• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A tf.distribute.Strategy for running on a single device."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import device_util
22from tensorflow.python.distribute import distribute_lib
23from tensorflow.python.distribute import distribute_utils
24from tensorflow.python.distribute import input_lib
25from tensorflow.python.distribute import numpy_dataset
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.util import nest
31from tensorflow.python.util.tf_export import tf_export
32
33
34# TODO(josh11b): Do we wrap values in types to generate errors if you are
35# doing something that won't work with other DistributionStrategy
36# implementations?
37
38
@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will also be placed on the specified device.

  Typical usage of this strategy could be testing your code with the
  tf.distribute.Strategy API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device. The returned distributed dataset can be iterated over
    similar to how regular datasets can.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```

    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(  # pylint: disable=useless-super-delegation
      self, dataset_fn, options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device so `dataset_fn` is called
    once.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
    the global batch size.  This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,).`
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if axis=None, value
    is simply returned. If axis is specified as something other than None,
    such as axis=0, value is reduced along that axis and returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be on `device` specified at strategy construction time.
    See example in the docs for this class.

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()
238
239
@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  # Reuse the V2 class documentation, inserting a
  # `tf.enable_eager_execution()` line at the top of its usage example.
  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    extended = OneDeviceExtended(self, device)
    super(OneDeviceStrategyV1, self).__init__(extended)
    # Set the "V1" cell of the strategy gauge to this strategy's name.
    gauge_cell = distribute_lib.distribution_strategy_gauge.get_cell("V1")
    gauge_cell.set("OneDeviceStrategy")

  # The constructor shares its documentation with the V2 constructor.
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__
252
253
254# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy.

  Variable creation, replica function calls and updates are all wrapped in a
  device scope for `self._device`. Input pipelines are built with
  `self._input_device` as the worker device.
  """

  def __init__(self, container_strategy, device):
    """Resolves `device` and derives the input device from it.

    Args:
      container_strategy: The strategy object this extended object belongs to.
      device: Device string identifying the single device to use.
    """
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
    """Builds the `InputWorkers` describing where input is produced and sent.

    Args:
      options: Optional input options. When absent, or when its
        `experimental_fetch_to_device` is truthy, data is fetched to
        `self._device`; otherwise it stays on `self._input_device`.

    Returns:
      An `input_lib.InputWorkers` with a single worker.
    """
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])

  @property
  def _input_workers(self):
    """Default `InputWorkers`, i.e. without any input options."""
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    """Creates a variable via `next_creator` under the right placement scope.

    Args:
      next_creator: Callable that actually creates the variable.
      **kwargs: Arguments for `next_creator`. An optional `colocate_with`
        entry is removed before they are forwarded.

    Returns:
      Whatever `next_creator(**kwargs)` returns.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      # No colocation request: place on this strategy's device.
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      # An explicit single device was requested.
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    """Delegates validation to `distribute_utils.validate_colocate`."""
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Makes an iterator over the dataset produced by `input_fn`.

    Args:
      input_fn: Function forwarded to `input_lib.InputFunctionIterator`.
      replication_mode: Accepted for interface compatibility; not used here.

    Returns:
      An `input_lib.InputFunctionIterator` built with a single default
      `InputContext`.
    """
    del replication_mode  # Unused.
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [distribute_lib.InputContext()],
                                           self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    """Builds a dataset from `numpy_input` on `self._input_device`."""
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
    """Returns `tensor` unchanged; `destinations` is ignored."""
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset, options):
    """Wraps `dataset` as a distributed dataset for the single worker.

    Args:
      dataset: Dataset to distribute.
      options: Optional input options.

    Returns:
      The result of `input_lib.get_distributed_dataset`.

    Raises:
      NotImplementedError: If `options` requests
        `InputReplicationMode.PER_REPLICA`.
    """
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    """Builds a distributed dataset from `dataset_fn` for the single worker.

    Args:
      dataset_fn: Function forwarded to
        `input_lib.get_distributed_datasets_from_function`.
      options: Optional input options.

    Returns:
      The result of `input_lib.get_distributed_datasets_from_function`, built
      with a single default `InputContext`.

    Raises:
      NotImplementedError: If `options` requests
        `InputReplicationMode.PER_REPLICA`.
    """
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    """Calls `value_fn` once with a default `ValueContext`."""
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())

  # TODO(priyag): Deal with OutOfRange errors  once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    """Runs `fn` for `iterations` steps inside a single while loop.

    Args:
      fn: Step function called as `fn(ctx, iterator.get_next())`, where `ctx`
        is an `input_lib.MultiStepContext`.
      iterator: Iterator providing the input for each step.
      iterations: Number of steps to run.
      initial_loop_values: Optional structure of initial loop values; it is
        flattened and used as the extra loop variables.

    Returns:
      The `MultiStepContext`, with `run_op` set to the grouped loop result and
      its last step outputs replaced by the loop's output tensors.
    """
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    try:
      # TODO(priyag): Use max_iterations instead of an explicit counter.
      cond = lambda i, *args: i < iterations
      i = constant_op.constant(0)
      loop_result = control_flow_ops.while_loop(
          cond, body, [i] + initial_loop_values, name="",
          parallel_iterations=1, back_prop=False, swap_memory=False,
          return_same_structure=True)
    finally:
      # Always drop the captured context, even if building the loop failed, so
      # a stale context is never left behind on this object.
      del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    """Calls `fn(*args, **kwargs)` on the device, inside a replica context."""
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
    """Returns `value` unchanged; all other arguments are ignored."""
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    """Returns `value` unchanged; all other arguments are ignored."""
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    """Runs `fn(var, *args, **kwargs)` via `_update_non_slot`."""
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Runs `fn(*args, **kwargs)` on the device inside an `UpdateContext`.

    Args:
      colocate_with: Ignored.
      fn: Function to call.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.
      group: If truthy, the result of `fn` is returned as is; otherwise each
        element of the result structure is wrapped in a one-element tuple.

    Returns:
      The (possibly wrapped) result of `fn`.
    """
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    """Wraps `value` in a one-element tuple."""
    return (value,)

  def value_container(self, value):
    """Returns `value` itself."""
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    """Always 1."""
    return 1

  @property
  def worker_devices(self):
    """One-element tuple holding the strategy's device."""
    return (self._device,)

  @property
  def parameter_devices(self):
    """One-element tuple holding the strategy's device."""
    return (self._device,)

  def non_slot_devices(self, var_list):
    """One-element tuple holding the strategy's device; ignores `var_list`."""
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    """Always True."""
    return True

  @property
  def experimental_between_graph(self):
    """Always False."""
    return False

  @property
  def should_checkpoint(self):
    """Always True."""
    return True

  @property
  def should_save_summary(self):
    """Always True."""
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    """Always False."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    """Returns `replica_id_in_sync_group` unchanged."""
    return replica_id_in_sync_group
478
479
class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """Replica context used while running a function under OneDeviceStrategy."""

  def __init__(self, strategy):
    # The replica id within the sync group is fixed to 0.
    super(_OneDeviceReplicaContext, self).__init__(
        strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    """Worker devices reported by the owning strategy's extended object."""
    extended = self._strategy.extended
    return extended.worker_devices
490