# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
"""Input dataset creator for `model.fit`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops


class DatasetCreator(object):
  """Object that returns a `tf.data.Dataset` when invoked.

  `DatasetCreator` is designated as a supported type for `x`, or the input, in
  `tf.keras.Model.fit`. Pass an instance of this class to `fit` when using a
  callable (with an `input_context` argument) that returns a `tf.data.Dataset`.

  ```python
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss="mse")

  def dataset_fn(input_context):
    global_batch_size = 64
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
    dataset = dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)
    return dataset

  model.fit(DatasetCreator(dataset_fn), epochs=10, steps_per_epoch=10)
  ```

  Args:
    dataset_fn: A callable that takes a single argument of type
      `tf.distribute.InputContext`, which is used for batch size calculation
      and cross-worker input pipeline sharding (if neither is needed, the
      `InputContext` parameter can be ignored in the `dataset_fn`), and
      returns a `tf.data.Dataset`.
  """

  def __init__(self, dataset_fn):
    if not callable(dataset_fn):
      raise TypeError(
          '`dataset_fn` for `DatasetCreator` must be a `callable`.')
    self.dataset_fn = dataset_fn

  def __call__(self, *args, **kwargs):
    # When a `DatasetCreator` is invoked, the args/kwargs are forwarded
    # directly to the user-provided `dataset_fn`, and the returned object is
    # validated to be a `tf.data.Dataset`.
    dataset = self.dataset_fn(*args, **kwargs)
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError(
          'The `callable` provided to `DatasetCreator` must return a Dataset.')
    return dataset
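

# ----------------------------------------------------------------------------
# Illustrative usage sketch (comments only; not part of this module's API).
# It shows, assuming TensorFlow is imported as `tf`, how a `DatasetCreator`
# forwards a hand-built `tf.distribute.InputContext` to the wrapped
# `dataset_fn` and validates that a `tf.data.Dataset` comes back. In real
# training, `model.fit` and the active `tf.distribute` strategy construct the
# `InputContext`; `_demo_dataset_fn` below is a hypothetical helper.
#
#   def _demo_dataset_fn(input_context):
#     batch_size = input_context.get_per_replica_batch_size(64)
#     dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
#     return dataset.batch(batch_size)
#
#   creator = DatasetCreator(_demo_dataset_fn)
#   context = tf.distribute.InputContext(
#       num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1)
#   dataset = creator(context)  # A batched, repeated `tf.data.Dataset`.
#   DatasetCreator(lambda _: [1., 1.])(context)  # Raises TypeError:
#                                                # not a `tf.data.Dataset`.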