# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training utility functions."""

import functools
import multiprocessing.pool
import time

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging


class ModelInputsTest(test.TestCase):

  def test_single_thing(self):
    a = np.ones(10)
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertTrue(tensor_util.is_tf_type(vals))
    vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
    self.assertEqual(1, len(vals))
    self.assertTrue(tensor_util.is_tf_type(vals[0]))
    self.assertEqual(backend.floatx(), vals[0].dtype)

  def test_single_thing_eager(self):
    if not context.executing_eagerly():
      self.skipTest('Run in eager mode only.')
    a = np.ones(10, dtype=np.int32)
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1'], model_inputs.get_input_names())
    val = model_inputs.get_symbolic_inputs()
    self.assertIsInstance(val, keras_tensor.KerasTensor)
    vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
    self.assertEqual(1, len(vals))
    self.assertIsInstance(vals[0], keras_tensor.KerasTensor)
    self.assertEqual(dtypes.int32, vals[0].dtype)

  def test_list(self):
    a = [np.ones(10), np.ones(20)]
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertTrue(tensor_util.is_tf_type(vals[0]))
    self.assertTrue(tensor_util.is_tf_type(vals[1]))

  def test_list_eager(self):
    if not context.executing_eagerly():
      self.skipTest('Run in eager mode only.')
    a = [np.ones(10), np.ones(20)]
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertIsInstance(vals[0], keras_tensor.KerasTensor)
    self.assertIsInstance(vals[1], keras_tensor.KerasTensor)

  def test_dict(self):
    a = {'b': np.ones(10), 'a': np.ones(20)}
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['a', 'b'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertTrue(tensor_util.is_tf_type(vals['a']))
    self.assertTrue(tensor_util.is_tf_type(vals['b']))

  def test_dict_eager(self):
    if not context.executing_eagerly():
      self.skipTest('Run in eager mode only.')
    a = {'b': np.ones(10), 'a': np.ones(20)}
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['a', 'b'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertIsInstance(vals['a'], keras_tensor.KerasTensor)
    self.assertIsInstance(vals['b'], keras_tensor.KerasTensor)


class DatasetUtilsTest(test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
      ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
      ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5))),
      ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0))),
      ('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
      ('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
      ('FixedLengthRecordDatasetV2',
       lambda: readers.FixedLengthRecordDatasetV2([], 42)),
      ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
      ('FromTensorSlices',
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
      ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
      ('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
          cycle_length=1), True),
      ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
      ('Options',
       lambda: dataset_ops.Dataset.range(5).with_options(options_lib.Options())
      ),
      ('PaddedBatch', lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
      ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0),
          cycle_length=1,
          num_parallel_calls=1)),
      ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1)),
      ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
      ('Range', lambda: dataset_ops.Dataset.range(0)),
      ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
      ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
      ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
      ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
      ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
      ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
      ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
      ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
      # pylint: enable=g-long-lambda
  )
  def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
    dataset = dataset_fn()

    if not expect_shuffled:
      with test.mock.patch.object(logging, 'warning') as mock_log:
        shuffled = training_utils_v1.verify_dataset_shuffled(dataset)
        self.assertRegex(
            str(mock_log.call_args), 'input dataset `x` is not shuffled.')
        self.assertFalse(shuffled)
    else:
      self.assertTrue(training_utils_v1.verify_dataset_shuffled(dataset))


