• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utilities for tf.data functionality."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import re
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import structure
25from tensorflow.python.eager import context
26from tensorflow.python.framework import combinations
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import gen_dataset_ops
33from tensorflow.python.ops import gen_experimental_dataset_ops
34from tensorflow.python.ops import tensor_array_ops
35from tensorflow.python.ops.ragged import ragged_tensor
36from tensorflow.python.platform import test
37
38
def default_test_combinations():
  """Returns the default test combinations for tf.data tests."""
  api_versions = [1, 2]
  execution_modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=api_versions,
                              mode=execution_modes)
42
43
def eager_only_combinations():
  """Returns the default test combinations for eager mode only tf.data tests."""
  api_versions = [1, 2]
  return combinations.combine(tf_api_version=api_versions, mode="eager")
47
48
def graph_only_combinations():
  """Returns the default test combinations for graph mode only tf.data tests."""
  api_versions = [1, 2]
  return combinations.combine(tf_api_version=api_versions, mode="graph")
52
53
def v1_only_combinations():
  """Returns the default test combinations for v1 only tf.data tests."""
  execution_modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=1, mode=execution_modes)
57
58
def v2_only_combinations():
  """Returns the default test combinations for v2 only tf.data tests."""
  execution_modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=2, mode=execution_modes)
62
63
def v2_eager_only_combinations():
  """Returns the default test combinations for v2 eager only tf.data tests."""
  return combinations.combine(tf_api_version=2, mode="eager")
