# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,
  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # First copy every element onto `device`, then keep a buffer of elements
    # there so consumers on that device do not wait on the copy.
    on_device = dataset.apply(copy_to_device(target_device=device))
    return on_device.prefetch(buffer_size)

  return _apply_fn


@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  # The heavy lifting lives in _CopyToDeviceDataset; this wrapper only binds
  # the device arguments so the result can be used with `Dataset.apply`.
  return lambda dataset: _CopyToDeviceDataset(
      dataset, target_device=target_device, source_device=source_device)


# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  The copy is implemented as a `GeneratorDataset` placed on the target
  device whose init/next/finalize functions each issue a `remote_call`
  back to the source device, where the real iterator over `input_dataset`
  lives.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied.
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    # GPU targets cannot support one-shot iterators; see
    # make_one_shot_iterator below.
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the input dataset's variant tensor so it can be captured by the
    # traced functions below and unwrapped inside _init_func; presumably this
    # makes the variant safe to carry across the device boundary — the
    # wrap/unwrap pair is the only use visible here.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The iterator must be initialized before its handle is returned, hence
      # the control dependency on make_iterator.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      # Runs _init_func on the source device, so the iterator resource is
      # created where the input dataset lives.
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func.

      Returns:
        The elements generated from `input_dataset`.
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
        # Flatten the element into a list of tensors so it can cross the
        # remote_call boundary.
        return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # experimental_ints_on_device keeps integer outputs on the calling device
    # (rather than the default host placement) — NOTE(review): inferred from
    # the attribute name; confirm against the defun attribute's definition.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      # Runs _next_func on the source device, pulling one element per call.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func.

      Returns:
        Tensor constant 0.
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # ignore_lookup_error: the resource may already be gone (e.g. the source
      # device was reset); finalization is best-effort.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        # GeneratorDataset's finalize function must return a tensor; the value
        # itself is unused.
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      # Runs _finalize_func on the source device, where the iterator resource
      # was created.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three concrete functions in the current graph so the
    # GeneratorDataset op below can reference them by name.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-access

    # The dataset itself is placed on the target device; each of its lifecycle
    # functions remote-calls back to the source device.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset.
Since 221 # there are strings that are inputs to the GeneratorDataset which can't be 222 # placed on a GPU, this fails for the GPU case. Therefore, disabling it for 223 # GPU 224 def make_one_shot_iterator(self): 225 if self._is_gpu_target: 226 raise ValueError("Cannot create a one shot iterator when using " 227 "`tf.data.experimental.copy_to_device()` on GPU. Please " 228 "use `Dataset.make_initializable_iterator()` instead.") 229 else: 230 return super(_CopyToDeviceDataset, self).make_one_shot_iterator() 231 232 233class _MapOnGpuDataset(dataset_ops.UnaryDataset): 234 """A `Dataset` that maps a function over elements in its using a GPU.""" 235 236 def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True): 237 """See `Dataset.map()` for details.""" 238 self._input_dataset = input_dataset 239 self._use_inter_op_parallelism = use_inter_op_parallelism 240 241 self._map_func = dataset_ops.StructuredFunctionWrapper( 242 map_func, 243 self._transformation_name(), 244 dataset=input_dataset, 245 defun_kwargs={"experimental_ints_on_device": True}) 246 variant_tensor = ged_ops.experimental_map_dataset( 247 self._input_dataset._variant_tensor, # pylint: disable=protected-access 248 self._map_func.function.captured_inputs, 249 f=self._map_func.function, 250 use_inter_op_parallelism=self._use_inter_op_parallelism, 251 **self._flat_structure) 252 super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor) 253 254 def _functions(self): 255 return [self._map_func] 256 257 @property 258 def element_spec(self): 259 return self._map_func.output_structure 260 261 def _transformation_name(self): 262 return "map_on_gpu()" 263 264 265def map_on_gpu(map_func): 266 """Maps `map_func` across the elements of this dataset. 267 268 NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs 269 `map_func` on GPU. It must be used after applying the 270 `tf.data.experimental.copy_to_device` transformation with a GPU device 271 argument. 
272 273 Args: 274 map_func: A function mapping a nested structure of tensors (having shapes 275 and types defined by `self.output_shapes` and `self.output_types`) to 276 another nested structure of tensors. 277 278 Returns: 279 A `Dataset` transformation function, which can be passed to 280 `tf.data.Dataset.apply`. 281 """ 282 283 def _apply_fn(dataset): 284 return _MapOnGpuDataset(dataset, map_func) 285 286 return _apply_fn 287