class StandardizeWeightsTest(keras_parameterized.TestCase):

  def test_sample_weights(self):
    y = np.array([0, 1, 0, 0, 2])
    sample_weights = np.array([0.5, 1., 1., 0., 2.])
    weights = training_utils_v1.standardize_weights(y, sample_weights)
    self.assertAllClose(weights, sample_weights)

  def test_class_weights(self):
    y = np.array([0, 1, 0, 0, 2])
    class_weights = {0: 0.5, 1: 1., 2: 1.5}
    weights = training_utils_v1.standardize_weights(
        y, class_weight=class_weights)
    self.assertAllClose(weights, np.array([0.5, 1., 0.5, 0.5, 1.5]))

  def test_sample_weights_and_class_weights(self):
    y = np.array([0, 1, 0, 0, 2])
    sample_weights = np.array([0.5, 1., 1., 0., 2.])
    class_weights = {0: 0.5, 1: 1., 2: 1.5}
    weights = training_utils_v1.standardize_weights(y, sample_weights,
                                                    class_weights)
    expected = sample_weights * np.array([0.5, 1., 0.5, 0.5, 1.5])
    self.assertAllClose(weights, expected)

  def test_dataset_with_class_weight(self):
    model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
    model.compile('rmsprop', 'mse')

    inputs = np.zeros((10, 3), np.float32)
    targets = np.zeros((10, 4), np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)
    class_weight_np = np.array([0.25, 0.25, 0.25, 0.25])
    class_weight = dict(enumerate(class_weight_np))

    model.fit(
        dataset,
        epochs=1,
        steps_per_epoch=2,
        verbose=1,
        class_weight=class_weight)


class MonitoredPool(multiprocessing.pool.ThreadPool):

  def __init__(self, *args, **kwargs):
    self._apply_counter = 0
    self._func_wrapper = None
    super(MonitoredPool, self).__init__(*args, **kwargs)

  def apply_async(self, func, *args, **kwargs):
    self._apply_counter += 1
    if self._func_wrapper:
      func = self._func_wrapper(func)  # pylint: disable=not-callable
    return super(MonitoredPool, self).apply_async(func, *args, **kwargs)


def add_sleep(f):
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    time.sleep(1.)
    return f(*args, **kwargs)
  return wrapped


def cause_error(f):
  @functools.wraps(f)
  def wrapped(batch_element, batch_start, batch_end, is_finished):  # pylint: disable=unused-argument
    # Induce a TypeError during assignment.
    return f(None, None, None, is_finished)
  return wrapped


_TEST_DATA = np.array((
    (3, 1, 3, 1, 2, 0, 3, 3, 1, 2),
    (0, 1, 2, 1, 3, 0, 0, 1, 3, 0),
    (3, 2, 1, 1, 1, 1, 1, 3, 2, 3),
    (2, 2, 0, 1, 0, 3, 3, 2, 1, 1),
    (3, 0, 3, 3, 3, 2, 1, 0, 0, 1),
    (1, 0, 3, 3, 3, 2, 1, 2, 3, 1),))


class AggregationTest(keras_parameterized.TestCase):

  def setUp(self):
    super(AggregationTest, self).setUp()
    self._old_pool = training_utils_v1._COPY_POOL
    self._old_threshold = (
        training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD)
    self._old_timeout = training_utils_v1.SliceAggregator._MAX_COPY_SECONDS
    training_utils_v1._COPY_POOL = MonitoredPool(
        training_utils_v1._COPY_THREADS)

  def tearDown(self):
    super(AggregationTest, self).tearDown()
    training_utils_v1._COPY_POOL = self._old_pool
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = (
        self._old_threshold)
    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout

  def _run_with_steps(self):
    aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
    for i, batch in enumerate(np.array_split(_TEST_DATA, 4)):
      if i == 0:
        aggregator.create(batch)
      aggregator.aggregate(batch)

    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)

    aggregator.finalize()
    return aggregator.results

  def _run_without_steps(self):
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps=False, num_samples=6)

    batch_start = 0
    for i, batch in enumerate(np.array_split(_TEST_DATA, 4)):
      if i == 0:
        aggregator.create(batch)

      batch_end = batch_start + batch.shape[0]
      aggregator.aggregate(batch, batch_start, batch_end)
      batch_start = batch_end

    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)

    aggregator.finalize()
    return aggregator.results

  def test_with_steps(self):
    self.assertAllEqual(self._run_with_steps(), _TEST_DATA)

  def test_without_steps(self):
    self.assertAllEqual(self._run_without_steps(), _TEST_DATA)

  def test_nested_aggregation(self):
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps=False, num_samples=6)

    batches = np.array_split(_TEST_DATA, 4)
    batch_start = 0
    for i, batch in enumerate(zip(batches, batches)):
      if i == 0:
        aggregator.create(batch)

      batch_end = batch_start + batch[0].shape[0]
      aggregator.aggregate(batch, batch_start, batch_end)
      batch_start = batch_end

    assert len(aggregator.results) == 2
    aggregator.finalize()
    self.assertAllEqual(aggregator.results, (_TEST_DATA, _TEST_DATA))

  def test_concat_single_batch(self):
    aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
    data = _TEST_DATA.copy()
    aggregator.create(data)
    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)

    aggregator.aggregate(data)
    aggregator.finalize()
    assert aggregator.results is data  # No copy.

  def test_slice_single_batch(self):
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps=False, num_samples=6)
    data = _TEST_DATA.copy()
    aggregator.create(data)
    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)

    aggregator.aggregate(data, 0, 6)
    aggregator.finalize()
    assert aggregator.results is data  # No copy.

  def test_async_copy(self):
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
    self.assertAllEqual(self._run_without_steps(), _TEST_DATA)

    # Two of the four batches will have 20 elements and two will have 10.
    self.assertEqual(training_utils_v1._COPY_POOL._apply_counter, 2)

  def test_async_copy_timeout(self):
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 0.1
    training_utils_v1._COPY_POOL._func_wrapper = add_sleep
    with self.assertRaisesRegex(ValueError, 'Timed out waiting for copy'):
      self._run_without_steps()

  def test_async_copy_reraise(self):
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 1.
    training_utils_v1._COPY_POOL._func_wrapper = cause_error
    with self.assertRaisesRegex(TypeError, 'NoneType'):
      self._run_without_steps()


