• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class implementing utilities used by tf.distribute.Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from collections import abc
22
23import contextlib
24import threading
25from tensorflow.python.distribute import tpu_values as tpu_values_lib
26from tensorflow.python.distribute import values as values_lib
27from tensorflow.python.eager import context
28from tensorflow.python.eager import tape
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import variable_scope as vs
34from tensorflow.python.util import nest
35
36
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Combines per-replica nests into one nest of wrapped values.

  `values` holds one entry per replica. Lists, tuples (including namedtuples)
  and mappings are traversed recursively, so the result has the structure of
  `values[0]` with every leaf regrouped across the replicas.

  Args:
    values: Sequence with one value per replica; all entries are expected to
      share the same structure.
    wrap_class: Callable applied to the per-replica leaves to wrap them.
    always_wrap: If True, leaves are wrapped with `wrap_class` even when every
      replica holds the very same object. A `DistributedVariable`, or the
      components of a distributed container, are still returned unwrapped.

  Returns:
    The regrouped structure.
  """
  first = values[0]
  rest = values[1:]

  def _regroup_at(selector):
    # Collects the element stored under `selector` (an index or a key) on every
    # replica and regroups that slice recursively.
    return regroup(
        tuple(per_replica[selector] for per_replica in values), wrap_class,
        always_wrap)

  if isinstance(first, list):
    for other in rest:
      assert isinstance(other, list)
      assert len(other) == len(first), (
          "len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
          (len(other), len(first), other, first))
    return [_regroup_at(index) for index in range(len(first))]

  if isinstance(first, tuple):
    for other in rest:
      assert isinstance(other, tuple)
      assert len(other) == len(first)
    regrouped = tuple(_regroup_at(index) for index in range(len(first)))
    if not hasattr(first, "_fields"):
      return regrouped
    # A tuple with `_fields` is a namedtuple: rebuild the same namedtuple type
    # from the regrouped values.
    assert hasattr(first, "_make")
    return first._make(regrouped)  # pylint: disable=protected-access

  if isinstance(first, abc.Mapping):
    keys = first.keys()
    for other in rest:
      assert isinstance(other, abc.Mapping), (
          "v[0]: %r  v[i]: %r" % (first, other))
      assert set(other.keys()) == set(keys), (
          "v[0].keys: %s  v[i].keys: %s" % (set(keys), set(other.keys())))
    # Instantiate the actual type so that dict subclasses are preserved.
    return type(first)({key: _regroup_at(key) for key in keys})

  # From here on `first` is a leaf. Check whether every replica holds the very
  # same object.
  all_identical = all(other is first for other in rest)

  # A DistributedVariable shared by all replicas is returned as is. It is
  # tested explicitly because it can look like it has a
  # `_distributed_container` member, since its members do.
  if all_identical and isinstance(first, values_lib.DistributedVariable):
    return first

  has_container = hasattr(first, "_distributed_container")

  # Any other shared object is returned unwrapped unless wrapping is forced.
  # Components of a distributed variable are excluded here: they are resolved
  # to their container below, which also covers the single-device case.
  if all_identical and not always_wrap and not has_container:
    return first

  if not has_container:
    return wrap_class(values)

  # Each replica holds a parallel component of the same distributed variable:
  # return the containing variable after checking that all components agree
  # on their container.
  # pylint: disable=protected-access
  assert not isinstance(first, values_lib.MirroredVariable), (
      "ids = %s, values = %s" % ([id(v) for v in values], values))
  container = first._distributed_container()
  assert container is not None
  for other in rest:
    assert container is other._distributed_container()
  # pylint: enable=protected-access
  return container
130
131
def select_replica(replica_id, structured):
  """Picks the components of `structured` that belong to one replica.

  Args:
    replica_id: Index used to select from the `values` of each sliced leaf.
    structured: Nest of regular values and distributed values.

  Returns:
    A nest with the structure of `structured`. Leaves that are
    `DistributedValues` but not `DistributedVariable` are replaced by
    `leaf.values[replica_id]`; every other leaf is passed through unchanged.
  """

  def _specialize(leaf):
    # A `DistributedVariable` is left untouched because it can be handled
    # directly in the replica context; only other `DistributedValues` are
    # sliced per replica.
    needs_slicing = (
        not isinstance(leaf, values_lib.DistributedVariable) and
        isinstance(leaf, values_lib.DistributedValues))
    return leaf.values[replica_id] if needs_slicing else leaf

  return nest.map_structure(_specialize, structured)
146
147
def select_replica_mirrored(replica_id, structured):
  """Selects one replica's values after checking the nest is mirrored.

  Args:
    replica_id: Replica index forwarded to `select_replica`.
    structured: Nest of regular and mirrored values.

  Returns:
    The result of `select_replica(replica_id, structured)`.

  Raises:
    TypeError: Raised by `assert_mirrored` when a distributed value in
      `structured` is not mirrored.
  """
  # Validate first so that no selection happens on non-mirrored input.
  assert_mirrored(structured)
  specialized = select_replica(replica_id, structured)
  return specialized
152
153
def assert_mirrored(structured):
  """Checks that `structured` only holds mirrored or regular values.

  Args:
    structured: Nest of values to inspect.

  Raises:
    TypeError: If a leaf is a `DistributedValues` for which `is_mirrored`
      returns False.
  """

  def _check(leaf):
    # Regular (non-distributed) values are always acceptable.
    if not isinstance(leaf, values_lib.DistributedValues):
      return
    if is_mirrored(leaf):
      return
    raise TypeError(
        "Expected value to be mirrored across replicas: %s in %s." %
        (leaf, structured))

  nest.map_structure(_check, structured)
164
165
def update_regroup(extended, updates, group):
  """Regroups update results, adding dependencies so all updates execute.

  Args:
    extended: Object whose `_local_results` is mapped over the regrouped
      values when `group` is false.
    updates: Per-replica update results, as accepted by `regroup`.
    group: If false, the updates are regrouped as `Mirrored` and converted
      with `extended._local_results`. Otherwise the per-replica results are
      grouped so that running any of them runs all of them.

  Returns:
    The regrouped updates.
  """

  def _group_across_replicas(per_replica):
    """Wraps per-replica results as Mirrored, tied together by a group op."""
    if len(per_replica) == 1:
      return values_lib.Mirrored(per_replica)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    barrier = control_flow_ops.group(per_replica)

    # Every replica is expected to perform the same computation, so the
    # results share a type. If they are not all tensors (just ops), the group
    # op alone is enough.
    if not all(tensor_util.is_tf_type(item) for item in per_replica):
      return barrier

    # Otherwise return tensors with the same values as the inputs that also
    # depend on the group op.
    dependent = []
    for item in per_replica:
      with ops.device(item.device), ops.control_dependencies([barrier]):
        dependent.append(array_ops.identity(item))
    return values_lib.Mirrored(dependent)

  if group:
    return regroup(updates, _group_across_replicas)

  mirrored = regroup(updates, values_lib.Mirrored)
  return nest.map_structure(extended._local_results, mirrored)  # pylint: disable=protected-access
197
198
def value_container(val):
  """Returns the container that the per-replica value `val` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    The container that `val` belongs to. If `val` does not belong to any
    container (including the case of the container having been destroyed),
    or if `val` is itself a `DistributedVariable`, returns `val`.
  """
  if not hasattr(val, "_distributed_container"):
    return val
  # DistributedVariable has _distributed_container defined, but it must be
  # returned as is rather than resolved to a container.
  if isinstance(val, values_lib.DistributedVariable):
    return val
  container = val._distributed_container()  # pylint: disable=protected-access
  return val if container is None else container
219
220
def is_distributed_variable(v):
  """Returns whether `v` is an instance of `values_lib.DistributedVariable`."""
  distributed_variable_cls = values_lib.DistributedVariable
  return isinstance(v, distributed_variable_cls)
224
225
def _validate_colocate_extended(v, extended):
  """Checks that `v` was created under the strategy that owns `extended`.

  Args:
    v: Object exposing a `_distribute_strategy` attribute.
    extended: The extended object that `v`'s strategy must reference.

  Raises:
    ValueError: If `v._distribute_strategy.extended` is not `extended`.
  """
  owner = v._distribute_strategy  # pylint: disable=protected-access
  if owner.extended is extended:
    return
  message = (
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
      (v, owner))
  raise ValueError(message)
233
234
def validate_colocate_distributed_variable(v, extended):
  """Validates a `colocate_vars_with` argument that must be distributed.

  Args:
    v: Candidate variable.
    extended: The extended object of the strategy `v` must belong to.

  Raises:
    ValueError: If `v` is not a `DistributedVariable`, or if it was created
      under a strategy whose `extended` is not `extended`.
  """
  if isinstance(v, values_lib.DistributedVariable):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
241
242
def validate_colocate(v, extended):
  """Validates a `colocate_vars_with` argument.

  Args:
    v: Candidate variable; it must expose `_distribute_strategy`.
    extended: The extended object of the strategy `v` must belong to.

  Raises:
    ValueError: If `v` has no `_distribute_strategy` attribute, or if it was
      created under a strategy whose `extended` is not `extended`.
  """
  if hasattr(v, "_distribute_strategy"):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
249
250
251# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validates the `synchronization` variable-creation argument.

  Args:
    kwargs: Variable-creation keyword arguments. The optional
      "synchronization" entry is validated (defaulting to `AUTO`); the
      optional "name" entry is only used in error messages.

  Returns:
    `ON_WRITE` if the requested mode is `AUTO`, otherwise the requested mode.

  Raises:
    ValueError: If the mode is `NONE` or is not one of `ON_READ`, `ON_WRITE`
      or `AUTO`.
  """
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  # The name only decorates the error messages. Look it up defensively so that
  # a missing "name" entry cannot replace the intended ValueError with a
  # KeyError.
  name = kwargs.get("name")
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(name))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, name))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization
270
271
def _validate_aggregation(kwargs):
  """Validates the `aggregation` variable-creation argument.

  Args:
    kwargs: Variable-creation keyword arguments. The optional "aggregation"
      entry is validated (defaulting to `NONE`); the optional "name" entry is
      only used in the error message.

  Returns:
    The requested aggregation mode.

  Raises:
    ValueError: If the mode is not one of `NONE`, `SUM`, `MEAN` or
      `ONLY_FIRST_REPLICA`.
  """
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    # Use .get() so that a missing "name" entry cannot replace the intended
    # ValueError with a KeyError.
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs.get("name")))
  return aggregation
