# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset_creator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core as core_layers
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec


class DatasetCreatorTest(test.TestCase):
  """Exercises `dataset_creator.DatasetCreator` directly and via `Model.fit`."""

  def test_dataset_creator(self):
    """Checks argument validation and the dataset returned by the creator."""
    # A non-callable argument is expected to be rejected at construction.
    with self.assertRaisesRegex(
        TypeError, "`dataset_fn` for `DatasetCreator` must be a `callable`."):
      dataset_creator.DatasetCreator(2)

    def _returns_int():
      return 3

    # A callable whose result is not a dataset is expected to fail when the
    # creator is invoked.
    with self.assertRaisesRegex(
        TypeError, "The `callable` provided to `DatasetCreator` must return "
        "a Dataset."):
      dataset_creator.DatasetCreator(_returns_int)()

    def _returns_dataset():
      return dataset_ops.DatasetV2.from_tensor_slices([1, 1])

    got = dataset_creator.DatasetCreator(_returns_dataset)()
    # Only the first element of each dataset is compared.
    first_actual = next(iter(got))
    first_expected = next(
        iter(dataset_ops.DatasetV2.from_tensor_slices([1, 1])))
    self.assertEqual(first_actual, first_expected)

  def test_dataset_creator_usage_in_parameter_server_model_fit(self):
    """Fits a model under `ParameterServerStrategyV2` from a `DatasetCreator`."""
    # In-process cluster with two workers and one parameter server; a "chief"
    # entry on an unused local port is added on top of it.
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=1, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_spec = ClusterSpec(cluster_def)
    resolver = SimpleClusterResolver(cluster_spec, rpc_layer="grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)

    # The model is built and compiled inside the strategy scope.
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      model.compile(gradient_descent.SGD(), loss="mse")

    def dataset_fn(input_context):
      """Builds an endlessly repeating, sharded, batched input pipeline."""
      global_batch_size = 64
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          global_batch_size)
      return (dataset_ops.DatasetV2.from_tensors(([1.], [1.]))
              .repeat()
              .shard(input_context.num_input_pipelines,
                     input_context.input_pipeline_id)
              .batch(per_replica_batch_size)
              .prefetch(2))

    history = model.fit(
        dataset_creator.DatasetCreator(dataset_fn),
        epochs=10,
        steps_per_epoch=10,
        verbose=0)
    # One recorded loss value is expected per epoch.
    self.assertLen(history.history["loss"], 10)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  test.main()