1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""DataAdapter tests.""" 16 17import math 18 19from absl.testing import parameterized 20import numpy as np 21 22from tensorflow.python import keras 23from tensorflow.python.data.experimental.ops import cardinality 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import sparse_tensor 28from tensorflow.python.keras import keras_parameterized 29from tensorflow.python.keras import testing_utils 30from tensorflow.python.keras.engine import data_adapter 31from tensorflow.python.keras.utils import data_utils 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import sparse_ops 34from tensorflow.python.platform import test 35from tensorflow.python.util import nest 36 37 38class DummyArrayLike(object): 39 """Dummy array-like object.""" 40 41 def __init__(self, data): 42 self.data = data 43 44 def __len__(self): 45 return len(self.data) 46 47 def __getitem__(self, key): 48 return self.data[key] 49 50 @property 51 def shape(self): 52 return self.data.shape 53 54 @property 55 def dtype(self): 56 return self.data.dtype 57 58 59def fail_on_convert(x, **kwargs): 60 _ = x 61 _ = kwargs 62 raise TypeError('Cannot convert DummyArrayLike to a tensor') 63ops.register_tensor_conversion_function(DummyArrayLike, fail_on_convert) 64 65 66class DataAdapterTestBase(keras_parameterized.TestCase): 67 68 def setUp(self): 69 super(DataAdapterTestBase, self).setUp() 70 self.batch_size = 5 71 self.numpy_input = np.zeros((50, 10)) 72 self.numpy_target = np.ones(50) 73 self.tensor_input = constant_op.constant(2.0, shape=(50, 10)) 74 self.tensor_target = array_ops.ones((50,)) 75 self.arraylike_input = DummyArrayLike(self.numpy_input) 76 self.arraylike_target = DummyArrayLike(self.numpy_target) 77 self.dataset_input = dataset_ops.DatasetV2.from_tensor_slices( 78 (self.numpy_input, self.numpy_target)).shuffle(50).batch( 79 self.batch_size) 80 81 def generator(): 82 while True: 83 yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size)) 84 self.generator_input = generator() 85 self.iterator_input = data_utils.threadsafe_generator(generator)() 86 self.sequence_input = TestSequence(batch_size=self.batch_size, 87 feature_shape=10) 88 self.text_input = [['abc']] 89 self.bytes_input = [[b'abc']] 90 self.model = keras.models.Sequential( 91 [keras.layers.Dense(8, input_shape=(10,), activation='softmax')]) 92 93 94class TestSequence(data_utils.Sequence): 95 96 def __init__(self, batch_size, feature_shape): 97 self.batch_size = batch_size 98 self.feature_shape = feature_shape 99 100 def __getitem__(self, item): 101 return (np.zeros((self.batch_size, self.feature_shape)), 102 np.ones((self.batch_size,))) 103 104 def __len__(self): 105 return 10 106 107 108class TensorLikeDataAdapterTest(DataAdapterTestBase): 109 110 def setUp(self): 111 super(TensorLikeDataAdapterTest, self).setUp() 112 self.adapter_cls = data_adapter.TensorLikeDataAdapter 113 114 def test_can_handle_numpy(self): 115 self.assertTrue(self.adapter_cls.can_handle(self.numpy_input)) 116 self.assertTrue( 117 self.adapter_cls.can_handle(self.numpy_input, self.numpy_target)) 118 119 self.assertFalse(self.adapter_cls.can_handle(self.dataset_input)) 120 self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) 121 self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) 122 self.assertFalse(self.adapter_cls.can_handle(self.text_input)) 123 self.assertFalse(self.adapter_cls.can_handle(self.bytes_input)) 124 125 def test_size_numpy(self): 126 adapter = self.adapter_cls( 127 self.numpy_input, self.numpy_target, batch_size=5) 128 self.assertEqual(adapter.get_size(), 10) 129 self.assertFalse(adapter.has_partial_batch()) 130 131 def test_batch_size_numpy(self): 132 adapter = self.adapter_cls( 133 self.numpy_input, self.numpy_target, batch_size=5) 134 self.assertEqual(adapter.batch_size(), 5) 135 136 def test_partial_batch_numpy(self): 137 adapter = self.adapter_cls( 138 self.numpy_input, self.numpy_target, batch_size=4) 139 self.assertEqual(adapter.get_size(), 13) # 50/4 140 self.assertTrue(adapter.has_partial_batch()) 141 self.assertEqual(adapter.partial_batch_size(), 2) 142 143 def test_epochs(self): 144 num_epochs = 3 145 adapter = self.adapter_cls( 146 self.numpy_input, self.numpy_target, batch_size=5, epochs=num_epochs) 147 ds_iter = iter(adapter.get_dataset()) 148 num_batches_per_epoch = self.numpy_input.shape[0] // 5 149 for _ in range(num_batches_per_epoch * num_epochs): 150 next(ds_iter) 151 with self.assertRaises(StopIteration): 152 next(ds_iter) 153 154 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 155 def test_training_numpy(self): 156 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 157 run_eagerly=testing_utils.should_run_eagerly()) 158 self.model.fit(self.numpy_input, self.numpy_target, batch_size=5) 159 160 def test_can_handle_pandas(self): 161 try: 162 import pandas as pd # pylint: disable=g-import-not-at-top 163 except ImportError: 164 self.skipTest('Skipping test because pandas is not installed.') 165 self.assertTrue(self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input))) 166 self.assertTrue( 167 self.adapter_cls.can_handle(pd.DataFrame(self.numpy_input)[0])) 168 self.assertTrue( 169 self.adapter_cls.can_handle( 170 pd.DataFrame(self.numpy_input), 171 pd.DataFrame(self.numpy_input)[0])) 172 173 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 174 def test_training_pandas(self): 175 try: 176 import pandas as pd # pylint: disable=g-import-not-at-top 177 except ImportError: 178 self.skipTest('Skipping test because pandas is not installed.') 179 input_a = keras.Input(shape=(3,), name='input_a') 180 input_b = keras.Input(shape=(3,), name='input_b') 181 input_c = keras.Input(shape=(1,), name='input_b') 182 183 x = keras.layers.Dense(4, name='dense_1')(input_a) 184 y = keras.layers.Dense(3, name='dense_2')(input_b) 185 z = keras.layers.Dense(1, name='dense_3')(input_c) 186 187 model_1 = keras.Model(inputs=input_a, outputs=x) 188 model_2 = keras.Model(inputs=[input_a, input_b], outputs=[x, y]) 189 model_3 = keras.Model(inputs=input_c, outputs=z) 190 191 model_1.compile(optimizer='rmsprop', loss='mse') 192 model_2.compile(optimizer='rmsprop', loss='mse') 193 194 input_a_np = np.random.random((10, 3)) 195 input_b_np = np.random.random((10, 3)) 196 input_a_df = pd.DataFrame(input_a_np) 197 input_b_df = pd.DataFrame(input_b_np) 198 199 output_a_df = pd.DataFrame(np.random.random((10, 4))) 200 output_b_df = pd.DataFrame(np.random.random((10, 3))) 201 202 model_1.fit(input_a_df, 203 output_a_df) 204 model_2.fit([input_a_df, input_b_df], 205 [output_a_df, output_b_df]) 206 model_1.fit([input_a_df], 207 [output_a_df]) 208 model_1.fit({'input_a': input_a_df}, 209 output_a_df) 210 model_2.fit({'input_a': input_a_df, 'input_b': input_b_df}, 211 [output_a_df, output_b_df]) 212 213 model_1.evaluate(input_a_df, 214 output_a_df) 215 model_2.evaluate([input_a_df, input_b_df], 216 [output_a_df, output_b_df]) 217 model_1.evaluate([input_a_df], 218 [output_a_df]) 219 model_1.evaluate({'input_a': input_a_df}, 220 output_a_df) 221 model_2.evaluate({'input_a': input_a_df, 'input_b': input_b_df}, 222 [output_a_df, output_b_df]) 223 224 # Verify predicting on pandas vs numpy returns the same result 225 predict_1_pandas = model_1.predict(input_a_df) 226 predict_2_pandas = model_2.predict([input_a_df, input_b_df]) 227 predict_3_pandas = model_3.predict(input_a_df[0]) 228 229 predict_1_numpy = model_1.predict(input_a_np) 230 predict_2_numpy = model_2.predict([input_a_np, input_b_np]) 231 predict_3_numpy = model_3.predict(np.asarray(input_a_df[0])) 232 233 self.assertAllClose(predict_1_numpy, predict_1_pandas) 234 self.assertAllClose(predict_2_numpy, predict_2_pandas) 235 self.assertAllClose(predict_3_numpy, predict_3_pandas) 236 237 # Extra ways to pass in dataframes 238 model_1.predict([input_a_df]) 239 model_1.predict({'input_a': input_a_df}) 240 model_2.predict({'input_a': input_a_df, 'input_b': input_b_df}) 241 242 def test_can_handle(self): 243 self.assertTrue(self.adapter_cls.can_handle(self.tensor_input)) 244 self.assertTrue( 245 self.adapter_cls.can_handle(self.tensor_input, self.tensor_target)) 246 247 self.assertFalse(self.adapter_cls.can_handle(self.arraylike_input)) 248 self.assertFalse( 249 self.adapter_cls.can_handle(self.arraylike_input, 250 self.arraylike_target)) 251 self.assertFalse(self.adapter_cls.can_handle(self.dataset_input)) 252 self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) 253 self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) 254 self.assertFalse(self.adapter_cls.can_handle(self.text_input)) 255 self.assertFalse(self.adapter_cls.can_handle(self.bytes_input)) 256 257 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 258 def test_training(self): 259 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 260 run_eagerly=testing_utils.should_run_eagerly()) 261 self.model.fit(self.tensor_input, self.tensor_target, batch_size=5) 262 263 def test_size(self): 264 adapter = self.adapter_cls( 265 self.tensor_input, self.tensor_target, batch_size=5) 266 self.assertEqual(adapter.get_size(), 10) 267 self.assertFalse(adapter.has_partial_batch()) 268 269 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 270 def test_shuffle_correctness(self): 271 num_samples = 100 272 batch_size = 32 273 x = np.arange(num_samples) 274 np.random.seed(99) 275 adapter = self.adapter_cls( 276 x, y=None, batch_size=batch_size, shuffle=True, epochs=2) 277 278 def _get_epoch(ds_iter): 279 ds_data = [] 280 for _ in range(int(math.ceil(num_samples / batch_size))): 281 ds_data.append(next(ds_iter).numpy()) 282 return np.concatenate(ds_data) 283 284 ds_iter = iter(adapter.get_dataset()) 285 286 # First epoch. 287 epoch_data = _get_epoch(ds_iter) 288 # Check that shuffling occurred. 289 self.assertNotAllClose(x, epoch_data) 290 # Check that each elements appears, and only once. 291 self.assertAllClose(x, np.sort(epoch_data)) 292 293 # Second epoch. 294 second_epoch_data = _get_epoch(ds_iter) 295 # Check that shuffling occurred. 296 self.assertNotAllClose(x, second_epoch_data) 297 # Check that shuffling is different across epochs. 298 self.assertNotAllClose(epoch_data, second_epoch_data) 299 # Check that each elements appears, and only once. 300 self.assertAllClose(x, np.sort(second_epoch_data)) 301 302 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 303 def test_batch_shuffle_correctness(self): 304 num_samples = 100 305 batch_size = 6 306 x = np.arange(num_samples) 307 np.random.seed(99) 308 adapter = self.adapter_cls( 309 x, y=None, batch_size=batch_size, shuffle='batch', epochs=2) 310 311 def _get_epoch_batches(ds_iter): 312 ds_data = [] 313 for _ in range(int(math.ceil(num_samples / batch_size))): 314 ds_data.append(next(ds_iter)[0].numpy()) 315 return ds_data 316 317 ds_iter = iter(adapter.get_dataset()) 318 319 # First epoch. 320 epoch_batch_data = _get_epoch_batches(ds_iter) 321 epoch_data = np.concatenate(epoch_batch_data) 322 323 def _verify_batch(batch): 324 # Verify that a batch contains only contiguous data, and that it has 325 # been shuffled. 326 shuffled_batch = np.sort(batch) 327 self.assertNotAllClose(batch, shuffled_batch) 328 for i in range(1, len(batch)): 329 self.assertEqual(shuffled_batch[i-1] + 1, shuffled_batch[i]) 330 331 # Assert that the data within each batch remains contiguous 332 for batch in epoch_batch_data: 333 _verify_batch(batch) 334 335 # Check that individual batches are unshuffled 336 # Check that shuffling occurred. 337 self.assertNotAllClose(x, epoch_data) 338 # Check that each elements appears, and only once. 339 self.assertAllClose(x, np.sort(epoch_data)) 340 341 # Second epoch. 342 second_epoch_batch_data = _get_epoch_batches(ds_iter) 343 second_epoch_data = np.concatenate(second_epoch_batch_data) 344 345 # Assert that the data within each batch remains contiguous 346 for batch in second_epoch_batch_data: 347 _verify_batch(batch) 348 349 # Check that shuffling occurred. 350 self.assertNotAllClose(x, second_epoch_data) 351 # Check that shuffling is different across epochs. 352 self.assertNotAllClose(epoch_data, second_epoch_data) 353 # Check that each elements appears, and only once. 354 self.assertAllClose(x, np.sort(second_epoch_data)) 355 356 @parameterized.named_parameters( 357 ('batch_size_5', 5, None, 5), 358 ('batch_size_50', 50, 4, 50), # Sanity check: batch_size takes precedence 359 ('steps_1', None, 1, 50), 360 ('steps_4', None, 4, 13), 361 ) 362 def test_batch_size(self, batch_size_in, steps, batch_size_out): 363 adapter = self.adapter_cls( 364 self.tensor_input, self.tensor_target, batch_size=batch_size_in, 365 steps=steps) 366 self.assertEqual(adapter.batch_size(), batch_size_out) 367 368 @parameterized.named_parameters( 369 ('batch_size_5', 5, None, 10, 0), 370 ('batch_size_4', 4, None, 13, 2), 371 ('steps_1', None, 1, 1, 0), 372 ('steps_5', None, 5, 5, 0), 373 ('steps_4', None, 4, 4, 11), 374 ) 375 def test_partial_batch( 376 self, batch_size_in, steps, size, partial_batch_size): 377 adapter = self.adapter_cls( 378 self.tensor_input, self.tensor_target, batch_size=batch_size_in, 379 steps=steps) 380 self.assertEqual(adapter.get_size(), size) # 50/steps 381 self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size)) 382 self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None) 383 384 385class GenericArrayLikeDataAdapterTest(DataAdapterTestBase): 386 387 def setUp(self): 388 super(GenericArrayLikeDataAdapterTest, self).setUp() 389 self.adapter_cls = data_adapter.GenericArrayLikeDataAdapter 390 391 def test_can_handle_some_numpy(self): 392 self.assertTrue(self.adapter_cls.can_handle( 393 self.arraylike_input)) 394 self.assertTrue( 395 self.adapter_cls.can_handle(self.arraylike_input, 396 self.arraylike_target)) 397 398 # Because adapters are mutually exclusive, don't handle cases 399 # where all the data is numpy or an eagertensor 400 self.assertFalse(self.adapter_cls.can_handle(self.numpy_input)) 401 self.assertFalse( 402 self.adapter_cls.can_handle(self.numpy_input, 403 self.numpy_target)) 404 self.assertFalse(self.adapter_cls.can_handle(self.tensor_input)) 405 self.assertFalse( 406 self.adapter_cls.can_handle(self.tensor_input, self.tensor_target)) 407 408 # But do handle mixes that include generic arraylike data 409 self.assertTrue( 410 self.adapter_cls.can_handle(self.numpy_input, 411 self.arraylike_target)) 412 self.assertTrue( 413 self.adapter_cls.can_handle(self.arraylike_input, 414 self.numpy_target)) 415 self.assertTrue( 416 self.adapter_cls.can_handle(self.arraylike_input, 417 self.tensor_target)) 418 self.assertTrue( 419 self.adapter_cls.can_handle(self.tensor_input, 420 self.arraylike_target)) 421 422 self.assertFalse(self.adapter_cls.can_handle(self.dataset_input)) 423 self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) 424 self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) 425 self.assertFalse(self.adapter_cls.can_handle(self.text_input)) 426 self.assertFalse(self.adapter_cls.can_handle(self.bytes_input)) 427 428 def test_size(self): 429 adapter = self.adapter_cls( 430 self.arraylike_input, 431 self.arraylike_target, batch_size=5) 432 self.assertEqual(adapter.get_size(), 10) 433 self.assertFalse(adapter.has_partial_batch()) 434 435 def test_epochs(self): 436 num_epochs = 3 437 adapter = self.adapter_cls( 438 self.arraylike_input, 439 self.numpy_target, batch_size=5, epochs=num_epochs) 440 ds_iter = iter(adapter.get_dataset()) 441 num_batches_per_epoch = self.numpy_input.shape[0] // 5 442 for _ in range(num_batches_per_epoch * num_epochs): 443 next(ds_iter) 444 with self.assertRaises(StopIteration): 445 next(ds_iter) 446 447 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 448 def test_training(self): 449 # First verify that DummyArrayLike can't be converted to a Tensor 450 with self.assertRaises(TypeError): 451 ops.convert_to_tensor_v2_with_dispatch(self.arraylike_input) 452 453 # Then train on the array like. 454 # It should not be converted to a tensor directly (which would force it into 455 # memory), only the sliced data should be converted. 456 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 457 run_eagerly=testing_utils.should_run_eagerly()) 458 self.model.fit(self.arraylike_input, 459 self.arraylike_target, batch_size=5) 460 self.model.fit(self.arraylike_input, 461 self.arraylike_target, 462 shuffle=True, batch_size=5) 463 self.model.fit(self.arraylike_input, 464 self.arraylike_target, 465 shuffle='batch', batch_size=5) 466 self.model.evaluate(self.arraylike_input, 467 self.arraylike_target, batch_size=5) 468 self.model.predict(self.arraylike_input, batch_size=5) 469 470 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 471 def test_training_numpy_target(self): 472 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 473 run_eagerly=testing_utils.should_run_eagerly()) 474 self.model.fit(self.arraylike_input, 475 self.numpy_target, batch_size=5) 476 self.model.fit(self.arraylike_input, 477 self.numpy_target, shuffle=True, 478 batch_size=5) 479 self.model.fit(self.arraylike_input, 480 self.numpy_target, shuffle='batch', 481 batch_size=5) 482 self.model.evaluate(self.arraylike_input, 483 self.numpy_target, batch_size=5) 484 485 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 486 def test_training_tensor_target(self): 487 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 488 run_eagerly=testing_utils.should_run_eagerly()) 489 self.model.fit(self.arraylike_input, 490 self.tensor_target, batch_size=5) 491 self.model.fit(self.arraylike_input, 492 self.tensor_target, shuffle=True, 493 batch_size=5) 494 self.model.fit(self.arraylike_input, 495 self.tensor_target, shuffle='batch', 496 batch_size=5) 497 self.model.evaluate(self.arraylike_input, 498 self.tensor_target, batch_size=5) 499 500 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 501 def test_shuffle_correctness(self): 502 num_samples = 100 503 batch_size = 32 504 x = DummyArrayLike(np.arange(num_samples)) 505 np.random.seed(99) 506 adapter = self.adapter_cls( 507 x, y=None, batch_size=batch_size, shuffle=True, epochs=2) 508 509 def _get_epoch(ds_iter): 510 ds_data = [] 511 for _ in range(int(math.ceil(num_samples / batch_size))): 512 ds_data.append(next(ds_iter).numpy()) 513 return np.concatenate(ds_data) 514 515 ds_iter = iter(adapter.get_dataset()) 516 517 # First epoch. 518 epoch_data = _get_epoch(ds_iter) 519 # Check that shuffling occurred. 520 self.assertNotAllClose(x, epoch_data) 521 # Check that each elements appears, and only once. 522 self.assertAllClose(x, np.sort(epoch_data)) 523 524 # Second epoch. 525 second_epoch_data = _get_epoch(ds_iter) 526 # Check that shuffling occurred. 527 self.assertNotAllClose(x, second_epoch_data) 528 # Check that shuffling is different across epochs. 529 self.assertNotAllClose(epoch_data, second_epoch_data) 530 # Check that each elements appears, and only once. 531 self.assertAllClose(x, np.sort(second_epoch_data)) 532 533 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 534 def test_batch_shuffle_correctness(self): 535 num_samples = 100 536 batch_size = 6 537 x = DummyArrayLike(np.arange(num_samples)) 538 np.random.seed(99) 539 adapter = self.adapter_cls( 540 x, y=None, batch_size=batch_size, shuffle='batch', epochs=2) 541 542 def _get_epoch_batches(ds_iter): 543 ds_data = [] 544 for _ in range(int(math.ceil(num_samples / batch_size))): 545 ds_data.append(next(ds_iter)[0].numpy()) 546 return ds_data 547 548 ds_iter = iter(adapter.get_dataset()) 549 550 # First epoch. 551 epoch_batch_data = _get_epoch_batches(ds_iter) 552 epoch_data = np.concatenate(epoch_batch_data) 553 554 def _verify_batch(batch): 555 # Verify that a batch contains only contiguous data, but that it has 556 # been shuffled. 557 shuffled_batch = np.sort(batch) 558 self.assertNotAllClose(batch, shuffled_batch) 559 for i in range(1, len(batch)): 560 self.assertEqual(shuffled_batch[i-1] + 1, shuffled_batch[i]) 561 562 # Assert that the data within each batch is shuffled contiguous data 563 for batch in epoch_batch_data: 564 _verify_batch(batch) 565 566 # Check that individual batches are unshuffled 567 # Check that shuffling occurred. 568 self.assertNotAllClose(x, epoch_data) 569 # Check that each elements appears, and only once. 570 self.assertAllClose(x, np.sort(epoch_data)) 571 572 # Second epoch. 573 second_epoch_batch_data = _get_epoch_batches(ds_iter) 574 second_epoch_data = np.concatenate(second_epoch_batch_data) 575 576 # Assert that the data within each batch remains contiguous 577 for batch in second_epoch_batch_data: 578 _verify_batch(batch) 579 580 # Check that shuffling occurred. 581 self.assertNotAllClose(x, second_epoch_data) 582 # Check that shuffling is different across epochs. 583 self.assertNotAllClose(epoch_data, second_epoch_data) 584 # Check that each elements appears, and only once. 585 self.assertAllClose(x, np.sort(second_epoch_data)) 586 587 @parameterized.named_parameters( 588 ('batch_size_5', 5, None, 5), 589 ('batch_size_50', 50, 4, 50), # Sanity check: batch_size takes precedence 590 ('steps_1', None, 1, 50), 591 ('steps_4', None, 4, 13), 592 ) 593 def test_batch_size(self, batch_size_in, steps, batch_size_out): 594 adapter = self.adapter_cls( 595 self.arraylike_input, 596 self.arraylike_target, batch_size=batch_size_in, 597 steps=steps) 598 self.assertEqual(adapter.batch_size(), batch_size_out) 599 600 @parameterized.named_parameters( 601 ('batch_size_5', 5, None, 10, 0), 602 ('batch_size_4', 4, None, 13, 2), 603 ('steps_1', None, 1, 1, 0), 604 ('steps_5', None, 5, 5, 0), 605 ('steps_4', None, 4, 4, 11), 606 ) 607 def test_partial_batch( 608 self, batch_size_in, steps, size, partial_batch_size): 609 adapter = self.adapter_cls( 610 self.arraylike_input, self.arraylike_target, 611 batch_size=batch_size_in, 612 steps=steps) 613 self.assertEqual(adapter.get_size(), size) # 50/steps 614 self.assertEqual(adapter.has_partial_batch(), bool(partial_batch_size)) 615 self.assertEqual(adapter.partial_batch_size(), partial_batch_size or None) 616 617 618class DatasetAdapterTest(DataAdapterTestBase): 619 620 def setUp(self): 621 super(DatasetAdapterTest, self).setUp() 622 self.adapter_cls = data_adapter.DatasetAdapter 623 624 def test_can_handle(self): 625 self.assertFalse(self.adapter_cls.can_handle(self.numpy_input)) 626 self.assertFalse(self.adapter_cls.can_handle(self.tensor_input)) 627 self.assertTrue(self.adapter_cls.can_handle(self.dataset_input)) 628 self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) 629 self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) 630 631 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 632 def test_training(self): 633 dataset = self.adapter_cls(self.dataset_input).get_dataset() 634 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 635 run_eagerly=testing_utils.should_run_eagerly()) 636 self.model.fit(dataset) 637 638 def test_size(self): 639 adapter = self.adapter_cls(self.dataset_input) 640 self.assertIsNone(adapter.get_size()) 641 642 def test_batch_size(self): 643 adapter = self.adapter_cls(self.dataset_input) 644 self.assertIsNone(adapter.batch_size()) 645 646 def test_partial_batch(self): 647 adapter = self.adapter_cls(self.dataset_input) 648 self.assertFalse(adapter.has_partial_batch()) 649 self.assertIsNone(adapter.partial_batch_size()) 650 651 def test_invalid_targets_argument(self): 652 with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'): 653 self.adapter_cls(self.dataset_input, y=self.dataset_input) 654 655 def test_invalid_sample_weights_argument(self): 656 with self.assertRaisesRegex(ValueError, 657 r'`sample_weight` argument is not supported'): 658 self.adapter_cls(self.dataset_input, sample_weights=self.dataset_input) 659 660 661class GeneratorDataAdapterTest(DataAdapterTestBase): 662 663 def setUp(self): 664 super(GeneratorDataAdapterTest, self).setUp() 665 self.adapter_cls = data_adapter.GeneratorDataAdapter 666 667 def test_can_handle(self): 668 self.assertFalse(self.adapter_cls.can_handle(self.numpy_input)) 669 self.assertFalse(self.adapter_cls.can_handle(self.tensor_input)) 670 self.assertFalse(self.adapter_cls.can_handle(self.dataset_input)) 671 self.assertTrue(self.adapter_cls.can_handle(self.generator_input)) 672 self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) 673 self.assertFalse(self.adapter_cls.can_handle(self.text_input)) 674 self.assertFalse(self.adapter_cls.can_handle(self.bytes_input)) 675 676 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 677 def test_training(self): 678 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 679 run_eagerly=testing_utils.should_run_eagerly()) 680 self.model.fit(self.generator_input, steps_per_epoch=10) 681 682 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 683 @testing_utils.run_v2_only 684 @data_utils.dont_use_multiprocessing_pool 685 def test_with_multiprocessing_training(self): 686 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 687 run_eagerly=testing_utils.should_run_eagerly()) 688 self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True, 689 max_queue_size=10, steps_per_epoch=10) 690 # Fit twice to ensure there isn't any duplication that prevent the worker 691 # from starting. 692 self.model.fit(self.iterator_input, workers=1, use_multiprocessing=True, 693 max_queue_size=10, steps_per_epoch=10) 694 695 def test_size(self): 696 adapter = self.adapter_cls(self.generator_input) 697 self.assertIsNone(adapter.get_size()) 698 699 def test_batch_size(self): 700 adapter = self.adapter_cls(self.generator_input) 701 self.assertEqual(adapter.batch_size(), None) 702 self.assertEqual(adapter.representative_batch_size(), 5) 703 704 def test_partial_batch(self): 705 adapter = self.adapter_cls(self.generator_input) 706 self.assertFalse(adapter.has_partial_batch()) 707 self.assertIsNone(adapter.partial_batch_size()) 708 709 def test_invalid_targets_argument(self): 710 with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'): 711 self.adapter_cls(self.generator_input, y=self.generator_input) 712 713 def test_invalid_sample_weights_argument(self): 714 with self.assertRaisesRegex(ValueError, 715 r'`sample_weight` argument is not supported'): 716 self.adapter_cls( 717 self.generator_input, sample_weights=self.generator_input) 718 719 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 720 def test_not_shuffled(self): 721 def generator(): 722 for i in range(10): 723 yield np.ones((1, 1)) * i 724 725 adapter = self.adapter_cls(generator(), shuffle=True) 726 for i, data in enumerate(adapter.get_dataset()): 727 self.assertEqual(i, data[0].numpy().flatten()) 728 729 730class KerasSequenceAdapterTest(DataAdapterTestBase): 731 732 def setUp(self): 733 super(KerasSequenceAdapterTest, self).setUp() 734 self.adapter_cls = data_adapter.KerasSequenceAdapter 735 736 def test_can_handle(self): 737 self.assertFalse(self.adapter_cls.can_handle(self.numpy_input)) 738 self.assertFalse(self.adapter_cls.can_handle(self.tensor_input)) 739 self.assertFalse(self.adapter_cls.can_handle(self.dataset_input)) 740 self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) 741 self.assertTrue(self.adapter_cls.can_handle(self.sequence_input)) 742 self.assertFalse(self.adapter_cls.can_handle(self.text_input)) 743 self.assertFalse(self.adapter_cls.can_handle(self.bytes_input)) 744 745 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 746 def test_training(self): 747 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 748 run_eagerly=testing_utils.should_run_eagerly()) 749 self.model.fit(self.sequence_input) 750 751 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 752 @testing_utils.run_v2_only 753 @data_utils.dont_use_multiprocessing_pool 754 def test_with_multiprocessing_training(self): 755 self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', 756 run_eagerly=testing_utils.should_run_eagerly()) 757 self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True, 758 max_queue_size=10, steps_per_epoch=10) 759 # Fit twice to ensure there isn't any duplication that prevent the worker 760 # from starting. 761 self.model.fit(self.sequence_input, workers=1, use_multiprocessing=True, 762 max_queue_size=10, steps_per_epoch=10) 763 764 def test_size(self): 765 adapter = self.adapter_cls(self.sequence_input) 766 self.assertEqual(adapter.get_size(), 10) 767 768 def test_batch_size(self): 769 adapter = self.adapter_cls(self.sequence_input) 770 self.assertEqual(adapter.batch_size(), None) 771 self.assertEqual(adapter.representative_batch_size(), 5) 772 773 def test_partial_batch(self): 774 adapter = self.adapter_cls(self.sequence_input) 775 self.assertFalse(adapter.has_partial_batch()) 776 self.assertIsNone(adapter.partial_batch_size()) 777 778 def test_invalid_targets_argument(self): 779 with self.assertRaisesRegex(ValueError, r'`y` argument is not supported'): 780 self.adapter_cls(self.sequence_input, y=self.sequence_input) 781 782 def test_invalid_sample_weights_argument(self): 783 with self.assertRaisesRegex(ValueError, 784 r'`sample_weight` argument is not supported'): 785 self.adapter_cls(self.sequence_input, sample_weights=self.sequence_input) 786 787 788class DataHandlerTest(keras_parameterized.TestCase): 789 790 def test_finite_dataset_with_steps_per_epoch(self): 791 data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1) 792 # User can choose to only partially consume `Dataset`. 793 data_handler = data_adapter.DataHandler( 794 data, initial_epoch=0, epochs=2, steps_per_epoch=2) 795 self.assertEqual(data_handler.inferred_steps, 2) 796 self.assertFalse(data_handler._adapter.should_recreate_iterator()) 797 returned_data = [] 798 for _, iterator in data_handler.enumerate_epochs(): 799 epoch_data = [] 800 for _ in data_handler.steps(): 801 epoch_data.append(next(iterator).numpy()) 802 returned_data.append(epoch_data) 803 self.assertEqual(returned_data, [[0, 1], [2, 3]]) 804 805 def test_finite_dataset_without_steps_per_epoch(self): 806 data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1) 807 data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2) 808 self.assertEqual(data_handler.inferred_steps, 3) 809 returned_data = [] 810 for _, iterator in data_handler.enumerate_epochs(): 811 epoch_data = [] 812 for _ in data_handler.steps(): 813 epoch_data.append(next(iterator).numpy()) 814 returned_data.append(epoch_data) 815 self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]]) 816 817 def test_finite_dataset_with_steps_per_epoch_exact_size(self): 818 data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1) 819 # If user specifies exact size of `Dataset` as `steps_per_epoch`, 820 # create a new iterator each epoch. 821 data_handler = data_adapter.DataHandler( 822 data, initial_epoch=0, epochs=2, steps_per_epoch=4) 823 self.assertTrue(data_handler._adapter.should_recreate_iterator()) 824 returned_data = [] 825 for _, iterator in data_handler.enumerate_epochs(): 826 epoch_data = [] 827 for _ in data_handler.steps(): 828 epoch_data.append(next(iterator).numpy()) 829 returned_data.append(epoch_data) 830 self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]]) 831 832 def test_infinite_dataset_with_steps_per_epoch(self): 833 data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1).repeat() 834 data_handler = data_adapter.DataHandler( 835 data, initial_epoch=0, epochs=2, steps_per_epoch=3) 836 returned_data = [] 837 for _, iterator in data_handler.enumerate_epochs(): 838 epoch_data = [] 839 for _ in data_handler.steps(): 840 epoch_data.append(next(iterator).numpy()) 841 returned_data.append(epoch_data) 842 self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]]) 843 844 def test_unknown_cardinality_dataset_with_steps_per_epoch(self): 845 ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6]) 846 filtered_ds = ds.filter(lambda x: x < 4) 847 self.assertEqual( 848 cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN) 849 850 # User can choose to only partially consume `Dataset`. 851 data_handler = data_adapter.DataHandler( 852 filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2) 853 self.assertFalse(data_handler._adapter.should_recreate_iterator()) 854 returned_data = [] 855 for _, iterator in data_handler.enumerate_epochs(): 856 epoch_data = [] 857 for _ in data_handler.steps(): 858 epoch_data.append(next(iterator)) 859 returned_data.append(epoch_data) 860 returned_data = self.evaluate(returned_data) 861 self.assertEqual(returned_data, [[0, 1], [2, 3]]) 862 self.assertEqual(data_handler.inferred_steps, 2) 863 864 def test_unknown_cardinality_dataset_without_steps_per_epoch(self): 865 ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6]) 866 filtered_ds = ds.filter(lambda x: x < 4) 867 self.assertEqual( 868 cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN) 869 870 data_handler = data_adapter.DataHandler( 871 filtered_ds, initial_epoch=0, epochs=2) 872 self.assertEqual(data_handler.inferred_steps, None) 873 self.assertTrue(data_handler._adapter.should_recreate_iterator()) 874 returned_data = [] 875 for _, iterator in data_handler.enumerate_epochs(): 876 epoch_data = [] 877 with data_handler.catch_stop_iteration(): 878 for _ in data_handler.steps(): 879 epoch_data.append(next(iterator)) 880 returned_data.append(epoch_data) 881 returned_data = self.evaluate(returned_data) 882 self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]]) 883 self.assertEqual(data_handler.inferred_steps, 4) 884 885 def test_insufficient_data(self): 886 ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1]) 887 ds = ds.filter(lambda *args, **kwargs: True) 888 data_handler = data_adapter.DataHandler( 889 ds, initial_epoch=0, epochs=2, steps_per_epoch=3) 890 returned_data = [] 891 for _, iterator in data_handler.enumerate_epochs(): 892 epoch_data = [] 893 for _ in data_handler.steps(): 894 with data_handler.catch_stop_iteration(): 895 epoch_data.append(next(iterator)) 896 returned_data.append(epoch_data) 897 returned_data = self.evaluate(returned_data) 898 self.assertTrue(data_handler._insufficient_data) 899 self.assertEqual(returned_data, [[0, 1]]) 900 901 def test_numpy(self): 902 x = np.array([0, 1, 2]) 903 y = np.array([0, 2, 4]) 904 sw = np.array([0, 4, 8]) 905 data_handler = data_adapter.DataHandler( 906 x=x, y=y, sample_weight=sw, batch_size=1, epochs=2) 907 returned_data = [] 908 for _, iterator in data_handler.enumerate_epochs(): 909 epoch_data = [] 910 for _ in data_handler.steps(): 911 epoch_data.append(next(iterator)) 912 returned_data.append(epoch_data) 913 returned_data = self.evaluate(returned_data) 914 self.assertEqual(returned_data, 915 [[(0, 0, 0), (1, 2, 4), 916 (2, 4, 8)], [(0, 0, 0), (1, 2, 4), (2, 4, 8)]]) 917 918 def test_generator(self): 919 920 def generator(): 921 for _ in range(2): 922 for step in range(3): 923 yield (ops.convert_to_tensor_v2_with_dispatch([step]),) 924 925 data_handler = data_adapter.DataHandler( 926 generator(), epochs=2, steps_per_epoch=3) 927 returned_data = [] 928 for _, iterator in data_handler.enumerate_epochs(): 929 epoch_data = [] 930 for _ in data_handler.steps(): 931 epoch_data.append(next(iterator)) 932 returned_data.append(epoch_data) 933 returned_data = self.evaluate(returned_data) 934 self.assertEqual(returned_data, [[([0],), ([1],), 935 ([2],)], [([0],), ([1],), ([2],)]]) 936 937 def test_composite_tensor(self): 938 st = sparse_tensor.SparseTensor( 939 indices=[[0, 0], [1, 0], [2, 0]], values=[0, 1, 2], dense_shape=[3, 1]) 940 data_handler = data_adapter.DataHandler(st, epochs=2, steps_per_epoch=3) 941 returned_data = [] 942 for _, iterator in data_handler.enumerate_epochs(): 943 epoch_data = [] 944 for _ in data_handler.steps(): 945 epoch_data.append(next(iterator)) 946 returned_data.append(epoch_data) 947 returned_data = self.evaluate( 948 nest.map_structure(sparse_ops.sparse_tensor_to_dense, returned_data)) 949 self.assertEqual(returned_data, [[([0],), ([1],), 950 ([2],)], [([0],), ([1],), ([2],)]]) 951 952 def test_iterator(self): 953 def generator(): 954 for _ in range(2): 955 for step in range(3): 956 yield (ops.convert_to_tensor_v2_with_dispatch([step]),) 957 958 it = iter(dataset_ops.Dataset.from_generator( 959 generator, output_types=('float32',))) 960 data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3) 961 returned_data = [] 962 for _, iterator in data_handler.enumerate_epochs(): 963 epoch_data = [] 964 for _ in data_handler.steps(): 965 epoch_data.append(next(iterator)) 966 returned_data.append(epoch_data) 967 returned_data = self.evaluate(returned_data) 968 self.assertEqual(returned_data, [[([0],), ([1],), ([2],)], 969 [([0],), ([1],), ([2],)]]) 970 971 def test_list_of_scalars(self): 972 data_handler = data_adapter.DataHandler([[0], [1], [2]], 973 epochs=2, 974 steps_per_epoch=3) 975 returned_data = [] 976 for _, iterator in data_handler.enumerate_epochs(): 977 epoch_data = [] 978 for _ in data_handler.steps(): 979 epoch_data.append(next(iterator)) 980 returned_data.append(epoch_data) 981 returned_data = self.evaluate(returned_data) 982 self.assertEqual(returned_data, [[([0],), ([1],), 983 ([2],)], [([0],), ([1],), ([2],)]]) 984 985 def test_class_weight_user_errors(self): 986 with self.assertRaisesRegex(ValueError, 'to be a dict with keys'): 987 data_adapter.DataHandler( 988 x=[[0], [1], [2]], 989 y=[[2], [1], [0]], 990 batch_size=1, 991 sample_weight=[[1.], [2.], [4.]], 992 class_weight={ 993 0: 0.5, 994 1: 1., 995 3: 1.5 # Skips class `2`. 996 }) 997 998 with self.assertRaisesRegex(ValueError, 'with a single output'): 999 data_adapter.DataHandler( 1000 x=np.ones((10, 1)), 1001 y=[np.ones((10, 1)), np.zeros((10, 1))], 1002 batch_size=2, 1003 class_weight={ 1004 0: 0.5, 1005 1: 1., 1006 2: 1.5 1007 }) 1008 1009 @parameterized.named_parameters(('numpy', True), ('dataset', False)) 1010 def test_single_x_input_no_tuple_wrapping(self, use_numpy): 1011 x = np.ones((10, 1)) 1012 1013 if use_numpy: 1014 batch_size = 2 1015 else: 1016 x = dataset_ops.Dataset.from_tensor_slices(x).batch(2) 1017 batch_size = None 1018 1019 data_handler = data_adapter.DataHandler(x, batch_size=batch_size) 1020 for _, iterator in data_handler.enumerate_epochs(): 1021 for _ in data_handler.steps(): 1022 # Check that single x input is not wrapped in a tuple. 1023 self.assertIsInstance(next(iterator), ops.Tensor) 1024 1025 1026class TestValidationSplit(keras_parameterized.TestCase): 1027 1028 @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False)) 1029 def test_validation_split_unshuffled(self, use_numpy): 1030 if use_numpy: 1031 x = np.array([0, 1, 2, 3, 4]) 1032 y = np.array([0, 2, 4, 6, 8]) 1033 sw = np.array([0, 4, 8, 12, 16]) 1034 else: 1035 x = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2, 3, 4]) 1036 y = ops.convert_to_tensor_v2_with_dispatch([0, 2, 4, 6, 8]) 1037 sw = ops.convert_to_tensor_v2_with_dispatch([0, 4, 8, 12, 16]) 1038 1039 (train_x, train_y, train_sw), (val_x, val_y, val_sw) = ( 1040 data_adapter.train_validation_split((x, y, sw), validation_split=0.2)) 1041 1042 if use_numpy: 1043 train_x = ops.convert_to_tensor_v2_with_dispatch(train_x) 1044 train_y = ops.convert_to_tensor_v2_with_dispatch(train_y) 1045 train_sw = ops.convert_to_tensor_v2_with_dispatch(train_sw) 1046 val_x = ops.convert_to_tensor_v2_with_dispatch(val_x) 1047 val_y = ops.convert_to_tensor_v2_with_dispatch(val_y) 1048 val_sw = ops.convert_to_tensor_v2_with_dispatch(val_sw) 1049 1050 self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3]) 1051 self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6]) 1052 self.assertEqual(train_sw.numpy().tolist(), [0, 4, 8, 12]) 1053 1054 self.assertEqual(val_x.numpy().tolist(), [4]) 1055 self.assertEqual(val_y.numpy().tolist(), [8]) 1056 self.assertEqual(val_sw.numpy().tolist(), [16]) 1057 1058 def test_validation_split_user_error(self): 1059 with self.assertRaisesRegex(ValueError, 'is only supported for Tensors'): 1060 data_adapter.train_validation_split( 1061 lambda: np.ones((10, 1)), validation_split=0.2) 1062 1063 def test_validation_split_examples_too_few(self): 1064 with self.assertRaisesRegex(ValueError, 'not sufficient to split it'): 1065 data_adapter.train_validation_split( 1066 np.ones((1, 10)), validation_split=0.2) 1067 1068 def test_validation_split_none(self): 1069 train_sw, val_sw = data_adapter.train_validation_split( 1070 None, validation_split=0.2) 1071 self.assertIsNone(train_sw) 1072 self.assertIsNone(val_sw) 1073 1074 (_, train_sw), (_, val_sw) = data_adapter.train_validation_split( 1075 (np.ones((10, 1)), None), validation_split=0.2) 1076 self.assertIsNone(train_sw) 1077 self.assertIsNone(val_sw) 1078 1079 1080class ListsOfScalarsDataAdapterTest(DataAdapterTestBase): 1081 1082 def setUp(self): 1083 super(ListsOfScalarsDataAdapterTest, self).setUp() 1084 self.adapter_cls = data_adapter.ListsOfScalarsDataAdapter 1085 1086 def test_can_list_inputs(self): 1087 self.assertTrue(self.adapter_cls.can_handle(self.text_input)) 1088 self.assertTrue(self.adapter_cls.can_handle(self.bytes_input)) 1089 1090 self.assertFalse(self.adapter_cls.can_handle(self.numpy_input)) 1091 self.assertFalse(self.adapter_cls.can_handle(self.tensor_input)) 1092 self.assertFalse(self.adapter_cls.can_handle(self.dataset_input)) 1093 self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) 1094 self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) 1095 self.assertFalse(self.adapter_cls.can_handle([])) 1096 1097 1098class TestUtils(keras_parameterized.TestCase): 1099 1100 def test_expand_1d_sparse_tensors_untouched(self): 1101 st = sparse_tensor.SparseTensor( 1102 indices=[[0], [10]], values=[1, 2], dense_shape=[10]) 1103 st = data_adapter.expand_1d(st) 1104 self.assertEqual(st.shape.rank, 1) 1105 1106 1107if __name__ == '__main__': 1108 ops.enable_eager_execution() 1109 test.main() 1110