• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Python wrapper for prefetching_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.ops import iterator_ops
22from tensorflow.python.eager import function
23from tensorflow.python.framework import device as framework_device
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import functional_ops
29from tensorflow.python.ops import gen_dataset_ops
30from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.util.tf_export import tf_export
33
34
@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _apply_fn(dataset):
    # First move the elements onto `device`, then buffer them there.
    on_device = dataset.apply(copy_to_device(target_device=device))
    return on_device.prefetch(buffer_size)

  return _apply_fn
56
57
@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # NOTE(review): default static optimizations and autotuning are switched
    # off for the copied dataset — presumably because graph rewrites and
    # autotune injection must not run on the cross-device pipeline; confirm.
    copy_options = dataset_ops.Options()
    copy_options.experimental_optimization.apply_default_optimizations = False
    copy_options.experimental_optimization.autotune = False
    copied = _CopyToDeviceDataset(
        dataset,
        target_device=target_device,
        source_device=source_device)
    return copied.with_options(copy_options)

  return _apply_fn
80
81
# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  Implemented as a `GeneratorDataset` placed on `target_device` whose
  init/next/finalize functions each `remote_call` back to `source_device`,
  where the real iterator over `input_dataset` lives. The iterator is passed
  between the three functions as a string handle.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset
    self._target_device = target_device
    # Used by make_one_shot_iterator() below to reject the GPU case.
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    # Tensor form of the device name, passed as `target` to remote_call.
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the input's variant tensor so it can be captured by _init_func and
    # unwrapped inside the traced function.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **dataset_ops.flat_structure(self._input_dataset))
      # Ensure the iterator is initialized before its handle is returned.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Each of the three dataset functions is wrapped in a remote_call so that
    # it executes on the source device even though the GeneratorDataset itself
    # is placed on the target device.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      # Look the iterator back up from its handle, pinned to the source device.
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # NOTE(review): "experimental_ints_on_device" presumably keeps int32
    # outputs on the (possibly GPU) device rather than host memory — confirm.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **dataset_ops.flat_structure(self._input_dataset))
      # Best-effort destroy; the dummy int64 output exists only to carry the
      # control dependency.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the concrete functions in the current graph so the
    # GeneratorDataset op below can reference them by name.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-access

    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **dataset_ops.flat_structure(self._input_dataset))
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU
  def make_one_shot_iterator(self):
    if self._is_gpu_target:
      raise ValueError("Cannot create a one shot iterator when using "
                       "`tf.data.experimental.copy_to_device()` on GPU. Please "
                       "use `Dataset.make_initializable_iterator()` instead.")
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
226
227
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    # NOTE: self._map_func must be assigned before flat_structure(self) below,
    # since this dataset's _element_structure property reads it.
    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        # NOTE(review): presumably keeps int32 results on the GPU rather than
        # in host memory, matching _CopyToDeviceDataset — confirm.
        defun_kwargs={"experimental_ints_on_device": True})
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **dataset_ops.flat_structure(self))
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    # Exposes the wrapped map function for checkpointing/serialization hooks.
    return [self._map_func]

  @property
  def _element_structure(self):
    # The output structure is determined by what map_func returns, not by the
    # input dataset's structure.
    return self._map_func.output_structure

  def _transformation_name(self):
    # Name used in error messages raised while tracing map_func.
    return "map_on_gpu()"
258
259
def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Defer all the work to the dataset class; this closure only binds
    # map_func.
    return _MapOnGpuDataset(input_dataset=dataset, map_func=map_func)

  return _apply_fn
282