• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""DataAdapter tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.python import keras
27from tensorflow.python.data.experimental.ops import cardinality
28from tensorflow.python.data.ops import dataset_ops
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.keras import keras_parameterized
33from tensorflow.python.keras import testing_utils
34from tensorflow.python.keras.engine import data_adapter
35from tensorflow.python.keras.utils import data_utils
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import sparse_ops
38from tensorflow.python.platform import test
39from tensorflow.python.util import nest
40
41
class DummyArrayLike(object):
  """Dummy array-like object.

  Wraps an array and exposes only `len()`, indexing, `shape` and `dtype`,
  all delegated to the wrapped data. It has no tensor conversion of its
  own, which lets the tests check that adapters slice it rather than
  converting it wholesale.
  """

  def __init__(self, data):
    # The wrapped array; every accessor below delegates to it.
    self.data = data

  def __len__(self):
    return len(self.data)

  def __getitem__(self, key):
    # Any index/slice is forwarded unchanged to the wrapped data.
    return self.data[key]

  @property
  def shape(self):
    return self.data.shape

  @property
  def dtype(self):
    return self.data.dtype
61
62
def fail_on_convert(x, **kwargs):
  """Tensor conversion function that always refuses to convert.

  Intended to be registered for `DummyArrayLike`, so that converting a
  whole array-like to a tensor fails loudly instead of silently loading
  all of its data into memory.

  Args:
    x: The object being converted (ignored).
    **kwargs: Any extra arguments supplied by the conversion machinery
      (ignored).

  Raises:
    TypeError: Always.
  """
  del x, kwargs  # Unused.
  raise TypeError('Cannot convert DummyArrayLike to a tensor')
# Make any direct tensor conversion of a `DummyArrayLike` raise `TypeError`.
ops.register_tensor_conversion_function(DummyArrayLike, fail_on_convert)
68
69
class DataAdapterTestBase(keras_parameterized.TestCase):
  """Shared fixtures for the data adapter tests.

  `setUp` builds inputs in every format the adapters are expected to accept
  or reject: numpy arrays, tensors, array-likes, a `tf.data` dataset, a
  generator, an iterator, a Keras `Sequence`, and nested text/bytes lists.
  It also builds a small model to train on.
  """

  def setUp(self):
    super(DataAdapterTestBase, self).setUp()
    self.batch_size = 5
    self.numpy_input = np.zeros((50, 10))
    self.numpy_target = np.ones(50)
    self.tensor_input = constant_op.constant(2.0, shape=(50, 10))
    self.tensor_target = array_ops.ones((50,))
    # Same numpy data, wrapped in the non-tensor-convertible array-like.
    self.arraylike_input = DummyArrayLike(self.numpy_input)
    self.arraylike_target = DummyArrayLike(self.numpy_target)
    self.dataset_input = dataset_ops.DatasetV2.from_tensor_slices(
        (self.numpy_input, self.numpy_target)).shuffle(50).batch(
            self.batch_size)

    def generator():
      # Never terminates: consumers must bound it (e.g. `steps_per_epoch`).
      while True:
        yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size))

    self.generator_input = generator()
    # Separate generator instance, wrapped by `threadsafe_generator`; used by
    # the multiprocessing tests.
    self.iterator_input = data_utils.threadsafe_generator(generator)()
    self.sequence_input = TestSequence(batch_size=self.batch_size,
                                       feature_shape=10)
    self.text_input = [['abc']]
    self.bytes_input = [[b'abc']]
    self.model = keras.models.Sequential(
        [keras.layers.Dense(8, input_shape=(10,), activation='softmax')])
96
97
class TestSequence(data_utils.Sequence):
  """Keras `Sequence` of 10 identical batches (zero features, one labels)."""

  def __init__(self, batch_size, feature_shape):
    self.batch_size = batch_size
    # Number of features per sample, i.e. the second dimension of a batch.
    self.feature_shape = feature_shape

  def __getitem__(self, item):
    # Every index yields the same batch, so the index itself is unused.
    del item
    return (np.zeros((self.batch_size, self.feature_shape)),
            np.ones((self.batch_size,)))

  def __len__(self):
    return 10
