# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops.

Provides `tf.data` transformations that move dataset elements between
devices: `prefetch_to_device`, `copy_to_device`, and the experimental
`map_on_gpu`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _apply_fn(dataset):
    # Copy elements onto `device` first; the trailing `prefetch` then buffers
    # elements that already live on the target device.
    return dataset.apply(
        copy_to_device(target_device=device)).prefetch(buffer_size)

  return _apply_fn


@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Turn off the default graph rewrites and autotuning for the copied
    # dataset; the copy is implemented via a GeneratorDataset placed on
    # `target_device` (see `_CopyToDeviceDataset` below), and these options
    # are explicitly disabled for it.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.autotune = False
    return _CopyToDeviceDataset(
        dataset, target_device=target_device,
        source_device=source_device).with_options(options)

  return _apply_fn


# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  The copy is implemented as a `GeneratorDataset` placed on the target
  device whose init/next/finalize functions each issue a `remote_call`
  back to the source device, where the real iterator over `input_dataset`
  lives.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    # Used by `make_one_shot_iterator` below: one-shot iterators are not
    # supported when the target is a GPU.
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the input dataset's variant tensor so that it can be captured by
    # `_init_func` and unwrapped inside it.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **dataset_ops.flat_structure(self._input_dataset))
      # Initialize the iterator before handing back its string handle.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      # Run `_init_func` on the source device, where the input dataset's
      # iterator must live.
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      # Flatten the (possibly nested) element into a list of tensors, which
      # is what the GeneratorDataset's next_func must produce.
      return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # NOTE(review): `experimental_ints_on_device` appears to let int32
    # outputs be placed in device memory rather than host memory — confirm
    # against the function-attribute docs before relying on it.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      # Run `_next_func` on the source device via its string handle.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **dataset_ops.flat_structure(self._input_dataset))
      # `ignore_lookup_error` makes destruction a no-op if the resource is
      # already gone; the constant 0 is just a dummy return value.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      # Run `_finalize_func` on the source device to release the iterator.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Make sure all three concrete functions are registered in the current
    # graph before the GeneratorDataset op references them.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-access

    # Place the GeneratorDataset itself on the target device; its functions
    # remote-call back to the source device.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **dataset_ops.flat_structure(self._input_dataset))
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0-arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU targets.
  def make_one_shot_iterator(self):
    # See the class-level note above: string inputs to the GeneratorDataset
    # cannot be placed on a GPU, so one-shot iterators are rejected there.
    if self._is_gpu_target:
      raise ValueError("Cannot create a one shot iterator when using "
                       "`tf.data.experimental.copy_to_device()` on GPU. Please "
                       "use `Dataset.make_initializable_iterator()` instead.")
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()


class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its elements using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    # `experimental_ints_on_device` mirrors the attribute used by
    # `_CopyToDeviceDataset._remote_next_func` above, so the traced
    # `map_func` can consume int32 tensors resident on the GPU.
    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **dataset_ops.flat_structure(self))
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    # The element structure is determined by `map_func`'s outputs, not by
    # the input dataset.
    return self._map_func.output_structure

  def _transformation_name(self):
    return "map_on_gpu()"


def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _MapOnGpuDataset(dataset, map_func)

  return _apply_fn