# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities for tf.data functionality."""
import os
import random
import re

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


def default_test_combinations():
  """Returns the default test combinations for tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])


def eager_only_combinations():
  """Returns the default test combinations for eager-mode-only tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode="eager")


def graph_only_combinations():
  """Returns the default test combinations for graph-mode-only tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode="graph")


def v1_only_combinations():
  """Returns the default test combinations for v1-only tf.data tests."""
  return combinations.combine(tf_api_version=1, mode=["eager", "graph"])


def v2_only_combinations():
  """Returns the default test combinations for v2-only tf.data tests."""
  return combinations.combine(tf_api_version=2, mode=["eager", "graph"])


def v2_eager_only_combinations():
  """Returns the default test combinations for v2 eager-only tf.data tests."""
  return combinations.combine(tf_api_version=2, mode="eager")
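

# A hedged usage sketch of the combination helpers above: tf.data tests
# typically apply them through `@combinations.generate` on a parameterized
# test case (`parameterized` comes from absl.testing; `RangeTest` and its
# test body are illustrative):
#
#   class RangeTest(DatasetTestBase, parameterized.TestCase):
#
#     @combinations.generate(default_test_combinations())
#     def testBasic(self):
#       dataset = dataset_ops.Dataset.range(5)
#       self.assertDatasetProduces(dataset, expected_output=list(range(5)))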


class DatasetTestBase(test.TestCase):
  """Base class for dataset tests."""

  def assert_op_cancelled(self, op):
    with self.assertRaises(errors.CancelledError):
      self.evaluate(op)

  def assertValuesEqual(self, expected, actual):
    """Asserts that two values are equal."""
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self.assertValuesEqual(expected[k], actual[k])
    elif sparse_tensor.is_sparse(expected):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
    else:
      self.assertAllEqual(expected, actual)

  def getNext(self, dataset, requires_initialization=False, shared_name=None):
    """Returns a callable that returns the next element of the dataset.

    Example use:
    ```python
    # In both graph and eager modes
    dataset = ...
    get_next = self.getNext(dataset)
    result = self.evaluate(get_next())
    ```

    Args:
      dataset: A dataset whose elements will be returned.
      requires_initialization: Indicates that when the test is executed in
        graph mode, it should use an initializable iterator to iterate through
        the dataset (e.g. when it contains stateful nodes). Defaults to False.
      shared_name: (Optional.) If non-empty, the returned iterator will be
        shared under the given name across multiple sessions that share the
        same devices (e.g. when using a remote server).

    Returns:
      A callable that returns the next element of `dataset`. Any `TensorArray`
      objects `dataset` outputs are stacked.
    """

    def ta_wrapper(gn):

      def _wrapper():
        r = gn()
        if isinstance(r, tensor_array_ops.TensorArray):
          return r.stack()
        else:
          return r

      return _wrapper

    # Create an anonymous iterator if we are executing eagerly or building a
    # graph inside a tf.function.
    if context.executing_eagerly() or ops.inside_function():
      iterator = iter(dataset)
      return ta_wrapper(iterator._next_internal)  # pylint: disable=protected-access
    else:
      if requires_initialization:
        iterator = dataset_ops.make_initializable_iterator(
            dataset, shared_name)
        self.evaluate(iterator.initializer)
      else:
        iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return ta_wrapper(lambda: get_next)

  def _compareOutputToExpected(self, result_values, expected_values,
                               assert_items_equal):
    if assert_items_equal:
      # TODO(shivaniagrawal): add support for nested elements containing sparse
      # tensors when needed.
      self.assertItemsEqual(result_values, expected_values)
      return
    for i in range(len(result_values)):
      nest.assert_same_structure(result_values[i], expected_values[i])
      for result_value, expected_value in zip(
          nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
        self.assertValuesEqual(expected_value, result_value)

  def getDatasetOutput(self, dataset, requires_initialization=False):
    """Evaluates every element of `dataset`, returning the results as a list."""
    get_next = self.getNext(
        dataset, requires_initialization=requires_initialization)
    return self.getIteratorOutput(get_next)

  def getIteratorOutput(self, get_next):
    """Evaluates `get_next` until end of input, returning the results."""
    results = []
    while True:
      try:
        results.append(self.evaluate(get_next()))
      except errors.OutOfRangeError:
        break
    return results
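
  # A minimal sketch of draining a dataset with the helpers above (inside a
  # test method of a `DatasetTestBase` subclass; the dataset is illustrative):
  #
  #   dataset = dataset_ops.Dataset.range(3).map(lambda x: x * 2)
  #   get_next = self.getNext(dataset)
  #   self.assertEqual(self.getIteratorOutput(get_next), [0, 2, 4])
  #   # Or, equivalently, in one call:
  #   # self.assertEqual(self.getDatasetOutput(dataset), [0, 2, 4])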

  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
      dataset: A dataset to check for the expected output / error.
      expected_output: A list of elements that the dataset is expected to
        produce.
      expected_shapes: A list of TensorShapes which is expected to match
        output_shapes of dataset.
      expected_error: A tuple `(type, predicate)` identifying the expected
        error `dataset` should raise. The `type` should match the expected
        exception type, while `predicate` should either be 1) a unary function
        that inputs the raised exception and returns a boolean indicator of
        success or 2) a regular expression that is expected to match the error
        message partially.
      requires_initialization: Indicates that when the test is executed in
        graph mode, it should use an initializable iterator to iterate through
        the dataset (e.g. when it contains stateful nodes). Defaults to False.
      num_test_iterations: Number of times `dataset` will be iterated. Defaults
        to 1.
      assert_items_equal: If True, asserts that the dataset produces the same
        elements as `expected_output`, regardless of their order.
      expected_error_iter: How many times to iterate before expecting an error,
        if an error is expected.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected_error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected_error should be provided."
      )
      with self.assertRaisesWithPredicateMatch(expected_error[0],
                                               expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        for _ in range(expected_error_iter):
          self.evaluate(get_next())
      return
    if expected_shapes:
      self.assertEqual(expected_shapes,
                       dataset_ops.get_legacy_output_shapes(dataset))
    self.assertGreater(num_test_iterations, 0)
    for _ in range(num_test_iterations):
      get_next = self.getNext(
          dataset, requires_initialization=requires_initialization)
      result = []
      for _ in range(len(expected_output)):
        try:
          result.append(self.evaluate(get_next()))
        except errors.OutOfRangeError:
          raise AssertionError(
              "Dataset ended early, producing %d elements out of %d. "
              "Dataset output: %s" %
              (len(result), len(expected_output), str(result)))
      self._compareOutputToExpected(result, expected_output,
                                    assert_items_equal)
      # Verify that the dataset is exhausted; check twice to ensure the
      # iterator keeps signaling end-of-sequence on repeated calls.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
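
  # Hedged sketches of the two `assertDatasetProduces` forms (the datasets and
  # the error message are illustrative):
  #
  #   # Expected-output form:
  #   self.assertDatasetProduces(
  #       dataset_ops.Dataset.range(3), expected_output=[0, 1, 2])
  #
  #   # Expected-error form, matching the message with a regex:
  #   self.assertDatasetProduces(
  #       failing_dataset,
  #       expected_error=(errors.InvalidArgumentError, "some message"))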

  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode."""
    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset1),
            dataset_ops.get_structure(dataset2)))

    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)

    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      assert len(op1) == len(op2)
      for i in range(len(op1)):
        if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
          self.assertValuesEqual(op1[i], op2[i])
        elif flattened_types[i] == dtypes.string:
          self.assertAllEqual(op1[i], op2[i])
        else:
          self.assertAllClose(op1[i], op2[i])

  def assertDatasetsRaiseSameError(self,
                                   dataset1,
                                   dataset2,
                                   exception_class,
                                   replacements=None):
    """Checks that datasets raise the same error on the first get_next call."""
    if replacements is None:
      replacements = []
    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)
    try:
      self.evaluate(next1())
      raise ValueError(
          "Expected dataset to raise an error of type %s, but it did not." %
          repr(exception_class))
    except exception_class as e:
      expected_message = e.message
      for old, new, count in replacements:
        expected_message = expected_message.replace(old, new, count)
      # Check that the first segment of the error messages is the same.
      with self.assertRaisesRegex(exception_class,
                                  re.escape(expected_message)):
        self.evaluate(next2())

  def structuredDataset(self, dataset_structure, shape=None,
                        dtype=dtypes.int64):
    """Returns a singleton dataset with the given structure."""
    if shape is None:
      shape = []
    if dataset_structure is None:
      return dataset_ops.Dataset.from_tensors(
          array_ops.zeros(shape, dtype=dtype))
    else:
      return dataset_ops.Dataset.zip(
          tuple([
              self.structuredDataset(substructure, shape, dtype)
              for substructure in dataset_structure
          ]))
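
  # A hedged sketch of `structuredDataset`: each `None` leaf becomes a zero
  # tensor of the given `shape` and `dtype`, and nesting mirrors the input
  # structure:
  #
  #   ds = self.structuredDataset((None, (None, None)))
  #   self.assertDatasetProduces(ds, expected_output=[(0, (0, 0))])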

  def verifyRandomAccess(self, dataset, expected):
    """Tests random access of a finite dataset, including the end-of-range error."""
    self.verifyRandomAccessInfiniteCardinality(dataset, expected)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(random_access.at(dataset, index=len(expected)))

  def verifyRandomAccessInfiniteCardinality(self, dataset, expected):
    """Tests randomly accessing elements of a dataset."""
    # Tests accessing the elements in a shuffled order with repeats.
    len_expected = len(expected)
    indices = list(range(len_expected)) * 2
    random.shuffle(indices)
    for i in indices:
      self.assertAllEqual(expected[i],
                          self.evaluate(random_access.at(dataset, i)))

    # Tests accessing the elements in order.
    indices = sorted(set(indices))
    for i in indices:
      self.assertAllEqual(expected[i],
                          self.evaluate(random_access.at(dataset, i)))

  def textFileInitializer(self, vals):
    file = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with open(file, "w") as f:
      f.write("\n".join(str(v) for v in vals) + "\n")
    return lookup_ops.TextFileInitializer(file, dtypes.int64,
                                          lookup_ops.TextFileIndex.LINE_NUMBER,
                                          dtypes.int64,
                                          lookup_ops.TextFileIndex.WHOLE_LINE)

  def keyValueTensorInitializer(self, vals):
    keys_tensor = constant_op.constant(
        list(range(len(vals))), dtype=dtypes.int64)
    vals_tensor = constant_op.constant(vals)
    return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)

  def datasetInitializer(self, vals):
    keys = dataset_ops.Dataset.range(len(vals))
    values = dataset_ops.Dataset.from_tensor_slices(vals)
    ds = dataset_ops.Dataset.zip((keys, values))
    return data_lookup_ops.DatasetInitializer(ds)

  def lookupTableInitializer(self, init_source, vals):
    """Returns a lookup table initializer for the given source and values.

    Args:
      init_source: One of ["textfile", "keyvaluetensor", "dataset"], indicating
        what type of initializer to use.
      vals: The initializer values. The keys will be `range(len(vals))`.

    Returns:
      A lookup table initializer of the requested type.

    Raises:
      ValueError: If `init_source` is not one of the supported values.
    """
    if init_source == "textfile":
      return self.textFileInitializer(vals)
    elif init_source == "keyvaluetensor":
      return self.keyValueTensorInitializer(vals)
    elif init_source == "dataset":
      return self.datasetInitializer(vals)
    else:
      raise ValueError("Unrecognized init_source: " + init_source)

  def graphRoundTrip(self, dataset, allow_stateful=False):
    """Converts a dataset to a graph and back."""
    graph = gen_dataset_ops.dataset_to_graph(
        dataset._variant_tensor, allow_stateful=allow_stateful)  # pylint: disable=protected-access
    return dataset_ops.from_variant(
        gen_experimental_dataset_ops.dataset_from_graph(graph),
        dataset.element_spec)

  def structuredElement(self, element_structure, shape=None,
                        dtype=dtypes.int64):
    """Returns an element with the given structure."""
    if shape is None:
      shape = []
    if element_structure is None:
      return array_ops.zeros(shape, dtype=dtype)
    else:
      return tuple([
          self.structuredElement(substructure, shape, dtype)
          for substructure in element_structure
      ])
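
  # A hedged sketch of the lookup-table helpers above, used with a static hash
  # table (the default value -1 and the mapping dataset are illustrative):
  #
  #   initializer = self.lookupTableInitializer("keyvaluetensor", [10, 11])
  #   table = lookup_ops.StaticHashTable(initializer, default_value=-1)
  #   ds = dataset_ops.Dataset.range(3).map(table.lookup)
  #   self.assertDatasetProduces(ds, expected_output=[10, 11, -1])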

  def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements):
    """Tests whether a dataset produces its elements deterministically.

    `dataset_fn` takes a delay_ms argument, which tells it how long to delay
    production of the first dataset element. This gives us a way to trigger
    out-of-order production of dataset elements.

    Args:
      dataset_fn: A function taking a delay_ms argument.
      expect_determinism: Whether to expect deterministic ordering.
      expected_elements: The elements expected to be produced by the dataset,
        assuming the dataset produces elements in deterministic order.
    """
    if expect_determinism:
      dataset = dataset_fn(100)
      actual = self.getDatasetOutput(dataset)
      self.assertAllEqual(expected_elements, actual)
      return

    # We consider the test a success if it succeeds under any delay_ms. The
    # delay_ms needed to observe non-deterministic ordering varies across
    # test machines. Usually 10 or 100 milliseconds is enough, but on slow
    # machines it could take longer.
    for delay_ms in [10, 100, 1000, 20000, 100000]:
      dataset = dataset_fn(delay_ms)
      actual = self.getDatasetOutput(dataset)
      self.assertCountEqual(expected_elements, actual)
      for i in range(len(actual)):
        if actual[i] != expected_elements[i]:
          return
    self.fail("Failed to observe nondeterministic ordering")

  def configureDevicesForMultiDeviceTest(self, num_devices):
    """Configures the number of logical devices for multi-device tests.

    It returns a list of device names. If invoked in a GPU-enabled runtime, the
    last device name will be for a GPU device. Otherwise, all device names will
    be for a CPU device.

    Args:
      num_devices: The number of devices to configure.

    Returns:
      A list of device names to use for a multi-device test.
    """
    cpus = config.list_physical_devices("CPU")
    gpus = config.list_physical_devices("GPU")
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration() for _ in range(num_devices)
    ])
    devices = ["/device:CPU:" + str(i) for i in range(num_devices - 1)]
    if gpus:
      devices.append("/device:GPU:0")
    else:
      devices.append("/device:CPU:" + str(num_devices - 1))
    return devices
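

# A hedged usage sketch for `configureDevicesForMultiDeviceTest` (the test
# class, dataset, and device pinning are illustrative):
#
#   class MultiDeviceTest(DatasetTestBase):
#
#     def testPlacement(self):
#       devices = self.configureDevicesForMultiDeviceTest(3)
#       dataset = dataset_ops.Dataset.range(10)
#       with ops.device(devices[-1]):
#         get_next = self.getNext(dataset)
#         self.evaluate(get_next())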