110
111
class TensorLikeDataAdapterTest(DataAdapterTestBase):
  """Tests for `data_adapter.TensorLikeDataAdapter`.

  Covers numpy arrays, eager tensors and pandas objects.
  """

  def setUp(self):
    super(TensorLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.TensorLikeDataAdapter

  def test_can_handle_numpy(self):
    self.assertTrue(self.adapter_cls.can_handle(self.numpy_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.numpy_input, self.numpy_target))

    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  def test_size_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_batch_size_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    self.assertEqual(adapter.batch_size(), 5)

  def test_partial_batch_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=4)
    self.assertEqual(adapter.get_size(), 13)   # ceil(50 / 4)
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)   # 50 - 12 * 4

  def test_epochs(self):
    num_epochs = 3
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5, epochs=num_epochs)
    ds_iter = iter(adapter.get_dataset())
    num_batches_per_epoch = self.numpy_input.shape[0] // 5
    for _ in range(num_batches_per_epoch * num_epochs):
      next(ds_iter)
    # Exactly `num_epochs` epochs are produced, then the iterator ends.
    with self.assertRaises(StopIteration):
      next(ds_iter)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_numpy(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)

  def test_can_handle_pandas(self):
    try:
      import pandas as pd  # pylint: disable=g-import-not-at-top
    except ImportError:
      self.skipTest('Skipping test because pandas is not installed.')
    self.assertTrue(self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)))
    self.assertTrue(
        self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)[0]))
    self.assertTrue(
        self.adapter_cls.can_handle(
            pd.DataFrame(self.numpy_input),
            pd.DataFrame(self.numpy_input)[0]))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_pandas(self):
    try:
      import pandas as pd  # pylint: disable=g-import-not-at-top
    except ImportError:
      self.skipTest('Skipping test because pandas is not installed.')
    input_a = keras.Input(shape=(3,), name='input_a')
    input_b = keras.Input(shape=(3,), name='input_b')
    input_c = keras.Input(shape=(1,), name='input_c')

    x = keras.layers.Dense(4, name='dense_1')(input_a)
    y = keras.layers.Dense(3, name='dense_2')(input_b)
    z = keras.layers.Dense(1, name='dense_3')(input_c)

    model_1 = keras.Model(inputs=input_a, outputs=x)
    model_2 = keras.Model(inputs=[input_a, input_b], outputs=[x, y])
    model_3 = keras.Model(inputs=input_c, outputs=z)

    model_1.compile(optimizer='rmsprop', loss='mse')
    model_2.compile(optimizer='rmsprop', loss='mse')

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    input_a_df = pd.DataFrame(input_a_np)
    input_b_df = pd.DataFrame(input_b_np)

    output_a_df = pd.DataFrame(np.random.random((10, 4)))
    output_b_df = pd.DataFrame(np.random.random((10, 3)))

    # Exercise DataFrames passed bare, in a list, and in a dict keyed by
    # input name.
    model_1.fit(input_a_df,
                output_a_df)
    model_2.fit([input_a_df, input_b_df],
                [output_a_df, output_b_df])
    model_1.fit([input_a_df],
                [output_a_df])
    model_1.fit({'input_a': input_a_df},
                output_a_df)
    model_2.fit({'input_a': input_a_df, 'input_b': input_b_df},
                [output_a_df, output_b_df])

    model_1.evaluate(input_a_df,
                     output_a_df)
    model_2.evaluate([input_a_df, input_b_df],
                     [output_a_df, output_b_df])
    model_1.evaluate([input_a_df],
                     [output_a_df])
    model_1.evaluate({'input_a': input_a_df},
                     output_a_df)
    model_2.evaluate({'input_a': input_a_df, 'input_b': input_b_df},
                     [output_a_df, output_b_df])

    # Verify predicting on pandas vs numpy returns the same result
    predict_1_pandas = model_1.predict(input_a_df)
    predict_2_pandas = model_2.predict([input_a_df, input_b_df])
    predict_3_pandas = model_3.predict(input_a_df[0])

    predict_1_numpy = model_1.predict(input_a_np)
    predict_2_numpy = model_2.predict([input_a_np, input_b_np])
    predict_3_numpy = model_3.predict(np.asarray(input_a_df[0]))

    self.assertAllClose(predict_1_numpy, predict_1_pandas)
    self.assertAllClose(predict_2_numpy, predict_2_pandas)
    self.assertAllClose(predict_3_numpy, predict_3_pandas)

    # Extra ways to pass in dataframes
    model_1.predict([input_a_df])
    model_1.predict({'input_a': input_a_df})
    model_2.predict({'input_a': input_a_df, 'input_b': input_b_df})

  def test_can_handle(self):
    self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.tensor_input, self.tensor_target))

    self.assertFalse(self.adapter_cls.can_handle(self.arraylike_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.arraylike_target))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)

  def test_size(self):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_shuffle_correctness(self):
    num_samples = 100
    batch_size = 32
    x = np.arange(num_samples)
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle=True, epochs=2)

    def _get_epoch(ds_iter):
      """Pulls one epoch of batches and concatenates them."""
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter).numpy())
      return np.concatenate(ds_data)

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batch_shuffle_correctness(self):
    num_samples = 100
    batch_size = 6
    x = np.arange(num_samples)
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle='batch', epochs=2)

    def _get_epoch_batches(ds_iter):
      """Pulls one epoch of batches and returns them as a list."""
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter)[0].numpy())
      return ds_data

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_batch_data = _get_epoch_batches(ds_iter)
    epoch_data = np.concatenate(epoch_batch_data)

    def _verify_batch(batch):
      # Verify that a batch contains only contiguous data, and that it has
      # been shuffled.
      shuffled_batch = np.sort(batch)
      self.assertNotAllClose(batch, shuffled_batch)
      for i in range(1, len(batch)):
        self.assertEqual(shuffled_batch[i-1] + 1, shuffled_batch[i])

    # Each batch must be a shuffled run of contiguous values.
    for batch in epoch_batch_data:
      _verify_batch(batch)

    # Check that the epoch as a whole was shuffled.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_batch_data = _get_epoch_batches(ds_iter)
    second_epoch_data = np.concatenate(second_epoch_batch_data)

    # Each batch must be a shuffled run of contiguous values.
    for batch in second_epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 5),
      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
      ('steps_1', None, 1, 50),
      ('steps_4', None, 4, 13),
      )
  def test_batch_size(self, batch_size_in, steps, batch_size_out):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.batch_size(), batch_size_out)

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 10, 0),
      ('batch_size_4', 4, None, 13, 2),
      ('steps_1', None, 1, 1, 0),
      ('steps_5', None, 5, 5, 0),
      ('steps_4', None, 4, 4, 11),
      )
  def test_partial_batch(
      self, batch_size_in, steps, size, partial_batch_size):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
        steps=steps)
    # ceil(50 / batch_size) when batch_size is given, otherwise `steps`.
    self.assertEqual(adapter.get_size(), size)
    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)