67
68
class DatasetTestBase(test.TestCase):
  """Base class for dataset tests.

  Provides helpers for iterating datasets uniformly in graph and eager mode
  and for asserting on the elements or errors that datasets produce.
  """

  def assert_op_cancelled(self, op):
    """Asserts that evaluating `op` raises `errors.CancelledError`."""
    with self.assertRaises(errors.CancelledError):
      self.evaluate(op)

  def assertValuesEqual(self, expected, actual):
    """Asserts that two values are equal.

    Recurses into dicts and compares sparse tensors component-wise; all other
    values are compared with `assertAllEqual`.
    """
    if isinstance(expected, dict):
      self.assertCountEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self.assertValuesEqual(expected[k], actual[k])
    elif sparse_tensor.is_sparse(expected):
      # Sparse tensors are compared component by component.
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
    else:
      self.assertAllEqual(expected, actual)

  def getNext(self, dataset, requires_initialization=False, shared_name=None):
    """Returns a callable that returns the next element of the dataset.

    Example use:
    ```python
    # In both graph and eager modes
    dataset = ...
    get_next = self.getNext(dataset)
    result = self.evaluate(get_next())
    ```

    Args:
      dataset: A dataset whose elements will be returned.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
      shared_name: (Optional.) If non-empty, the returned iterator will be
        shared under the given name across multiple sessions that share the same
        devices (e.g. when using a remote server).

    Returns:
      A callable that returns the next element of `dataset`. Any `TensorArray`
      objects `dataset` outputs are stacked.
    """
    def ta_wrapper(gn):
      # Wraps `gn` so that any `TensorArray` it returns is stacked into a
      # dense tensor before being handed to the caller.
      def _wrapper():
        r = gn()
        if isinstance(r, tensor_array_ops.TensorArray):
          return r.stack()
        else:
          return r
      return _wrapper

    # Create an anonymous iterator if we are in eager-mode or are graph inside
    # of a tf.function.
    if context.executing_eagerly() or ops.inside_function():
      iterator = iter(dataset)
      return ta_wrapper(iterator._next_internal)  # pylint: disable=protected-access
    else:
      if requires_initialization:
        iterator = dataset_ops.make_initializable_iterator(dataset, shared_name)
        self.evaluate(iterator.initializer)
      else:
        iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return ta_wrapper(lambda: get_next)

  def _compareOutputToExpected(self, result_values, expected_values,
                               assert_items_equal):
    """Compares produced element values against the expected values.

    Args:
      result_values: The values the dataset actually produced.
      expected_values: The values the dataset was expected to produce.
      assert_items_equal: If True, compare as multisets (order-insensitive);
        otherwise compare element-wise, recursing into nested structures.
    """
    if assert_items_equal:
      # TODO(shivaniagrawal): add support for nested elements containing sparse
      # tensors when needed.
      self.assertCountEqual(result_values, expected_values)
      return
    # Fail loudly on a length mismatch instead of silently comparing a prefix.
    self.assertEqual(len(result_values), len(expected_values))
    for result_value, expected_value in zip(result_values, expected_values):
      nest.assert_same_structure(result_value, expected_value)
      for result_leaf, expected_leaf in zip(
          nest.flatten(result_value), nest.flatten(expected_value)):
        self.assertValuesEqual(expected_leaf, result_leaf)

  def getDatasetOutput(self, dataset, requires_initialization=False):
    """Evaluates `dataset` to exhaustion and returns its elements as a list."""
    get_next = self.getNext(
        dataset, requires_initialization=requires_initialization)
    results = []
    while True:
      try:
        results.append(self.evaluate(get_next()))
      except errors.OutOfRangeError:
        break
    return results

  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
      dataset: A dataset to check for the expected output / error.
      expected_output: A list of elements that the dataset is expected to
        produce.
      expected_shapes: A list of TensorShapes which is expected to match
        output_shapes of dataset.
      expected_error: A tuple `(type, predicate)` identifying the expected error
        `dataset` should raise. The `type` should match the expected exception
        type, while `predicate` should either be 1) a unary function that inputs
        the raised exception and returns a boolean indicator of success or 2) a
        regular expression that is expected to match the error message
        partially.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
      num_test_iterations: Number of times `dataset` will be iterated. Defaults
        to 1.
      assert_items_equal: Tests expected_output has (only) the same elements
        regardless of order.
      expected_error_iter: How many times to iterate before expecting an error,
        if an error is expected.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected_error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected_error should be provided."
      )
      with self.assertRaisesWithPredicateMatch(expected_error[0],
                                               expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        for _ in range(expected_error_iter):
          self.evaluate(get_next())
      return
    if expected_shapes:
      self.assertEqual(expected_shapes,
                       dataset_ops.get_legacy_output_shapes(dataset))
    self.assertGreater(num_test_iterations, 0)
    for _ in range(num_test_iterations):
      get_next = self.getNext(
          dataset, requires_initialization=requires_initialization)
      result = []
      for _ in range(len(expected_output)):
        result.append(self.evaluate(get_next()))
      self._compareOutputToExpected(result, expected_output, assert_items_equal)
      # Verify that the iterator is exhausted and *stays* exhausted.
      for _ in range(2):
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(get_next())

  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode."""
    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset1),
            dataset_ops.get_structure(dataset2)))

    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)

    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        # dataset1 is exhausted; dataset2 must be exhausted too.
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      assert len(op1) == len(op2)
      for i, (value1, value2) in enumerate(zip(op1, op2)):
        if sparse_tensor.is_sparse(value1) or ragged_tensor.is_ragged(value1):
          self.assertValuesEqual(value1, value2)
        elif flattened_types[i] == dtypes.string:
          self.assertAllEqual(value1, value2)
        else:
          # Numeric components are compared approximately to tolerate
          # floating-point differences.
          self.assertAllClose(value1, value2)

  def assertDatasetsRaiseSameError(self,
                                   dataset1,
                                   dataset2,
                                   exception_class,
                                   replacements=None):
    """Checks that datasets raise the same error on the first get_next call.

    Args:
      dataset1: The first dataset; its error message is used as the reference.
      dataset2: The second dataset; must raise a matching error.
      exception_class: The expected exception type.
      replacements: (Optional.) A list of `(old, new, count)` tuples applied to
        dataset1's error message before matching it against dataset2's.
    """
    if replacements is None:
      replacements = []
    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)
    try:
      self.evaluate(next1())
    except exception_class as e:
      expected_message = e.message
      for old, new, count in replacements:
        expected_message = expected_message.replace(old, new, count)
      # Check that the first segment of the error messages are the same.
      with self.assertRaisesRegex(exception_class,
                                  re.escape(expected_message)):
        self.evaluate(next2())
    else:
      # Raised outside the `try` so it cannot be swallowed by the `except`
      # clause when `exception_class` is `ValueError` (or a superclass).
      raise ValueError(
          "Expected dataset to raise an error of type %s, but it did not." %
          repr(exception_class))

  def structuredDataset(self, dataset_structure, shape=None,
                        dtype=dtypes.int64):
    """Returns a singleton dataset with the given structure.

    Args:
      dataset_structure: None for a scalar leaf, or an iterable of
        substructures which are zipped into a tuple dataset.
      shape: The shape of each leaf tensor. Defaults to a scalar (`[]`).
      dtype: The dtype of each leaf tensor. Defaults to int64.
    """
    if shape is None:
      shape = []
    if dataset_structure is None:
      return dataset_ops.Dataset.from_tensors(
          array_ops.zeros(shape, dtype=dtype))
    else:
      return dataset_ops.Dataset.zip(
          tuple([
              self.structuredDataset(substructure, shape, dtype)
              for substructure in dataset_structure
          ]))

  def graphRoundTrip(self, dataset, allow_stateful=False):
    """Converts a dataset to a graph and back."""
    graph = gen_dataset_ops.dataset_to_graph(
        dataset._variant_tensor, allow_stateful=allow_stateful)  # pylint: disable=protected-access
    return dataset_ops.from_variant(
        gen_experimental_dataset_ops.dataset_from_graph(graph),
        dataset.element_spec)

  def structuredElement(self, element_structure, shape=None,
                        dtype=dtypes.int64):
    """Returns an element with the given structure.

    Args:
      element_structure: None for a scalar leaf, or an iterable of
        substructures which become a nested tuple of tensors.
      shape: The shape of each leaf tensor. Defaults to a scalar (`[]`).
      dtype: The dtype of each leaf tensor. Defaults to int64.
    """
    if shape is None:
      shape = []
    if element_structure is None:
      return array_ops.zeros(shape, dtype=dtype)
    else:
      return tuple([
          self.structuredElement(substructure, shape, dtype)
          for substructure in element_structure
      ])

  def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements):
    """Tests whether a dataset produces its elements deterministically.

    `dataset_fn` takes a delay_ms argument, which tells it how long to delay
    production of the first dataset element. This gives us a way to trigger
    out-of-order production of dataset elements.

    Args:
      dataset_fn: A function taking a delay_ms argument.
      expect_determinism: Whether to expect deterministic ordering.
      expected_elements: The elements expected to be produced by the dataset,
        assuming the dataset produces elements in deterministic order.
    """
    if expect_determinism:
      dataset = dataset_fn(100)
      actual = self.getDatasetOutput(dataset)
      self.assertAllEqual(expected_elements, actual)
      return

    # We consider the test a success if it succeeds under any delay_ms. The
    # delay_ms needed to observe non-deterministic ordering varies across
    # test machines. Usually 10 or 100 milliseconds is enough, but on slow
    # machines it could take longer.
    for delay_ms in [10, 100, 1000, 20000]:
      dataset = dataset_fn(delay_ms)
      actual = self.getDatasetOutput(dataset)
      self.assertCountEqual(expected_elements, actual)
      for produced, expected in zip(actual, expected_elements):
        if produced != expected:
          return
    self.fail("Failed to observe nondeterministic ordering")
347