# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset."""

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_spec):
    self._element_spec = element_spec

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=structure.get_flat_tensor_types(self._element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self._element_spec)))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=structure.get_flat_tensor_types(self._element_spec),
          output_shapes=structure.get_flat_tensor_shapes(self._element_spec))

    next_func_concrete = _next_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
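
  For example, `MultiDeviceIterator._eager_reset` re-initializes the
  multi-device iterator resource and then rebuilds each per-device dataset by
  wrapping the cached prototype `_PerDeviceGenerator` in this class with the
  fresh incarnation_id, which avoids re-tracing the init/next/finalize
  functions.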
157 """ 158 159 def __init__(self, per_device_dataset, incarnation_id): 160 # pylint: disable=protected-access 161 self._element_spec = per_device_dataset.element_spec 162 self._init_func = per_device_dataset._init_func 163 self._init_captured_args = self._init_func.captured_inputs 164 165 self._next_func = per_device_dataset._next_func 166 self._next_captured_args = per_device_dataset._next_captured_args 167 # The captured arguments to the next_func are string_handle, incarnation_id. 168 # We update the incarnation id to the new one. 169 self._next_captured_args[ 170 per_device_dataset._incarnation_id_index] = incarnation_id 171 172 self._finalize_func = per_device_dataset._finalize_func 173 self._finalize_captured_args = per_device_dataset._finalize_captured_args 174 175 variant_tensor = gen_dataset_ops.generator_dataset( 176 self._init_captured_args, 177 self._next_captured_args, 178 self._finalize_captured_args, 179 init_func=self._init_func, 180 next_func=self._next_func, 181 finalize_func=self._finalize_func, 182 **self._flat_structure) 183 super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor) 184 185 def _inputs(self): 186 # TODO(b/116506223): Determine which datasets should be used as inputs here. 187 return [] 188 189 @property 190 def element_spec(self): 191 return self._element_spec 192 193 194def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size, 195 experimental_slack): 196 """Uses _prototype_device_datasets[i] to build a dataset for the device.""" 197 ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id) 198 if prefetch_buffer_size > 0: 199 if experimental_slack: 200 ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1) 201 else: 202 ds = ds.prefetch(prefetch_buffer_size) 203 return ds 204 205 206class MultiDeviceIterator(object): 207 """An iterator over multiple devices.""" 208 209 def __init__(self, 210 dataset, 211 devices, 212 max_buffer_size=1, 213 prefetch_buffer_size=1, 214 source_device="/cpu:0"): 215 """Constructs a MultiDeviceIterator. 216 217 Args: 218 dataset: The input dataset to be iterated over. 219 devices: The list of devices to fetch data to. 220 max_buffer_size: Maximum size of the host side per device buffer to keep. 221 prefetch_buffer_size: if > 0, then we setup a buffer on each device to 222 prefetch into. 223 source_device: The host device to place the `dataset` on. In order to 224 prevent deadlocks, if the prefetch_buffer_size is greater than the 225 max_buffer_size, we set the max_buffer_size to prefetch_buffer_size. 226 """ 227 options = options_lib.Options() 228 options.experimental_distribute.num_devices = len(devices) 229 dataset = dataset.with_options(options) 230 self._dataset = dataset._apply_debug_options() # pylint: disable=protected-access 231 self._experimental_slack = dataset.options().experimental_slack 232 self._devices = devices 233 self._source_device = source_device 234 self._source_device_tensor = ops.convert_to_tensor(source_device) 235 self._max_buffer_size = max_buffer_size 236 self._prefetch_buffer_size = prefetch_buffer_size 237 238 if self._prefetch_buffer_size > self._max_buffer_size: 239 self._max_buffer_size = self._prefetch_buffer_size 240 241 # Create the MultiDeviceIterator. 242 with ops.device(self._source_device): 243 # TODO(b/121378567): Get rid of this shared_name hack. 
244 shared_name = "" 245 if context.executing_eagerly(): 246 shared_name = context.shared_name() 247 self._multi_device_iterator_resource = ( 248 gen_dataset_ops.multi_device_iterator( 249 devices=self._devices, 250 shared_name=shared_name, 251 container="", 252 **self._dataset._flat_structure)) # pylint: disable=protected-access 253 if context.executing_eagerly(): 254 # Delete the resource when this object is deleted 255 self._resource_deleter = resource_variable_ops.EagerResourceDeleter( 256 handle=self._multi_device_iterator_resource, 257 handle_device=self._source_device) 258 259 # The incarnation ID is used to ensure consistency between the per-device 260 # iterators and the multi-device iterator. 261 self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( 262 self._dataset._variant_tensor, # pylint: disable=protected-access 263 self._multi_device_iterator_resource, 264 max_buffer_size=self._max_buffer_size) 265 266 self._prototype_device_datasets = [] 267 for i, device in enumerate(self._devices): 268 with ops.device(device): 269 ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, 270 self._incarnation_id, 271 self._source_device_tensor, 272 self._dataset.element_spec) 273 self._prototype_device_datasets.append(ds) 274 275 # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to 276 # initialize the device side of the pipeline. This would allow the 277 # MultiDeviceIterator to choose, for example, to move some transformations 278 # into the device side from its input. It might be useful in rewriting. 279 # Create the per device iterators. 280 self._device_iterators = [] 281 for i, device in enumerate(self._devices): 282 with ops.device(device): 283 ds = _create_device_dataset(self._prototype_device_datasets[i], 284 self._incarnation_id, 285 self._prefetch_buffer_size, 286 self._experimental_slack) 287 if context.executing_eagerly(): 288 self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds)) 289 else: 290 self._device_iterators.append( 291 dataset_ops.make_initializable_iterator(ds)) 292 293 if not context.executing_eagerly(): 294 device_iterator_initializers = [ 295 iterator.initializer for iterator in self._device_iterators 296 ] 297 self._initializer = control_flow_ops.group(*device_iterator_initializers) 298 299 def _create_device_dataset(self, i): 300 """Uses _prototype_device_datasets[i] to build a dataset for the device.""" 301 ds = self._prototype_device_datasets[i] 302 ds = _ReincarnatedPerDeviceGenerator(ds, self._incarnation_id) 303 if self._prefetch_buffer_size > 0: 304 if self._experimental_slack: 305 ds = dataset_ops.PrefetchDataset( 306 ds, self._prefetch_buffer_size, slack_period=1) 307 else: 308 ds = ds.prefetch(self._prefetch_buffer_size) 309 return ds 310 311 def get_next(self, device=None): 312 """Returns the next element given a `device`, else returns all in a list.""" 313 if device is not None: 314 index = self._devices.index(device) 315 return self._device_iterators[index].get_next() 316 317 result = [] 318 for i, device in enumerate(self._devices): 319 with ops.device(device): 320 result.append(self._device_iterators[i].get_next()) 321 return result 322 323 def get_next_as_optional(self): 324 result = [] 325 for i, device in enumerate(self._devices): 326 with ops.device(device): 327 result.append(self._device_iterators[i].get_next_as_optional()) 328 return result 329 330 @property 331 def initializer(self): 332 if context.executing_eagerly(): 333 return control_flow_ops.no_op() 334 return self._initializer 

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not ops.executing_eagerly_outside_functions():
      raise ValueError("Eager reset is only supported in eager mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def element_spec(self):
    return self._dataset.element_spec


class MultiDeviceIteratorResourceDeleter(object):
  """An object which cleans up a multi-device iterator resource.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectible.
  """

  __slots__ = [
      "_deleter", "_multi_device_iterator", "_iterators", "_device",
      "_eager_mode"
  ]

  def __init__(self, multi_device_iterator, iterators, device, deleter):
    self._deleter = deleter
    self._multi_device_iterator = multi_device_iterator
    self._iterators = iterators
    self._device = device
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    with ops.device(self._device):
      # Make sure the resource is deleted in the same mode as it was created
      # in. We pass in the iterator handles as inputs to the op to make sure
      # that this op runs after all the iterators are deleted.
      if self._eager_mode:
        with context.eager_mode():
          gen_dataset_ops.delete_multi_device_iterator(
              multi_device_iterator=self._multi_device_iterator,
              iterators=self._iterators,
              deleter=self._deleter)
      else:
        with context.graph_mode():
          gen_dataset_ops.delete_multi_device_iterator(
              multi_device_iterator=self._multi_device_iterator,
              iterators=self._iterators,
              deleter=self._deleter)


class MultiDeviceIteratorSpec(type_spec.TypeSpec):
  """Type specification for `OwnedMultiDeviceIterator`."""

  __slots__ = ["_devices", "_source_device", "_element_spec"]

  def __init__(self, devices, source_device, element_spec):
    self._devices = devices
    self._source_device = source_device
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedMultiDeviceIterator

  def _serialize(self):
    return (tuple(self._devices), self._source_device, self._element_spec)

  @property
  def _component_specs(self):
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
        tensor_spec.TensorSpec([], dtypes.variant)
    ]
    for _ in range(len(self._devices)):
      specs.append(iterator_ops.IteratorSpec(self._element_spec))
    return specs

  def _to_components(self, value):
    # pylint: disable=protected-access
    c = [value._multi_device_iterator_resource, value._deleter]
    c.extend(value._device_iterators)
    return c

  def _from_components(self, components):
    return OwnedMultiDeviceIterator(
        dataset=None,
        devices=self._devices,
        source_device=self._source_device,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return MultiDeviceIteratorSpec(
        value._devices,
        value._source_device,
        value.element_spec)


class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
  """An iterator over multiple devices.

  The multi-device iterator resource created through `OwnedMultiDeviceIterator`
  is owned by the Python object and the lifetime of the underlying resource is
  tied to the lifetime of the `OwnedMultiDeviceIterator` object. This makes
  `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
  tf.functions.
  """

  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to
        keep. In order to prevent deadlocks, if `prefetch_buffer_size` is
        greater than `max_buffer_size`, `max_buffer_size` is set to
        `prefetch_buffer_size`.
      prefetch_buffer_size: if > 0, then we set up a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A (nested) structure of `tf.TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode (outside of `tf.function`).
485 """ 486 if not context.executing_eagerly() and not ops.inside_function(): 487 raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of " 488 "tf.function or when eager execution is enabled.") 489 if devices is None: 490 raise ValueError("`devices` must be provided") 491 error_message = "Either `dataset` or both `components` and " 492 "`element_spec` need to be provided." 493 494 if dataset is None: 495 if (components is None or element_spec is None): 496 raise ValueError(error_message) 497 self._element_spec = element_spec 498 self._devices = devices 499 self._source_device = source_device 500 self._multi_device_iterator_resource = components[0] 501 self._deleter = components[1] 502 self._device_iterators = components[2:] 503 iterator_handles = [] 504 for it in self._device_iterators: 505 iterator_handles.append(it._iterator_resource) # pylint: disable=protected-access 506 else: 507 if (components is not None or element_spec is not None): 508 raise ValueError(error_message) 509 options = options_lib.Options() 510 options.experimental_distribute.num_devices = len(devices) 511 dataset = dataset.with_options(options) 512 dataset = dataset._apply_debug_options() # pylint: disable=protected-access 513 self._element_spec = dataset.element_spec 514 experimental_slack = dataset.options().experimental_slack 515 self._devices = devices 516 self._source_device = source_device 517 source_device_tensor = ops.convert_to_tensor(self._source_device) 518 519 if prefetch_buffer_size > max_buffer_size: 520 max_buffer_size = prefetch_buffer_size 521 522 # Create the MultiDeviceIterator. 523 with ops.device(self._source_device): 524 self._multi_device_iterator_resource, self._deleter = ( 525 gen_dataset_ops.anonymous_multi_device_iterator( 526 devices=self._devices, **dataset._flat_structure)) # pylint: disable=protected-access 527 528 # The incarnation ID is used to ensure consistency between the 529 # per-device iterators and the multi-device iterator. 530 incarnation_id = gen_dataset_ops.multi_device_iterator_init( 531 dataset._variant_tensor, # pylint: disable=protected-access 532 self._multi_device_iterator_resource, 533 max_buffer_size=max_buffer_size) 534 535 prototype_device_datasets = [] 536 for i, device in enumerate(self._devices): 537 with ops.device(device): 538 ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, 539 incarnation_id, source_device_tensor, 540 dataset.element_spec) 541 prototype_device_datasets.append(ds) 542 543 # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to 544 # initialize the device side of the pipeline. This would allow the 545 # MultiDeviceIterator to choose, for example, to move some transformations 546 # into the device side from its input. It might be useful in rewriting. 547 # Create the per device iterators. 
      self._device_iterators = []
      iterator_handles = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)
          iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

    self._resource_deleter = MultiDeviceIteratorResourceDeleter(
        multi_device_iterator=self._multi_device_iterator_resource,
        iterators=iterator_handles,
        device=self._source_device,
        deleter=self._deleter)

  def get_next(self, device=None):
    """Returns the next element for `device`, else one element per device."""
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return MultiDeviceIteratorSpec(self._devices, self._source_device,
                                   self._element_spec)
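
# Illustrative eager-mode usage of `OwnedMultiDeviceIterator` (a sketch, not
# part of the original file; the device names are assumptions):
#
#   dataset = tf.data.Dataset.range(10)
#   mdi = OwnedMultiDeviceIterator(dataset, ["/gpu:0", "/gpu:1"])
#   for per_device_elements in mdi:
#     # `per_device_elements` is a list with one element per device.
#     gpu0_element, gpu1_element = per_device_elements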