387
388
class GenericArrayLikeDataAdapterTest(DataAdapterTestBase):
  """Tests for `data_adapter.GenericArrayLikeDataAdapter`.

  Uses `DummyArrayLike`, which cannot be converted to a tensor directly.
  """

  def setUp(self):
    super(GenericArrayLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GenericArrayLikeDataAdapter

  def test_can_handle_some_numpy(self):
    self.assertTrue(self.adapter_cls.can_handle(
        self.arraylike_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.arraylike_target))

    # Because adapters are mutually exclusive, don't handle cases
    # where all the data is numpy or an eagertensor
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.numpy_input,
                                    self.numpy_target))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.tensor_input, self.tensor_target))

    # But do handle mixes that include generic arraylike data
    self.assertTrue(
        self.adapter_cls.can_handle(self.numpy_input,
                                    self.arraylike_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.numpy_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.tensor_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.tensor_input,
                                    self.arraylike_target))

    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  def test_size(self):
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.arraylike_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_epochs(self):
    num_epochs = 3
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.numpy_target, batch_size=5, epochs=num_epochs)
    ds_iter = iter(adapter.get_dataset())
    num_batches_per_epoch = self.numpy_input.shape[0] // 5
    for _ in range(num_batches_per_epoch * num_epochs):
      next(ds_iter)
    # Exactly `num_epochs` epochs are produced, then the iterator ends.
    with self.assertRaises(StopIteration):
      next(ds_iter)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    # First verify that DummyArrayLike can't be converted to a Tensor
    with self.assertRaises(TypeError):
      ops.convert_to_tensor_v2_with_dispatch(self.arraylike_input)

    # Then train on the array like.
    # It should not be converted to a tensor directly (which would force it into
    # memory), only the sliced data should be converted.
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.arraylike_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.arraylike_target,
                   shuffle=True, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.arraylike_target,
                   shuffle='batch', batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.arraylike_target, batch_size=5)
    self.model.predict(self.arraylike_input, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_numpy_target(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.numpy_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.numpy_target, shuffle=True,
                   batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.numpy_target, shuffle='batch',
                   batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.numpy_target, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_tensor_target(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.tensor_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.tensor_target, shuffle=True,
                   batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.tensor_target, shuffle='batch',
                   batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.tensor_target, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_shuffle_correctness(self):
    num_samples = 100
    batch_size = 32
    x = DummyArrayLike(np.arange(num_samples))
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle=True, epochs=2)

    def _get_epoch(ds_iter):
      """Pulls one epoch of batches and concatenates them."""
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter).numpy())
      return np.concatenate(ds_data)

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batch_shuffle_correctness(self):
    num_samples = 100
    batch_size = 6
    x = DummyArrayLike(np.arange(num_samples))
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle='batch', epochs=2)

    def _get_epoch_batches(ds_iter):
      """Pulls one epoch of batches and returns them as a list."""
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter)[0].numpy())
      return ds_data

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_batch_data = _get_epoch_batches(ds_iter)
    epoch_data = np.concatenate(epoch_batch_data)

    def _verify_batch(batch):
      # Verify that a batch contains only contiguous data, but that it has
      # been shuffled.
      shuffled_batch = np.sort(batch)
      self.assertNotAllClose(batch, shuffled_batch)
      for i in range(1, len(batch)):
        self.assertEqual(shuffled_batch[i-1] + 1, shuffled_batch[i])

    # Each batch must be a shuffled run of contiguous values.
    for batch in epoch_batch_data:
      _verify_batch(batch)

    # Check that the epoch as a whole was shuffled.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_batch_data = _get_epoch_batches(ds_iter)
    second_epoch_data = np.concatenate(second_epoch_batch_data)

    # Each batch must be a shuffled run of contiguous values.
    for batch in second_epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that shuffling is different across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 5),
      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
      ('steps_1', None, 1, 50),
      ('steps_4', None, 4, 13),
  )
  def test_batch_size(self, batch_size_in, steps, batch_size_out):
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.arraylike_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.batch_size(), batch_size_out)

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 10, 0),
      ('batch_size_4', 4, None, 13, 2),
      ('steps_1', None, 1, 1, 0),
      ('steps_5', None, 5, 5, 0),
      ('steps_4', None, 4, 4, 11),
  )
  def test_partial_batch(
      self, batch_size_in, steps, size, partial_batch_size):
    adapter = self.adapter_cls(
        self.arraylike_input, self.arraylike_target,
        batch_size=batch_size_in,
        steps=steps)
    # ceil(50 / batch_size) when batch_size is given, otherwise `steps`.
    self.assertEqual(adapter.get_size(), size)
    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)
