1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrapper for prefetching_ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.data.ops import dataset_ops 21from tensorflow.python.data.ops import iterator_ops 22from tensorflow.python.data.util import structure 23from tensorflow.python.eager import context 24from tensorflow.python.eager import function 25from tensorflow.python.framework import composite_tensor 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_spec 30from tensorflow.python.framework import type_spec 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import functional_ops 34from tensorflow.python.ops import gen_dataset_ops 35from tensorflow.python.ops import resource_variable_ops 36 37 38class _PerDeviceGenerator(dataset_ops.DatasetV2): 39 """A `dummy` generator dataset.""" 40 41 def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, 42 source_device, element_spec): 43 self._element_spec = element_spec 44 45 multi_device_iterator_string_handle = ( 46 gen_dataset_ops.multi_device_iterator_to_string_handle( 47 multi_device_iterator_resource)) 48 49 # TODO(b/124254153): Enable autograph once the overhead is low enough. 50 @function.defun(autograph=False) # Pure graph code. 51 def _init_func(): 52 return multi_device_iterator_string_handle 53 54 init_func_concrete = _init_func.get_concrete_function() 55 56 # TODO(b/124254153): Enable autograph once the overhead is low enough. 57 @function.defun(autograph=False) # Pure graph code. 58 def _remote_init_func(): 59 return functional_ops.remote_call( 60 target=source_device, 61 args=init_func_concrete.captured_inputs, 62 Tout=[dtypes.string], 63 f=init_func_concrete) 64 65 self._init_func = _remote_init_func.get_concrete_function() 66 self._init_captured_args = self._init_func.captured_inputs 67 68 # TODO(b/124254153): Enable autograph once the overhead is low enough. 69 @function.defun( 70 input_signature=[tensor_spec.TensorSpec([], dtypes.string)], 71 autograph=False) # Pure graph code. 72 def _next_func(string_handle): 73 # pylint: disable=protected-access 74 multi_device_iterator = ( 75 gen_dataset_ops.multi_device_iterator_from_string_handle( 76 string_handle=string_handle, 77 output_types=structure.get_flat_tensor_types(self._element_spec), 78 output_shapes=structure.get_flat_tensor_shapes( 79 self._element_spec))) 80 return gen_dataset_ops.multi_device_iterator_get_next_from_shard( 81 multi_device_iterator=multi_device_iterator, 82 shard_num=shard_num, 83 incarnation_id=incarnation_id, 84 output_types=structure.get_flat_tensor_types(self._element_spec), 85 output_shapes=structure.get_flat_tensor_shapes(self._element_spec)) 86 87 next_func_concrete = _next_func.get_concrete_function() 88 89 # TODO(b/124254153): Enable autograph once the overhead is low enough. 90 @function.defun_with_attributes( 91 input_signature=[tensor_spec.TensorSpec([], dtypes.string)], 92 attributes={"experimental_ints_on_device": True}, 93 autograph=False) # Pure graph code. 94 def _remote_next_func(string_handle): 95 return functional_ops.remote_call( 96 target=source_device, 97 args=[string_handle] + next_func_concrete.captured_inputs, 98 Tout=structure.get_flat_tensor_types(self._element_spec), 99 f=next_func_concrete) 100 101 self._next_func = _remote_next_func.get_concrete_function() 102 self._next_captured_args = self._next_func.captured_inputs 103 104 self._incarnation_id_index = -1 105 for i, arg in enumerate(self._next_captured_args): 106 if arg is incarnation_id: 107 self._incarnation_id_index = i 108 109 # TODO(b/124254153): Enable autograph once the overhead is low enough. 110 @function.defun( 111 input_signature=[tensor_spec.TensorSpec([], dtypes.string)], 112 autograph=False) # Pure graph code. 113 def _finalize_func(unused_string_handle): 114 return array_ops.constant(0, dtypes.int64) 115 116 finalize_func_concrete = _finalize_func.get_concrete_function() 117 118 # TODO(b/124254153): Enable autograph once the overhead is low enough. 119 @function.defun( 120 input_signature=[tensor_spec.TensorSpec([], dtypes.string)], 121 autograph=False) # Pure graph code. 122 def _remote_finalize_func(string_handle): 123 return functional_ops.remote_call( 124 target=source_device, 125 args=[string_handle] + finalize_func_concrete.captured_inputs, 126 Tout=[dtypes.int64], 127 f=finalize_func_concrete) 128 129 self._finalize_func = _remote_finalize_func.get_concrete_function() 130 self._finalize_captured_args = self._finalize_func.captured_inputs 131 132 variant_tensor = gen_dataset_ops.generator_dataset( 133 self._init_captured_args, 134 self._next_captured_args, 135 self._finalize_captured_args, 136 init_func=self._init_func, 137 next_func=self._next_func, 138 finalize_func=self._finalize_func, 139 **self._flat_structure) 140 super(_PerDeviceGenerator, self).__init__(variant_tensor) 141 142 def _inputs(self): 143 # TODO(b/116506223): Determine which datasets should be used as inputs here. 144 return [] 145 146 @property 147 def element_spec(self): 148 return self._element_spec 149 150 151class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2): 152 """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id. 153 154 Re-uses the functions from the provided per_device_dataset and just switches 155 out the function argument corresponding to the incarnation_id. 156 """ 157 158 def __init__(self, per_device_dataset, incarnation_id): 159 # pylint: disable=protected-access 160 self._element_spec = per_device_dataset.element_spec 161 self._init_func = per_device_dataset._init_func 162 self._init_captured_args = self._init_func.captured_inputs 163 164 self._next_func = per_device_dataset._next_func 165 self._next_captured_args = per_device_dataset._next_captured_args 166 # The captured arguments to the next_func are string_handle, incarnation_id. 167 # We update the incarnation id to the new one. 168 self._next_captured_args[ 169 per_device_dataset._incarnation_id_index] = incarnation_id 170 171 self._finalize_func = per_device_dataset._finalize_func 172 self._finalize_captured_args = per_device_dataset._finalize_captured_args 173 174 variant_tensor = gen_dataset_ops.generator_dataset( 175 self._init_captured_args, 176 self._next_captured_args, 177 self._finalize_captured_args, 178 init_func=self._init_func, 179 next_func=self._next_func, 180 finalize_func=self._finalize_func, 181 **self._flat_structure) 182 super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor) 183 184 def _inputs(self): 185 # TODO(b/116506223): Determine which datasets should be used as inputs here. 186 return [] 187 188 @property 189 def element_spec(self): 190 return self._element_spec 191 192 193def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size, 194 experimental_slack): 195 """Uses _prototype_device_datasets[i] to build a dataset for the device.""" 196 ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id) 197 if prefetch_buffer_size > 0: 198 if experimental_slack: 199 ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1) 200 else: 201 ds = ds.prefetch(prefetch_buffer_size) 202 # TODO(jsimsa): Enable auto-tuning and optimizations when supported for 203 # non-CPU devices. 204 options = dataset_ops.Options() 205 options.experimental_optimization.apply_default_optimizations = False 206 options.experimental_optimization.autotune = False 207 ds = ds.with_options(options) 208 return ds 209 210 211class MultiDeviceIterator(object): 212 """An iterator over multiple devices.""" 213 214 def __init__(self, 215 dataset, 216 devices, 217 max_buffer_size=1, 218 prefetch_buffer_size=1, 219 source_device="/cpu:0"): 220 """Constructs a MultiDeviceIterator. 221 222 Args: 223 dataset: The input dataset to be iterated over. 224 devices: The list of devices to fetch data to. 225 max_buffer_size: Maximum size of the host side per device buffer to keep. 226 prefetch_buffer_size: if > 0, then we setup a buffer on each device to 227 prefetch into. 228 source_device: The host device to place the `dataset` on. In order to 229 prevent deadlocks, if the prefetch_buffer_size is greater than the 230 max_buffer_size, we set the max_buffer_size to prefetch_buffer_size. 231 """ 232 options = dataset_ops.Options() 233 options.experimental_distribute.num_devices = len(devices) 234 dataset = dataset.with_options(options) 235 self._dataset = dataset._apply_options() # pylint: disable=protected-access 236 self._experimental_slack = dataset.options().experimental_slack 237 self._devices = devices 238 self._source_device = source_device 239 self._source_device_tensor = ops.convert_to_tensor(source_device) 240 self._max_buffer_size = max_buffer_size 241 self._prefetch_buffer_size = prefetch_buffer_size 242 243 if self._prefetch_buffer_size > self._max_buffer_size: 244 self._max_buffer_size = self._prefetch_buffer_size 245 246 # Create the MultiDeviceIterator. 247 with ops.device(self._source_device): 248 # TODO(b/121378567): Get rid of this shared_name hack. 249 shared_name = "" 250 if context.executing_eagerly(): 251 shared_name = context.shared_name() 252 self._multi_device_iterator_resource = ( 253 gen_dataset_ops.multi_device_iterator( 254 devices=self._devices, 255 shared_name=shared_name, 256 container="", 257 **self._dataset._flat_structure)) # pylint: disable=protected-access 258 if context.executing_eagerly(): 259 # Delete the resource when this object is deleted 260 self._resource_deleter = resource_variable_ops.EagerResourceDeleter( 261 handle=self._multi_device_iterator_resource, 262 handle_device=self._source_device) 263 264 # The incarnation ID is used to ensure consistency between the per-device 265 # iterators and the multi-device iterator. 266 self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( 267 self._dataset._variant_tensor, # pylint: disable=protected-access 268 self._multi_device_iterator_resource, 269 max_buffer_size=self._max_buffer_size) 270 271 self._prototype_device_datasets = [] 272 for i, device in enumerate(self._devices): 273 with ops.device(device): 274 ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, 275 self._incarnation_id, 276 self._source_device_tensor, 277 self._dataset.element_spec) 278 self._prototype_device_datasets.append(ds) 279 280 # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to 281 # initialize the device side of the pipeline. This would allow the 282 # MultiDeviceIterator to choose, for example, to move some transformations 283 # into the device side from its input. It might be useful in rewriting. 284 # Create the per device iterators. 285 self._device_iterators = [] 286 for i, device in enumerate(self._devices): 287 with ops.device(device): 288 ds = _create_device_dataset(self._prototype_device_datasets[i], 289 self._incarnation_id, 290 self._prefetch_buffer_size, 291 self._experimental_slack) 292 if context.executing_eagerly(): 293 self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds)) 294 else: 295 self._device_iterators.append( 296 dataset_ops.make_initializable_iterator(ds)) 297 298 if not context.executing_eagerly(): 299 device_iterator_initializers = [ 300 iterator.initializer for iterator in self._device_iterators 301 ] 302 self._initializer = control_flow_ops.group(*device_iterator_initializers) 303 304 def _create_device_dataset(self, i): 305 """Uses _prototype_device_datasets[i] to build a dataset for the device.""" 306 ds = self._prototype_device_datasets[i] 307 ds = _ReincarnatedPerDeviceGenerator(ds, self._incarnation_id) 308 if self._prefetch_buffer_size > 0: 309 if self._experimental_slack: 310 ds = dataset_ops.PrefetchDataset( 311 ds, self._prefetch_buffer_size, slack_period=1) 312 else: 313 ds = ds.prefetch(self._prefetch_buffer_size) 314 # TODO(jsimsa): Enable auto-tuning and optimizations when supported for 315 # non-CPU devices. 316 options = dataset_ops.Options() 317 options.experimental_optimization.apply_default_optimizations = False 318 options.experimental_optimization.autotune = False 319 ds = ds.with_options(options) 320 return ds 321 322 def get_next(self, device=None): 323 """Returns the next element given a `device`, else returns all in a list.""" 324 if device is not None: 325 index = self._devices.index(device) 326 return self._device_iterators[index].get_next() 327 328 result = [] 329 for i, device in enumerate(self._devices): 330 with ops.device(device): 331 result.append(self._device_iterators[i].get_next()) 332 return result 333 334 def get_next_as_optional(self): 335 result = [] 336 for i, device in enumerate(self._devices): 337 with ops.device(device): 338 result.append(self._device_iterators[i].get_next_as_optional()) 339 return result 340 341 @property 342 def initializer(self): 343 if context.executing_eagerly(): 344 return control_flow_ops.no_op() 345 return self._initializer 346 347 def _eager_reset(self): 348 """Resets the MultiDeviceIterator in eager mode.""" 349 if not ops.executing_eagerly_outside_functions(): 350 raise ValueError("Eager reset is only supported in eager mode.") 351 # pylint: disable=protected-access 352 self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( 353 self._dataset._variant_tensor, 354 self._multi_device_iterator_resource, 355 max_buffer_size=self._max_buffer_size) 356 for i, device in enumerate(self._devices): 357 with ops.device(device): 358 ds = _create_device_dataset(self._prototype_device_datasets[i], 359 self._incarnation_id, 360 self._prefetch_buffer_size, 361 self._experimental_slack) 362 # Reset the device iterator resources with the new dataset. 363 ds_variant = ds._variant_tensor 364 gen_dataset_ops.make_iterator( 365 ds_variant, self._device_iterators[i]._iterator_resource) 366 367 @property 368 def element_spec(self): 369 return self._dataset.element_spec 370 371 372class MultiDeviceIteratorResourceDeleter(object): 373 """An object which cleans up a Multi Device Iterator resource. 374 375 An alternative to defining a __del__ method on an object. Even if the parent 376 object is part of a reference cycle, the cycle will be collectible. 377 """ 378 379 __slots__ = [ 380 "_deleter", "_multi_device_iterator", "_iterators", "_device", 381 "_eager_mode" 382 ] 383 384 def __init__(self, multi_device_iterator, iterators, device, deleter): 385 self._deleter = deleter 386 self._multi_device_iterator = multi_device_iterator 387 self._iterators = iterators 388 self._device = device 389 self._eager_mode = context.executing_eagerly() 390 391 def __del__(self): 392 with ops.device(self._device): 393 # Make sure the resource is deleted in the same mode as it was created in. 394 # We pass in the iterator handles as inputs to the op to make sure that 395 # this op runs after all the iterators are deleted. 396 if self._eager_mode: 397 with context.eager_mode(): 398 gen_dataset_ops.delete_multi_device_iterator( 399 multi_device_iterator=self._multi_device_iterator, 400 iterators=self._iterators, 401 deleter=self._deleter) 402 else: 403 with context.graph_mode(): 404 gen_dataset_ops.delete_multi_device_iterator( 405 multi_device_iterator=self._multi_device_iterator, 406 iterators=self._iterators, 407 deleter=self._deleter) 408 409 410class MultiDeviceIteratorSpec(type_spec.TypeSpec): 411 """Type specification for `OwnedMultiDeviceIterator`.""" 412 413 __slots__ = ["_devices", "_source_device", "_element_spec"] 414 415 def __init__(self, devices, source_device, element_spec): 416 self._devices = devices 417 self._source_device = source_device 418 self._element_spec = element_spec 419 420 @property 421 def value_type(self): 422 return OwnedMultiDeviceIterator 423 424 def _serialize(self): 425 return (tuple(self._devices), self._source_device, self._element_spec) 426 427 @property 428 def _component_specs(self): 429 specs = [ 430 tensor_spec.TensorSpec([], dtypes.resource), 431 tensor_spec.TensorSpec([], dtypes.variant) 432 ] 433 for _ in range(len(self._devices)): 434 specs.append(iterator_ops.IteratorSpec(self._element_spec)) 435 return specs 436 437 def _to_components(self, value): 438 # pylint: disable=protected-access 439 c = [value._multi_device_iterator_resource, value._deleter] 440 c.extend(value._device_iterators) 441 return c 442 443 def _from_components(self, components): 444 return OwnedMultiDeviceIterator( 445 dataset=None, 446 devices=self._devices, 447 source_device=self._source_device, 448 components=components, 449 element_spec=self._element_spec) 450 451 @staticmethod 452 def from_value(value): 453 # pylint: disable=protected-access 454 return MultiDeviceIteratorSpec( 455 value._devices, 456 value._source_device, 457 value.element_spec) 458 459 460class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor): 461 """An iterator over multiple devices. 462 463 The multi-device iterator resource created through `OwnedMultiDeviceIterator` 464 is owned by the Python object and the life time of the underlying resource is 465 tied to the life time of the `OwnedMultiDeviceIterator` object. This makes 466 `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of 467 tf.functions. 468 """ 469 470 def __init__(self, 471 dataset=None, 472 devices=None, 473 max_buffer_size=1, 474 prefetch_buffer_size=1, 475 source_device="/cpu:0", 476 components=None, 477 element_spec=None): 478 """Constructs an owned MultiDeviceIterator object. 479 480 Args: 481 dataset: The input dataset to be iterated over. 482 devices: The list of devices to fetch data to. 483 max_buffer_size: Maximum size of the host side per device buffer to keep. 484 prefetch_buffer_size: if > 0, then we setup a buffer on each device to 485 prefetch into. 486 source_device: The host device to place the `dataset` on. In order to 487 prevent deadlocks, if the prefetch_buffer_size is greater than the 488 max_buffer_size, we set the max_buffer_size to prefetch_buffer_size. 489 components: Tensor components to construct the MultiDeviceIterator from. 490 element_spec: A nested structure of `TypeSpec` objects that 491 represents the type specification of elements of the iterator. 492 493 Raises: 494 RuntimeError: If executed in graph mode or outside of function building 495 mode. 496 """ 497 if not context.executing_eagerly() and not ops.inside_function(): 498 raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of " 499 "tf.function or when eager execution is enabled.") 500 if devices is None: 501 raise ValueError("`devices` must be provided") 502 error_message = "Either `dataset` or both `components` and " 503 "`element_spec` need to be provided." 504 505 if dataset is None: 506 if (components is None or element_spec is None): 507 raise ValueError(error_message) 508 self._element_spec = element_spec 509 self._devices = devices 510 self._source_device = source_device 511 self._multi_device_iterator_resource = components[0] 512 self._deleter = components[1] 513 self._device_iterators = components[2:] 514 iterator_handles = [] 515 for it in self._device_iterators: 516 iterator_handles.append(it._iterator_resource) # pylint: disable=protected-access 517 else: 518 if (components is not None or element_spec is not None): 519 raise ValueError(error_message) 520 options = dataset_ops.Options() 521 options.experimental_distribute.num_devices = len(devices) 522 dataset = dataset.with_options(options) 523 dataset = dataset._apply_options() # pylint: disable=protected-access 524 self._element_spec = dataset.element_spec 525 experimental_slack = dataset.options().experimental_slack 526 self._devices = devices 527 self._source_device = source_device 528 source_device_tensor = ops.convert_to_tensor(self._source_device) 529 530 if prefetch_buffer_size > max_buffer_size: 531 max_buffer_size = prefetch_buffer_size 532 533 # Create the MultiDeviceIterator. 534 with ops.device(self._source_device): 535 self._multi_device_iterator_resource, self._deleter = ( 536 gen_dataset_ops.anonymous_multi_device_iterator( 537 devices=self._devices, **dataset._flat_structure)) # pylint: disable=protected-access 538 539 # The incarnation ID is used to ensure consistency between the 540 # per-device iterators and the multi-device iterator. 541 incarnation_id = gen_dataset_ops.multi_device_iterator_init( 542 dataset._variant_tensor, # pylint: disable=protected-access 543 self._multi_device_iterator_resource, 544 max_buffer_size=max_buffer_size) 545 546 prototype_device_datasets = [] 547 for i, device in enumerate(self._devices): 548 with ops.device(device): 549 ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, 550 incarnation_id, source_device_tensor, 551 dataset.element_spec) 552 prototype_device_datasets.append(ds) 553 554 # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to 555 # initialize the device side of the pipeline. This would allow the 556 # MultiDeviceIterator to choose, for example, to move some transformations 557 # into the device side from its input. It might be useful in rewriting. 558 # Create the per device iterators. 559 self._device_iterators = [] 560 iterator_handles = [] 561 for i, device in enumerate(self._devices): 562 with ops.device(device): 563 ds = _create_device_dataset(prototype_device_datasets[i], 564 incarnation_id, prefetch_buffer_size, 565 experimental_slack) 566 iterator = iter(ds) 567 self._device_iterators.append(iterator) 568 iterator_handles.append(iterator._iterator_resource) # pylint: disable=protected-access 569 570 self._resource_deleter = MultiDeviceIteratorResourceDeleter( 571 multi_device_iterator=self._multi_device_iterator_resource, 572 iterators=iterator_handles, 573 device=self._source_device, 574 deleter=self._deleter) 575 576 def get_next(self, device=None): 577 """Returns the next element given a `device`, else returns all in a list.""" 578 if device is not None: 579 index = self._devices.index(device) 580 return self._device_iterators[index].get_next() 581 582 result = [] 583 for i, device in enumerate(self._devices): 584 with ops.device(device): 585 result.append(self._device_iterators[i].get_next()) 586 return result 587 588 def __iter__(self): 589 return self 590 591 def next(self): 592 return self.__next__() 593 594 def __next__(self): 595 try: 596 return self.get_next() 597 except errors.OutOfRangeError: 598 raise StopIteration 599 600 def get_next_as_optional(self): 601 result = [] 602 for i, device in enumerate(self._devices): 603 with ops.device(device): 604 result.append(self._device_iterators[i].get_next_as_optional()) 605 return result 606 607 @property 608 def element_spec(self): 609 return self._element_spec 610 611 @property 612 def _type_spec(self): 613 return MultiDeviceIteratorSpec(self._devices, self._source_device, 614 self._element_spec) 615