282
283
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Creates a distributed variable with given synchronization and aggregation.

  Args:
    strategy: Strategy passed as first argument to the variable class. If
      `strategy.extended._use_var_policy` is set (default False), the variable
      is built from `class_mapping["VariableClass"]` plus a policy.
    real_mirrored_creator: Callable invoked with the adjusted `kwargs`; it
      returns the component variables that the result wraps.
    class_mapping: Mapping from synchronization mode, and from the key
      "VariableClass", to the variable class to instantiate.
    policy_mapping: Mapping from synchronization mode to policy class. Only
      consulted when the policy-based path is used.
    **kwargs: Variable-creation arguments. "collections", "synchronization",
      "aggregation", "caching_device" and "trainable" are interpreted here; the
      rest is forwarded to `real_mirrored_creator`.

  Returns:
    The variable wrapping the component variables.

  Raises:
    ValueError: If the requested synchronization or aggregation is invalid.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the wrapping variable to those collections instead of its
  # components.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  elif isinstance(var_collections, (list, tuple, set)):
    # Work on a private copy: TRAINABLE_VARIABLES may be appended below, which
    # must neither mutate the caller's object nor fail for tuples and sets.
    var_collections = list(var_collections)
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    # `global_variables_initializer` wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      # pylint:disable=protected-access
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
      # pylint:enable=protected-access
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      trainable_ref = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        # Match by identity (not equality) and drop only the first match.
        for i, trainable_variable in enumerate(trainable_ref):
          if value is trainable_variable:
            del trainable_ref[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
352
353
# Utility functions
def is_mirrored(val):
  """Returns True if `val` is Mirrored, or a variable replicated and in sync.

  For a `DistributedVariable` with a truthy `_policy`, the decision is
  delegated to the policy's `_is_mirrored()`. Any other value is mirrored
  exactly when it is an instance of `values_lib.Mirrored`.

  Args:
    val: Value to inspect.

  Returns:
    Whether `val` is considered mirrored.
  """
  # pylint: disable=protected-access
  if isinstance(val, values_lib.DistributedVariable) and val._policy:
    return val._policy._is_mirrored()
  # pylint: enable=protected-access
  return isinstance(val, values_lib.Mirrored)
362
363
def is_sync_on_read(val):
  """Returns the negation of the mirrored check for `val`.

  For a `DistributedVariable` with a truthy `_policy`, returns the negation of
  the policy's `_is_mirrored()`. For any other value, returns True whenever
  the value is not an instance of `values_lib.Mirrored`.

  Args:
    val: Value to inspect.

  Returns:
    Whether `val` is considered sync-on-read.
  """
  # pylint: disable=protected-access
  if isinstance(val, values_lib.DistributedVariable) and val._policy:
    return not val._policy._is_mirrored()
  # pylint: enable=protected-access
  return not isinstance(val, values_lib.Mirrored)
369
370
class CachingScopeLocal(threading.local):
  """Class for maintaining thread local state for caching scope.

  The state is a pair of counters: how many caching scopes were entered and
  how many were exited. A scope is active while more scopes were entered than
  exited.
  """

  def __init__(self):
    super(CachingScopeLocal, self).__init__()
    self.new_cache_scope_count = 0
    self.cache_scope_exited_count = 0

  def enter_scope(self):
    """Records that a caching scope has been entered."""
    self.new_cache_scope_count += 1

  def exit_scope(self):
    """Records that a caching scope has been exited."""
    self.cache_scope_exited_count += 1

  def in_caching_scope(self):
    """Returns True while an entered caching scope has not been exited yet."""
    return self.new_cache_scope_count > self.cache_scope_exited_count


caching_scope_local = CachingScopeLocal()


@contextlib.contextmanager
def cache_variable_reads():
  """Scope for caching variable reads for AggregatingVariable.

  The variable reads for AggregatingVariable inside this scope are cached. i.e.
  the first read of variable reads the value from possibly remote handle, but
  subsequent reads are returned using local cached value.

  For example:
  strategy = ParameterServerStrategy...
  with strategy.scope():
    # Variable v is of AggregatingVariable type with actual variable residing
    # on PS.
    v = tf.Variable(1.0)

  with distribute_utils.cache_variable_reads():
    v.read_value()  # Reads value 1.0
    v.assign(constant_op.constant(5.0))  # v changes to 5.0
    t1 = v.read_value()
    t2 = v.read_value()  # Both t1 & t2 return cached value 1.0 from local CPU.

  Notes about cache_variable_reads scope:
  1. Nesting of scope cache_variable_reads() is not supported
  2. And when caching scope is enabled, the thread enabling the cache and
    mirrored_run._MirroredReplicaThread threads spawned from it will have
    caching enabled.

  Yields:
    A context for caching variables.

  Raises:
    ValueError: If the scope is entered while another caching scope is already
      active.
  """
  # Reject nesting *before* entering the try block: a rejected nested entry
  # never called enter_scope(), so it must not reach the exit_scope() in
  # `finally` either. Otherwise the counters get out of balance and the
  # enclosing scope (and later scopes) would report caching as disabled.
  if caching_scope_local.in_caching_scope():
    # There is nested cache scope, which is not supported.
    raise ValueError("cache_variable_reads scope cannot be nested")

  caching_scope_local.enter_scope()
  try:
    yield
  finally:
    caching_scope_local.exit_scope()
431
432
# The following mappings indicate the policy / variable class that must be used
# for a given variable `synchronization` and `aggregation` pair. They are keyed
# by synchronization mode only; AUTO does not appear as a key because
# `_validate_synchronization` converts it to ON_WRITE before the lookup.
# OnWritePolicy is used for:
# (synchronization=AUTO, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

# The "VariableClass" entry is the class `create_mirrored_variable` instantiates
# (together with a policy) when the strategy uses variable policies; the
# synchronization-keyed entries are used otherwise.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

# TPU counterparts of the two mappings above, with the same key layout.
TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}
461