620
621
class DatasetAdapterTest(DataAdapterTestBase):
  """Tests for `data_adapter.DatasetAdapter` (already-batched `tf.data`)."""

  def setUp(self):
    super(DatasetAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.DatasetAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertTrue(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    dataset = self.adapter_cls(self.dataset_input).get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(dataset)

  def test_size(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertIsNone(adapter.get_size())

  def test_batch_size(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertIsNone(adapter.batch_size())

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.dataset_input, y=self.dataset_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(self.dataset_input, sample_weights=self.dataset_input)
663
664
class GeneratorDataAdapterTest(DataAdapterTestBase):
  """Tests for `data_adapter.GeneratorDataAdapter`."""

  def setUp(self):
    super(GeneratorDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GeneratorDataAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertTrue(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    # The generator is infinite, so the epoch length must be given.
    self.model.fit(self.generator_input, steps_per_epoch=10)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @testing_utils.run_v2_only
  @data_utils.dont_use_multiprocessing_pool
  def test_with_multiprocessing_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)
    # Fit twice to ensure there isn't any duplication that prevent the worker
    # from starting.
    self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)

  def test_size(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertIsNone(adapter.get_size())

  def test_batch_size(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertIsNone(adapter.batch_size())
    self.assertEqual(adapter.representative_batch_size(), 5)

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.generator_input, y=self.generator_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(
          self.generator_input, sample_weights=self.generator_input)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_not_shuffled(self):
    def generator():
      for i in range(10):
        yield np.ones((1, 1)) * i

    # Even with shuffle=True, elements must come out in generator order.
    adapter = self.adapter_cls(generator(), shuffle=True)
    for i, data in enumerate(adapter.get_dataset()):
      # Each element holds a single value; compare it as a scalar.
      self.assertEqual(i, data[0].numpy().item())
732
733
class KerasSequenceAdapterTest(DataAdapterTestBase):
  """Tests for `data_adapter.KerasSequenceAdapter`."""

  def setUp(self):
    super(KerasSequenceAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.KerasSequenceAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertTrue(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.sequence_input)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @testing_utils.run_v2_only
  @data_utils.dont_use_multiprocessing_pool
  def test_with_multiprocessing_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)
    # Fit twice to ensure there isn't any duplication that prevent the worker
    # from starting.
    self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)

  def test_size(self):
    adapter = self.adapter_cls(self.sequence_input)
    # `TestSequence.__len__` returns 10.
    self.assertEqual(adapter.get_size(), 10)

  def test_batch_size(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertIsNone(adapter.batch_size())
    self.assertEqual(adapter.representative_batch_size(), 5)

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.sequence_input, y=self.sequence_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(self.sequence_input, sample_weights=self.sequence_input)
790
791
class DataHandlerTest(keras_parameterized.TestCase):
  """Tests for iterating data epoch-by-epoch via `data_adapter.DataHandler`."""

  def _consume(self, data_handler, to_numpy=False):
    """Drains every epoch and step of `data_handler`, collecting the batches.

    Args:
      data_handler: The `data_adapter.DataHandler` to iterate.
      to_numpy: If True, call `.numpy()` on each batch as soon as it is
        fetched; otherwise the raw `next(iterator)` result is kept.

    Returns:
      A list with one entry per epoch; each entry is the list of batches
      fetched for that epoch's steps.
    """
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        batch = next(iterator)
        epoch_data.append(batch.numpy() if to_numpy else batch)
      returned_data.append(epoch_data)
    return returned_data

  def test_finite_dataset_with_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
    # User can choose to only partially consume `Dataset`.
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertEqual(data_handler.inferred_steps, 2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = self._consume(data_handler, to_numpy=True)
    self.assertEqual(returned_data, [[0, 1], [2, 3]])

  def test_finite_dataset_without_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1)
    data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2)
    self.assertEqual(data_handler.inferred_steps, 3)
    returned_data = self._consume(data_handler, to_numpy=True)
    self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

  def test_finite_dataset_with_steps_per_epoch_exact_size(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
    # If user specifies exact size of `Dataset` as `steps_per_epoch`,
    # create a new iterator each epoch.
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=4)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = self._consume(data_handler, to_numpy=True)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])

  def test_infinite_dataset_with_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1).repeat()
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=3)
    returned_data = self._consume(data_handler, to_numpy=True)
    self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

  def test_unknown_cardinality_dataset_with_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    # User can choose to only partially consume `Dataset`.
    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = self.evaluate(self._consume(data_handler))
    self.assertEqual(returned_data, [[0, 1], [2, 3]])
    self.assertEqual(data_handler.inferred_steps, 2)

  def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2)
    # Step count is unknown up front for a filtered dataset.
    self.assertIsNone(data_handler.inferred_steps)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      # The whole step loop is wrapped so that exhausting the dataset ends
      # the epoch rather than propagating `StopIteration`.
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
    # After a full pass, `inferred_steps` reflects the steps actually seen.
    self.assertEqual(data_handler.inferred_steps, 4)

  def test_insufficient_data(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])
    ds = ds.filter(lambda *args, **kwargs: True)
    # Only 2 elements exist but 3 steps per epoch are requested; the handler
    # is expected to flag `_insufficient_data` and stop after one epoch.
    data_handler = data_adapter.DataHandler(
        ds, initial_epoch=0, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        with data_handler.catch_stop_iteration():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertTrue(data_handler._insufficient_data)
    self.assertEqual(returned_data, [[0, 1]])

  def test_numpy(self):
    x = np.array([0, 1, 2])
    y = np.array([0, 2, 4])
    sw = np.array([0, 4, 8])
    data_handler = data_adapter.DataHandler(
        x=x, y=y, sample_weight=sw, batch_size=1, epochs=2)
    returned_data = self.evaluate(self._consume(data_handler))
    self.assertEqual(returned_data,
                     [[(0, 0, 0), (1, 2, 4),
                       (2, 4, 8)], [(0, 0, 0), (1, 2, 4), (2, 4, 8)]])

  def test_generator(self):

    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    data_handler = data_adapter.DataHandler(
        generator(), epochs=2, steps_per_epoch=3)
    returned_data = self.evaluate(self._consume(data_handler))
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_composite_tensor(self):
    st = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 0], [2, 0]], values=[0, 1, 2], dense_shape=[3, 1])
    data_handler = data_adapter.DataHandler(st, epochs=2, steps_per_epoch=3)
    returned_data = self._consume(data_handler)
    # Densify the sparse batches so they can be compared as plain values.
    returned_data = self.evaluate(
        nest.map_structure(sparse_ops.sparse_tensor_to_dense, returned_data))
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_iterator(self):
    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    it = iter(dataset_ops.Dataset.from_generator(
        generator, output_types=('float32',)))
    data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3)
    returned_data = self.evaluate(self._consume(data_handler))
    self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
                                     [([0],), ([1],), ([2],)]])

  def test_list_of_scalars(self):
    data_handler = data_adapter.DataHandler([[0], [1], [2]],
                                            epochs=2,
                                            steps_per_epoch=3)
    returned_data = self.evaluate(self._consume(data_handler))
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_class_weight_user_errors(self):
    with self.assertRaisesRegex(ValueError, 'to be a dict with keys'):
      data_adapter.DataHandler(
          x=[[0], [1], [2]],
          y=[[2], [1], [0]],
          batch_size=1,
          sample_weight=[[1.], [2.], [4.]],
          class_weight={
              0: 0.5,
              1: 1.,
              3: 1.5  # Skips class `2`.
          })

    with self.assertRaisesRegex(ValueError, 'with a single output'):
      data_adapter.DataHandler(
          x=np.ones((10, 1)),
          y=[np.ones((10, 1)), np.zeros((10, 1))],
          batch_size=2,
          class_weight={
              0: 0.5,
              1: 1.,
              2: 1.5
          })

  @parameterized.named_parameters(('numpy', True), ('dataset', False))
  def test_single_x_input_no_tuple_wrapping(self, use_numpy):
    x = np.ones((10, 1))

    if use_numpy:
      batch_size = 2
    else:
      x = dataset_ops.Dataset.from_tensor_slices(x).batch(2)
      batch_size = None

    data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
    for _, iterator in data_handler.enumerate_epochs():
      for _ in data_handler.steps():
        # Check that single x input is not wrapped in a tuple.
        self.assertIsInstance(next(iterator), ops.Tensor)