class CompositeTensorTestUtils(keras_parameterized.TestCase):

  def test_is_composite(self):
    # Validate that all composite tensor and value types return true.
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])))
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])))
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            ragged_tensor.RaggedTensor.from_row_splits(
                np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            ragged_tensor_value.RaggedTensorValue(
                np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))

    # Test that numpy arrays and tensors return false.
    self.assertFalse(
        training_utils_v1.is_composite_or_composite_value(np.ndarray([0, 1])))
    self.assertFalse(
        training_utils_v1.is_composite_or_composite_value(
            ops.convert_to_tensor_v2_with_dispatch([3, 1])))

  def test_sparse_concatenation(self):
    tensor_1 = sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])
    tensor_2 = sparse_tensor.SparseTensor([[0, 0]], [2], [1, 1])
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)
    evaluated_tensor = self.evaluate(concatenated_tensor)
    self.assertAllEqual(evaluated_tensor.indices, [[0, 0], [1, 0]])
    self.assertAllEqual(evaluated_tensor.values, [1, 2])
    self.assertAllEqual(evaluated_tensor.dense_shape, [2, 1])

  def test_sparse_value_concatenation(self):
    tensor_1 = sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])
    tensor_2 = sparse_tensor.SparseTensorValue([[0, 0]], [2], [1, 1])
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)
    self.assertAllEqual(concatenated_tensor.indices, [[0, 0], [1, 0]])
    self.assertAllEqual(concatenated_tensor.values, [1, 2])
    self.assertAllEqual(concatenated_tensor.dense_shape, [2, 1])

  def test_ragged_concatenation(self):
    tensor_1 = ragged_tensor.RaggedTensor.from_row_splits(
        np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
    tensor_2 = ragged_tensor.RaggedTensor.from_row_splits(
        np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)
    evaluated_tensor = self.evaluate(concatenated_tensor)

    self.assertAllEqual(evaluated_tensor.values, [0, 1, 2, 3, 4, 5])
    self.assertAllEqual(evaluated_tensor.row_splits, [0, 1, 3, 5, 6])

  def test_ragged_value_concatenation(self):
    tensor_1 = ragged_tensor_value.RaggedTensorValue(
        np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
    tensor_2 = ragged_tensor_value.RaggedTensorValue(
        np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)

    self.assertAllEqual(concatenated_tensor.values, [0, 1, 2, 3, 4, 5])
    self.assertAllEqual(concatenated_tensor.row_splits, [0, 1, 3, 5, 6])


if __name__ == '__main__':
  test.main()