# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
    """Initialize an `InputWorkers` object.

    Args:
      device_map: A `DeviceMap` with the computation devices fed by the
        input workers.
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
      logical_device: The logical device of `device_map` to feed.
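
    Example (an illustrative sketch, not taken from tests; it assumes
    `values.ReplicaDeviceMap`, which maps each replica to a single device):

      # One input worker on the local CPU feeding two GPUs.
      device_map = values.ReplicaDeviceMap(("/device:GPU:0", "/device:GPU:1"))
      input_workers = InputWorkers(
          device_map,
          worker_device_pairs=(("/device:CPU:0",
                                ("/device:GPU:0", "/device:GPU:1")),))
      assert input_workers.num_workers == 1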
    """
    self._device_map = device_map
    self._logical_device = logical_device
    if worker_device_pairs is None:
      worker_device_pairs = ((
          device_util.canonicalize("/device:CPU:0"),
          device_map.logical_to_actual_devices(logical_device)),)
    self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in worker_device_pairs)
    flattened = tuple(d for l in self._fed_devices for d in l)
    assert (flattened ==
            device_map.logical_to_actual_devices(logical_device)), (
                "flattened: %s logical device %d: %s" %
                (flattened, logical_device,
                 device_map.logical_to_actual_devices(logical_device)))

  @property
  def device_map(self):
    return self._device_map

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s\n  device_map: %s}" % (
        self.__class__.__name__, debug_repr, self._device_map)


class InputIterator(object):
  """An input iterator, intended to be passed to `DistributionStrategy.run`."""

  def get_next(self):
    """Returns the next inputs for all replicas."""
    raise NotImplementedError("must be implemented in descendants")

  def initialize(self):
    """Initialize the underlying input dataset, when applicable.

    In eager mode, this will create a new iterator and return it.
    In graph mode, this will initialize the same underlying iterator(s).

    Users are required to call this if
    - This iterator was returned from a call to `make_input_fn_iterator` with an
      input function that returns a dataset.
    - Or this iterator was returned from a call to `make_dataset_iterator`.

    Returns:
      A list of initialization ops to be executed.
    """
    raise NotImplementedError("must be implemented in descendants")


class InputIteratorImpl(InputIterator):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators):
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    replicas = []
    worker_has_values = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        worker_has_value, next_element = (
            self._iterators[i].get_next_as_list(new_name))
        worker_has_values.append(worker_has_value)
        # Make `replicas` a flat list of values across all replicas.
        replicas.append(next_element)

    out_of_range_replicas = []

    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # Since this is only called when there is no data left, calling
      # get_next() will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    # `global_has_value` indicates whether there is data in this global batch.
    # We do an all-reduce across all the workers in the multi-worker case.
    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
    if len(worker_has_values) > 1:
      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
        # Place the tf.reduce_any op on device 0 to minimize communication
        # cost.
        # TODO(b/128545270): Investigate why placing it on worker 0 will cause
        # the entire data to copy back from device to host.
        global_has_value = math_ops.reduce_any(worker_has_values)
    else:
      global_has_value = worker_has_values[0]

    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(global_has_value,
                                           lambda: replicas[i][j],
                                           lambda: out_of_range_fn(i, device))
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas`, because it calls
    # get_next() internally and therefore preserves the static shapes.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      flattened_replicas[i].set_shape(replica_data.get_shape())
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(self._input_workers.device_map, replicas)

  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return init_ops

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None


class InputFunctionIterator(InputIteratorImpl):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
    """
    assert isinstance(input_workers, InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not the same as the number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, dataset_ops.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(input_workers, iterators)


class DatasetIterator(InputIteratorImpl):
  """Iterator created from input dataset."""

  def __init__(self, dataset, input_workers, split_batch_by=None):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value. To achieve this, we first unbatch the
    input dataset and then rebatch it with the per-replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    TODO(priyag): Support multi worker / host cases properly by cloning
    and sharding the dataset on each worker. The current setup will only work
    in some cases, such as the in-graph multi worker GPU case. If the input
    pipeline has random shuffling (with a different seed on each worker), each
    worker will see random input from the same overall dataset in each step.
    Otherwise, each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
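
    Example (an illustrative sketch; the two-replica `values.ReplicaDeviceMap`
    and the device names below are assumptions made for the example):

      # Global batch size 16 split across 2 replicas gives a per-replica
      # batch size of 8. `batch()` is the last operation on the dataset,
      # which is one of the supported forms listed above.
      dataset = dataset_ops.Dataset.range(64).batch(16, drop_remainder=True)
      device_map = values.ReplicaDeviceMap(("/device:GPU:0", "/device:GPU:1"))
      input_workers = InputWorkers(device_map)
      iterator = DatasetIterator(dataset, input_workers, split_batch_by=2)
      iterator.initialize()  # In graph mode, run the returned ops first.
      per_replica_batch = iterator.get_next()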
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        worker_devices = input_workers.compute_devices_for_worker(i)
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                worker_devices)
        iterators.append(iterator)

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators)


def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`."""

  def create_dummy_tensor(feature_shape, feature_type):
    """Create a dummy tensor with possible batch dimensions set to 0."""

    # Ideally we should set the batch dimension to 0, but since
    # DistributionStrategy doesn't know the batch dimension, we try to guess it
    # as best we can. If the feature has unknown dimensions, we set them to 0.
    # If the feature shape is already static, we guess the first dimension to
    # be the batch dimension and set it to 0.
    dims = []
    for dim in feature_shape.dims:
      if dim.value is None:
        dims.append(tensor_shape.Dimension(0))
      else:
        dims.append(dim)
    if feature_shape.is_fully_defined() and dims:
      dims[0] = tensor_shape.Dimension(0)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    return dummy_tensor

  result = []
  # pylint: disable=protected-access
  for feature_shape, feature_type in zip(value_structure._flat_shapes,
                                         value_structure._flat_types):
    result.append(create_dummy_tensor(feature_shape, feature_type))

  if isinstance(value_structure, structure.NestedStructure):
    result = nest.pack_sequence_as(value_structure._nested_structure, result)
  else:
    result = result[0]
  # pylint: enable=protected-access

  return result


class _SingleWorkerDatasetIterator(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
    """Create an iterator for the `dataset` to fetch data to worker's `devices`.

    `MultiDeviceIterator` is used to prefetch input to the devices on the
    given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._make_iterator()

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      return self._iterator.get_next(device)

  def get_next_as_list(self, name=None):
    """Get next element from underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned.

    Args:
      name: not used.

    Returns:
      A boolean tensor indicating whether there is any data in the next
      element, and the real data as the next element, or a list of dummy
      tensors if no data is left.
    """
    del name
    with ops.device(self._worker):
      data_list = self._iterator.get_next_as_optional()
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op in the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # Since MultiDeviceIterator fetches data in order, we only need to
          # check whether the first replica has a value to see whether there
          # is data left for this single worker.
          if i == 0:
            worker_has_value = data.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          real_data = control_flow_ops.cond(
              data.has_value(),
              lambda: data.get_value(),
              lambda: _dummy_tensor_fn(data.value_structure))
          result.append(real_data)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, result

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if context.executing_eagerly():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return constant_op.constant(True), data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operation can be followed by a prefetch.")


def _get_batched_dataset_attributes(d):
  """Get `batch_size`, `drop_remainder` of dataset."""
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size = d._batch_size
    drop_remainder = d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size = d._batch_size_t
    drop_remainder = d._drop_remainder_t
  # pylint: enable=protected-access

  if tensor_util.is_tensor(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tensor(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  return batch_size, drop_remainder


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object."""
  # pylint: disable=protected-access

  # First, get batch_size and drop_remainder from the dataset. We need
  # to walk back the dataset creation process and find the batched version in
  # order to get the attributes.
  batched_dataset = _get_batched_dataset(dataset)
  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)

  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size

  return batch_size, drop_remainder, prefetch_buffer


class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing non
  tensor outputs. In the future it will be augmented to support other use
  cases, such as emitting output every N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on the last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in a cross-replica context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so that it can later be
        determined whether a given output was already reduced.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
        # Setting this inside the `merge_fn` because all replicas share the
        # same context object, so it's more robust to set it only once (even
        # if all the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the values
        # in a list as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
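

# Example sketch (illustrative only): a user step function that records a
# mean-reduced loss in the `MultiStepContext` passed to it by
# `experimental_run_steps_on_iterator`. `reduce_util.ReduceOp.MEAN` refers to
# the reduce op enum in this code base; `compute_loss` is a hypothetical
# user-defined function, and the `step_fn(ctx, inputs)` signature is an
# assumption for the example.
#
#   from tensorflow.python.distribute import reduce_util
#
#   def step_fn(ctx, inputs):
#     loss = compute_loss(inputs)  # hypothetical user-defined function
#     ctx.set_last_step_output(
#         "loss", loss, reduce_op=reduce_util.ReduceOp.MEAN)
#     ctx.set_non_tensor_output("log_message", "finished step")
#     return loss
#
#   # After the steps run, the reduced value can be read from
#   # ctx.last_step_outputs["loss"] and the non tensor output from
#   # ctx.non_tensor_outputs["log_message"].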