• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utilities for tf.data functionality."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21import re
22
23from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.data.util import nest
26from tensorflow.python.data.util import structure
27from tensorflow.python.eager import context
28from tensorflow.python.framework import combinations
29from tensorflow.python.framework import config
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import errors
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import gen_dataset_ops
37from tensorflow.python.ops import gen_experimental_dataset_ops
38from tensorflow.python.ops import lookup_ops
39from tensorflow.python.ops import tensor_array_ops
40from tensorflow.python.ops.ragged import ragged_tensor
41from tensorflow.python.platform import test
42
43
44def default_test_combinations():
45  """Returns the default test combinations for tf.data tests."""
46  return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])
47
48
49def eager_only_combinations():
50  """Returns the default test combinations for eager mode only tf.data tests."""
51  return combinations.combine(tf_api_version=[1, 2], mode="eager")
52
53
54def graph_only_combinations():
55  """Returns the default test combinations for graph mode only tf.data tests."""
56  return combinations.combine(tf_api_version=[1, 2], mode="graph")
57
58
59def v1_only_combinations():
60  """Returns the default test combinations for v1 only tf.data tests."""
61  return combinations.combine(tf_api_version=1, mode=["eager", "graph"])
62
63
64def v2_only_combinations():
65  """Returns the default test combinations for v2 only tf.data tests."""
66  return combinations.combine(tf_api_version=2, mode=["eager", "graph"])
67
68
69def v2_eager_only_combinations():
70  """Returns the default test combinations for v2 eager only tf.data tests."""
71  return combinations.combine(tf_api_version=2, mode="eager")
72
73
74class DatasetTestBase(test.TestCase):
75  """Base class for dataset tests."""
76
77  def assert_op_cancelled(self, op):
78    with self.assertRaises(errors.CancelledError):
79      self.evaluate(op)
80
81  def assertValuesEqual(self, expected, actual):
82    """Asserts that two values are equal."""
83    if isinstance(expected, dict):
84      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
85      for k in expected.keys():
86        self.assertValuesEqual(expected[k], actual[k])
87    elif sparse_tensor.is_sparse(expected):
88      self.assertAllEqual(expected.indices, actual.indices)
89      self.assertAllEqual(expected.values, actual.values)
90      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
91    else:
92      self.assertAllEqual(expected, actual)
93
94  def getNext(self, dataset, requires_initialization=False, shared_name=None):
95    """Returns a callable that returns the next element of the dataset.
96
97    Example use:
98    ```python
99    # In both graph and eager modes
100    dataset = ...
101    get_next = self.getNext(dataset)
102    result = self.evaluate(get_next())
103    ```
104
105    Args:
106      dataset: A dataset whose elements will be returned.
107      requires_initialization: Indicates that when the test is executed in graph
108        mode, it should use an initializable iterator to iterate through the
109        dataset (e.g. when it contains stateful nodes). Defaults to False.
110      shared_name: (Optional.) If non-empty, the returned iterator will be
111        shared under the given name across multiple sessions that share the same
112        devices (e.g. when using a remote server).
113    Returns:
114      A callable that returns the next element of `dataset`. Any `TensorArray`
115      objects `dataset` outputs are stacked.
116    """
117    def ta_wrapper(gn):
118      def _wrapper():
119        r = gn()
120        if isinstance(r, tensor_array_ops.TensorArray):
121          return r.stack()
122        else:
123          return r
124      return _wrapper
125
126    # Create an anonymous iterator if we are in eager-mode or are graph inside
127    # of a tf.function.
128    if context.executing_eagerly() or ops.inside_function():
129      iterator = iter(dataset)
130      return ta_wrapper(iterator._next_internal)  # pylint: disable=protected-access
131    else:
132      if requires_initialization:
133        iterator = dataset_ops.make_initializable_iterator(dataset, shared_name)
134        self.evaluate(iterator.initializer)
135      else:
136        iterator = dataset_ops.make_one_shot_iterator(dataset)
137      get_next = iterator.get_next()
138      return ta_wrapper(lambda: get_next)
139
140  def _compareOutputToExpected(self, result_values, expected_values,
141                               assert_items_equal):
142    if assert_items_equal:
143      # TODO(shivaniagrawal): add support for nested elements containing sparse
144      # tensors when needed.
145      self.assertItemsEqual(result_values, expected_values)
146      return
147    for i in range(len(result_values)):
148      nest.assert_same_structure(result_values[i], expected_values[i])
149      for result_value, expected_value in zip(
150          nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
151        self.assertValuesEqual(expected_value, result_value)
152
153  def getDatasetOutput(self, dataset, requires_initialization=False):
154    get_next = self.getNext(
155        dataset, requires_initialization=requires_initialization)
156    return self.getIteratorOutput(get_next)
157
158  def getIteratorOutput(self, get_next):
159    """Evaluates `get_next` until end of input, returning the results."""
160    results = []
161    while True:
162      try:
163        results.append(self.evaluate(get_next()))
164      except errors.OutOfRangeError:
165        break
166    return results
167
168  def assertDatasetProduces(self,
169                            dataset,
170                            expected_output=None,
171                            expected_shapes=None,
172                            expected_error=None,
173                            requires_initialization=False,
174                            num_test_iterations=1,
175                            assert_items_equal=False,
176                            expected_error_iter=1):
177    """Asserts that a dataset produces the expected output / error.
178
179    Args:
180      dataset: A dataset to check for the expected output / error.
181      expected_output: A list of elements that the dataset is expected to
182        produce.
183      expected_shapes: A list of TensorShapes which is expected to match
184        output_shapes of dataset.
185      expected_error: A tuple `(type, predicate)` identifying the expected error
186        `dataset` should raise. The `type` should match the expected exception
187        type, while `predicate` should either be 1) a unary function that inputs
188        the raised exception and returns a boolean indicator of success or 2) a
189        regular expression that is expected to match the error message
190        partially.
191      requires_initialization: Indicates that when the test is executed in graph
192        mode, it should use an initializable iterator to iterate through the
193        dataset (e.g. when it contains stateful nodes). Defaults to False.
194      num_test_iterations: Number of times `dataset` will be iterated. Defaults
195        to 1.
196      assert_items_equal: Tests expected_output has (only) the same elements
197        regardless of order.
198      expected_error_iter: How many times to iterate before expecting an error,
199        if an error is expected.
200    """
201    self.assertTrue(
202        expected_error is not None or expected_output is not None,
203        "Exactly one of expected_output or expected error should be provided.")
204    if expected_error:
205      self.assertTrue(
206          expected_output is None,
207          "Exactly one of expected_output or expected error should be provided."
208      )
209      with self.assertRaisesWithPredicateMatch(expected_error[0],
210                                               expected_error[1]):
211        get_next = self.getNext(
212            dataset, requires_initialization=requires_initialization)
213        for _ in range(expected_error_iter):
214          self.evaluate(get_next())
215      return
216    if expected_shapes:
217      self.assertEqual(expected_shapes,
218                       dataset_ops.get_legacy_output_shapes(dataset))
219    self.assertGreater(num_test_iterations, 0)
220    for _ in range(num_test_iterations):
221      get_next = self.getNext(
222          dataset, requires_initialization=requires_initialization)
223      result = []
224      for _ in range(len(expected_output)):
225        try:
226          result.append(self.evaluate(get_next()))
227        except errors.OutOfRangeError:
228          raise AssertionError(
229              "Dataset ended early, producing %d elements out of %d. "
230              "Dataset output: %s" %
231              (len(result), len(expected_output), str(result)))
232      self._compareOutputToExpected(result, expected_output, assert_items_equal)
233      with self.assertRaises(errors.OutOfRangeError):
234        self.evaluate(get_next())
235      with self.assertRaises(errors.OutOfRangeError):
236        self.evaluate(get_next())
237
238  def assertDatasetsEqual(self, dataset1, dataset2):
239    """Checks that datasets are equal. Supports both graph and eager mode."""
240    self.assertTrue(
241        structure.are_compatible(
242            dataset_ops.get_structure(dataset1),
243            dataset_ops.get_structure(dataset2)))
244
245    flattened_types = nest.flatten(
246        dataset_ops.get_legacy_output_types(dataset1))
247
248    next1 = self.getNext(dataset1)
249    next2 = self.getNext(dataset2)
250
251    while True:
252      try:
253        op1 = self.evaluate(next1())
254      except errors.OutOfRangeError:
255        with self.assertRaises(errors.OutOfRangeError):
256          self.evaluate(next2())
257        break
258      op2 = self.evaluate(next2())
259
260      op1 = nest.flatten(op1)
261      op2 = nest.flatten(op2)
262      assert len(op1) == len(op2)
263      for i in range(len(op1)):
264        if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
265          self.assertValuesEqual(op1[i], op2[i])
266        elif flattened_types[i] == dtypes.string:
267          self.assertAllEqual(op1[i], op2[i])
268        else:
269          self.assertAllClose(op1[i], op2[i])
270
271  def assertDatasetsRaiseSameError(self,
272                                   dataset1,
273                                   dataset2,
274                                   exception_class,
275                                   replacements=None):
276    """Checks that datasets raise the same error on the first get_next call."""
277    if replacements is None:
278      replacements = []
279    next1 = self.getNext(dataset1)
280    next2 = self.getNext(dataset2)
281    try:
282      self.evaluate(next1())
283      raise ValueError(
284          "Expected dataset to raise an error of type %s, but it did not." %
285          repr(exception_class))
286    except exception_class as e:
287      expected_message = e.message
288      for old, new, count in replacements:
289        expected_message = expected_message.replace(old, new, count)
290      # Check that the first segment of the error messages are the same.
291      with self.assertRaisesRegexp(exception_class,
292                                   re.escape(expected_message)):
293        self.evaluate(next2())
294
295  def structuredDataset(self, dataset_structure, shape=None,
296                        dtype=dtypes.int64):
297    """Returns a singleton dataset with the given structure."""
298    if shape is None:
299      shape = []
300    if dataset_structure is None:
301      return dataset_ops.Dataset.from_tensors(
302          array_ops.zeros(shape, dtype=dtype))
303    else:
304      return dataset_ops.Dataset.zip(
305          tuple([
306              self.structuredDataset(substructure, shape, dtype)
307              for substructure in dataset_structure
308          ]))
309
310  def textFileInitializer(self, vals):
311    file = os.path.join(self.get_temp_dir(), "text_file_initializer")
312    with open(file, "w") as f:
313      f.write("\n".join(str(v) for v in vals) + "\n")
314    return lookup_ops.TextFileInitializer(file, dtypes.int64,
315                                          lookup_ops.TextFileIndex.LINE_NUMBER,
316                                          dtypes.int64,
317                                          lookup_ops.TextFileIndex.WHOLE_LINE)
318
319  def keyValueTensorInitializer(self, vals):
320    keys_tensor = constant_op.constant(
321        list(range(len(vals))), dtype=dtypes.int64)
322    vals_tensor = constant_op.constant(vals)
323    return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
324
325  def datasetInitializer(self, vals):
326    keys = dataset_ops.Dataset.range(len(vals))
327    values = dataset_ops.Dataset.from_tensor_slices(vals)
328    ds = dataset_ops.Dataset.zip((keys, values))
329    return data_lookup_ops.DatasetInitializer(ds)
330
331  def lookupTableInitializer(self, init_source, vals):
332    """Returns a lookup table initializer for the given source and values.
333
334    Args:
335      init_source: One of ["textfile", "keyvalue", "dataset"], indicating what
336        type of initializer to use.
337      vals: The initializer values. The keys will be `range(len(vals))`.
338    """
339    if init_source == "textfile":
340      return self.textFileInitializer(vals)
341    elif init_source == "keyvaluetensor":
342      return self.keyValueTensorInitializer(vals)
343    elif init_source == "dataset":
344      return self.datasetInitializer(vals)
345    else:
346      raise ValueError("Unrecognized init_source: " + init_source)
347
348  def graphRoundTrip(self, dataset, allow_stateful=False):
349    """Converts a dataset to a graph and back."""
350    graph = gen_dataset_ops.dataset_to_graph(
351        dataset._variant_tensor, allow_stateful=allow_stateful)  # pylint: disable=protected-access
352    return dataset_ops.from_variant(
353        gen_experimental_dataset_ops.dataset_from_graph(graph),
354        dataset.element_spec)
355
356  def structuredElement(self, element_structure, shape=None,
357                        dtype=dtypes.int64):
358    """Returns an element with the given structure."""
359    if shape is None:
360      shape = []
361    if element_structure is None:
362      return array_ops.zeros(shape, dtype=dtype)
363    else:
364      return tuple([
365          self.structuredElement(substructure, shape, dtype)
366          for substructure in element_structure
367      ])
368
369  def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements):
370    """Tests whether a dataset produces its elements deterministically.
371
372    `dataset_fn` takes a delay_ms argument, which tells it how long to delay
373    production of the first dataset element. This gives us a way to trigger
374    out-of-order production of dataset elements.
375
376    Args:
377      dataset_fn: A function taking a delay_ms argument.
378      expect_determinism: Whether to expect deterministic ordering.
379      expected_elements: The elements expected to be produced by the dataset,
380        assuming the dataset produces elements in deterministic order.
381    """
382    if expect_determinism:
383      dataset = dataset_fn(100)
384      actual = self.getDatasetOutput(dataset)
385      self.assertAllEqual(expected_elements, actual)
386      return
387
388    # We consider the test a success if it succeeds under any delay_ms. The
389    # delay_ms needed to observe non-deterministic ordering varies across
390    # test machines. Usually 10 or 100 milliseconds is enough, but on slow
391    # machines it could take longer.
392    for delay_ms in [10, 100, 1000, 20000, 100000]:
393      dataset = dataset_fn(delay_ms)
394      actual = self.getDatasetOutput(dataset)
395      self.assertCountEqual(expected_elements, actual)
396      for i in range(len(actual)):
397        if actual[i] != expected_elements[i]:
398          return
399    self.fail("Failed to observe nondeterministic ordering")
400
401  def configureDevicesForMultiDeviceTest(self, num_devices):
402    """Configures number of logical devices for multi-device tests.
403
404    It returns a list of device names. If invoked in GPU-enabled runtime, the
405    last device name will be for a GPU device. Otherwise, all device names will
406    be for a CPU device.
407
408    Args:
409      num_devices: The number of devices to configure.
410
411    Returns:
412      A list of device names to use for a multi-device test.
413    """
414    cpus = config.list_physical_devices("CPU")
415    gpus = config.list_physical_devices("GPU")
416    config.set_logical_device_configuration(cpus[0], [
417        context.LogicalDeviceConfiguration() for _ in range(num_devices)
418    ])
419    devices = ["/device:CPU:" + str(i) for i in range(num_devices - 1)]
420    if gpus:
421      devices.append("/device:GPU:0")
422    else:
423      devices.append("/device:CPU:" + str(num_devices - 1))
424    return devices
425