# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,
  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be
      prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _apply_fn(dataset):
    return dataset.apply(
        copy_to_device(target_device=device)).prefetch(buffer_size)

  return _apply_fn
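
# Illustrative usage sketch (assumes a GPU device named "/gpu:0" is available);
# `prefetch_to_device` must remain the last transformation in the pipeline:
#
#   dataset = tf.data.Dataset.range(10)
#   dataset = dataset.apply(
#       tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=2))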


@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _CopyToDeviceDataset(
        dataset, target_device=target_device, source_device=source_device)

  return _apply_fn
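
# Illustrative usage sketch (assumes a GPU device named "/gpu:0" is available);
# copying is typically followed by a small prefetch on the target device:
#
#   dataset = tf.data.Dataset.range(10)
#   dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
#   dataset = dataset.prefetch(1)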


# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device."""

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

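    # The copy is implemented with three functions, each of which runs on
    # `source_device` via `functional_ops.remote_call`: an init function that
    # creates an iterator over the input dataset and returns its string
    # handle, a next function that pulls elements through that handle, and a
    # finalize function that destroys the iterator resource.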
    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-access

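    # The remote init/next/finalize functions (together with their captured
    # inputs) back a GeneratorDataset op that is placed on the target device,
    # so elements are requested from the target device and fetched from the
    # source device.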
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU
  def make_one_shot_iterator(self):
    if self._is_gpu_target:
      raise ValueError("Cannot create a one shot iterator when using "
                       "`tf.data.experimental.copy_to_device()` on GPU. Please "
                       "use `Dataset.make_initializable_iterator()` instead.")
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()


class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "map_on_gpu()"


def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _MapOnGpuDataset(dataset, map_func)

  return _apply_fn
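
# Illustrative usage sketch (assumes a GPU device named "/gpu:0" is available
# and that elements are plain tensors); `map_on_gpu` is applied after
# `copy_to_device` has placed the elements on the GPU:
#
#   dataset = tf.data.Dataset.from_tensor_slices([1., 2., 3.])
#   dataset = dataset.apply(
#       tf.data.experimental.copy_to_device("/gpu:0"))
#   dataset = dataset.apply(map_on_gpu(lambda x: x * 2.0))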