# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset.

  Exposes one shard of a multi-device iterator as a generator-style dataset.
  The init/next/finalize functions built below all route through
  `functional_ops.remote_call` targeting `source_device`, so the actual
  iteration happens on the host that owns the `MultiDeviceIterator` resource
  while this dataset is placed on the consuming device.
  """

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_structure):
    """Constructs a _PerDeviceGenerator.

    Args:
      shard_num: Index of the shard of the multi-device iterator this
        generator reads from (one shard per target device).
      multi_device_iterator_resource: Resource handle to the underlying
        `MultiDeviceIterator` op, placed on `source_device`.
      incarnation_id: Scalar tensor identifying the current incarnation of the
        multi-device iterator; passed to each `get_next_from_shard` call so
        stale iterators are rejected after a re-init.
      source_device: Device (tensor or string) hosting the iterator resource;
        used as the `remote_call` target.
      element_structure: Structure describing the type/shape of the elements
        produced by this dataset.
    """
    self._structure = element_structure

    # Serialize the resource handle so it can be re-materialized inside the
    # remotely-executed functions below.
    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      # Run the init function on the source device so the string handle is
      # produced where the iterator resource lives.
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # Re-materialize the iterator resource from the string handle, then pull
      # the next element for this generator's shard.
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._structure._flat_types,
              output_shapes=self._structure._flat_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._structure._flat_types,
          output_shapes=self._structure._flat_shapes)

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    # Record where `incarnation_id` was captured so that
    # `_ReincarnatedPerDeviceGenerator` can swap in a fresh incarnation id
    # without re-tracing the functions. NOTE(review): `==` between captured
    # tensors appears to rely on graph-mode identity comparison — confirm.
    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg == incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      # Nothing to clean up; return a dummy scalar to satisfy the
      # generator_dataset finalize signature.
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Assemble the generator dataset from the three remote functions and
    # their captured inputs.
    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs
    # here.
    return []

  @property
  def _element_structure(self):
    return self._structure


class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    """Constructs a _ReincarnatedPerDeviceGenerator.

    Args:
      per_device_dataset: The `_PerDeviceGenerator` whose traced functions are
        reused.
      incarnation_id: Fresh incarnation id tensor to substitute into the
        next-function's captured arguments.
    """
    # pylint: disable=protected-access
    self._structure = per_device_dataset._structure

    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    self._next_captured_args = per_device_dataset._next_captured_args
    # The captured arguments to the next_func are string_handle,
    # incarnation_id. We update the incarnation id to the new one.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs
    # here.
    return []

  @property
  def _element_structure(self):
    return self._structure


class MultiDeviceIterator(object):
  """An iterator over multiple devices."""

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to
        keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.

    In order to prevent deadlocks, if the prefetch_buffer_size is greater
    than the max_buffer_size, we set the max_buffer_size to
    prefetch_buffer_size.

    Raises:
      RuntimeError: If run in Eager mode.
    """
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    # See the docstring: keep the host-side buffer at least as large as the
    # per-device prefetch buffer to avoid deadlock.
    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **dataset_ops.flat_structure(self._dataset)))
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the
      # per-device iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    # One prototype generator per device; these carry the traced remote
    # functions that _create_device_dataset reincarnates on (re)init.
    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i, self._multi_device_iterator_resource, self._incarnation_id,
            self._source_device_tensor, self._dataset._element_structure)  # pylint: disable=protected-access
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    # In graph mode, expose a single op that initializes all per-device
    # iterators together.
    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def _create_device_dataset(self, i):
    """Uses _prototype_device_datasets[i] to build a dataset for the device."""
    ds = self._prototype_device_datasets[i]
    ds = _ReincarnatedPerDeviceGenerator(ds, self._incarnation_id)
    if self._prefetch_buffer_size > 0:
      ds = ds.prefetch(self._prefetch_buffer_size)
    # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
    # non-CPU devices.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.autotune = False
    ds = ds.with_options(options)
    return ds

  def get_next(self, device=None):
    """Returns the next element given a `device`, else returns all in a list."""
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
    # Like get_next(), but wraps each per-device element in an Optional so
    # end-of-sequence can be observed without raising.
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(iterator_ops.get_next_as_optional(
            self._device_iterators[i]))
    return result

  @property
  def initializer(self):
    # Eager iterators are one-shot and need no initialization; return a no-op
    # so callers can run this property unconditionally.
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not context.executing_eagerly():
      raise ValueError("Eager reset is only supported in eager mode.")
    # pylint: disable=protected-access
    # Re-initialize the underlying resource; this yields a fresh incarnation
    # id which the rebuilt per-device datasets capture.
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def _element_structure(self):
    return dataset_ops.get_structure(self._dataset)