• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental API for testing of tf.data."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import gen_experimental_dataset_ops
24
25
def assert_next(transformations):
  """Builds a transformation asserting which transformations come next.

  Name each transformation by its base name, leaving off any version suffix:
  pass "Batch" rather than "BatchV2". The base name "Batch" matches all of
  "Batch", "BatchV1", "BatchV2", etc.

  Args:
    transformations: A `tf.string` vector `tf.Tensor` naming the
      transformations that are expected to happen next.

  Returns:
    A `Dataset` transformation function, suitable for passing to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in an `_AssertNextDataset` and returns the result."""
    asserting_dataset = _AssertNextDataset(dataset, transformations)
    return asserting_dataset

  return _apply_fn
47
48
def non_serializable():
  """Builds an identity transformation that is non-serializable.

  Returns:
    A `Dataset` transformation function, suitable for passing to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in a `_NonSerializableDataset` and returns the result."""
    wrapped_dataset = _NonSerializableDataset(dataset)
    return wrapped_dataset

  return _apply_fn
62
63
def sleep(sleep_microseconds):
  """Builds a transformation that sleeps before producing each input element.

  Args:
    sleep_microseconds: The number of microseconds to sleep before producing
      an input element.

  Returns:
    A `Dataset` transformation function, suitable for passing to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in a `_SleepDataset` and returns the result."""
    sleeping_dataset = _SleepDataset(dataset, sleep_microseconds)
    return sleeping_dataset

  return _apply_fn
80
81
class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` asserting which transformations are applied next."""

  def __init__(self, input_dataset, transformations):
    """See `assert_next()` for details.

    Args:
      input_dataset: The `Dataset` being wrapped.
      transformations: Names of the expected transformations; converted to a
        `tf.string` tensor.

    Raises:
      ValueError: If `transformations` is `None`.
    """
    # Reject a missing argument up front, before any state is set.
    if transformations is None:
      raise ValueError("At least one transformation should be specified")

    # NOTE(review): `_flat_structure` is read below before the base-class
    # constructor runs; presumably it relies on `_input_dataset` already being
    # set, so keep this assignment first -- confirm against the base class.
    self._input_dataset = input_dataset
    self._transformations = ops.convert_to_tensor(
        transformations, dtype=dtypes.string, name="transformations")

    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    make_dataset = gen_experimental_dataset_ops.experimental_assert_next_dataset
    variant_tensor = make_dataset(
        input_variant, self._transformations, **self._flat_structure)
    super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
98
99
class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` applying a non-serializable identity transformation."""

  def __init__(self, input_dataset):
    """See `non_serializable()` for details.

    Args:
      input_dataset: The `Dataset` being wrapped.
    """
    # NOTE(review): `_flat_structure` is read below before the base-class
    # constructor runs; presumably it relies on `_input_dataset` already being
    # set, so keep this assignment first -- confirm against the base class.
    self._input_dataset = input_dataset

    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    make_dataset = (
        gen_experimental_dataset_ops.experimental_non_serializable_dataset)
    variant_tensor = make_dataset(input_variant, **self._flat_structure)
    super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
111
112
class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that sleeps before producing each upstream element."""

  def __init__(self, input_dataset, sleep_microseconds):
    """See `sleep()` for details.

    Args:
      input_dataset: The `Dataset` being wrapped.
      sleep_microseconds: The sleep duration in microseconds; passed to the
        underlying op as given.
    """
    # NOTE(review): `_flat_structure` is read below before the base-class
    # constructor runs; presumably it relies on `_input_dataset` already being
    # set, so keep this assignment first -- confirm against the base class.
    self._input_dataset = input_dataset
    self._sleep_microseconds = sleep_microseconds

    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    make_dataset = gen_experimental_dataset_ops.sleep_dataset
    variant_tensor = make_dataset(
        input_variant, self._sleep_microseconds, **self._flat_structure)
    super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
124