• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed values for PS."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import weakref
23
24import numpy as np
25
26from tensorflow.python.distribute import distribute_lib
27from tensorflow.python.distribute import distribute_utils
28from tensorflow.python.distribute import distribution_strategy_context as ds_context
29from tensorflow.python.distribute import values
30from tensorflow.python.distribute import values_util
31from tensorflow.python.eager import context
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.ops import variable_scope as vs
36from tensorflow.python.training.tracking import base as trackable
37from tensorflow.python.types import core
38
39
40# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas.

  Reads (`value`, `read_value`, properties, operator overloads) are
  delegated directly to the wrapped variable.  Writes (`assign`,
  `assign_add`, `assign_sub`) are routed through `_assign_func`, which,
  when called in replica context, reduces the per-replica values
  according to the configured `tf.VariableAggregation` before applying
  the update.
  """

  def __init__(self, strategy, v, aggregation):
    """Wraps `v` so that its updates are aggregated across replicas.

    Args:
      strategy: The `tf.distribute.Strategy` this variable was created under.
      v: The variable to wrap.
      aggregation: A `tf.VariableAggregation` value describing how values
        from different replicas are combined before assignment.
    """
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `AggregatingVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `AggregatingVariable`.  To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of a `AggregatingVariable`
    within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    """Returns the wrapped variable."""
    return self._v

  @property
  def distribute_strategy(self):
    """The `tf.distribute.Strategy` this variable was created under."""
    return self._distribute_strategy

  def __getattr__(self, name):
    # Any attribute not defined on the wrapper falls through to the
    # wrapped variable.
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    """Applies the assign function `kwargs["f"]` with proper aggregation.

    Behavior depends on the calling context:
      * In an update context (inside `strategy.extended.update`), `f` is
        applied to the wrapped variable directly.
      * In cross-replica context, the call is wrapped in
        `strategy.extended.update`.
      * In replica context, the per-replica values are first reduced via
        `values_util.apply_aggregation` inside a merge call, and the
        reduced value is then assigned through `strategy.extended.update`.

    Args:
      *args: Positional arguments forwarded to `f` (typically the value).
      **kwargs: Keyword arguments; must contain `f`, the assign function
        to apply to the wrapped variable.

    Returns:
      The result of applying `f`.

    Raises:
      ValueError: If called in replica context while the aggregation mode
        is `tf.VariableAggregation.NONE`.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          # Reduce the per-replica `value` into a single update value.
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          # A per-replica `name` makes no sense for a single update; pick
          # the first replica's name.
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    """Subtracts a value from this variable, aggregating across replicas."""
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    """Adds a value to this variable, aggregating across replicas."""
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    """Assigns a value to this variable, aggregating across replicas."""
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  # The properties and methods below simply delegate to the wrapped
  # variable, so the wrapper can be used wherever a variable is expected.

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    return self._v.value()

  def read_value(self):
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    """The `tf.VariableAggregation` mode used for cross-replica updates."""
    return self._aggregation

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    """Returns the wrapped variable as the value to checkpoint.

    If the wrapped variable is itself a `CachingVariable`, delegate to it
    so the checkpoint refers to the innermost variable.
    """
    if isinstance(self._v, CachingVariable):
      return self._v._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, SavedModel with
    # AggregatingVariable are identical to SavedModel with normal variables.
    obj_map, resource_map = self._v._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._v]
    return obj_map, resource_map

  # Arithmetic, comparison and indexing operators delegate to the wrapped
  # variable so an AggregatingVariable can be used like a regular tensor.
  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  # __div__/__rdiv__ only exist on the wrapped variable under Python 2
  # semantics; fall back to NotImplemented when absent.
  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor via the wrapped variable."""
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
347
348
class CachingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor):
  """A wrapper around a variable that caches read value locally.

  Inside a caching scope (see `distribute_utils.caching_scope_local`),
  reads return a locally cached copy of the value placed on `CPU:0`;
  outside such a scope, reads go straight to the wrapped variable.
  Writes (`assign`, `assign_add`, `assign_sub`) always go directly to the
  wrapped variable.
  """

  def __init__(self, v):
    """Wraps variable `v` with local read caching.

    Args:
      v: The variable to wrap.
    """
    self._v = v
    # Locally cached copy of the value; populated lazily by
    # `cached_read_value` and dropped when a new caching scope begins.
    self._cache = None
    # Number of caching scopes this wrapper has already observed, compared
    # against the global scope counter to detect a new scope.
    self._current_new_cache_scope_count = 0

  def get(self):
    """Returns the wrapped variable."""
    return self._v

  def __getattr__(self, name):
    # Any attribute not defined on the wrapper falls through to the
    # wrapped variable.
    return getattr(self._v, name)

  def read_value(self):
    """Reads the value, using the local cache inside a caching scope."""
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    # Sparse reads are never cached; delegate directly.
    return self._v.sparse_read(indices, name=name)

  def cached_read_value(self):
    """Returns the cached value, refreshing it once per caching scope."""
    # A new caching scope has started since the last read: drop the stale
    # cache so it is re-populated below.
    if (distribute_utils.caching_scope_local.new_cache_scope_count >
        self._current_new_cache_scope_count):
      self._current_new_cache_scope_count += 1
      self._cache = None

    # The cached copy is placed on the local CPU.
    with ops.device("CPU:0"):
      if self._cache is not None:
        return self._cache
      else:
        self._cache = array_ops.identity(self._v)
        return self._cache

  # Writes bypass the cache and go straight to the wrapped variable.

  def assign_sub(self, *args, **kwargs):
    return self._v.assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    return self._v.assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    return self._v.assign(*args, **kwargs)

  # The properties below simply delegate to the wrapped variable.

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    """Returns the value, using the local cache inside a caching scope."""
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  @property
  def constraint(self):
    return self._v.constraint

  # Numeric conversion protocols; these read through `value()`/`numpy()`
  # and therefore honor the caching scope.
  def __array__(self):
    return np.asarray(self.numpy())

  def __complex__(self):
    return complex(self.value().numpy())

  def __int__(self):
    return int(self.value().numpy())

  def __float__(self):
    return float(self.value().numpy())

  def numpy(self):
    """Returns the value as a NumPy array.

    Raises:
      NotImplementedError: If eager execution is not enabled.
    """
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts to a tensor, using the cache inside a caching scope."""
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    # NOTE(review): the `as_ref` argument is not forwarded — the wrapped
    # variable is always converted by value; confirm this is deliberate.
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=False)  # pylint: disable=protected-access

  @classmethod
  def _overload_overloadable_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      # Overloading __eq__ or __ne__ does not work as expected.
      if operator == "__eq__" or operator == "__ne__":
        continue
      cls._tensor_overload_operator(operator)

  @classmethod
  def _tensor_overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    # Operators are applied to `v.value()`, so they see the cached value
    # when inside a caching scope.
    def _operator(v, *args, **kwargs):
      return tensor_operator(v.value(), *args, **kwargs)  # pylint: disable=protected-access
    setattr(cls, operator, _operator)

  def _gather_saveables_for_checkpoint(self):
    """Returns the wrapped variable as the value to checkpoint."""
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, SavedModel with
    # CachingVariable are identical to SavedModel with normal variables.
    obj_map, resource_map = self._v._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._v]
    return obj_map, resource_map
510
511
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  """Converts an `AggregatingVariable` to a tensor by reading its value."""
  return var._dense_var_to_tensor(  # pylint: disable=protected-access
      dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    AggregatingVariable, _tensor_conversion_aggregate)
520
521
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_caching(var, dtype=None, name=None, as_ref=False):
  """Converts a `CachingVariable` to a tensor by reading its value."""
  return var._dense_var_to_tensor(  # pylint: disable=protected-access
      dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    CachingVariable, _tensor_conversion_caching)
530
# Install `ops.Tensor` operator overloads (except __eq__/__ne__) on
# CachingVariable so instances support arithmetic and comparison operators.
CachingVariable._overload_overloadable_operators()  # pylint: disable=protected-access