# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DataAdapter tests."""

import math

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class DummyArrayLike(object):
  """Dummy array-like object."""

  def __init__(self, data):
    self.data = data

  def __len__(self):
    return len(self.data)

  def __getitem__(self, key):
    return self.data[key]

  @property
  def shape(self):
    return self.data.shape

  @property
  def dtype(self):
    return self.data.dtype

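# A conversion function that always raises: registering it ensures that any
# code path which tries to convert a DummyArrayLike to a tensor wholesale
# (instead of slicing it batch by batch) fails loudly in the tests below.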
def fail_on_convert(x, **kwargs):
  _ = x
  _ = kwargs
  raise TypeError('Cannot convert DummyArrayLike to a tensor')
ops.register_tensor_conversion_function(DummyArrayLike, fail_on_convert)


class DataAdapterTestBase(keras_parameterized.TestCase):

  def setUp(self):
    super(DataAdapterTestBase, self).setUp()
    self.batch_size = 5
    self.numpy_input = np.zeros((50, 10))
    self.numpy_target = np.ones(50)
    self.tensor_input = constant_op.constant(2.0, shape=(50, 10))
    self.tensor_target = array_ops.ones((50,))
    self.arraylike_input = DummyArrayLike(self.numpy_input)
    self.arraylike_target = DummyArrayLike(self.numpy_target)
    self.dataset_input = dataset_ops.DatasetV2.from_tensor_slices(
        (self.numpy_input, self.numpy_target)).shuffle(50).batch(
            self.batch_size)

    def generator():
      while True:
        yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size))
    self.generator_input = generator()
    self.iterator_input = data_utils.threadsafe_generator(generator)()
    self.sequence_input = TestSequence(batch_size=self.batch_size,
                                       feature_shape=10)
    self.text_input = [['abc']]
    self.bytes_input = [[b'abc']]
    self.model = keras.models.Sequential(
        [keras.layers.Dense(8, input_shape=(10,), activation='softmax')])
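    # Note: the Dense(8, activation='softmax') model pairs with the all-ones
    # targets above: class `1` is a valid label for the 8-way
    # `sparse_categorical_crossentropy` used throughout these tests.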


class TestSequence(data_utils.Sequence):

  def __init__(self, batch_size, feature_shape):
    self.batch_size = batch_size
    self.feature_shape = feature_shape

  def __getitem__(self, item):
    return (np.zeros((self.batch_size, self.feature_shape)),
            np.ones((self.batch_size,)))

  def __len__(self):
    return 10


class TensorLikeDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(TensorLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.TensorLikeDataAdapter

  def test_can_handle_numpy(self):
    self.assertTrue(self.adapter_cls.can_handle(self.numpy_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.numpy_input, self.numpy_target))

    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  def test_size_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_batch_size_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    self.assertEqual(adapter.batch_size(), 5)

  def test_partial_batch_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=4)
    self.assertEqual(adapter.get_size(), 13)   # ceil(50/4)
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)

  def test_epochs(self):
    num_epochs = 3
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5, epochs=num_epochs)
    ds_iter = iter(adapter.get_dataset())
    num_batches_per_epoch = self.numpy_input.shape[0] // 5
    for _ in range(num_batches_per_epoch * num_epochs):
      next(ds_iter)
    with self.assertRaises(StopIteration):
      next(ds_iter)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_numpy(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)

  def test_can_handle_pandas(self):
    try:
      import pandas as pd  # pylint: disable=g-import-not-at-top
    except ImportError:
      self.skipTest('Skipping test because pandas is not installed.')
    self.assertTrue(self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)))
    self.assertTrue(
        self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)[0]))
    self.assertTrue(
        self.adapter_cls.can_handle(
            pd.DataFrame(self.numpy_input),
            pd.DataFrame(self.numpy_input)[0]))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_pandas(self):
    try:
      import pandas as pd  # pylint: disable=g-import-not-at-top
    except ImportError:
      self.skipTest('Skipping test because pandas is not installed.')
    input_a = keras.Input(shape=(3,), name='input_a')
    input_b = keras.Input(shape=(3,), name='input_b')
    input_c = keras.Input(shape=(1,), name='input_c')

    x = keras.layers.Dense(4, name='dense_1')(input_a)
    y = keras.layers.Dense(3, name='dense_2')(input_b)
    z = keras.layers.Dense(1, name='dense_3')(input_c)

    model_1 = keras.Model(inputs=input_a, outputs=x)
    model_2 = keras.Model(inputs=[input_a, input_b], outputs=[x, y])
    model_3 = keras.Model(inputs=input_c, outputs=z)

    model_1.compile(optimizer='rmsprop', loss='mse')
    model_2.compile(optimizer='rmsprop', loss='mse')

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    input_a_df = pd.DataFrame(input_a_np)
    input_b_df = pd.DataFrame(input_b_np)

    output_a_df = pd.DataFrame(np.random.random((10, 4)))
    output_b_df = pd.DataFrame(np.random.random((10, 3)))

    model_1.fit(input_a_df,
                output_a_df)
    model_2.fit([input_a_df, input_b_df],
                [output_a_df, output_b_df])
    model_1.fit([input_a_df],
                [output_a_df])
    model_1.fit({'input_a': input_a_df},
                output_a_df)
    model_2.fit({'input_a': input_a_df, 'input_b': input_b_df},
                [output_a_df, output_b_df])

    model_1.evaluate(input_a_df,
                     output_a_df)
    model_2.evaluate([input_a_df, input_b_df],
                     [output_a_df, output_b_df])
    model_1.evaluate([input_a_df],
                     [output_a_df])
    model_1.evaluate({'input_a': input_a_df},
                     output_a_df)
    model_2.evaluate({'input_a': input_a_df, 'input_b': input_b_df},
                     [output_a_df, output_b_df])

    # Verify predicting on pandas vs numpy returns the same result
    predict_1_pandas = model_1.predict(input_a_df)
    predict_2_pandas = model_2.predict([input_a_df, input_b_df])
    predict_3_pandas = model_3.predict(input_a_df[0])

    predict_1_numpy = model_1.predict(input_a_np)
    predict_2_numpy = model_2.predict([input_a_np, input_b_np])
    predict_3_numpy = model_3.predict(np.asarray(input_a_df[0]))

    self.assertAllClose(predict_1_numpy, predict_1_pandas)
    self.assertAllClose(predict_2_numpy, predict_2_pandas)
    self.assertAllClose(predict_3_numpy, predict_3_pandas)

    # Extra ways to pass in dataframes
    model_1.predict([input_a_df])
    model_1.predict({'input_a': input_a_df})
    model_2.predict({'input_a': input_a_df, 'input_b': input_b_df})

  def test_can_handle(self):
    self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.tensor_input, self.tensor_target))

    self.assertFalse(self.adapter_cls.can_handle(self.arraylike_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.arraylike_target))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)

  def test_size(self):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_shuffle_correctness(self):
    num_samples = 100
    batch_size = 32
    x = np.arange(num_samples)
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle=True, epochs=2)

    def _get_epoch(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter).numpy())
      return np.concatenate(ds_data)

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batch_shuffle_correctness(self):
    num_samples = 100
    batch_size = 6
    x = np.arange(num_samples)
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle='batch', epochs=2)
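    # With shuffle='batch', the assertions below verify that every element
    # still appears exactly once, that each batch is a contiguous run of the
    # input, and that the order within each batch is shuffled.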

    def _get_epoch_batches(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter)[0].numpy())
      return ds_data

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_batch_data = _get_epoch_batches(ds_iter)
    epoch_data = np.concatenate(epoch_batch_data)

    def _verify_batch(batch):
      # Verify that a batch contains only contiguous data, and that it has
      # been shuffled.
      sorted_batch = np.sort(batch)
      self.assertNotAllClose(batch, sorted_batch)
      for i in range(1, len(batch)):
        self.assertEqual(sorted_batch[i-1] + 1, sorted_batch[i])

    # Assert that the data within each batch is a contiguous run, shuffled
    # internally.
    for batch in epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred across the epoch.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_batch_data = _get_epoch_batches(ds_iter)
    second_epoch_data = np.concatenate(second_epoch_batch_data)

    # Assert that the data within each batch is a contiguous run, shuffled
    # internally.
    for batch in second_epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 5),
      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
      ('steps_1', None, 1, 50),
      ('steps_4', None, 4, 13),
      )
  def test_batch_size(self, batch_size_in, steps, batch_size_out):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.batch_size(), batch_size_out)

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 10, 0),
      ('batch_size_4', 4, None, 13, 2),
      ('steps_1', None, 1, 1, 0),
      ('steps_5', None, 5, 5, 0),
      ('steps_4', None, 4, 4, 11),
      )
  def test_partial_batch(
      self, batch_size_in, steps, size, partial_batch_size):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.get_size(), size)  # steps, or ceil(50/batch_size)
    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)


class GenericArrayLikeDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(GenericArrayLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GenericArrayLikeDataAdapter

  def test_can_handle_some_numpy(self):
    self.assertTrue(self.adapter_cls.can_handle(
        self.arraylike_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.arraylike_target))

    # Because adapters are mutually exclusive, don't handle cases
    # where all of the data is numpy or an EagerTensor.
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.numpy_input,
                                    self.numpy_target))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.tensor_input, self.tensor_target))

    # But do handle mixes that include generic arraylike data
    self.assertTrue(
        self.adapter_cls.can_handle(self.numpy_input,
                                    self.arraylike_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.numpy_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.tensor_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.tensor_input,
                                    self.arraylike_target))

    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  def test_size(self):
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.arraylike_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_epochs(self):
    num_epochs = 3
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.numpy_target, batch_size=5, epochs=num_epochs)
    ds_iter = iter(adapter.get_dataset())
    num_batches_per_epoch = self.numpy_input.shape[0] // 5
    for _ in range(num_batches_per_epoch * num_epochs):
      next(ds_iter)
    with self.assertRaises(StopIteration):
      next(ds_iter)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    # First verify that DummyArrayLike can't be converted to a Tensor
    with self.assertRaises(TypeError):
      ops.convert_to_tensor_v2_with_dispatch(self.arraylike_input)

    # Then train on the array-like.
    # It should not be converted to a tensor directly (which would force it
    # into memory); only the sliced data should be converted.
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.arraylike_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.arraylike_target,
                   shuffle=True, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.arraylike_target,
                   shuffle='batch', batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.arraylike_target, batch_size=5)
    self.model.predict(self.arraylike_input, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_numpy_target(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.numpy_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.numpy_target, shuffle=True,
                   batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.numpy_target, shuffle='batch',
                   batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.numpy_target, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_tensor_target(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.tensor_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.tensor_target, shuffle=True,
                   batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.tensor_target, shuffle='batch',
                   batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.tensor_target, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_shuffle_correctness(self):
    num_samples = 100
    batch_size = 32
    x = DummyArrayLike(np.arange(num_samples))
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle=True, epochs=2)

    def _get_epoch(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter).numpy())
      return np.concatenate(ds_data)

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batch_shuffle_correctness(self):
    num_samples = 100
    batch_size = 6
    x = DummyArrayLike(np.arange(num_samples))
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle='batch', epochs=2)

    def _get_epoch_batches(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter)[0].numpy())
      return ds_data

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_batch_data = _get_epoch_batches(ds_iter)
    epoch_data = np.concatenate(epoch_batch_data)

    def _verify_batch(batch):
      # Verify that a batch contains only contiguous data, and that it has
      # been shuffled.
      sorted_batch = np.sort(batch)
      self.assertNotAllClose(batch, sorted_batch)
      for i in range(1, len(batch)):
        self.assertEqual(sorted_batch[i-1] + 1, sorted_batch[i])

    # Assert that the data within each batch is a contiguous run, shuffled
    # internally.
    for batch in epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred across the epoch.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_batch_data = _get_epoch_batches(ds_iter)
    second_epoch_data = np.concatenate(second_epoch_batch_data)

    # Assert that the data within each batch is a contiguous run, shuffled
    # internally.
    for batch in second_epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 5),
      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
      ('steps_1', None, 1, 50),
      ('steps_4', None, 4, 13),
  )
  def test_batch_size(self, batch_size_in, steps, batch_size_out):
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.arraylike_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.batch_size(), batch_size_out)

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 10, 0),
      ('batch_size_4', 4, None, 13, 2),
      ('steps_1', None, 1, 1, 0),
      ('steps_5', None, 5, 5, 0),
      ('steps_4', None, 4, 4, 11),
  )
  def test_partial_batch(
      self, batch_size_in, steps, size, partial_batch_size):
    adapter = self.adapter_cls(
        self.arraylike_input, self.arraylike_target,
        batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.get_size(), size)  # steps, or ceil(50/batch_size)
    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)


class DatasetAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(DatasetAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.DatasetAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertTrue(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    dataset = self.adapter_cls(self.dataset_input).get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(dataset)

  def test_size(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertIsNone(adapter.get_size())

  def test_batch_size(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertIsNone(adapter.batch_size())

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.dataset_input, y=self.dataset_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(self.dataset_input, sample_weights=self.dataset_input)


class GeneratorDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(GeneratorDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GeneratorDataAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertTrue(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.generator_input, steps_per_epoch=10)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @testing_utils.run_v2_only
  @data_utils.dont_use_multiprocessing_pool
  def test_with_multiprocessing_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)
    # Fit twice to ensure there isn't any duplication that prevents the
    # workers from starting.
    self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)

  def test_size(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertIsNone(adapter.get_size())

  def test_batch_size(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertEqual(adapter.batch_size(), None)
    self.assertEqual(adapter.representative_batch_size(), 5)

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.generator_input, y=self.generator_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(
          self.generator_input, sample_weights=self.generator_input)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_not_shuffled(self):
    def generator():
      for i in range(10):
        yield np.ones((1, 1)) * i

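    # Generator input cannot be reshuffled ahead of time, so `shuffle=True`
    # should be ignored and the elements should arrive in generation order.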
    adapter = self.adapter_cls(generator(), shuffle=True)
    for i, data in enumerate(adapter.get_dataset()):
      self.assertEqual(i, data[0].numpy().flatten())


class KerasSequenceAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(KerasSequenceAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.KerasSequenceAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertTrue(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.sequence_input)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @testing_utils.run_v2_only
  @data_utils.dont_use_multiprocessing_pool
  def test_with_multiprocessing_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)
    # Fit twice to ensure there isn't any duplication that prevents the
    # workers from starting.
    self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)

  def test_size(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertEqual(adapter.get_size(), 10)

  def test_batch_size(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertEqual(adapter.batch_size(), None)
    self.assertEqual(adapter.representative_batch_size(), 5)

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.sequence_input, y=self.sequence_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(self.sequence_input, sample_weights=self.sequence_input)


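# The DataHandler tests below exercise the epoch/step protocol used by
# Model.fit: `enumerate_epochs()` yields (epoch, iterator) pairs, `steps()`
# yields the step indices of one epoch, and `catch_stop_iteration()` ends an
# epoch cleanly when the underlying iterator is exhausted.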
class DataHandlerTest(keras_parameterized.TestCase):

  def test_finite_dataset_with_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
    # User can choose to only partially consume `Dataset`.
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertEqual(data_handler.inferred_steps, 2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1], [2, 3]])

  def test_finite_dataset_without_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1)
    data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2)
    self.assertEqual(data_handler.inferred_steps, 3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

  def test_finite_dataset_with_steps_per_epoch_exact_size(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
    # If user specifies exact size of `Dataset` as `steps_per_epoch`,
    # create a new iterator each epoch.
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=4)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])

  def test_infinite_dataset_with_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1).repeat()
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

  def test_unknown_cardinality_dataset_with_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    # User can choose to only partially consume `Dataset`.
    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1], [2, 3]])
    self.assertEqual(data_handler.inferred_steps, 2)

  def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2)
    self.assertEqual(data_handler.inferred_steps, None)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
    self.assertEqual(data_handler.inferred_steps, 4)

  def test_insufficient_data(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])
    ds = ds.filter(lambda *args, **kwargs: True)
    data_handler = data_adapter.DataHandler(
        ds, initial_epoch=0, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        with data_handler.catch_stop_iteration():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertTrue(data_handler._insufficient_data)
    self.assertEqual(returned_data, [[0, 1]])

  def test_numpy(self):
    x = np.array([0, 1, 2])
    y = np.array([0, 2, 4])
    sw = np.array([0, 4, 8])
    data_handler = data_adapter.DataHandler(
        x=x, y=y, sample_weight=sw, batch_size=1, epochs=2)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data,
                     [[(0, 0, 0), (1, 2, 4),
                       (2, 4, 8)], [(0, 0, 0), (1, 2, 4), (2, 4, 8)]])

  def test_generator(self):

    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    data_handler = data_adapter.DataHandler(
        generator(), epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_composite_tensor(self):
    st = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 0], [2, 0]], values=[0, 1, 2], dense_shape=[3, 1])
    data_handler = data_adapter.DataHandler(st, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(
        nest.map_structure(sparse_ops.sparse_tensor_to_dense, returned_data))
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_iterator(self):
    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    it = iter(dataset_ops.Dataset.from_generator(
        generator, output_types=('float32',)))
    data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
                                     [([0],), ([1],), ([2],)]])

  def test_list_of_scalars(self):
    data_handler = data_adapter.DataHandler([[0], [1], [2]],
                                            epochs=2,
                                            steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_class_weight_user_errors(self):
    with self.assertRaisesRegex(ValueError, 'to be a dict with keys'):
      data_adapter.DataHandler(
          x=[[0], [1], [2]],
          y=[[2], [1], [0]],
          batch_size=1,
          sample_weight=[[1.], [2.], [4.]],
          class_weight={
              0: 0.5,
              1: 1.,
              3: 1.5  # Skips class `2`.
          })

    with self.assertRaisesRegex(ValueError, 'with a single output'):
      data_adapter.DataHandler(
          x=np.ones((10, 1)),
          y=[np.ones((10, 1)), np.zeros((10, 1))],
          batch_size=2,
          class_weight={
              0: 0.5,
              1: 1.,
              2: 1.5
          })

  @parameterized.named_parameters(('numpy', True), ('dataset', False))
  def test_single_x_input_no_tuple_wrapping(self, use_numpy):
    x = np.ones((10, 1))

    if use_numpy:
      batch_size = 2
    else:
      x = dataset_ops.Dataset.from_tensor_slices(x).batch(2)
      batch_size = None

    data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
    for _, iterator in data_handler.enumerate_epochs():
      for _ in data_handler.steps():
        # Check that single x input is not wrapped in a tuple.
        self.assertIsInstance(next(iterator), ops.Tensor)


class TestValidationSplit(keras_parameterized.TestCase):

  @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
  def test_validation_split_unshuffled(self, use_numpy):
    if use_numpy:
      x = np.array([0, 1, 2, 3, 4])
      y = np.array([0, 2, 4, 6, 8])
      sw = np.array([0, 4, 8, 12, 16])
    else:
      x = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2, 3, 4])
      y = ops.convert_to_tensor_v2_with_dispatch([0, 2, 4, 6, 8])
      sw = ops.convert_to_tensor_v2_with_dispatch([0, 4, 8, 12, 16])

    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
        data_adapter.train_validation_split((x, y, sw), validation_split=0.2))
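    # `train_validation_split` takes the last `validation_split` fraction of
    # the data, unshuffled, as the validation set.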

    if use_numpy:
      train_x = ops.convert_to_tensor_v2_with_dispatch(train_x)
      train_y = ops.convert_to_tensor_v2_with_dispatch(train_y)
      train_sw = ops.convert_to_tensor_v2_with_dispatch(train_sw)
      val_x = ops.convert_to_tensor_v2_with_dispatch(val_x)
      val_y = ops.convert_to_tensor_v2_with_dispatch(val_y)
      val_sw = ops.convert_to_tensor_v2_with_dispatch(val_sw)

    self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
    self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
    self.assertEqual(train_sw.numpy().tolist(), [0, 4, 8, 12])

    self.assertEqual(val_x.numpy().tolist(), [4])
    self.assertEqual(val_y.numpy().tolist(), [8])
    self.assertEqual(val_sw.numpy().tolist(), [16])

  def test_validation_split_user_error(self):
    with self.assertRaisesRegex(ValueError, 'is only supported for Tensors'):
      data_adapter.train_validation_split(
          lambda: np.ones((10, 1)), validation_split=0.2)

  def test_validation_split_examples_too_few(self):
    with self.assertRaisesRegex(ValueError, 'not sufficient to split it'):
      data_adapter.train_validation_split(
          np.ones((1, 10)), validation_split=0.2)

  def test_validation_split_none(self):
    train_sw, val_sw = data_adapter.train_validation_split(
        None, validation_split=0.2)
    self.assertIsNone(train_sw)
    self.assertIsNone(val_sw)

    (_, train_sw), (_, val_sw) = data_adapter.train_validation_split(
        (np.ones((10, 1)), None), validation_split=0.2)
    self.assertIsNone(train_sw)
    self.assertIsNone(val_sw)


class ListsOfScalarsDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(ListsOfScalarsDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.ListsOfScalarsDataAdapter

  def test_can_list_inputs(self):
    self.assertTrue(self.adapter_cls.can_handle(self.text_input))
    self.assertTrue(self.adapter_cls.can_handle(self.bytes_input))

    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle([]))


class TestUtils(keras_parameterized.TestCase):

  def test_expand_1d_sparse_tensors_untouched(self):
    st = sparse_tensor.SparseTensor(
        indices=[[0], [10]], values=[1, 2], dense_shape=[11])
    st = data_adapter.expand_1d(st)
    self.assertEqual(st.shape.rank, 1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()