1028
1029
class TestValidationSplit(keras_parameterized.TestCase):
  """Tests for `data_adapter.train_validation_split`."""

  @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
  def test_validation_split_unshuffled(self, use_numpy):
    """A 0.2 split keeps order: the first 4 rows train, the last validates."""
    to_input = np.array if use_numpy else ops.convert_to_tensor_v2_with_dispatch
    x = to_input([0, 1, 2, 3, 4])
    y = to_input([0, 2, 4, 6, 8])
    sw = to_input([0, 4, 8, 12, 16])

    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
        data_adapter.train_validation_split((x, y, sw), validation_split=0.2))

    pieces = [train_x, train_y, train_sw, val_x, val_y, val_sw]
    if use_numpy:
      # Convert so both parameterizations can be inspected via `.numpy()`.
      pieces = [ops.convert_to_tensor_v2_with_dispatch(p) for p in pieces]

    expected_lists = [
        [0, 1, 2, 3], [0, 2, 4, 6], [0, 4, 8, 12],  # train x / y / sw
        [4], [8], [16],  # validation x / y / sw
    ]
    for piece, expected in zip(pieces, expected_lists):
      self.assertEqual(piece.numpy().tolist(), expected)

  def test_validation_split_user_error(self):
    """Non-array inputs (here a lambda) cannot be split."""
    not_splittable = lambda: np.ones((10, 1))
    with self.assertRaisesRegex(ValueError, 'is only supported for Tensors'):
      data_adapter.train_validation_split(
          not_splittable, validation_split=0.2)

  def test_validation_split_examples_too_few(self):
    """A single example cannot be divided into train and validation parts."""
    single_example = np.ones((1, 10))
    with self.assertRaisesRegex(ValueError, 'not sufficient to split it'):
      data_adapter.train_validation_split(
          single_example, validation_split=0.2)

  def test_validation_split_none(self):
    """`None`, alone or nested, passes through as `None` on both sides."""
    train_part, val_part = data_adapter.train_validation_split(
        None, validation_split=0.2)
    self.assertIsNone(train_part)
    self.assertIsNone(val_part)

    (_, train_part), (_, val_part) = data_adapter.train_validation_split(
        (np.ones((10, 1)), None), validation_split=0.2)
    self.assertIsNone(train_part)
    self.assertIsNone(val_part)
