# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DataAdapter tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class DummyArrayLike(object):
  """Dummy array-like object."""

  def __init__(self, data):
    self.data = data

  def __len__(self):
    return len(self.data)

  def __getitem__(self, key):
    return self.data[key]

  @property
  def shape(self):
    return self.data.shape

  @property
  def dtype(self):
    return self.data.dtype


def fail_on_convert(x, **kwargs):
  _ = x
  _ = kwargs
  raise TypeError('Cannot convert DummyArrayLike to a tensor')
ops.register_tensor_conversion_function(DummyArrayLike, fail_on_convert)


class DataAdapterTestBase(keras_parameterized.TestCase):

  def setUp(self):
    super(DataAdapterTestBase, self).setUp()
    self.batch_size = 5
    self.numpy_input = np.zeros((50, 10))
    self.numpy_target = np.ones(50)
    self.tensor_input = constant_op.constant(2.0, shape=(50, 10))
    self.tensor_target = array_ops.ones((50,))
    self.arraylike_input = DummyArrayLike(self.numpy_input)
    self.arraylike_target = DummyArrayLike(self.numpy_target)
    self.dataset_input = dataset_ops.DatasetV2.from_tensor_slices(
        (self.numpy_input, self.numpy_target)).shuffle(50).batch(
            self.batch_size)

    def generator():
      while True:
        yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size))
    self.generator_input = generator()
    self.iterator_input = data_utils.threadsafe_generator(generator)()
    self.sequence_input = TestSequence(batch_size=self.batch_size,
                                       feature_shape=10)
    self.text_input = [['abc']]
    self.bytes_input = [[b'abc']]
    self.model = keras.models.Sequential(
        [keras.layers.Dense(8, input_shape=(10,), activation='softmax')])


class TestSequence(data_utils.Sequence):

  def __init__(self, batch_size, feature_shape):
    self.batch_size = batch_size
    self.feature_shape = feature_shape

  def __getitem__(self, item):
    return (np.zeros((self.batch_size, self.feature_shape)),
            np.ones((self.batch_size,)))

  def __len__(self):
    return 10

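
# Each DataAdapter advertises support for a given input type through
# `can_handle`, and the handlers are intended to be mutually exclusive:
# exactly one adapter should claim any given input. The test classes below
# therefore assert both the positive and the negative cases against the
# shared fixtures defined in DataAdapterTestBase.
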
class TensorLikeDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(TensorLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.TensorLikeDataAdapter

  def test_can_handle_numpy(self):
    self.assertTrue(self.adapter_cls.can_handle(self.numpy_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.numpy_input, self.numpy_target))

    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  def test_size_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_batch_size_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5)
    self.assertEqual(adapter.batch_size(), 5)

  def test_partial_batch_numpy(self):
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=4)
    self.assertEqual(adapter.get_size(), 13)  # ceil(50/4) = 13.
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)

  def test_epochs(self):
    num_epochs = 3
    adapter = self.adapter_cls(
        self.numpy_input, self.numpy_target, batch_size=5, epochs=num_epochs)
    ds_iter = iter(adapter.get_dataset())
    num_batches_per_epoch = self.numpy_input.shape[0] // 5
    for _ in range(num_batches_per_epoch * num_epochs):
      next(ds_iter)
    with self.assertRaises(StopIteration):
      next(ds_iter)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_numpy(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)

  def test_can_handle_pandas(self):
    try:
      import pandas as pd  # pylint: disable=g-import-not-at-top
    except ImportError:
      self.skipTest('Skipping test because pandas is not installed.')
    self.assertTrue(self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)))
    self.assertTrue(
        self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)[0]))
    self.assertTrue(
        self.adapter_cls.can_handle(
            pd.DataFrame(self.numpy_input),
            pd.DataFrame(self.numpy_input)[0]))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_pandas(self):
    try:
      import pandas as pd  # pylint: disable=g-import-not-at-top
    except ImportError:
      self.skipTest('Skipping test because pandas is not installed.')
    input_a = keras.Input(shape=(3,), name='input_a')
    input_b = keras.Input(shape=(3,), name='input_b')
    input_c = keras.Input(shape=(1,), name='input_c')

    x = keras.layers.Dense(4, name='dense_1')(input_a)
    y = keras.layers.Dense(3, name='dense_2')(input_b)
    z = keras.layers.Dense(1, name='dense_3')(input_c)

    model_1 = keras.Model(inputs=input_a, outputs=x)
    model_2 = keras.Model(inputs=[input_a, input_b], outputs=[x, y])
    model_3 = keras.Model(inputs=input_c, outputs=z)

    model_1.compile(optimizer='rmsprop', loss='mse')
    model_2.compile(optimizer='rmsprop', loss='mse')

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    input_a_df = pd.DataFrame(input_a_np)
    input_b_df = pd.DataFrame(input_b_np)

    output_a_df = pd.DataFrame(np.random.random((10, 4)))
    output_b_df = pd.DataFrame(np.random.random((10, 3)))

    model_1.fit(input_a_df,
                output_a_df)
    model_2.fit([input_a_df, input_b_df],
                [output_a_df, output_b_df])
    model_1.fit([input_a_df],
                [output_a_df])
    model_1.fit({'input_a': input_a_df},
                output_a_df)
    model_2.fit({'input_a': input_a_df, 'input_b': input_b_df},
                [output_a_df, output_b_df])

    model_1.evaluate(input_a_df,
                     output_a_df)
    model_2.evaluate([input_a_df, input_b_df],
                     [output_a_df, output_b_df])
    model_1.evaluate([input_a_df],
                     [output_a_df])
    model_1.evaluate({'input_a': input_a_df},
                     output_a_df)
    model_2.evaluate({'input_a': input_a_df, 'input_b': input_b_df},
                     [output_a_df, output_b_df])

    # Verify that predicting on pandas vs numpy returns the same result.
    predict_1_pandas = model_1.predict(input_a_df)
    predict_2_pandas = model_2.predict([input_a_df, input_b_df])
    predict_3_pandas = model_3.predict(input_a_df[0])

    predict_1_numpy = model_1.predict(input_a_np)
    predict_2_numpy = model_2.predict([input_a_np, input_b_np])
    predict_3_numpy = model_3.predict(np.asarray(input_a_df[0]))

    self.assertAllClose(predict_1_numpy, predict_1_pandas)
    self.assertAllClose(predict_2_numpy, predict_2_pandas)
    self.assertAllClose(predict_3_numpy, predict_3_pandas)

    # Extra ways to pass in dataframes.
    model_1.predict([input_a_df])
    model_1.predict({'input_a': input_a_df})
    model_2.predict({'input_a': input_a_df, 'input_b': input_b_df})

  def test_can_handle(self):
    self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.tensor_input, self.tensor_target))

    self.assertFalse(self.adapter_cls.can_handle(self.arraylike_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.arraylike_target))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)

  def test_size(self):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_shuffle_correctness(self):
    num_samples = 100
    batch_size = 32
    x = np.arange(num_samples)
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle=True, epochs=2)

    def _get_epoch(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter).numpy())
      return np.concatenate(ds_data)

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that the shuffling differs across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batch_shuffle_correctness(self):
    num_samples = 100
    batch_size = 6
    x = np.arange(num_samples)
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle='batch', epochs=2)

    def _get_epoch_batches(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter)[0].numpy())
      return ds_data

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_batch_data = _get_epoch_batches(ds_iter)
    epoch_data = np.concatenate(epoch_batch_data)

    def _verify_batch(batch):
      # Verify that the batch contains only contiguous data, but that its
      # order has been shuffled.
      shuffled_batch = np.sort(batch)
      self.assertNotAllClose(batch, shuffled_batch)
      for i in range(1, len(batch)):
        self.assertEqual(shuffled_batch[i - 1] + 1, shuffled_batch[i])

    # Assert that the data within each batch remains contiguous.
    for batch in epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred across the epoch.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_batch_data = _get_epoch_batches(ds_iter)
    second_epoch_data = np.concatenate(second_epoch_batch_data)

    # Assert that the data within each batch remains contiguous.
    for batch in second_epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that the shuffling differs across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

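  # Note on the `shuffle='batch'` semantics verified above: shuffling happens
  # in batch-sized chunks, so each emitted batch is a contiguous slice of the
  # input whose internal order (and whose position within the epoch) has been
  # shuffled. Illustrative example only: with batch_size=3 on np.arange(9),
  # one valid epoch is [[4, 3, 5], [8, 6, 7], [0, 2, 1]].
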
  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 5),
      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes
                                     # precedence over steps.
      ('steps_1', None, 1, 50),
      ('steps_4', None, 4, 13),
  )
  def test_batch_size(self, batch_size_in, steps, batch_size_out):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.batch_size(), batch_size_out)

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 10, 0),
      ('batch_size_4', 4, None, 13, 2),
      ('steps_1', None, 1, 1, 0),
      ('steps_5', None, 5, 5, 0),
      ('steps_4', None, 4, 4, 11),
  )
  def test_partial_batch(
      self, batch_size_in, steps, size, partial_batch_size):
    adapter = self.adapter_cls(
        self.tensor_input, self.tensor_target, batch_size=batch_size_in,
        steps=steps)
    # Size is `steps` when given, else ceil(50 / batch_size).
    self.assertEqual(adapter.get_size(), size)
    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)


class GenericArrayLikeDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(GenericArrayLikeDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GenericArrayLikeDataAdapter

  def test_can_handle_some_numpy(self):
    self.assertTrue(self.adapter_cls.can_handle(
        self.arraylike_input))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.arraylike_target))

    # Because adapters are mutually exclusive, this adapter does not handle
    # cases where all of the data is numpy or an EagerTensor.
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.numpy_input,
                                    self.numpy_target))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(
        self.adapter_cls.can_handle(self.tensor_input, self.tensor_target))

    # But it does handle mixes that include generic array-like data.
    self.assertTrue(
        self.adapter_cls.can_handle(self.numpy_input,
                                    self.arraylike_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.numpy_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.arraylike_input,
                                    self.tensor_target))
    self.assertTrue(
        self.adapter_cls.can_handle(self.tensor_input,
                                    self.arraylike_target))

    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  def test_size(self):
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.arraylike_target, batch_size=5)
    self.assertEqual(adapter.get_size(), 10)
    self.assertFalse(adapter.has_partial_batch())

  def test_epochs(self):
    num_epochs = 3
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.numpy_target, batch_size=5, epochs=num_epochs)
    ds_iter = iter(adapter.get_dataset())
    num_batches_per_epoch = self.numpy_input.shape[0] // 5
    for _ in range(num_batches_per_epoch * num_epochs):
      next(ds_iter)
    with self.assertRaises(StopIteration):
      next(ds_iter)

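  # The adapter only relies on the minimal array-like contract that
  # DummyArrayLike implements (`__len__`, `__getitem__`, `shape`, `dtype`),
  # so data that lives out of memory (an HDF5 dataset, for example) can be
  # trained on without being loaded wholesale; only the slices taken for each
  # batch are ever converted to tensors, as the training tests below verify.
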
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    # First, verify that DummyArrayLike can't be converted to a Tensor.
    with self.assertRaises(TypeError):
      ops.convert_to_tensor_v2_with_dispatch(self.arraylike_input)

    # Then train on the array-like. It should not be converted to a tensor
    # directly (which would force it into memory); only the sliced data
    # should be converted.
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.arraylike_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.arraylike_target,
                   shuffle=True, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.arraylike_target,
                   shuffle='batch', batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.arraylike_target, batch_size=5)
    self.model.predict(self.arraylike_input, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_numpy_target(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.numpy_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.numpy_target, shuffle=True,
                   batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.numpy_target, shuffle='batch',
                   batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.numpy_target, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training_tensor_target(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.arraylike_input,
                   self.tensor_target, batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.tensor_target, shuffle=True,
                   batch_size=5)
    self.model.fit(self.arraylike_input,
                   self.tensor_target, shuffle='batch',
                   batch_size=5)
    self.model.evaluate(self.arraylike_input,
                        self.tensor_target, batch_size=5)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_shuffle_correctness(self):
    num_samples = 100
    batch_size = 32
    x = DummyArrayLike(np.arange(num_samples))
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle=True, epochs=2)

    def _get_epoch(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter).numpy())
      return np.concatenate(ds_data)

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_data = _get_epoch(ds_iter)
    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that the shuffling differs across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batch_shuffle_correctness(self):
    num_samples = 100
    batch_size = 6
    x = DummyArrayLike(np.arange(num_samples))
    np.random.seed(99)
    adapter = self.adapter_cls(
        x, y=None, batch_size=batch_size, shuffle='batch', epochs=2)

    def _get_epoch_batches(ds_iter):
      ds_data = []
      for _ in range(int(math.ceil(num_samples / batch_size))):
        ds_data.append(next(ds_iter)[0].numpy())
      return ds_data

    ds_iter = iter(adapter.get_dataset())

    # First epoch.
    epoch_batch_data = _get_epoch_batches(ds_iter)
    epoch_data = np.concatenate(epoch_batch_data)

    def _verify_batch(batch):
      # Verify that the batch contains only contiguous data, but that its
      # order has been shuffled.
      shuffled_batch = np.sort(batch)
      self.assertNotAllClose(batch, shuffled_batch)
      for i in range(1, len(batch)):
        self.assertEqual(shuffled_batch[i - 1] + 1, shuffled_batch[i])

    # Assert that the data within each batch is shuffled contiguous data.
    for batch in epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred across the epoch.
    self.assertNotAllClose(x, epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(epoch_data))

    # Second epoch.
    second_epoch_batch_data = _get_epoch_batches(ds_iter)
    second_epoch_data = np.concatenate(second_epoch_batch_data)

    # Assert that the data within each batch remains contiguous.
    for batch in second_epoch_batch_data:
      _verify_batch(batch)

    # Check that shuffling occurred.
    self.assertNotAllClose(x, second_epoch_data)
    # Check that the shuffling differs across epochs.
    self.assertNotAllClose(epoch_data, second_epoch_data)
    # Check that each element appears exactly once.
    self.assertAllClose(x, np.sort(second_epoch_data))

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 5),
      ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes
                                     # precedence over steps.
      ('steps_1', None, 1, 50),
      ('steps_4', None, 4, 13),
  )
  def test_batch_size(self, batch_size_in, steps, batch_size_out):
    adapter = self.adapter_cls(
        self.arraylike_input,
        self.arraylike_target, batch_size=batch_size_in,
        steps=steps)
    self.assertEqual(adapter.batch_size(), batch_size_out)

  @parameterized.named_parameters(
      ('batch_size_5', 5, None, 10, 0),
      ('batch_size_4', 4, None, 13, 2),
      ('steps_1', None, 1, 1, 0),
      ('steps_5', None, 5, 5, 0),
      ('steps_4', None, 4, 4, 11),
  )
  def test_partial_batch(
      self, batch_size_in, steps, size, partial_batch_size):
    adapter = self.adapter_cls(
        self.arraylike_input, self.arraylike_target,
        batch_size=batch_size_in,
        steps=steps)
    # Size is `steps` when given, else ceil(50 / batch_size).
    self.assertEqual(adapter.get_size(), size)
    self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size))
    self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None)


class DatasetAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(DatasetAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.DatasetAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertTrue(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    dataset = self.adapter_cls(self.dataset_input).get_dataset()
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(dataset)

  def test_size(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertIsNone(adapter.get_size())

  def test_batch_size(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertIsNone(adapter.batch_size())

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.dataset_input, y=self.dataset_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(self.dataset_input, sample_weights=self.dataset_input)


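# A `tf.data.Dataset` already pairs inputs with targets and sample weights,
# which is why DatasetAdapter rejects separate `y` / `sample_weight`
# arguments in the two tests above. A minimal sketch of the equivalent
# bundled form (illustrative only, not part of the tests):
#
#   ds = dataset_ops.DatasetV2.from_tensor_slices((x, y, sw)).batch(5)
#   model.fit(ds)

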
class GeneratorDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(GeneratorDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.GeneratorDataAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertTrue(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.generator_input, steps_per_epoch=10)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @testing_utils.run_v2_only
  @data_utils.dont_use_multiprocessing_pool
  def test_with_multiprocessing_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)
    # Fit twice to ensure there isn't any duplication that prevents the
    # worker from starting.
    self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)

  def test_size(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertIsNone(adapter.get_size())

  def test_batch_size(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertEqual(adapter.batch_size(), None)
    self.assertEqual(adapter.representative_batch_size(), 5)

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.generator_input, y=self.generator_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(
          self.generator_input, sample_weights=self.generator_input)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_not_shuffled(self):
    def generator():
      for i in range(10):
        yield np.ones((1, 1)) * i

    adapter = self.adapter_cls(generator(), shuffle=True)
    for i, data in enumerate(adapter.get_dataset()):
      self.assertEqual(i, data[0].numpy().flatten())


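# There is no random access into a generator, so GeneratorDataAdapter cannot
# shuffle: `shuffle=True` is ignored and the data arrives in generation
# order, as `test_not_shuffled` above verifies.

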
class KerasSequenceAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(KerasSequenceAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.KerasSequenceAdapter

  def test_can_handle(self):
    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertTrue(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle(self.text_input))
    self.assertFalse(self.adapter_cls.can_handle(self.bytes_input))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.sequence_input)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @testing_utils.run_v2_only
  @data_utils.dont_use_multiprocessing_pool
  def test_with_multiprocessing_training(self):
    self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd',
                       run_eagerly=testing_utils.should_run_eagerly())
    self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)
    # Fit twice to ensure there isn't any duplication that prevents the
    # worker from starting.
    self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True,
                   max_queue_size=10, steps_per_epoch=10)

  def test_size(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertEqual(adapter.get_size(), 10)

  def test_batch_size(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertEqual(adapter.batch_size(), None)
    self.assertEqual(adapter.representative_batch_size(), 5)

  def test_partial_batch(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())

  def test_invalid_targets_argument(self):
    with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'):
      self.adapter_cls(self.sequence_input, y=self.sequence_input)

  def test_invalid_sample_weights_argument(self):
    with self.assertRaisesRegex(ValueError,
                                r'`sample_weight` argument is not supported'):
      self.adapter_cls(self.sequence_input, sample_weights=self.sequence_input)


class DataHandlerTest(keras_parameterized.TestCase):

  def test_finite_dataset_with_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
    # The user can choose to consume only part of the `Dataset`.
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertEqual(data_handler.inferred_steps, 2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1], [2, 3]])

  def test_finite_dataset_without_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1)
    data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2)
    self.assertEqual(data_handler.inferred_steps, 3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

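  # When `steps_per_epoch` covers the dataset exactly, the handler recreates
  # the iterator at each epoch boundary; contrast this with
  # `test_finite_dataset_with_steps_per_epoch` above, where partial
  # consumption keeps a single iterator alive so the next epoch resumes where
  # the previous one stopped.
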
  def test_finite_dataset_with_steps_per_epoch_exact_size(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
    # If the user specifies the exact size of the `Dataset` as
    # `steps_per_epoch`, create a new iterator for each epoch.
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=4)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])

  def test_infinite_dataset_with_steps_per_epoch(self):
    data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1).repeat()
    data_handler = data_adapter.DataHandler(
        data, initial_epoch=0, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator).numpy())
      returned_data.append(epoch_data)
    self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

  def test_unknown_cardinality_dataset_with_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    # The user can choose to consume only part of the `Dataset`.
    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1], [2, 3]])
    self.assertEqual(data_handler.inferred_steps, 2)

  def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2)
    self.assertEqual(data_handler.inferred_steps, None)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
    self.assertEqual(data_handler.inferred_steps, 4)

  def test_insufficient_data(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])
    ds = ds.filter(lambda *args, **kwargs: True)
    data_handler = data_adapter.DataHandler(
        ds, initial_epoch=0, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        with data_handler.catch_stop_iteration():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertTrue(data_handler._insufficient_data)
    self.assertEqual(returned_data, [[0, 1]])

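  # `catch_stop_iteration` lets an epoch end cleanly when a dataset of
  # unknown cardinality runs out. If the data is exhausted before
  # `steps_per_epoch` steps have run, the handler records
  # `_insufficient_data` and attempts no further epochs, which is why only
  # one (short) epoch is returned above.
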
  def test_numpy(self):
    x = np.array([0, 1, 2])
    y = np.array([0, 2, 4])
    sw = np.array([0, 4, 8])
    data_handler = data_adapter.DataHandler(
        x=x, y=y, sample_weight=sw, batch_size=1, epochs=2)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data,
                     [[(0, 0, 0), (1, 2, 4),
                       (2, 4, 8)], [(0, 0, 0), (1, 2, 4), (2, 4, 8)]])

  def test_generator(self):

    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    data_handler = data_adapter.DataHandler(
        generator(), epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_composite_tensor(self):
    st = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 0], [2, 0]], values=[0, 1, 2], dense_shape=[3, 1])
    data_handler = data_adapter.DataHandler(st, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(
        nest.map_structure(sparse_ops.sparse_tensor_to_dense, returned_data))
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_iterator(self):
    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    it = iter(dataset_ops.Dataset.from_generator(
        generator, output_types=('float32',)))
    data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
                                     [([0],), ([1],), ([2],)]])

  def test_list_of_scalars(self):
    data_handler = data_adapter.DataHandler([[0], [1], [2]],
                                            epochs=2,
                                            steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])

  def test_class_weight_user_errors(self):
    with self.assertRaisesRegex(ValueError, 'to be a dict with keys'):
      data_adapter.DataHandler(
          x=[[0], [1], [2]],
          y=[[2], [1], [0]],
          batch_size=1,
          sample_weight=[[1.], [2.], [4.]],
          class_weight={
              0: 0.5,
              1: 1.,
              3: 1.5  # Skips class `2`.
          })

    with self.assertRaisesRegex(ValueError, 'with a single output'):
      data_adapter.DataHandler(
          x=np.ones((10, 1)),
          y=[np.ones((10, 1)), np.zeros((10, 1))],
          batch_size=2,
          class_weight={
              0: 0.5,
              1: 1.,
              2: 1.5
          })

  @parameterized.named_parameters(('numpy', True), ('dataset', False))
  def test_single_x_input_no_tuple_wrapping(self, use_numpy):
    x = np.ones((10, 1))

    if use_numpy:
      batch_size = 2
    else:
      x = dataset_ops.Dataset.from_tensor_slices(x).batch(2)
      batch_size = None

    data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
    for _, iterator in data_handler.enumerate_epochs():
      for _ in data_handler.steps():
        # Check that a single x input is not wrapped in a tuple.
        self.assertIsInstance(next(iterator), ops.Tensor)


class TestValidationSplit(keras_parameterized.TestCase):

  @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
  def test_validation_split_unshuffled(self, use_numpy):
    if use_numpy:
      x = np.array([0, 1, 2, 3, 4])
      y = np.array([0, 2, 4, 6, 8])
      sw = np.array([0, 4, 8, 12, 16])
    else:
      x = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2, 3, 4])
      y = ops.convert_to_tensor_v2_with_dispatch([0, 2, 4, 6, 8])
      sw = ops.convert_to_tensor_v2_with_dispatch([0, 4, 8, 12, 16])

    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
        data_adapter.train_validation_split((x, y, sw), validation_split=0.2))

    if use_numpy:
      train_x = ops.convert_to_tensor_v2_with_dispatch(train_x)
      train_y = ops.convert_to_tensor_v2_with_dispatch(train_y)
      train_sw = ops.convert_to_tensor_v2_with_dispatch(train_sw)
      val_x = ops.convert_to_tensor_v2_with_dispatch(val_x)
      val_y = ops.convert_to_tensor_v2_with_dispatch(val_y)
      val_sw = ops.convert_to_tensor_v2_with_dispatch(val_sw)

    self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
    self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
    self.assertEqual(train_sw.numpy().tolist(), [0, 4, 8, 12])

    self.assertEqual(val_x.numpy().tolist(), [4])
    self.assertEqual(val_y.numpy().tolist(), [8])
    self.assertEqual(val_sw.numpy().tolist(), [16])

  def test_validation_split_user_error(self):
    with self.assertRaisesRegex(ValueError, 'is only supported for Tensors'):
      data_adapter.train_validation_split(
          lambda: np.ones((10, 1)), validation_split=0.2)

  def test_validation_split_examples_too_few(self):
    with self.assertRaisesRegex(ValueError, 'not sufficient to split it'):
      data_adapter.train_validation_split(
          np.ones((1, 10)), validation_split=0.2)

  def test_validation_split_none(self):
    train_sw, val_sw = data_adapter.train_validation_split(
        None, validation_split=0.2)
    self.assertIsNone(train_sw)
    self.assertIsNone(val_sw)

    (_, train_sw), (_, val_sw) = data_adapter.train_validation_split(
        (np.ones((10, 1)), None), validation_split=0.2)
    self.assertIsNone(train_sw)
    self.assertIsNone(val_sw)


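# `train_validation_split` takes the trailing fraction of each array as the
# validation set, without shuffling; this is why the unshuffled test above
# expects the last sample (x=4, y=8, sw=16) to land in the validation split.

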
class ListsOfScalarsDataAdapterTest(DataAdapterTestBase):

  def setUp(self):
    super(ListsOfScalarsDataAdapterTest, self).setUp()
    self.adapter_cls = data_adapter.ListsOfScalarsDataAdapter

  def test_can_list_inputs(self):
    self.assertTrue(self.adapter_cls.can_handle(self.text_input))
    self.assertTrue(self.adapter_cls.can_handle(self.bytes_input))

    self.assertFalse(self.adapter_cls.can_handle(self.numpy_input))
    self.assertFalse(self.adapter_cls.can_handle(self.tensor_input))
    self.assertFalse(self.adapter_cls.can_handle(self.dataset_input))
    self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
    self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
    self.assertFalse(self.adapter_cls.can_handle([]))


class TestUtils(keras_parameterized.TestCase):

  def test_expand_1d_sparse_tensors_untouched(self):
    st = sparse_tensor.SparseTensor(
        indices=[[0], [10]], values=[1, 2], dense_shape=[10])
    st = data_adapter.expand_1d(st)
    self.assertEqual(st.shape.rank, 1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()