• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for dataset_creator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from tensorflow.python.compat import v2_compat
21from tensorflow.python.data.ops import dataset_ops
22from tensorflow.python.distribute import multi_worker_test_base
23from tensorflow.python.distribute import parameter_server_strategy_v2
24from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
25from tensorflow.python.keras.engine import sequential
26from tensorflow.python.keras.layers import core as core_layers
27from tensorflow.python.keras.optimizer_v2 import gradient_descent
28from tensorflow.python.keras.utils import dataset_creator
29from tensorflow.python.platform import test
30from tensorflow.python.training.server_lib import ClusterSpec
31
32
class DatasetCreatorTest(test.TestCase):
  """Tests `DatasetCreator` argument checks and its use with `Model.fit`."""

  def test_dataset_creator(self):
    """Covers both expected `TypeError`s and one successful creation."""
    # Constructing a creator from a non-callable is expected to raise.
    not_callable_regex = (
        "`dataset_fn` for `DatasetCreator` must be a `callable`.")
    with self.assertRaisesRegex(TypeError, not_callable_regex):
      dataset_creator.DatasetCreator(2)

    # Building and then invoking a creator whose callable returns a plain
    # int is expected to raise as well.
    def returns_int():
      return 3

    not_dataset_regex = (
        "The `callable` provided to `DatasetCreator` must return "
        "a Dataset.")
    with self.assertRaisesRegex(TypeError, not_dataset_regex):
      dataset_creator.DatasetCreator(returns_int)()

    # With a callable that returns a dataset, the first element obtained
    # through the creator is compared against the first element of an
    # independently built, identical dataset.
    def make_dataset():
      return dataset_ops.DatasetV2.from_tensor_slices([1, 1])

    created = dataset_creator.DatasetCreator(make_dataset)()
    first_created = next(iter(created))
    first_expected = next(iter(make_dataset()))
    self.assertEqual(first_created, first_expected)

  def test_dataset_creator_usage_in_parameter_server_model_fit(self):
    """Runs `Model.fit` on a `DatasetCreator` under `ParameterServerStrategyV2`."""
    # In-process cluster with two workers and one parameter server; a chief
    # entry on an unused local port is added to the cluster definition.
    cluster_spec_dict = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=1, rpc_layer="grpc")
    chief_port = multi_worker_test_base.pick_unused_port()
    cluster_spec_dict["chief"] = ["localhost:%d" % chief_port]

    resolver = SimpleClusterResolver(
        ClusterSpec(cluster_spec_dict), rpc_layer="grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)

    # The model is created inside the strategy scope and compiled afterwards.
    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
    model.compile(gradient_descent.SGD(), loss="mse")

    def sharded_dataset_fn(input_context):
      """Returns a repeated, sharded, batched and prefetched dataset."""
      global_batch_size = 64
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          global_batch_size)
      examples = dataset_ops.DatasetV2.from_tensors(([1.], [1.]))
      return (examples.repeat()
              .shard(input_context.num_input_pipelines,
                     input_context.input_pipeline_id)
              .batch(per_replica_batch_size)
              .prefetch(2))

    num_epochs = 10
    creator = dataset_creator.DatasetCreator(sharded_dataset_fn)
    history = model.fit(
        creator, epochs=num_epochs, steps_per_epoch=10, verbose=0)

    # One loss value should have been recorded per epoch.
    self.assertLen(history.history["loss"], num_epochs)
80
81
# Script entry point: only runs when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
  # `enable_v2_behavior()` is called before the test runner starts so that it
  # is in effect for every test case in this module.
  v2_compat.enable_v2_behavior()
  test.main()
85