1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test utilities for tf.data functionality.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21import re 22 23from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.data.util import nest 26from tensorflow.python.data.util import structure 27from tensorflow.python.eager import context 28from tensorflow.python.framework import combinations 29from tensorflow.python.framework import config 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import errors 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import sparse_tensor 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import gen_dataset_ops 37from tensorflow.python.ops import gen_experimental_dataset_ops 38from tensorflow.python.ops import lookup_ops 39from tensorflow.python.ops import tensor_array_ops 40from tensorflow.python.ops.ragged import ragged_tensor 41from tensorflow.python.platform import test 42 43 44def default_test_combinations(): 45 """Returns the default test combinations for tf.data tests.""" 46 return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"]) 47 48 49def eager_only_combinations(): 50 """Returns the default test combinations for eager mode only tf.data tests.""" 51 return combinations.combine(tf_api_version=[1, 2], mode="eager") 52 53 54def graph_only_combinations(): 55 """Returns the default test combinations for graph mode only tf.data tests.""" 56 return combinations.combine(tf_api_version=[1, 2], mode="graph") 57 58 59def v1_only_combinations(): 60 """Returns the default test combinations for v1 only tf.data tests.""" 61 return combinations.combine(tf_api_version=1, mode=["eager", "graph"]) 62 63 64def v2_only_combinations(): 65 """Returns the default test combinations for v2 only tf.data tests.""" 66 return combinations.combine(tf_api_version=2, mode=["eager", "graph"]) 67 68 69def v2_eager_only_combinations(): 70 """Returns the default test combinations for v2 eager only tf.data tests.""" 71 return combinations.combine(tf_api_version=2, mode="eager") 72 73 74class DatasetTestBase(test.TestCase): 75 """Base class for dataset tests.""" 76 77 def assert_op_cancelled(self, op): 78 with self.assertRaises(errors.CancelledError): 79 self.evaluate(op) 80 81 def assertValuesEqual(self, expected, actual): 82 """Asserts that two values are equal.""" 83 if isinstance(expected, dict): 84 self.assertItemsEqual(list(expected.keys()), list(actual.keys())) 85 for k in expected.keys(): 86 self.assertValuesEqual(expected[k], actual[k]) 87 elif sparse_tensor.is_sparse(expected): 88 self.assertAllEqual(expected.indices, actual.indices) 89 self.assertAllEqual(expected.values, actual.values) 90 self.assertAllEqual(expected.dense_shape, actual.dense_shape) 91 else: 92 self.assertAllEqual(expected, actual) 93 94 def getNext(self, dataset, requires_initialization=False, shared_name=None): 95 """Returns a callable that returns the next element of the dataset. 96 97 Example use: 98 ```python 99 # In both graph and eager modes 100 dataset = ... 101 get_next = self.getNext(dataset) 102 result = self.evaluate(get_next()) 103 ``` 104 105 Args: 106 dataset: A dataset whose elements will be returned. 107 requires_initialization: Indicates that when the test is executed in graph 108 mode, it should use an initializable iterator to iterate through the 109 dataset (e.g. when it contains stateful nodes). Defaults to False. 110 shared_name: (Optional.) If non-empty, the returned iterator will be 111 shared under the given name across multiple sessions that share the same 112 devices (e.g. when using a remote server). 113 Returns: 114 A callable that returns the next element of `dataset`. Any `TensorArray` 115 objects `dataset` outputs are stacked. 116 """ 117 def ta_wrapper(gn): 118 def _wrapper(): 119 r = gn() 120 if isinstance(r, tensor_array_ops.TensorArray): 121 return r.stack() 122 else: 123 return r 124 return _wrapper 125 126 # Create an anonymous iterator if we are in eager-mode or are graph inside 127 # of a tf.function. 128 if context.executing_eagerly() or ops.inside_function(): 129 iterator = iter(dataset) 130 return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access 131 else: 132 if requires_initialization: 133 iterator = dataset_ops.make_initializable_iterator(dataset, shared_name) 134 self.evaluate(iterator.initializer) 135 else: 136 iterator = dataset_ops.make_one_shot_iterator(dataset) 137 get_next = iterator.get_next() 138 return ta_wrapper(lambda: get_next) 139 140 def _compareOutputToExpected(self, result_values, expected_values, 141 assert_items_equal): 142 if assert_items_equal: 143 # TODO(shivaniagrawal): add support for nested elements containing sparse 144 # tensors when needed. 145 self.assertItemsEqual(result_values, expected_values) 146 return 147 for i in range(len(result_values)): 148 nest.assert_same_structure(result_values[i], expected_values[i]) 149 for result_value, expected_value in zip( 150 nest.flatten(result_values[i]), nest.flatten(expected_values[i])): 151 self.assertValuesEqual(expected_value, result_value) 152 153 def getDatasetOutput(self, dataset, requires_initialization=False): 154 get_next = self.getNext( 155 dataset, requires_initialization=requires_initialization) 156 return self.getIteratorOutput(get_next) 157 158 def getIteratorOutput(self, get_next): 159 """Evaluates `get_next` until end of input, returning the results.""" 160 results = [] 161 while True: 162 try: 163 results.append(self.evaluate(get_next())) 164 except errors.OutOfRangeError: 165 break 166 return results 167 168 def assertDatasetProduces(self, 169 dataset, 170 expected_output=None, 171 expected_shapes=None, 172 expected_error=None, 173 requires_initialization=False, 174 num_test_iterations=1, 175 assert_items_equal=False, 176 expected_error_iter=1): 177 """Asserts that a dataset produces the expected output / error. 178 179 Args: 180 dataset: A dataset to check for the expected output / error. 181 expected_output: A list of elements that the dataset is expected to 182 produce. 183 expected_shapes: A list of TensorShapes which is expected to match 184 output_shapes of dataset. 185 expected_error: A tuple `(type, predicate)` identifying the expected error 186 `dataset` should raise. The `type` should match the expected exception 187 type, while `predicate` should either be 1) a unary function that inputs 188 the raised exception and returns a boolean indicator of success or 2) a 189 regular expression that is expected to match the error message 190 partially. 191 requires_initialization: Indicates that when the test is executed in graph 192 mode, it should use an initializable iterator to iterate through the 193 dataset (e.g. when it contains stateful nodes). Defaults to False. 194 num_test_iterations: Number of times `dataset` will be iterated. Defaults 195 to 1. 196 assert_items_equal: Tests expected_output has (only) the same elements 197 regardless of order. 198 expected_error_iter: How many times to iterate before expecting an error, 199 if an error is expected. 200 """ 201 self.assertTrue( 202 expected_error is not None or expected_output is not None, 203 "Exactly one of expected_output or expected error should be provided.") 204 if expected_error: 205 self.assertTrue( 206 expected_output is None, 207 "Exactly one of expected_output or expected error should be provided." 208 ) 209 with self.assertRaisesWithPredicateMatch(expected_error[0], 210 expected_error[1]): 211 get_next = self.getNext( 212 dataset, requires_initialization=requires_initialization) 213 for _ in range(expected_error_iter): 214 self.evaluate(get_next()) 215 return 216 if expected_shapes: 217 self.assertEqual(expected_shapes, 218 dataset_ops.get_legacy_output_shapes(dataset)) 219 self.assertGreater(num_test_iterations, 0) 220 for _ in range(num_test_iterations): 221 get_next = self.getNext( 222 dataset, requires_initialization=requires_initialization) 223 result = [] 224 for _ in range(len(expected_output)): 225 try: 226 result.append(self.evaluate(get_next())) 227 except errors.OutOfRangeError: 228 raise AssertionError( 229 "Dataset ended early, producing %d elements out of %d. " 230 "Dataset output: %s" % 231 (len(result), len(expected_output), str(result))) 232 self._compareOutputToExpected(result, expected_output, assert_items_equal) 233 with self.assertRaises(errors.OutOfRangeError): 234 self.evaluate(get_next()) 235 with self.assertRaises(errors.OutOfRangeError): 236 self.evaluate(get_next()) 237 238 def assertDatasetsEqual(self, dataset1, dataset2): 239 """Checks that datasets are equal. Supports both graph and eager mode.""" 240 self.assertTrue( 241 structure.are_compatible( 242 dataset_ops.get_structure(dataset1), 243 dataset_ops.get_structure(dataset2))) 244 245 flattened_types = nest.flatten( 246 dataset_ops.get_legacy_output_types(dataset1)) 247 248 next1 = self.getNext(dataset1) 249 next2 = self.getNext(dataset2) 250 251 while True: 252 try: 253 op1 = self.evaluate(next1()) 254 except errors.OutOfRangeError: 255 with self.assertRaises(errors.OutOfRangeError): 256 self.evaluate(next2()) 257 break 258 op2 = self.evaluate(next2()) 259 260 op1 = nest.flatten(op1) 261 op2 = nest.flatten(op2) 262 assert len(op1) == len(op2) 263 for i in range(len(op1)): 264 if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]): 265 self.assertValuesEqual(op1[i], op2[i]) 266 elif flattened_types[i] == dtypes.string: 267 self.assertAllEqual(op1[i], op2[i]) 268 else: 269 self.assertAllClose(op1[i], op2[i]) 270 271 def assertDatasetsRaiseSameError(self, 272 dataset1, 273 dataset2, 274 exception_class, 275 replacements=None): 276 """Checks that datasets raise the same error on the first get_next call.""" 277 if replacements is None: 278 replacements = [] 279 next1 = self.getNext(dataset1) 280 next2 = self.getNext(dataset2) 281 try: 282 self.evaluate(next1()) 283 raise ValueError( 284 "Expected dataset to raise an error of type %s, but it did not." % 285 repr(exception_class)) 286 except exception_class as e: 287 expected_message = e.message 288 for old, new, count in replacements: 289 expected_message = expected_message.replace(old, new, count) 290 # Check that the first segment of the error messages are the same. 291 with self.assertRaisesRegexp(exception_class, 292 re.escape(expected_message)): 293 self.evaluate(next2()) 294 295 def structuredDataset(self, dataset_structure, shape=None, 296 dtype=dtypes.int64): 297 """Returns a singleton dataset with the given structure.""" 298 if shape is None: 299 shape = [] 300 if dataset_structure is None: 301 return dataset_ops.Dataset.from_tensors( 302 array_ops.zeros(shape, dtype=dtype)) 303 else: 304 return dataset_ops.Dataset.zip( 305 tuple([ 306 self.structuredDataset(substructure, shape, dtype) 307 for substructure in dataset_structure 308 ])) 309 310 def textFileInitializer(self, vals): 311 file = os.path.join(self.get_temp_dir(), "text_file_initializer") 312 with open(file, "w") as f: 313 f.write("\n".join(str(v) for v in vals) + "\n") 314 return lookup_ops.TextFileInitializer(file, dtypes.int64, 315 lookup_ops.TextFileIndex.LINE_NUMBER, 316 dtypes.int64, 317 lookup_ops.TextFileIndex.WHOLE_LINE) 318 319 def keyValueTensorInitializer(self, vals): 320 keys_tensor = constant_op.constant( 321 list(range(len(vals))), dtype=dtypes.int64) 322 vals_tensor = constant_op.constant(vals) 323 return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor) 324 325 def datasetInitializer(self, vals): 326 keys = dataset_ops.Dataset.range(len(vals)) 327 values = dataset_ops.Dataset.from_tensor_slices(vals) 328 ds = dataset_ops.Dataset.zip((keys, values)) 329 return data_lookup_ops.DatasetInitializer(ds) 330 331 def lookupTableInitializer(self, init_source, vals): 332 """Returns a lookup table initializer for the given source and values. 333 334 Args: 335 init_source: One of ["textfile", "keyvalue", "dataset"], indicating what 336 type of initializer to use. 337 vals: The initializer values. The keys will be `range(len(vals))`. 338 """ 339 if init_source == "textfile": 340 return self.textFileInitializer(vals) 341 elif init_source == "keyvaluetensor": 342 return self.keyValueTensorInitializer(vals) 343 elif init_source == "dataset": 344 return self.datasetInitializer(vals) 345 else: 346 raise ValueError("Unrecognized init_source: " + init_source) 347 348 def graphRoundTrip(self, dataset, allow_stateful=False): 349 """Converts a dataset to a graph and back.""" 350 graph = gen_dataset_ops.dataset_to_graph( 351 dataset._variant_tensor, allow_stateful=allow_stateful) # pylint: disable=protected-access 352 return dataset_ops.from_variant( 353 gen_experimental_dataset_ops.dataset_from_graph(graph), 354 dataset.element_spec) 355 356 def structuredElement(self, element_structure, shape=None, 357 dtype=dtypes.int64): 358 """Returns an element with the given structure.""" 359 if shape is None: 360 shape = [] 361 if element_structure is None: 362 return array_ops.zeros(shape, dtype=dtype) 363 else: 364 return tuple([ 365 self.structuredElement(substructure, shape, dtype) 366 for substructure in element_structure 367 ]) 368 369 def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements): 370 """Tests whether a dataset produces its elements deterministically. 371 372 `dataset_fn` takes a delay_ms argument, which tells it how long to delay 373 production of the first dataset element. This gives us a way to trigger 374 out-of-order production of dataset elements. 375 376 Args: 377 dataset_fn: A function taking a delay_ms argument. 378 expect_determinism: Whether to expect deterministic ordering. 379 expected_elements: The elements expected to be produced by the dataset, 380 assuming the dataset produces elements in deterministic order. 381 """ 382 if expect_determinism: 383 dataset = dataset_fn(100) 384 actual = self.getDatasetOutput(dataset) 385 self.assertAllEqual(expected_elements, actual) 386 return 387 388 # We consider the test a success if it succeeds under any delay_ms. The 389 # delay_ms needed to observe non-deterministic ordering varies across 390 # test machines. Usually 10 or 100 milliseconds is enough, but on slow 391 # machines it could take longer. 392 for delay_ms in [10, 100, 1000, 20000, 100000]: 393 dataset = dataset_fn(delay_ms) 394 actual = self.getDatasetOutput(dataset) 395 self.assertCountEqual(expected_elements, actual) 396 for i in range(len(actual)): 397 if actual[i] != expected_elements[i]: 398 return 399 self.fail("Failed to observe nondeterministic ordering") 400 401 def configureDevicesForMultiDeviceTest(self, num_devices): 402 """Configures number of logical devices for multi-device tests. 403 404 It returns a list of device names. If invoked in GPU-enabled runtime, the 405 last device name will be for a GPU device. Otherwise, all device names will 406 be for a CPU device. 407 408 Args: 409 num_devices: The number of devices to configure. 410 411 Returns: 412 A list of device names to use for a multi-device test. 413 """ 414 cpus = config.list_physical_devices("CPU") 415 gpus = config.list_physical_devices("GPU") 416 config.set_logical_device_configuration(cpus[0], [ 417 context.LogicalDeviceConfiguration() for _ in range(num_devices) 418 ]) 419 devices = ["/device:CPU:" + str(i) for i in range(num_devices - 1)] 420 if gpus: 421 devices.append("/device:GPU:0") 422 else: 423 devices.append("/device:CPU:" + str(num_devices - 1)) 424 return devices 425