1082
1083
class ListsOfScalarsDataAdapterTest(DataAdapterTestBase):
  """Tests for `data_adapter.ListsOfScalarsDataAdapter`."""

  def setUp(self):
    super(ListsOfScalarsDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.ListsOfScalarsDataAdapter

  def test_can_list_inputs(self):
    """Accepts the text/bytes list fixtures; rejects everything else."""
    can_handle = self.adapter_cls.can_handle

    for accepted in (self.text_input, self.bytes_input):
      self.assertTrue(can_handle(accepted))

    rejected_inputs = (
        self.numpy_input,
        self.tensor_input,
        self.dataset_input,
        self.generator_input,
        self.sequence_input,
        [],
    )
    for rejected in rejected_inputs:
      self.assertFalse(can_handle(rejected))
1100
1101
class TestUtils(keras_parameterized.TestCase):
  """Tests for module-level helpers in `data_adapter`."""

  def test_expand_1d_sparse_tensors_untouched(self):
    """`expand_1d` leaves a rank-1 `SparseTensor` at rank 1."""
    # Indices must lie inside `dense_shape`; for a dense shape of [10] the
    # last valid index is 9 (the previous fixture used an out-of-range 10).
    st = sparse_tensor.SparseTensor(
        indices=[[0], [9]], values=[1, 2], dense_shape=[10])
    st = data_adapter.expand_1d(st)
    self.assertEqual(st.shape.rank, 1)
1109
1110
if __name__ == '__main__':
  # Enable eager execution before running the tests; several tests call
  # `.numpy()` on tensors directly.
  ops.enable_eager_execution()
  test.main()
1114