# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Various classes representing TPU distributed values.
16
17Note that the tests are in values_test.py .
18
19"""
20
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


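# Editorial note (illustrative, not part of the original module): the classes
# below are not constructed directly by users. Variables created under a TPU
# strategy scope get wrapped in one of them, roughly:
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
#   with strategy.scope():
#     v = tf.Variable(1.0)  # becomes one of the TPU distributed variables below
#
# The exact wrapper class depends on `synchronization`/`aggregation` and the
# TF version, so treat this as a sketch rather than a guarantee.

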
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with tensor.graph.as_default():
      yield


@contextlib.contextmanager
def _maybe_on_device(var):
  # Add a device scope for packed variables.
  if isinstance(var, packed.PackedVarAndDevice):
    with ops.device(var.device):
      yield
  else:
    yield


def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_assign_fn(
          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
      with ops.control_dependencies([op]):
        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access

  return assign_fn

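# Illustrative example (editorial addition, assumes a component or packed TPU
# variable `var` exposing `.handle`, `.dtype` and `_read_variable_op()`):
#
#   _assign_add = _make_raw_assign_fn(
#       gen_resource_variable_ops.assign_add_variable_op)
#   new_value = _assign_add(var, 1.0, read_value=True)
#
# The wrapper enters the handle's graph when needed and pins the op to the
# packed variable's device before issuing the raw resource update.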

_scatter_error_msg = ("{op_name} is only supported for distributed "
                      "variables (variables created within certain "
                      "`tf.distribute.Strategy` scopes) with NONE "
                      "aggregation, got: {aggregation}.")


def _make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wrap `raw_scatter_xxx_fn` so that it can be called w/ and w/o packed handle."""

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_scatter_xxx_fn(
          handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, var.dtype),
          name=name)
      with ops.control_dependencies([op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn

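# Illustrative example (editorial addition): the scatter wrapper follows the
# same pattern as the assign wrapper, but consumes an IndexedSlices-style
# `sparse_delta` (assumed to expose `.indices` and `.values`):
#
#   _scatter_add = _make_raw_scatter_xxx_fn(
#       gen_resource_variable_ops.resource_scatter_add)
#   # `delta` is assumed to be a tf.IndexedSlices(values=..., indices=...).
#   new_value = _scatter_add(var, delta)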

class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
    # correctly since in eager mode different variables can have the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          f"`TPUVariableMixin.{name}` not accessible within a TPU context.")

  def get(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      var = self._get_on_device_or_primary()
      if isinstance(var, packed.PackedVarAndDevice):
        return var.on_device_handle()
      else:
        return var.handle
    else:
      is_packed = self._packed_var is not None
      val = self._values
      if is_packed:
        val = [self._packed_var]

      return tpu_context.get_replicated_var_handle(self._handle_id, val,
                                                   self._is_mirrored(),
                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()

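# Illustrative note (editorial addition): outside a TPU context the mixin
# defers to the regular DistributedVariable behavior; inside `tpu.rewrite()` /
# `strategy.run()` it reads through the replicated handle instead. A rough
# sketch, assuming `v` is a TPU distributed variable of a TPUStrategy:
#
#   v.read_value()                          # eager/outer graph: primary read
#   strategy.run(lambda: v.read_value())    # replica context: replicated handle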

class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy."""

  def _is_mirrored(self):
    return self._policy._is_mirrored()  # pylint: disable=protected-access

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    return self._policy.assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    return self._policy.assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    return self._policy.assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    return self._policy.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    return self._policy.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    return self._policy.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    return self._policy.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    return self._policy.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    return self._policy.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)


class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    raise NotImplementedError

  def _is_mirrored(self):
    return True


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self, *args, **kwargs)

  def _is_mirrored(self):
    return False


# Common methods between OnWrite and Mirrored variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  assign_sub_fn = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_add(var, value, use_locking=False, name=None, read_value=True):
  assign_add_fn = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


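# Illustrative note (editorial addition): the module-level assign helpers above
# funnel every cross-replica update through `var._update(...)`, which applies
# the variable's aggregation-aware update logic. A hedged sketch of how the
# policies below typically reach them (constructor signature assumed from
# `values.OnWritePolicy`):
#
#   policy = TPUOnWritePolicy(variable_scope.VariableAggregation.MEAN)
#   policy.assign_add(var, 1.0)  # -> assign_add(var, 1.0) -> var._update(...)

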
class TPUOnWritePolicy(values.OnWritePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
  """

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          var,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def _scatter_xxx(self,
                   raw_scatter_xxx_fn,
                   op_name,
                   var,
                   sparse_delta,
                   use_locking=False,
                   name=None):
    scatter_xxx_fn = _make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
    if tpu_util.enclosing_tpu_context():
      if self._aggregation != variable_scope.VariableAggregation.NONE:
        raise NotImplementedError(
            _scatter_error_msg.format(
                op_name=op_name, aggregation=self._aggregation))
      return scatter_xxx_fn(
          var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
    else:
      return var._update(  # pylint: disable=protected-access
          update_fn=scatter_xxx_fn,
          value=sparse_delta,
          use_locking=use_locking,
          name=name)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub,
                             "scatter_sub", var, sparse_delta, use_locking,
                             name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add,
                             "scatter_add", var, sparse_delta, use_locking,
                             name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max,
                             "scatter_max", var, sparse_delta, use_locking,
                             name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min,
                             "scatter_min", var, sparse_delta, use_locking,
                             name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul,
                             "scatter_mul", var, sparse_delta, use_locking,
                             name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div,
                             "scatter_div", var, sparse_delta, use_locking,
                             name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update,
                             "scatter_update", var, sparse_delta, use_locking,
                             name)

  def _is_mirrored(self):
    return True

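# Illustrative note (editorial addition): inside a TPU context the ON_WRITE
# scatter ops above are only allowed for NONE aggregation; anything else raises
# NotImplementedError using `_scatter_error_msg`. A hedged sketch:
#
#   # `var` is assumed to have been created with aggregation=NONE.
#   # `delta` is assumed to be a tf.IndexedSlices(values=..., indices=...).
#   strategy.run(lambda: var.scatter_add(delta))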

class TPUOnReadPolicy(values.OnReadPolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in a
  `tf.distribute` scope.
  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
                                                            **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
                                                            **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          var, *args, **kwargs)

  def _is_mirrored(self):
    return False

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError