# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Important value classes relevant to `ClusterCoordinator`.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import threading

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec as type_spec_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class RemoteValueStatus(enum.Enum):
  """The status of a `RemoteValue` object.

  A `RemoteValue` object can be in one of three states:
    1) not ready: no value, no non-retryable error, and not aborted;
    2) aborted: the execution of the function was aborted because of task
       failure, but can be retried;
    3) ready: has a value or a non-retryable error.

  The initial state of a `RemoteValue` is "not ready". Once its corresponding
  closure has been executed at least once, it becomes aborted or ready. The
  state transitions are:
    1) not ready -> 2) aborted:
      when the corresponding closure is aborted due to worker failure, and the
      worker failure is not immediately handled.
    1) not ready -> 3) ready:
      when the corresponding closure has been executed successfully.
    2) aborted -> 3) ready:
      when the `RemoteValue` is rebuilt by rerunning the corresponding closure
      and the closure has been executed successfully.
    3) ready -> 2) aborted:
      when the corresponding closure had been executed successfully but later
      the corresponding remote worker failed. This is currently only
      implemented for resource `RemoteValue`s like iterators.
  """
  NOT_READY = "NOT_READY"
  ABORTED = "ABORTED"
  READY = "READY"


@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
  """An asynchronously available value of a scheduled function.

  This class is used as the return value of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`, where
  the underlying value becomes available at a later time, once the function
  has been executed.

  Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to
  a subsequent function scheduled with
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is
  currently not supported.

  Example:

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  with strategy.scope():
    v1 = tf.Variable(initial_value=0.0)
    v2 = tf.Variable(initial_value=1.0)

  @tf.function
  def worker_fn():
    v1.assign_add(0.1)
    v2.assign_sub(0.2)
    return v1.read_value() / v2.read_value()

  result = coordinator.schedule(worker_fn)
  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
  assert result.fetch() == 0.125

  for _ in range(10):
    # `worker_fn` will be run on arbitrary workers that are available. The
    # `result` value will be available later.
    result = coordinator.schedule(worker_fn)
  ```
  """

  def fetch(self):
    """Wait for the result of `RemoteValue` and return the numpy result.

    This makes the value concrete by copying the remote value to local.

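    Example (a sketch; `coordinator` and `worker_fn` are assumed to exist, as
    in the class example above):

    ```python
    result = coordinator.schedule(worker_fn)
    result.fetch()  # Blocks until ready; returns numpy value(s), e.g. 0.125.
    result.get()    # Same, but returns `tf.Tensor`(s) instead of numpy.
    ```
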
    Returns:
      The numpy array structure of the actual output of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single value, or a structure of values, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def get(self):
    """Wait for the result of `RemoteValue` and return the tensor result.

    This makes the value concrete by copying the remote tensor to local.

    Returns:
      The actual output (in the form of `tf.Tensor`s) of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single Tensor, or a structure of Tensors, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")


# TODO(yuefengz): create an implementation for resource RemoteValue which needs
# to remember the closure object while a normal RemoteValue doesn't.
class RemoteValueImpl(RemoteValue):
  """Implementation of `RemoteValue`."""

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
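
    The internal state (`_values`, `_error`, `_status`) is synchronized via
    `_status_available_event`: `get()` blocks on the event, which is set by
    `_set_values`, `_set_error`, or `_set_aborted` once the scheduled closure
    has finished, failed, or been aborted.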
169    """
170    self._closure = closure
171    self._type_spec = type_spec
172    self._values = None
173    self._has_fetched_to_local = False
174    self._has_fetched_to_local_lock = threading.Lock()
175    self._fetched_tensors = None
176    self._error = None
177    self._status_available_event = threading.Event()
178    self._status = RemoteValueStatus.NOT_READY
179
180  def _set_aborted(self):
181    self._status = RemoteValueStatus.ABORTED
182    self._values = None
183    self._error = None
184
185    # Wake up any waiting thread and clear the event.
186    self._status_available_event.set()
187
188  def _rebuild_on(self, worker):
189    self._status_available_event.clear()
190    # TODO(yuefengz): we may need to rebuild its inputs as well.
191    self._closure.execute_on(worker)
192
193  def _set_values(self, tensors):
194    self._status = RemoteValueStatus.READY
195    self._values = tensors
196    self._error = None
197    self._status_available_event.set()
198
199  def _set_error(self, exception):
200    self._status = RemoteValueStatus.READY
201    self._values = None
202    self._error = exception
203    self._status_available_event.set()
204
205  def _get_values(self):
206    self._status_available_event.wait()
207    return self._values
208
209  def _get_error(self):
210    self._status_available_event.wait()
211    return self._error
212
213  def _wait_and_maybe_error(self):
214    self._status_available_event.wait()
215    if self._status is RemoteValueStatus.ABORTED:
216      raise errors.CancelledError(
217          None, None,
218          "The corresponding function is aborted. Please reschedule the "
219          "function.")
220    if self._error is not None:
221      raise self._error
222
  def fetch(self):
    # TODO(rchao): Discuss the possibility of letting users perform `numpy`
    # themselves at API graduation.
    return nest.map_structure(
        lambda x: x.numpy() if hasattr(x, "numpy") else x, self.get())

  def get(self):
    self._wait_and_maybe_error()

    with self._has_fetched_to_local_lock:
      if not self._has_fetched_to_local:

        def copy_tensor(composite_tensor_obj):
          """Copy a remote tensor to local (the coordinator)."""
          if isinstance(composite_tensor_obj, input_lib.DistributedIterator):
            # A DistributedIterator cannot be copied to local; users should not
            # access that anyway.
            return composite_tensor_obj

          with ops.device("/job:%s" % context.get_server_def().job_name):
            # Copying to local (the coordinator) with `tf.device`.
            return array_ops.identity(composite_tensor_obj)

        # When `self._values` is `None`, the associated function does not have
        # a return value, and there is nothing to fetch.
        if self._values is not None:
          self._fetched_tensors = nest.map_structure(copy_tensor, self._values)
        self._has_fetched_to_local = True

    return self._fetched_tensors


@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a
  collection of values, where each value is located on its corresponding
  worker. When used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker is passed into the function being executed at
  that worker.

  Currently, the only supported way to create a
  `tf.distribute.experimental.coordinator.PerWorkerValues` object is to call
  `iter` on the distributed dataset instance returned by
  `ClusterCoordinator.create_per_worker_dataset`. Creating a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet
  supported.
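
  Example (a sketch; assumes a `ClusterCoordinator` named `coordinator`, a
  `dataset_fn` returning a `tf.data.Dataset`, and a `tf.function` `worker_fn`):

  ```python
  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  # Calling `iter` produces a `PerWorkerValues` of per-worker iterators.
  per_worker_iter = iter(per_worker_dataset)
  result = coordinator.schedule(worker_fn, args=(per_worker_iter,))
  ```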
271  """
272
273  def __init__(self, values):
274    for v in values:
275      if not isinstance(v, RemoteValue):
276        raise AssertionError(
277            "`PerWorkerValues` should only take `RemoteValue`s.")
278    self._values = tuple(values)
279
280  @property
281  def _type_spec(self):
282    return PerWorkerValuesTypeSpec(
283        self._values[0]._type_spec,  # pylint: disable=protected-access
284        type(self))
285
286
287class PerWorkerValuesTypeSpec(type_spec_lib.TypeSpec):
288  """TypeSpec for PerWorkerValues.
289
290  It only support tracing a function using a PerWorkerValues.
291  """

  def __init__(self, value_spec, descendant_type):
    assert value_spec
    self._value_spec = value_spec
    self._descendant_type = descendant_type

  def _serialize(self):
    return (self._value_spec,)

  @property
  def value_type(self):
    return self._descendant_type

  def most_specific_compatible_type(self, other):
    raise NotImplementedError(
        "most_specific_compatible_type is not implemented")

  @property
  def _component_specs(self):
    return self._value_spec

  def _to_components(self, value):
    return self._value_spec

  def _from_components(self, value):
    return value


class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from a dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
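
    Example (a sketch; assumes a `ClusterCoordinator` named `coordinator`
    already exists):

    ```python
    def dataset_fn():
      return tf.data.Dataset.range(8).batch(2)

    per_worker_dataset = PerWorkerDatasetFromDatasetFunction(
        dataset_fn, coordinator)
    per_worker_iterator = iter(per_worker_dataset)
    ```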
330    """
331
332    def disallow_variable_creation(next_creator, **kwargs):
333      raise ValueError("Creating variables in `dataset_fn` is not allowed.")
334
335    if isinstance(dataset_fn, def_function.Function):
336      with variable_scope.variable_creator_scope(disallow_variable_creation):
337        dataset_fn = dataset_fn.get_concrete_function()
338    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
339      with variable_scope.variable_creator_scope(disallow_variable_creation):
340        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
341    self._dataset_fn = dataset_fn
342    self._coordinator = coordinator
343    self._element_spec = None
344
345  def __iter__(self):
346    # We would like users to create iterators outside `tf.function`s so that we
347    # can track them.
348    if (not context.executing_eagerly() or
349        ops.get_default_graph().building_function):
350      raise RuntimeError(
351          "__iter__() is not supported inside of tf.function or in graph mode.")
352
353    def _create_per_worker_iterator():
354      dataset = self._dataset_fn()
355      return iter(dataset)
356
357    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
358    # times, for the same object it should only create and register resource
359    # once. Using object id to distinguish different iterator resources.
360    per_worker_iterator = self._coordinator._create_per_worker_resources(
361        _create_per_worker_iterator)
362
363    # Setting type_spec of each RemoteValue so that functions taking these
364    # RemoteValues as inputs can be traced.
365    for iterator_remote_value in per_worker_iterator._values:
366      iterator_remote_value._type_spec = (
367          input_lib.get_iterator_spec_from_dataset(
368              self._coordinator.strategy, self._dataset_fn.structured_outputs))
369
370    return PerWorkerDistributedIterator(per_worker_iterator._values)
371

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    if not isinstance(self._dataset_fn, tf_function.ConcreteFunction):
      raise NotImplementedError(
          "`element_spec` is not supported when the `dataset_fn` is not "
          "a `ConcreteFunction`.")
    return self._dataset_fn.structured_outputs.element_spec


def serialize_dataset_to_graph(dataset):
  """Serializes `dataset` to a `GraphDef` for use in another process."""
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  graph_def = gen_dataset_ops.dataset_to_graph_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      external_state_policy=ExternalStatePolicy.WARN.value,
      strip_device_assignment=True)
  return graph_def


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset given a graph def."""

  def __init__(self, graph_def, element_spec):
    self._elem_spec = element_spec
    variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def deserialize_dataset_from_graph(graph_def, element_spec):
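  """Rebuilds a dataset from a serialized `GraphDef`.

  A sketch of the round trip with `serialize_dataset_to_graph` (the
  `element_spec` is passed separately, since this function takes it as an
  argument rather than recovering it from the graph):

  ```python
  ds = dataset_ops.Dataset.range(4)
  graph_def = serialize_dataset_to_graph(ds)
  restored = deserialize_dataset_from_graph(graph_def, ds.element_spec)
  ```
  """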
  return _RemoteDataset(graph_def, element_spec)


class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable from the given dataset.

    Under the hood, it creates a `dataset_fn` that deserializes the dataset
    from a graph.

    Args:
      dataset: A `tf.data.Dataset`, a `DistributedDataset` or a
        `DistributedDatasetsFromFunction`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
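
    Example (a sketch; assumes a `tf.distribute` strategy named `strategy` and
    a `ClusterCoordinator` named `coordinator` already exist):

    ```python
    dist_dataset = strategy.experimental_distribute_dataset(
        tf.data.Dataset.range(4))
    per_worker_dataset = PerWorkerDatasetFromDataset(dist_dataset, coordinator)
    ```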
425    """
426    if isinstance(dataset, input_lib.DistributedDataset):
427      original_dataset = dataset._original_dataset
428      serialized = serialize_dataset_to_graph(original_dataset)
429
430      def dataset_fn():
431        deserialized = deserialize_dataset_from_graph(
432            serialized, original_dataset.element_spec)
433        dataset.build(dataset_to_replace=deserialized)
434        return dataset
435    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
436      def dataset_fn():
437        dataset.build()
438        return dataset
439    elif isinstance(dataset, dataset_ops.Dataset):
440      serialized = serialize_dataset_to_graph(dataset)
441
442      def dataset_fn():
443        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
444    else:
445      raise ValueError("Unexpected dataset type!")
446
447    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)
448
449
def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
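  """Returns a per-worker dataset, given a dataset or a dataset function.

  A sketch of usage, assuming a `ClusterCoordinator` named `coordinator`:

  ```python
  per_worker_dataset = get_per_worker_dataset(
      lambda: dataset_ops.Dataset.range(4), coordinator)
  per_worker_iterator = iter(per_worker_dataset)
  ```
  """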
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)


class PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over a `PerWorkerDistributedIterator` "
                              "is not supported right now.")