# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Important value classes relevant to `ClusterCoordinator`.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import threading

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec as type_spec_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class RemoteValueStatus(enum.Enum):
  """The status of a `RemoteValue` object.

  A `RemoteValue` object can have three states:
    1) not ready: no value, no non-retryable error and not aborted;
    2) aborted: i.e. the execution of the function was aborted because of task
      failure, but can be retried;
    3) ready: i.e. has value or has non-retryable error;

  The initial state of a `RemoteValue` is "not ready". When its corresponding
  closure has been executed at least once, it will become aborted or ready. The
  state transitions are:
    1) not ready -> 2) aborted:
      when the corresponding closure is aborted due to worker failure, and the
      worker failure is not immediately handled.
    1) not ready -> 3) ready:
      when the corresponding closure has been executed successfully.
    2) aborted -> 3) ready:
      when the `RemoteValue` is rebuilt by rerunning the corresponding closure
      and the closure has been executed successfully.
    3) ready -> 2) aborted:
      when the corresponding closure had been executed successfully but later
      the corresponding remote worker failed. This is currently only
      implemented for resource `RemoteValue`s such as iterators.
  """
  NOT_READY = "NOT_READY"
  ABORTED = "ABORTED"
  READY = "READY"


@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
  """An asynchronously available value of a scheduled function.

  This class is used as the return value of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` where
  the underlying value becomes available at a later time once the function has
  been executed.

  Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to
  a subsequent function scheduled with
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is
  currently not supported.

  Example:

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  with strategy.scope():
    v1 = tf.Variable(initial_value=0.0)
    v2 = tf.Variable(initial_value=1.0)

  @tf.function
  def worker_fn():
    v1.assign_add(0.1)
    v2.assign_sub(0.2)
    return v1.read_value() / v2.read_value()

  result = coordinator.schedule(worker_fn)
  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
  assert result.fetch() == 0.125

  for _ in range(10):
    # `worker_fn` will be run on arbitrary workers that are available. The
    # `result` value will be available later.
    result = coordinator.schedule(worker_fn)
  ```
  """

  def fetch(self):
    """Wait for the result of `RemoteValue` and return the numpy result.

    This makes the value concrete by copying the remote value to local.

    Returns:
      The numpy array structure of the actual output of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single value, or a structure of values, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def get(self):
    """Wait for the result of `RemoteValue` and return the tensor result.

    This makes the value concrete by copying the remote tensor to local.

    Returns:
      The actual output (in the form of `tf.Tensor`s) of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single Tensor, or a structure of Tensors, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")


# TODO(yuefengz): create an implementation for resource RemoteValue which needs
# to remember the closure object while a normal RemoteValue doesn't.
class RemoteValueImpl(RemoteValue):
  """Implementation of `RemoteValue`."""

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
    """
    self._closure = closure
    self._type_spec = type_spec
    self._values = None
    self._has_fetched_to_local = False
    self._has_fetched_to_local_lock = threading.Lock()
    self._fetched_tensors = None
    self._error = None
    self._status_available_event = threading.Event()
    self._status = RemoteValueStatus.NOT_READY

  def _set_aborted(self):
    self._status = RemoteValueStatus.ABORTED
    self._values = None
    self._error = None

    # Wake up any waiting thread by setting the event.
    self._status_available_event.set()

  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    self._status_available_event.set()

  def _set_error(self, exception):
    self._status = RemoteValueStatus.READY
    self._values = None
    self._error = exception
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def _wait_and_maybe_error(self):
    self._status_available_event.wait()
    if self._status is RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error

  def fetch(self):
    # TODO(rchao): Discuss the possibility of letting users perform `numpy`
    # themselves at API graduation.
    return nest.map_structure(
        lambda x: x.numpy() if hasattr(x, "numpy") else x, self.get())

  def get(self):
    self._wait_and_maybe_error()

    with self._has_fetched_to_local_lock:
      if not self._has_fetched_to_local:

        def copy_tensor(composite_tensor_obj):
          """Copy a remote tensor to local (the coordinator)."""
          if isinstance(composite_tensor_obj, input_lib.DistributedIterator):
            # A DistributedIterator cannot be copied to local; users should not
            # access that anyway.
            return composite_tensor_obj

          with ops.device("/job:%s" % context.get_server_def().job_name):
            # Copying to local (the coordinator) with `tf.device`.
            return array_ops.identity(composite_tensor_obj)

        if self._values is not None:
          # When `self._values` is `None`, it indicates the associated function
          # does not have a return value.
          self._fetched_tensors = nest.map_structure(copy_tensor, self._values)
        self._has_fetched_to_local = True

    return self._fetched_tensors


@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
  of values, where each of the values is located on its corresponding worker,
  and upon being used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.
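
  For example, a `PerWorkerValues` of per-worker dataset iterators can be
  obtained and consumed as follows (a minimal sketch; `dataset_fn` and
  `step_fn` are illustrative names, and `coordinator` is a
  `tf.distribute.experimental.coordinator.ClusterCoordinator`):

  ```python
  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  # `iter` produces a `PerWorkerValues` of per-worker iterators.
  per_worker_iterator = iter(per_worker_dataset)

  @tf.function
  def step_fn(iterator):
    return next(iterator)

  # Each worker runs `step_fn` with the iterator resource that lives on it.
  result = coordinator.schedule(step_fn, args=(per_worker_iterator,))
  ```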

  Currently, the only supported path to create an object of
  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
  distributed dataset instance. The mechanism to create a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
  """

  def __init__(self, values):
    for v in values:
      if not isinstance(v, RemoteValue):
        raise AssertionError(
            "`PerWorkerValues` should only take `RemoteValue`s.")
    self._values = tuple(values)

  @property
  def _type_spec(self):
    return PerWorkerValuesTypeSpec(
        self._values[0]._type_spec,  # pylint: disable=protected-access
        type(self))


class PerWorkerValuesTypeSpec(type_spec_lib.TypeSpec):
  """TypeSpec for PerWorkerValues.

  It only supports tracing a function that takes a `PerWorkerValues` as input.
  """

  def __init__(self, value_spec, descendant_type):
    assert value_spec
    self._value_spec = value_spec
    self._descendant_type = descendant_type

  def _serialize(self):
    return (self._value_spec,)

  @property
  def value_type(self):
    return self._descendant_type

  def most_specific_compatible_type(self, other):
    raise NotImplementedError(
        "most_specific_compatible_type is not implemented")

  @property
  def _component_specs(self):
    return self._value_spec

  def _to_components(self, value):
    return self._value_spec

  def _from_components(self, value):
    return value


class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from a dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If `PerWorkerDatasetFromDatasetFunction.__iter__` is called multiple
    # times on the same object, it should only create and register the
    # resources once. Object ids are used to distinguish different iterator
    # resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    if not isinstance(self._dataset_fn, tf_function.ConcreteFunction):
      raise NotImplementedError(
          "`element_spec` is not supported when the `dataset_fn` is not "
          "a `ConcreteFunction`.")
    return self._dataset_fn.structured_outputs.element_spec


def serialize_dataset_to_graph(dataset):
  """Serializes `dataset` to a `GraphDef` that can be shipped to workers.

  See the illustrative sketch at the end of this module for the round trip
  with `deserialize_dataset_from_graph`.
  """
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  graph_def = gen_dataset_ops.dataset_to_graph_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      external_state_policy=ExternalStatePolicy.WARN.value,
      strip_device_assignment=True)
  return graph_def


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset given a graph def."""

  def __init__(self, graph_def, element_spec):
    self._elem_spec = element_spec
    variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def deserialize_dataset_from_graph(graph_def, element_spec):
  """Rebuilds a dataset with the given `element_spec` from a `GraphDef`."""
  return _RemoteDataset(graph_def, element_spec)


class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable from the given dataset.

    Under the hood, it creates a `dataset_fn` that deserializes the dataset
    from a serialized graph.

    Args:
      dataset: A `tf.data.Dataset`, a `DistributedDataset` or a
        `DistributedDatasetsFromFunction`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
    else:
      raise ValueError("Unexpected dataset type: %r." % type(dataset))

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)


def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  """Returns a per-worker dataset built from a dataset or a dataset function."""
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)


class PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over a `PerWorkerDistributedIterator` "
                              "is not supported right now.")
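

# Illustrative sketch: `serialize_dataset_to_graph` and
# `deserialize_dataset_from_graph` round-trip a dataset through a `GraphDef`,
# which is how `PerWorkerDatasetFromDataset` replicates a coordinator-side
# dataset onto each worker. Assuming eager execution, the round trip looks
# roughly like:
#
#   dataset = dataset_ops.Dataset.range(4)
#   graph_def = serialize_dataset_to_graph(dataset)
#   restored = deserialize_dataset_from_graph(graph_def, dataset.element_spec)
#   # `restored` yields the same elements as `dataset`.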