# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import six

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated


def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            split_batch_by=None,
                            input_context=None):
  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.

  This is a common function that is used by all strategies to return the right
  tf.data.Dataset wrapped instance depending on the `dataset` argument type.

  Args:
    dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.

  Returns:
    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
  """
  if isinstance(dataset, dataset_ops.DatasetV1):
    return DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)
  else:
    return DistributedDataset(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)


def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy):
  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.

  This is a common function that is used by all strategies to return the right
  tf.data.Dataset wrapped instance depending on whether we are in graph or
  eager mode.

  Args:
    dataset_fn: a function that returns a tf.data.DatasetV1 or
      tf.data.DatasetV2 instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match worker order in
      `worker_device_pairs`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.

  Returns:
    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
  """
  if ops.executing_eagerly_outside_functions():
    return DistributedDatasetsFromFunction(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy)
  else:
    return DistributedDatasetsFromFunctionV1(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy)


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, worker_device_pairs):
    """Initialize an `InputWorkers` object.

    Args:
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
    """
    self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in worker_device_pairs)

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
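

# Illustrative sketch (not part of the original module): an `InputWorkers`
# that maps a single input worker (the local CPU) to two compute devices.
# The device strings below are hypothetical placeholders.
#
#   _example_worker_device_pairs = [
#       ("/job:localhost/replica:0/task:0/device:CPU:0",
#        ("/job:localhost/replica:0/task:0/device:GPU:0",
#         "/job:localhost/replica:0/task:0/device:GPU:1")),
#   ]
#   input_workers = InputWorkers(_example_worker_device_pairs)
#   assert input_workers.num_workers == 1
#   assert len(input_workers.compute_devices_for_worker(0)) == 2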


def _get_next_as_optional(iterator, strategy, name=None):
  """Returns an empty dataset indicator and the next input from the iterator."""
  replicas = []
  worker_has_values = []
  worker_devices = []
  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
    if name is not None:
      d = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, d.job, d.task)
    else:
      new_name = None

    with ops.device(worker):
      worker_has_value, next_element = (
          iterator._iterators[i].get_next_as_list(new_name))  # pylint: disable=protected-access
      # Collective all-reduce requires explicit devices for inputs.
      with ops.device("/cpu:0"):
        # Converting to integers for all-reduce.
        worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
        worker_devices.append(worker_has_value.device)
        worker_has_values.append(worker_has_value)
      # Make `replicas` a flat list of values across all replicas.
      replicas.append(next_element)

  # Run an all-reduce to see whether any worker has values.
  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
  # cases.
  if getattr(strategy.extended, "_support_per_replica_values", True):
    # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
    # though it doesn't actually have a value per replica.
    worker_has_values = values.PerReplica(worker_has_values)
    global_has_value = strategy.reduce(
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
  else:
    assert len(worker_has_values) == 1
    global_has_value = worker_has_values[0]
  global_has_value = array_ops.reshape(
      math_ops.cast(global_has_value, dtypes.bool), [])
  return global_has_value, replicas
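

# Worked example (illustrative, not part of the original module): with two
# workers where only the first one still has data, the per-worker booleans
# [True, False] are cast to int32 as [1, 0]; the SUM all-reduce yields 1,
# which casts back to a True `global_has_value`, so iteration continues and
# the exhausted worker contributes dummy tensors with a 0-sized batch
# dimension. Only when every worker reports False (sum 0) does the caller's
# `tf.cond` fall through to `out_of_range_fn` and raise OutOfRangeError.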


class DistributedIterator(object):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators, strategy):
    static_shape = True
    for iterator in iterators:
      if not isinstance(iterator, _SingleWorkerDatasetIterator):
        continue
      flattened_shapes = nest.flatten(iterator.output_shapes)
      for output_shape in flattened_shapes:
        if not output_shape.is_fully_defined():
          static_shape = False
          break

    # TODO(b/133073708): we currently need a flag to control the usage because
    # there is a performance difference between get_next() and
    # get_next_as_optional(), and we only enable get_next_as_optional when the
    # output shapes are not static.
    #
    # TODO(yuefengz): Currently `experimental_enable_get_next_as_optional` is
    # always set to False in CollectiveAllReduceStrategy. We want a way to
    # distinguish between multi-worker and single-worker graph mode, so we can
    # enable the behavior in the single-worker case.
    #
    # TODO(rxsang): We want to always enable the get_next_as_optional behavior
    # when the user passes an input_fn instead of a dataset.
    if getattr(
        strategy.extended, "experimental_enable_get_next_as_optional", False):
      self._enable_get_next_as_optional = not static_shape
    else:
      self._enable_get_next_as_optional = False

    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_static_shapes(new_name))
      return values.regroup(replicas)

    out_of_range_replicas = []
    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # Since this is only called when there is no data left, calling
      # get_next() will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(
                global_has_value,
                lambda: replicas[i][j],
                lambda: out_of_range_fn(i, device),
                strict=True,
            )
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas` because it is calling
    # get_next() inside.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      for target, source in zip(
          nest.flatten(flattened_replicas[i], expand_composites=True),
          nest.flatten(replica_data, expand_composites=True)):
        target.set_shape(source.get_shape())
      # `SparseTensor` shape is not determined by the shape of its component
      # tensors. Rather, its shape depends on a tensor's values.
      if sparse_tensor.is_sparse(replica_data) and replica_data.get_shape():
        dense_shape = replica_data.get_shape()
        with ops.device(flattened_replicas[i].op.device):
          # For partially defined shapes, fill in missing values from tensor.
          if not dense_shape.is_fully_defined():
            dense_shape = array_ops.stack([
                flattened_replicas[i].dense_shape[j] if dim is None else dim
                for j, dim in enumerate(dense_shape.as_list())
            ])
          flattened_replicas[i] = sparse_tensor.SparseTensor(
              indices=flattened_replicas[i].indices,
              values=flattened_replicas[i].values,
              dense_shape=dense_shape)
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(replicas)

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize
  # the iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec
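

# Illustrative usage sketch (not part of the original module); assumes
# `dist_dataset` is a distributed dataset produced by the helpers above:
#
#   iterator = iter(dist_dataset)   # returns a DistributedIterator
#   batch = iterator.get_next()     # per-replica values, regrouped
#   # `next(iterator)` behaves the same, but raises StopIteration instead of
#   # tf.errors.OutOfRangeError once the underlying data is exhausted.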


class DistributedIteratorV1(DistributedIterator):
  """Input Iterator for tf.data.DatasetV1."""

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return super(DistributedIteratorV1, self)._initializer

  @property
  def initializer(self):
    """Returns a list of ops that initialize the iterator."""
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None


class _IterableInput(object):
  """Base class for iterable inputs for distribution strategies."""

  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Execute a `reduce_fn` over all the elements of the input."""
    iterator = iter(self)
    has_data, data = _get_next_as_optional(iterator, self._strategy)

    def cond(has_data, data, state):
      del data, state  # Unused.
      return has_data

    def loop_body(has_data, data, state):
      """Executes `reduce_fn` in a loop till the dataset is empty."""
      del has_data  # Unused.
      # `data` is a list of lists here, where each inner list corresponds to
      # one worker.
      # TODO(b/130570614): Add support for the multiworker and TPU pods use
      # case.
      if self._input_workers.num_workers == 1:
        data = data[0]
      else:
        raise ValueError("Dataset iteration within a tf.function is"
                         " not supported for multiple workers.")
      state = reduce_fn(state, values.regroup(data))
      has_data, data = _get_next_as_optional(iterator, self._strategy)
      return has_data, data, state

    has_data, data, final_state = control_flow_ops.while_loop(
        cond, loop_body, [has_data, data, initial_state],
        parallel_iterations=1)
    return final_state


class DistributedDataset(_IterableInput):
  """Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of
        the dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    # We clone and shard the dataset on each worker. The current setup tries
    # to shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker will run the entire
    # preprocessing pipeline and only receive its own shard of the dataset.
    if split_batch_by:
      try:
        # pylint: disable=protected-access
        with ops.colocate_with(dataset._variant_tensor):
          dataset = distribute._RebatchDataset(dataset, split_batch_by)
          # Add a prefetch to pipeline rebatching for performance.
          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
          # leverage static graph rewrites to insert _RebatchDataset before
          # the final `prefetch` if it exists.
          dataset = dataset.prefetch(split_batch_by)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to "
                  "be able to split your input across {} replicas.\n Please "
                  "see the tf.distribute.Strategy guide. {}".format(
                      split_batch_by, e)),
              sys.exc_info()[2])
        else:
          raise

    # TODO(b/138745411): Remove once stateful transformations are supported.
    options = dataset_ops.Options()
    options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
    dataset = dataset.with_options(options)

    self._cloned_datasets = []
    if input_context:
      # Between-graph case, where we rely on the input_context for sharding.
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, dataset.element_spec)  # pylint: disable=protected-access

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
    iterator = DistributedIterator(self._input_workers, worker_iterators,
                                   self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    return self._element_spec
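

# Illustrative usage sketch (not part of the original module); the strategy,
# dataset, and step_fn below are placeholders:
#
#   dist_dataset = get_distributed_dataset(dataset, input_workers, strategy)
#   for per_replica_batch in dist_dataset:   # eager mode or inside tf.function
#     strategy.experimental_run_v2(step_fn, args=(per_replica_batch,))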


class DistributedDatasetV1(DistributedDataset):
  """Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)

  def make_one_shot_iterator(self):
    """Get a one-time-use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator, which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already initialized iterators. Hence we cannot
    # create an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator


# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
  """Inputs created from dataset function."""

  def __init__(self, dataset_fn, input_workers, input_contexts, strategy):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to
        call(s) to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)

    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not the same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    self._dataset_fn = dataset_fn
    self._input_workers = input_workers
    self._input_contexts = input_contexts
    self._strategy = strategy
    self._element_spec = None

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    iterators, element_spec = _create_iterators_per_worker_with_input_context(
        self._input_contexts, self._input_workers, self._dataset_fn)
    iterator = DistributedIterator(self._input_workers, iterators,
                                   self._strategy)
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         element_spec)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    if self._element_spec is None:
      raise ValueError("You must create an iterator before calling "
                       "`element_spec` on the distributed dataset or iterator. "
                       "This is because the dataset function is not called "
                       "before an iterator is created.")

    return self._element_spec


class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot
    # create an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator, which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators, element_spec = _create_iterators_per_worker_with_input_context(
        self._input_contexts, self._input_workers, self._dataset_fn)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy)
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         element_spec)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return iterator
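

# Illustrative sketch of a `dataset_fn` (not part of the original module); the
# file pattern and batch size are placeholders. Each call receives an
# `InputContext` describing the shard this input pipeline should read:
#
#   def _example_dataset_fn(input_context):
#     d = tf.data.Dataset.list_files("/path/to/data-*")
#     d = d.shard(input_context.num_input_pipelines,
#                 input_context.input_pipeline_id)
#     return d.batch(input_context.get_per_replica_batch_size(64))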


# TODO(anjalisridhar): This class will soon be removed in favor of newer
# APIs.
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to
        call(s) to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    """
    assert isinstance(input_workers, InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not the same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, dataset_ops.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(input_workers, iterators,
                                                strategy)


# TODO(anjalisridhar): This class will soon be removed and users should move
# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of
        the dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    dist_dataset = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)  # pylint: disable=protected-access
    super(DatasetIterator, self).__init__(
        input_workers,
        worker_iterators,  # pylint: disable=protected-access
        strategy)
    self._element_spec = dist_dataset.element_spec


def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`."""

  def create_dummy_tensor(type_spec):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
      # Splice out the ragged dimensions.
      # pylint: disable=protected-access
      feature_shape = type_spec._shape[:1].concatenate(
          type_spec._shape[(1 + type_spec._ragged_rank):])
      feature_type = type_spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = type_spec.shape
      feature_type = type_spec.dtype
    # Ideally we would set the batch dimension to 0; however, since in
    # DistributionStrategy we don't know which dimension is the batch
    # dimension, we guess it as best we can. If the feature has unknown
    # dimensions, we set them to 0. If the feature shape is already static,
    # we treat the first dimension as the batch dimension and set it to 0.
    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
            if feature_shape else [])
    if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or
                 feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(type_spec, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims),
                                   feature_type)
    if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
      # Reinsert the ragged dimensions with size 0.
      # pylint: disable=protected-access
      row_splits = array_ops.zeros(1, type_spec._row_splits_dtype)
      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
          dummy_tensor, (row_splits,) * type_spec._ragged_rank,
          validate=False)
      # pylint: enable=protected-access
    return dummy_tensor

  return nest.map_structure(create_dummy_tensor, value_structure)
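

# Worked examples (illustrative, not part of the original module):
#   - For a TensorSpec with shape [None, 3] and dtype float32, the unknown
#     batch dimension is set to 0, so the dummy tensor is zeros of shape
#     [0, 3].
#   - For a fully defined shape such as [32, 3], the first dimension is
#     treated as the batch dimension, so the dummy tensor has shape [0, 3].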


class _SingleWorkerDatasetIterator(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    `MultiDeviceIterator` is used to prefetch input to the devices on the
    given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._make_iterator()

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      return self._iterator.get_next(device)

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than get_next_as_list()
    (but can only be used when the shapes are static).

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
    del name
    with ops.device(self._worker):
      return self._iterator.get_next()

  def get_next_as_list(self, name=None):
    """Get next element from the underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned. Use of get_next_as_optional() and
    the extra logic here adds overhead compared to
    get_next_as_list_static_shapes(), but allows us to handle non-static
    shapes.

    Args:
      name: not used.

    Returns:
      A boolean tensor that indicates whether there is any data in the next
      element, and the real data as the next element or a list of dummy
      tensors if no data is left.
    """
    del name
    with ops.device(self._worker):
      data_list = self._iterator.get_next_as_optional()
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op in the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # Since MultiDeviceIterator fetches data in order, we only need to
          # check whether the first replica has a value to see whether there
          # is data left for this single worker.
          if i == 0:
            worker_has_value = data.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          real_data = control_flow_ops.cond(
              data.has_value(),
              lambda: data.get_value(),
              lambda: _dummy_tensor_fn(data.value_structure),
              strict=True,
          )
          result.append(real_data)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, result

  def initialize(self):
    """Initialize the underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return constant_op.constant(True), data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []


def _create_iterators_per_worker(worker_datasets, input_workers):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)

  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                              worker_devices)
      iterators.append(iterator)
  return iterators


def _create_iterators_per_worker_with_input_context(input_contexts,
                                                    input_workers,
                                                    dataset_fn):
  """Create a multidevice iterator per worker, given a dataset function."""
  iterators = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      # TODO(b/138745411): Remove once stateful transformations are supported.
      options = dataset_ops.Options()
      options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
      dataset = dataset.with_options(options)
      devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
      iterators.append(iterator)
  return iterators, dataset.element_spec


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` need to be the last operations on the dataset. "
      "The batch operations can be followed by a prefetch.")


def _get_batched_dataset_attributes(d):
  """Get `batch_size` and `drop_remainder` of the dataset."""
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size = d._batch_size
    drop_remainder = d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size = d._batch_size_t
    drop_remainder = d._drop_remainder_t
  # pylint: enable=protected-access

  if tensor_util.is_tensor(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tensor(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  return batch_size, drop_remainder
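

# Worked example (illustrative, not part of the original module): for a
# pipeline built as `d = some_dataset.batch(32).prefetch(2)`, `d` is a
# PrefetchDataset, so `_get_batched_dataset(d)` walks back one step and
# returns the underlying BatchDataset; `_get_batched_dataset_attributes`
# then reports batch_size=32 and drop_remainder=False (the default).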


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object."""
  # pylint: disable=protected-access

  # First, get batch_size and drop_remainder from the dataset. We need to
  # walk back the dataset creation process and find the batched version in
  # order to get the attributes.
  batched_dataset = _get_batched_dataset(dataset)
  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)

  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size

  return batch_size, drop_remainder, prefetch_buffer


class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using
  the `experimental_run_steps_on_iterator` API. For example, it allows the
  user's step function to specify which outputs to emit at what frequency.
  Currently it supports capturing output from the last step, as well as
  capturing non tensor outputs. In the future it will be augmented to support
  other use cases such as output every N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce`
        method. For example, if using MirroredStrategy and reduction is set,
        output must be a `PerReplica` value.
        The reduce method is also recorded in a dictionary
        `_last_step_outputs_reduce_ops` for later interpretation of the
        outputs as already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the
        # same context object, so it's more robust to set it only once (even
        # if all the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the
        # values in a list as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))


def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
  """
  num_replicas = len(strategy.extended.worker_devices)

  # If the number of devices used in the strategy is just 1 then we return
  # the tensor_spec as is.
  if num_replicas == 1:
    return tensor_spec

  # If the number of devices is greater than 1 then we assume the input to
  # tf.function is a per replica type.
  def _get_value_per_replica(tensor_spec_per_input):
    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
    return values.PerReplicaSpec(*value_specs)

  return nest.map_structure(_get_value_per_replica, tensor_spec)
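

# Worked example (illustrative, not part of the original module): with a
# strategy whose `extended.worker_devices` lists two devices, a
# tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32) passed to
# `_create_distributed_tensor_spec` becomes `values.PerReplicaSpec(spec, spec)`,
# i.e. one copy of the spec per replica; with a single device the spec is
# returned unchanged.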