• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python3
2# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for `DatasetCreator` with `Model.fit` across usages and strategies."""
17
18import os
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python import keras
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.framework import constant_op
26from tensorflow.python.keras import callbacks as callbacks_lib
27from tensorflow.python.keras.engine import sequential
28from tensorflow.python.keras.layers import core as core_layers
29from tensorflow.python.keras.layers.preprocessing import string_lookup
30from tensorflow.python.keras.optimizer_v2 import gradient_descent
31from tensorflow.python.keras.utils import dataset_creator
32from tensorflow.python.ops import random_ops
33from tensorflow.python.platform import test
34from tensorflow.python.platform import tf_logging as logging
35
36
class DatasetCreatorModelFitTestBase(test.TestCase, parameterized.TestCase):
  """The base class for DatasetCreator with Model.fit tests.

  Subclasses pass a distribution `strategy` to the `_model_*` helpers, which
  build and compile a small sigmoid-output model under `strategy.scope()` and
  then run `fit`, `evaluate` or `predict` on it.
  """

  def _get_dataset_fn(self, use_lookup_layer):
    """Returns a `dataset_fn` suitable for wrapping in a `DatasetCreator`.

    Args:
      use_lookup_layer: If True, a four-word vocabulary file is written to the
        test's temp directory and the returned function yields string features
        mapped through a `StringLookup` layer. Otherwise it yields random
        float features with random float labels.

    Returns:
      A callable taking an `input_context` (ignored) and returning a shuffled,
      infinitely repeating `DatasetV2` batched by 2.
    """
    if use_lookup_layer:
      filepath = os.path.join(self.get_temp_dir(), "vocab")
      with open(filepath, "w", encoding="utf-8") as f:
        f.write("\n".join(["earth", "wind", "and", "fire"]))

      def dataset_fn(input_context):
        del input_context
        # Built inside `dataset_fn`, so each invocation creates its own layer.
        lookup_layer = string_lookup.StringLookup(
            num_oov_indices=1, vocabulary=filepath)
        # "michigan" is not in the vocabulary file, so it hits the OOV index.
        x = np.array([["earth", "wind", "and", "fire"],
                      ["fire", "and", "earth", "michigan"]])
        y = np.array([0, 1])
        map_fn = lambda x, y: (lookup_layer(x), y)
        return dataset_ops.DatasetV2.from_tensor_slices(
            (x, y)).shuffle(10).repeat().batch(2).map(map_fn)

    else:

      def dataset_fn(input_context):
        del input_context
        x = random_ops.random_uniform((10, 10))
        y = random_ops.random_uniform((10,))
        return dataset_ops.DatasetV2.from_tensor_slices(
            (x, y)).shuffle(10).repeat().batch(2)

    return dataset_fn

  def _model_compile(self,
                     strategy,
                     steps_per_execution=1,
                     run_eagerly=False,
                     with_normalization_layer=False,
                     use_lookup_layer=False):
    """Builds a small model under `strategy.scope()` and compiles it.

    The model is `Dense(10)` (optionally followed by `BatchNormalization`)
    followed by `Dense(1, activation="sigmoid")`, compiled with SGD,
    binary cross-entropy and an `Accuracy` metric. The metric is also stored
    on `self._accuracy_metric` so tests can inspect it.

    Args:
      strategy: Distribution strategy whose scope the model is created in.
      steps_per_execution: Forwarded to `Model.compile`.
      run_eagerly: Forwarded to `Model.compile`.
      with_normalization_layer: If True, inserts a `BatchNormalization` layer.
      use_lookup_layer: Currently unused; accepted for call-site
        compatibility with `_model_fit`.

    Returns:
      A `(model, callbacks)` tuple, where `callbacks` is a fresh list holding
      one result-asserting callback.
    """
    del use_lookup_layer  # Unused; see docstring.

    class ResultAssertingCallback(callbacks_lib.Callback):
      """Asserts that epochs strictly increase and each logs a float loss."""

      def __init__(self):
        # Initialise the base `Callback` attributes before adding our own.
        super().__init__()
        self._prev_epoch = -1

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        # `logs` defaults to None; treat that as "no loss reported" so the
        # intended RuntimeError is raised rather than an AttributeError.
        logs = logs or {}
        loss = logs.get("loss", None)
        is_loss_float = (
            loss is not None and isinstance(loss, (float, np.floating)))
        if not is_loss_float:
          raise RuntimeError("loss is supposed to be in the logs and float.")

    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      if with_normalization_layer:
        norm = keras.layers.BatchNormalization(
            axis=-1, input_shape=(4, 4, 3), momentum=0.8)
        model.add(norm)
      model.add(core_layers.Dense(1, activation="sigmoid"))
      self._accuracy_metric = keras.metrics.Accuracy()

    model.compile(
        gradient_descent.SGD(),
        loss="binary_crossentropy",
        metrics=[self._accuracy_metric],
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]

  def _model_fit(self,
                 strategy,
                 steps_per_execution=1,
                 validation_data=None,
                 x=None,
                 y=None,
                 shuffle=True,
                 batch_size=None,
                 steps_per_epoch=10,
                 run_eagerly=False,
                 with_normalization_layer=False,
                 callbacks=None,
                 use_lookup_layer=False,
                 epochs=10):
    """Compiles a model via `_model_compile` and runs `Model.fit` on it.

    Args:
      strategy: Distribution strategy used to build the model.
      steps_per_execution: Forwarded to `Model.compile`.
      validation_data: Validation input. If None, a new `DatasetCreator`
        built from `_get_dataset_fn(use_lookup_layer)` is used, so validation
        always runs.
      x: Training input. If None, a `DatasetCreator` built from
        `_get_dataset_fn(use_lookup_layer)` is used.
      y: Training targets, forwarded to `Model.fit`.
      shuffle: Forwarded to `Model.fit`.
      batch_size: Forwarded to `Model.fit`.
      steps_per_epoch: Used for both `steps_per_epoch` and `validation_steps`.
      run_eagerly: Forwarded to `Model.compile`.
      with_normalization_layer: If True, the model has a `BatchNormalization`
        layer.
      callbacks: Optional extra callbacks. The caller's list is not modified.
      use_lookup_layer: Selects the lookup-layer dataset variant.
      epochs: Number of epochs to train for.

    Returns:
      The fitted model.
    """
    # Copy so the caller's list is never extended with our callbacks; reusing
    # such a list would re-register a stale epoch-asserting callback.
    callbacks = list(callbacks or [])

    model, default_callbacks = self._model_compile(strategy,
                                                   steps_per_execution,
                                                   run_eagerly,
                                                   with_normalization_layer,
                                                   use_lookup_layer)
    callbacks += default_callbacks

    if x is None:
      x = dataset_creator.DatasetCreator(self._get_dataset_fn(use_lookup_layer))

    if validation_data is None:
      validation_data = dataset_creator.DatasetCreator(
          self._get_dataset_fn(use_lookup_layer))

    model.fit(
        x,
        y,
        shuffle=shuffle,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=steps_per_epoch)
    return model

  def _model_evaluate(self,
                      strategy,
                      steps_per_execution=1,
                      x=None,
                      y=None,
                      batch_size=None,
                      steps=10,
                      run_eagerly=False,
                      with_normalization_layer=False,
                      callbacks=None):
    """Compiles a model via `_model_compile` and runs `Model.evaluate` on it.

    Args:
      strategy: Distribution strategy used to build the model.
      steps_per_execution: Forwarded to `Model.compile`.
      x: Evaluation input. If None, a `DatasetCreator` yielding random
        `(10,)` features and `(1,)` labels in batches of 8 is used.
      y: Evaluation targets, forwarded to `Model.evaluate`.
      batch_size: Forwarded to `Model.evaluate`.
      steps: Number of evaluation steps.
      run_eagerly: Forwarded to `Model.compile`.
      with_normalization_layer: If True, the model has a `BatchNormalization`
        layer.
      callbacks: Optional extra callbacks. The caller's list is not modified.

    Returns:
      The evaluated model.
    """
    # Copy so the caller's list is never extended with our callbacks.
    callbacks = list(callbacks or [])

    model, default_callbacks = self._model_compile(
        strategy,
        steps_per_execution,
        run_eagerly,
        with_normalization_layer,
    )
    callbacks += default_callbacks

    def dataset_fn(input_context):
      del input_context
      x = random_ops.random_uniform((10, 10))
      y = random_ops.random_uniform((10, 1))
      return dataset_ops.DatasetV2.from_tensor_slices(
          (x, y)).shuffle(10).repeat().batch(8)

    if x is None:
      x = dataset_creator.DatasetCreator(dataset_fn)

    model.evaluate(
        x=x, y=y, steps=steps, callbacks=callbacks, batch_size=batch_size)
    return model

  def _model_predict(
      self,
      strategy,
      model=None,
      steps_per_execution=1,
      test_data=None,
      steps=10,
      with_normalization_layer=False,
  ):
    """Runs `Model.predict` and returns predictions rounded to 4 decimals.

    Args:
      strategy: Distribution strategy used when a model must be built.
      model: Model to predict with. If None, one is built via
        `_model_compile` and its result-asserting callback is attached.
      steps_per_execution: Forwarded to `Model.compile` when building a model.
      test_data: Prediction input. If None, a dataset of six scalars,
        repeated and batched by 2, is used.
      steps: Number of prediction steps.
      with_normalization_layer: If True and a model is built, it has a
        `BatchNormalization` layer.

    Returns:
      A `(model, predictions)` tuple, where `predictions` is rounded with
      `np.around(..., 4)`.
    """
    callbacks = []

    if model is None:
      model, default_callbacks = self._model_compile(
          strategy,
          steps_per_execution,
          with_normalization_layer=with_normalization_layer,
      )
      callbacks += default_callbacks

    def create_test_data():
      x = constant_op.constant([1., 2., 3., 1., 5., 1.])
      return dataset_ops.DatasetV2.from_tensor_slices(x).repeat().batch(2)

    if test_data is None:
      test_data = create_test_data()

    predictions = model.predict(x=test_data, steps=steps, callbacks=callbacks)
    predictions = np.around(predictions, 4)
    return model, predictions
217