• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utilities for tf.data functionality."""
16import os
17import random
18import re
19
20from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
21from tensorflow.python.data.experimental.ops import random_access
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import structure
25from tensorflow.python.eager import context
26from tensorflow.python.framework import combinations
27from tensorflow.python.framework import config
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import gen_dataset_ops
35from tensorflow.python.ops import gen_experimental_dataset_ops
36from tensorflow.python.ops import lookup_ops
37from tensorflow.python.ops import tensor_array_ops
38from tensorflow.python.ops.ragged import ragged_tensor
39from tensorflow.python.platform import test
40
41
def default_test_combinations():
  """Returns the tf.data default: API versions 1 and 2 x eager and graph."""
  api_versions = [1, 2]
  execution_modes = ["eager", "graph"]
  return combinations.combine(
      tf_api_version=api_versions, mode=execution_modes)
45
46
def eager_only_combinations():
  """Returns eager-mode-only combinations for API versions 1 and 2."""
  api_versions = [1, 2]
  return combinations.combine(tf_api_version=api_versions, mode="eager")
50
51
def graph_only_combinations():
  """Returns graph-mode-only combinations for API versions 1 and 2."""
  api_versions = [1, 2]
  return combinations.combine(tf_api_version=api_versions, mode="graph")
55
56
def v1_only_combinations():
  """Returns API-version-1-only combinations for eager and graph modes."""
  execution_modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=1, mode=execution_modes)
60
61
def v2_only_combinations():
  """Returns API-version-2-only combinations for eager and graph modes."""
  execution_modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=2, mode=execution_modes)
65
66
def v2_eager_only_combinations():
  """Returns the single combination of API version 2 with eager mode."""
  combination_kwargs = {"tf_api_version": 2, "mode": "eager"}
  return combinations.combine(**combination_kwargs)
70
71
72class DatasetTestBase(test.TestCase):
73  """Base class for dataset tests."""
74
75  def assert_op_cancelled(self, op):
76    with self.assertRaises(errors.CancelledError):
77      self.evaluate(op)
78
79  def assertValuesEqual(self, expected, actual):
80    """Asserts that two values are equal."""
81    if isinstance(expected, dict):
82      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
83      for k in expected.keys():
84        self.assertValuesEqual(expected[k], actual[k])
85    elif sparse_tensor.is_sparse(expected):
86      self.assertAllEqual(expected.indices, actual.indices)
87      self.assertAllEqual(expected.values, actual.values)
88      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
89    else:
90      self.assertAllEqual(expected, actual)
91
92  def getNext(self, dataset, requires_initialization=False, shared_name=None):
93    """Returns a callable that returns the next element of the dataset.
94
95    Example use:
96    ```python
97    # In both graph and eager modes
98    dataset = ...
99    get_next = self.getNext(dataset)
100    result = self.evaluate(get_next())
101    ```
102
103    Args:
104      dataset: A dataset whose elements will be returned.
105      requires_initialization: Indicates that when the test is executed in graph
106        mode, it should use an initializable iterator to iterate through the
107        dataset (e.g. when it contains stateful nodes). Defaults to False.
108      shared_name: (Optional.) If non-empty, the returned iterator will be
109        shared under the given name across multiple sessions that share the same
110        devices (e.g. when using a remote server).
111    Returns:
112      A callable that returns the next element of `dataset`. Any `TensorArray`
113      objects `dataset` outputs are stacked.
114    """
115    def ta_wrapper(gn):
116      def _wrapper():
117        r = gn()
118        if isinstance(r, tensor_array_ops.TensorArray):
119          return r.stack()
120        else:
121          return r
122      return _wrapper
123
124    # Create an anonymous iterator if we are in eager-mode or are graph inside
125    # of a tf.function.
126    if context.executing_eagerly() or ops.inside_function():
127      iterator = iter(dataset)
128      return ta_wrapper(iterator._next_internal)  # pylint: disable=protected-access
129    else:
130      if requires_initialization:
131        iterator = dataset_ops.make_initializable_iterator(dataset, shared_name)
132        self.evaluate(iterator.initializer)
133      else:
134        iterator = dataset_ops.make_one_shot_iterator(dataset)
135      get_next = iterator.get_next()
136      return ta_wrapper(lambda: get_next)
137
138  def _compareOutputToExpected(self, result_values, expected_values,
139                               assert_items_equal):
140    if assert_items_equal:
141      # TODO(shivaniagrawal): add support for nested elements containing sparse
142      # tensors when needed.
143      self.assertItemsEqual(result_values, expected_values)
144      return
145    for i in range(len(result_values)):
146      nest.assert_same_structure(result_values[i], expected_values[i])
147      for result_value, expected_value in zip(
148          nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
149        self.assertValuesEqual(expected_value, result_value)
150
151  def getDatasetOutput(self, dataset, requires_initialization=False):
152    get_next = self.getNext(
153        dataset, requires_initialization=requires_initialization)
154    return self.getIteratorOutput(get_next)
155
156  def getIteratorOutput(self, get_next):
157    """Evaluates `get_next` until end of input, returning the results."""
158    results = []
159    while True:
160      try:
161        results.append(self.evaluate(get_next()))
162      except errors.OutOfRangeError:
163        break
164    return results
165
166  def assertDatasetProduces(self,
167                            dataset,
168                            expected_output=None,
169                            expected_shapes=None,
170                            expected_error=None,
171                            requires_initialization=False,
172                            num_test_iterations=1,
173                            assert_items_equal=False,
174                            expected_error_iter=1):
175    """Asserts that a dataset produces the expected output / error.
176
177    Args:
178      dataset: A dataset to check for the expected output / error.
179      expected_output: A list of elements that the dataset is expected to
180        produce.
181      expected_shapes: A list of TensorShapes which is expected to match
182        output_shapes of dataset.
183      expected_error: A tuple `(type, predicate)` identifying the expected error
184        `dataset` should raise. The `type` should match the expected exception
185        type, while `predicate` should either be 1) a unary function that inputs
186        the raised exception and returns a boolean indicator of success or 2) a
187        regular expression that is expected to match the error message
188        partially.
189      requires_initialization: Indicates that when the test is executed in graph
190        mode, it should use an initializable iterator to iterate through the
191        dataset (e.g. when it contains stateful nodes). Defaults to False.
192      num_test_iterations: Number of times `dataset` will be iterated. Defaults
193        to 1.
194      assert_items_equal: Tests expected_output has (only) the same elements
195        regardless of order.
196      expected_error_iter: How many times to iterate before expecting an error,
197        if an error is expected.
198    """
199    self.assertTrue(
200        expected_error is not None or expected_output is not None,
201        "Exactly one of expected_output or expected error should be provided.")
202    if expected_error:
203      self.assertTrue(
204          expected_output is None,
205          "Exactly one of expected_output or expected error should be provided."
206      )
207      with self.assertRaisesWithPredicateMatch(expected_error[0],
208                                               expected_error[1]):
209        get_next = self.getNext(
210            dataset, requires_initialization=requires_initialization)
211        for _ in range(expected_error_iter):
212          self.evaluate(get_next())
213      return
214    if expected_shapes:
215      self.assertEqual(expected_shapes,
216                       dataset_ops.get_legacy_output_shapes(dataset))
217    self.assertGreater(num_test_iterations, 0)
218    for _ in range(num_test_iterations):
219      get_next = self.getNext(
220          dataset, requires_initialization=requires_initialization)
221      result = []
222      for _ in range(len(expected_output)):
223        try:
224          result.append(self.evaluate(get_next()))
225        except errors.OutOfRangeError:
226          raise AssertionError(
227              "Dataset ended early, producing %d elements out of %d. "
228              "Dataset output: %s" %
229              (len(result), len(expected_output), str(result)))
230      self._compareOutputToExpected(result, expected_output, assert_items_equal)
231      with self.assertRaises(errors.OutOfRangeError):
232        self.evaluate(get_next())
233      with self.assertRaises(errors.OutOfRangeError):
234        self.evaluate(get_next())
235
236  def assertDatasetsEqual(self, dataset1, dataset2):
237    """Checks that datasets are equal. Supports both graph and eager mode."""
238    self.assertTrue(
239        structure.are_compatible(
240            dataset_ops.get_structure(dataset1),
241            dataset_ops.get_structure(dataset2)))
242
243    flattened_types = nest.flatten(
244        dataset_ops.get_legacy_output_types(dataset1))
245
246    next1 = self.getNext(dataset1)
247    next2 = self.getNext(dataset2)
248
249    while True:
250      try:
251        op1 = self.evaluate(next1())
252      except errors.OutOfRangeError:
253        with self.assertRaises(errors.OutOfRangeError):
254          self.evaluate(next2())
255        break
256      op2 = self.evaluate(next2())
257
258      op1 = nest.flatten(op1)
259      op2 = nest.flatten(op2)
260      assert len(op1) == len(op2)
261      for i in range(len(op1)):
262        if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
263          self.assertValuesEqual(op1[i], op2[i])
264        elif flattened_types[i] == dtypes.string:
265          self.assertAllEqual(op1[i], op2[i])
266        else:
267          self.assertAllClose(op1[i], op2[i])
268
269  def assertDatasetsRaiseSameError(self,
270                                   dataset1,
271                                   dataset2,
272                                   exception_class,
273                                   replacements=None):
274    """Checks that datasets raise the same error on the first get_next call."""
275    if replacements is None:
276      replacements = []
277    next1 = self.getNext(dataset1)
278    next2 = self.getNext(dataset2)
279    try:
280      self.evaluate(next1())
281      raise ValueError(
282          "Expected dataset to raise an error of type %s, but it did not." %
283          repr(exception_class))
284    except exception_class as e:
285      expected_message = e.message
286      for old, new, count in replacements:
287        expected_message = expected_message.replace(old, new, count)
288      # Check that the first segment of the error messages are the same.
289      with self.assertRaisesRegexp(exception_class,
290                                   re.escape(expected_message)):
291        self.evaluate(next2())
292
293  def structuredDataset(self, dataset_structure, shape=None,
294                        dtype=dtypes.int64):
295    """Returns a singleton dataset with the given structure."""
296    if shape is None:
297      shape = []
298    if dataset_structure is None:
299      return dataset_ops.Dataset.from_tensors(
300          array_ops.zeros(shape, dtype=dtype))
301    else:
302      return dataset_ops.Dataset.zip(
303          tuple([
304              self.structuredDataset(substructure, shape, dtype)
305              for substructure in dataset_structure
306          ]))
307
308  def verifyRandomAccess(self, dataset, expected):
309    self.verifyRandomAccessInfiniteCardinality(dataset, expected)
310    with self.assertRaises(errors.OutOfRangeError):
311      self.evaluate(random_access.at(dataset, index=len(expected)))
312
313  def verifyRandomAccessInfiniteCardinality(self, dataset, expected):
314    """Tests randomly accessing elements of a dataset."""
315    # Tests accessing the elements in a shuffled order with repeats.
316    len_expected = len(expected)
317    indices = list(range(len_expected)) * 2
318    random.shuffle(indices)
319    for i in indices:
320      self.assertAllEqual(expected[i],
321                          self.evaluate(random_access.at(dataset, i)))
322
323    # Tests accessing the elements in order.
324    indices = set(sorted(indices))
325    for i in indices:
326      self.assertAllEqual(expected[i],
327                          self.evaluate(random_access.at(dataset, i)))
328
329  def textFileInitializer(self, vals):
330    file = os.path.join(self.get_temp_dir(), "text_file_initializer")
331    with open(file, "w") as f:
332      f.write("\n".join(str(v) for v in vals) + "\n")
333    return lookup_ops.TextFileInitializer(file, dtypes.int64,
334                                          lookup_ops.TextFileIndex.LINE_NUMBER,
335                                          dtypes.int64,
336                                          lookup_ops.TextFileIndex.WHOLE_LINE)
337
338  def keyValueTensorInitializer(self, vals):
339    keys_tensor = constant_op.constant(
340        list(range(len(vals))), dtype=dtypes.int64)
341    vals_tensor = constant_op.constant(vals)
342    return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
343
344  def datasetInitializer(self, vals):
345    keys = dataset_ops.Dataset.range(len(vals))
346    values = dataset_ops.Dataset.from_tensor_slices(vals)
347    ds = dataset_ops.Dataset.zip((keys, values))
348    return data_lookup_ops.DatasetInitializer(ds)
349
350  def lookupTableInitializer(self, init_source, vals):
351    """Returns a lookup table initializer for the given source and values.
352
353    Args:
354      init_source: One of ["textfile", "keyvalue", "dataset"], indicating what
355        type of initializer to use.
356      vals: The initializer values. The keys will be `range(len(vals))`.
357    """
358    if init_source == "textfile":
359      return self.textFileInitializer(vals)
360    elif init_source == "keyvaluetensor":
361      return self.keyValueTensorInitializer(vals)
362    elif init_source == "dataset":
363      return self.datasetInitializer(vals)
364    else:
365      raise ValueError("Unrecognized init_source: " + init_source)
366
367  def graphRoundTrip(self, dataset, allow_stateful=False):
368    """Converts a dataset to a graph and back."""
369    graph = gen_dataset_ops.dataset_to_graph(
370        dataset._variant_tensor, allow_stateful=allow_stateful)  # pylint: disable=protected-access
371    return dataset_ops.from_variant(
372        gen_experimental_dataset_ops.dataset_from_graph(graph),
373        dataset.element_spec)
374
375  def structuredElement(self, element_structure, shape=None,
376                        dtype=dtypes.int64):
377    """Returns an element with the given structure."""
378    if shape is None:
379      shape = []
380    if element_structure is None:
381      return array_ops.zeros(shape, dtype=dtype)
382    else:
383      return tuple([
384          self.structuredElement(substructure, shape, dtype)
385          for substructure in element_structure
386      ])
387
388  def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements):
389    """Tests whether a dataset produces its elements deterministically.
390
391    `dataset_fn` takes a delay_ms argument, which tells it how long to delay
392    production of the first dataset element. This gives us a way to trigger
393    out-of-order production of dataset elements.
394
395    Args:
396      dataset_fn: A function taking a delay_ms argument.
397      expect_determinism: Whether to expect deterministic ordering.
398      expected_elements: The elements expected to be produced by the dataset,
399        assuming the dataset produces elements in deterministic order.
400    """
401    if expect_determinism:
402      dataset = dataset_fn(100)
403      actual = self.getDatasetOutput(dataset)
404      self.assertAllEqual(expected_elements, actual)
405      return
406
407    # We consider the test a success if it succeeds under any delay_ms. The
408    # delay_ms needed to observe non-deterministic ordering varies across
409    # test machines. Usually 10 or 100 milliseconds is enough, but on slow
410    # machines it could take longer.
411    for delay_ms in [10, 100, 1000, 20000, 100000]:
412      dataset = dataset_fn(delay_ms)
413      actual = self.getDatasetOutput(dataset)
414      self.assertCountEqual(expected_elements, actual)
415      for i in range(len(actual)):
416        if actual[i] != expected_elements[i]:
417          return
418    self.fail("Failed to observe nondeterministic ordering")
419
420  def configureDevicesForMultiDeviceTest(self, num_devices):
421    """Configures number of logical devices for multi-device tests.
422
423    It returns a list of device names. If invoked in GPU-enabled runtime, the
424    last device name will be for a GPU device. Otherwise, all device names will
425    be for a CPU device.
426
427    Args:
428      num_devices: The number of devices to configure.
429
430    Returns:
431      A list of device names to use for a multi-device test.
432    """
433    cpus = config.list_physical_devices("CPU")
434    gpus = config.list_physical_devices("GPU")
435    config.set_logical_device_configuration(cpus[0], [
436        context.LogicalDeviceConfiguration() for _ in range(num_devices)
437    ])
438    devices = ["/device:CPU:" + str(i) for i in range(num_devices - 1)]
439    if gpus:
440      devices.append("/device:GPU:0")
441    else:
442      devices.append("/device:CPU:" + str(num_devices - 1))
443    return devices
444