# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for training utility functions."""
16
17import functools
18import multiprocessing.pool
19import time
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.data.ops import options as options_lib
26from tensorflow.python.data.ops import readers
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.keras import backend
33from tensorflow.python.keras import keras_parameterized
34from tensorflow.python.keras import testing_utils
35from tensorflow.python.keras.engine import keras_tensor
36from tensorflow.python.keras.engine import training_utils_v1
37from tensorflow.python.ops.ragged import ragged_tensor
38from tensorflow.python.ops.ragged import ragged_tensor_value
39from tensorflow.python.platform import test
40from tensorflow.python.platform import tf_logging as logging
41
42
43class ModelInputsTest(test.TestCase):
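  """Checks ModelInputs input-name inference and symbolic-input creation."""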

  def test_single_thing(self):
    a = np.ones(10)
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertTrue(tensor_util.is_tf_type(vals))
    vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
    self.assertEqual(1, len(vals))
    self.assertTrue(tensor_util.is_tf_type(vals[0]))
    self.assertEqual(backend.floatx(), vals[0].dtype)

  def test_single_thing_eager(self):
    if not context.executing_eagerly():
      self.skipTest('Run in eager mode only.')
    a = np.ones(10, dtype=np.int32)
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1'], model_inputs.get_input_names())
    val = model_inputs.get_symbolic_inputs()
    self.assertIsInstance(val, keras_tensor.KerasTensor)
    vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
    self.assertEqual(1, len(vals))
    self.assertIsInstance(vals[0], keras_tensor.KerasTensor)
    self.assertEqual(dtypes.int32, vals[0].dtype)

  def test_list(self):
    a = [np.ones(10), np.ones(20)]
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertTrue(tensor_util.is_tf_type(vals[0]))
    self.assertTrue(tensor_util.is_tf_type(vals[1]))

  def test_list_eager(self):
    if not context.executing_eagerly():
      self.skipTest('Run in eager mode only.')
    a = [np.ones(10), np.ones(20)]
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertIsInstance(vals[0], keras_tensor.KerasTensor)
    self.assertIsInstance(vals[1], keras_tensor.KerasTensor)

  def test_dict(self):
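    # Dict inputs are traversed in sorted key order, so 'a' is reported first.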
    a = {'b': np.ones(10), 'a': np.ones(20)}
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['a', 'b'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertTrue(tensor_util.is_tf_type(vals['a']))
    self.assertTrue(tensor_util.is_tf_type(vals['b']))

  def test_dict_eager(self):
    if not context.executing_eagerly():
      self.skipTest('Run in eager mode only.')
    a = {'b': np.ones(10), 'a': np.ones(20)}
    model_inputs = training_utils_v1.ModelInputs(a)
    self.assertEqual(['a', 'b'], model_inputs.get_input_names())
    vals = model_inputs.get_symbolic_inputs()
    self.assertIsInstance(vals['a'], keras_tensor.KerasTensor)
    self.assertIsInstance(vals['b'], keras_tensor.KerasTensor)


class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
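  """Checks verify_dataset_shuffled across many tf.data transformations."""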

  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
      ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
      ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5))),
      ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0))),
      ('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
      ('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
      ('FixedLengthRecordDatasetV2',
       lambda: readers.FixedLengthRecordDatasetV2([], 42)),
      ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
      ('FromTensorSlices',
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
      ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
      ('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
          cycle_length=1), True),
      ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
      ('Options',
       lambda: dataset_ops.Dataset.range(5).with_options(options_lib.Options())
      ),
      ('PaddedBatch', lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
      ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0),
          cycle_length=1,
          num_parallel_calls=1)),
      ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1)),
      ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
      ('Range', lambda: dataset_ops.Dataset.range(0)),
      ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
      ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
      ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
      ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
      ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
      ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
      ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
      ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
      # pylint: enable=g-long-lambda
  )
  def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
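    # Pipelines without a shuffle should log a warning and return False.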
    dataset = dataset_fn()

    if not expect_shuffled:
      with test.mock.patch.object(logging, 'warning') as mock_log:
        shuffled = training_utils_v1.verify_dataset_shuffled(dataset)
        self.assertRegex(
            str(mock_log.call_args), 'input dataset `x` is not shuffled.')
        self.assertFalse(shuffled)
    else:
      self.assertTrue(training_utils_v1.verify_dataset_shuffled(dataset))


class StandardizeWeightsTest(keras_parameterized.TestCase):
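  """Checks standardize_weights with sample weights, class weights, or both."""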

  def test_sample_weights(self):
    y = np.array([0, 1, 0, 0, 2])
    sample_weights = np.array([0.5, 1., 1., 0., 2.])
    weights = training_utils_v1.standardize_weights(y, sample_weights)
    self.assertAllClose(weights, sample_weights)

  def test_class_weights(self):
    y = np.array([0, 1, 0, 0, 2])
    class_weights = {0: 0.5, 1: 1., 2: 1.5}
    weights = training_utils_v1.standardize_weights(
        y, class_weight=class_weights)
    self.assertAllClose(weights, np.array([0.5, 1., 0.5, 0.5, 1.5]))

  def test_sample_weights_and_class_weights(self):
    y = np.array([0, 1, 0, 0, 2])
    sample_weights = np.array([0.5, 1., 1., 0., 2.])
    class_weights = {0: 0.5, 1: 1., 2: 1.5}
    weights = training_utils_v1.standardize_weights(y, sample_weights,
                                                    class_weights)
    expected = sample_weights * np.array([0.5, 1., 0.5, 0.5, 1.5])
    self.assertAllClose(weights, expected)

  def test_dataset_with_class_weight(self):
    model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
    model.compile('rmsprop', 'mse')

    inputs = np.zeros((10, 3), np.float32)
    targets = np.zeros((10, 4), np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)
    class_weight_np = np.array([0.25, 0.25, 0.25, 0.25])
    class_weight = dict(enumerate(class_weight_np))

    model.fit(
        dataset,
        epochs=1,
        steps_per_epoch=2,
        verbose=1,
        class_weight=class_weight)


class MonitoredPool(multiprocessing.pool.ThreadPool):
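  """ThreadPool that counts apply_async calls and can wrap submitted funcs."""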

  def __init__(self, *args, **kwargs):
    self._apply_counter = 0
    self._func_wrapper = None
    super(MonitoredPool, self).__init__(*args, **kwargs)

  def apply_async(self, func, *args, **kwargs):
    self._apply_counter += 1
    if self._func_wrapper:
      func = self._func_wrapper(func)  # pylint: disable=not-callable
    return super(MonitoredPool, self).apply_async(func, *args, **kwargs)


def add_sleep(f):
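  """Wraps f so each call sleeps one second first (forces copy timeouts)."""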
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    time.sleep(1.)
    return f(*args, **kwargs)
  return wrapped


def cause_error(f):
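  """Wraps f so it is called with None arguments, provoking a TypeError."""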
  @functools.wraps(f)
  def wrapped(batch_element, batch_start, batch_end, is_finished):  # pylint: disable=unused-argument
    # Induce a TypeError during assignment.
    return f(None, None, None, is_finished)
  return wrapped


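# A 6x10 integer array used as ground truth by the aggregation tests below.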
_TEST_DATA = np.array((
    (3, 1, 3, 1, 2, 0, 3, 3, 1, 2),
    (0, 1, 2, 1, 3, 0, 0, 1, 3, 0),
    (3, 2, 1, 1, 1, 1, 1, 3, 2, 3),
    (2, 2, 0, 1, 0, 3, 3, 2, 1, 1),
    (3, 0, 3, 3, 3, 2, 1, 0, 0, 1),
    (1, 0, 3, 3, 3, 2, 1, 2, 3, 1),))


class AggregationTest(keras_parameterized.TestCase):
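  """Exercises OutputsAggregator in both steps-based and sample-count modes."""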

  def setUp(self):
    super(AggregationTest, self).setUp()
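    # Save module-level state so each test can safely mutate it.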
    self._old_pool = training_utils_v1._COPY_POOL
    self._old_threshold = (
        training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD)
    self._old_timeout = training_utils_v1.SliceAggregator._MAX_COPY_SECONDS
    training_utils_v1._COPY_POOL = MonitoredPool(
        training_utils_v1._COPY_THREADS)

  def tearDown(self):
    super(AggregationTest, self).tearDown()
    training_utils_v1._COPY_POOL = self._old_pool
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = (
        self._old_threshold)
    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout

  def _run_with_steps(self):
    aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
    for i, batch in enumerate(np.array_split(_TEST_DATA, 4)):
      if i == 0:
        aggregator.create(batch)
      aggregator.aggregate(batch)

    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)

    aggregator.finalize()
    return aggregator.results

  def _run_without_steps(self):
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps=False, num_samples=6)

    batch_start = 0
    for i, batch in enumerate(np.array_split(_TEST_DATA, 4)):
      if i == 0:
        aggregator.create(batch)

      batch_end = batch_start + batch.shape[0]
      aggregator.aggregate(batch, batch_start, batch_end)
      batch_start = batch_end

    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)

    aggregator.finalize()
    return aggregator.results

  def test_with_steps(self):
    self.assertAllEqual(self._run_with_steps(), _TEST_DATA)

  def test_without_steps(self):
    self.assertAllEqual(self._run_without_steps(), _TEST_DATA)

  def test_nested_aggregation(self):
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps=False, num_samples=6)

    batches = np.array_split(_TEST_DATA, 4)
    batch_start = 0
    for i, batch in enumerate(zip(batches, batches)):
      if i == 0:
        aggregator.create(batch)

      batch_end = batch_start + batch[0].shape[0]
      aggregator.aggregate(batch, batch_start, batch_end)
      batch_start = batch_end

    assert len(aggregator.results) == 2
    aggregator.finalize()
    self.assertAllEqual(aggregator.results, (_TEST_DATA, _TEST_DATA))

  def test_concat_single_batch(self):
    aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
    data = _TEST_DATA.copy()
    aggregator.create(data)
    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)

    aggregator.aggregate(data)
    aggregator.finalize()
    assert aggregator.results is data  # No copy.

  def test_slice_single_batch(self):
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps=False, num_samples=6)
    data = _TEST_DATA.copy()
    aggregator.create(data)
    assert len(aggregator.results) == 1
    assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)

    aggregator.aggregate(data, 0, 6)
    aggregator.finalize()
    assert aggregator.results is data  # No copy.

  def test_async_copy(self):
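    # Lower the copy threshold so larger batches go through the pool.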
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
    self.assertAllEqual(self._run_without_steps(), _TEST_DATA)

    # Two of the four batches will have 20 elements and two will have 10.
    self.assertEqual(training_utils_v1._COPY_POOL._apply_counter, 2)

  def test_async_copy_timeout(self):
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 0.1
    training_utils_v1._COPY_POOL._func_wrapper = add_sleep
    with self.assertRaisesRegex(ValueError, 'Timed out waiting for copy'):
      self._run_without_steps()

  def test_async_copy_reraise(self):
    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 1.
    training_utils_v1._COPY_POOL._func_wrapper = cause_error
    with self.assertRaisesRegex(TypeError, 'NoneType'):
      self._run_without_steps()


class CompositeTensorTestUtils(keras_parameterized.TestCase):
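  """Checks composite tensor detection and concatenation helpers."""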

  def test_is_composite(self):
    # Validate that all composite tensor and value types return true.
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])))
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])))
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            ragged_tensor.RaggedTensor.from_row_splits(
                np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))
    self.assertTrue(
        training_utils_v1.is_composite_or_composite_value(
            ragged_tensor_value.RaggedTensorValue(
                np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))

    # Test that numpy arrays and tensors return false.
    self.assertFalse(
        training_utils_v1.is_composite_or_composite_value(np.ndarray([0, 1])))
    self.assertFalse(
        training_utils_v1.is_composite_or_composite_value(
            ops.convert_to_tensor_v2_with_dispatch([3, 1])))

  def test_sparse_concatenation(self):
    tensor_1 = sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])
    tensor_2 = sparse_tensor.SparseTensor([[0, 0]], [2], [1, 1])
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)
    evaluated_tensor = self.evaluate(concatenated_tensor)
    self.assertAllEqual(evaluated_tensor.indices, [[0, 0], [1, 0]])
    self.assertAllEqual(evaluated_tensor.values, [1, 2])
    self.assertAllEqual(evaluated_tensor.dense_shape, [2, 1])

  def test_sparse_value_concatenation(self):
    tensor_1 = sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])
    tensor_2 = sparse_tensor.SparseTensorValue([[0, 0]], [2], [1, 1])
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)
    self.assertAllEqual(concatenated_tensor.indices, [[0, 0], [1, 0]])
    self.assertAllEqual(concatenated_tensor.values, [1, 2])
    self.assertAllEqual(concatenated_tensor.dense_shape, [2, 1])

  def test_ragged_concatenation(self):
    tensor_1 = ragged_tensor.RaggedTensor.from_row_splits(
        np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
    tensor_2 = ragged_tensor.RaggedTensor.from_row_splits(
        np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)
    evaluated_tensor = self.evaluate(concatenated_tensor)

    self.assertAllEqual(evaluated_tensor.values, [0, 1, 2, 3, 4, 5])
    self.assertAllEqual(evaluated_tensor.row_splits, [0, 1, 3, 5, 6])

  def test_ragged_value_concatenation(self):
    tensor_1 = ragged_tensor_value.RaggedTensorValue(
        np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
    tensor_2 = ragged_tensor_value.RaggedTensorValue(
        np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
    concatenated_tensor = training_utils_v1._append_composite_tensor(
        tensor_1, tensor_2)

    self.assertAllEqual(concatenated_tensor.values, [0, 1, 2, 3, 4, 5])
    self.assertAllEqual(concatenated_tensor.row_splits, [0, 1, 3, 5, 6])


if __name__ == '__main__':
  test.main()