• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""A simple MNIST model for testing multi-worker distribution strategies with Keras."""
16
17import tensorflow as tf
18
19
def mnist_synthetic_dataset(batch_size, steps_per_epoch):
  """Generate synthetic MNIST-shaped train and eval datasets for testing.

  The training data is constant (all-ones images and all-ones labels) and is
  repeated indefinitely; the eval data is 10000 uniformly random images with
  uniformly random integer labels.

  NOTE(review): both datasets are batched with a fixed size of 64, regardless
  of `batch_size`; `batch_size` only affects how many training examples are
  created. Confirm whether that is intended before changing it.

  Args:
    batch_size: Multiplied by `steps_per_epoch` to get the number of distinct
      training examples created before repetition.
    steps_per_epoch: Multiplied by `batch_size` to get the number of distinct
      training examples created before repetition.

  Returns:
    A `(train_ds, eval_ds)` tuple of batched `tf.data.Dataset` objects whose
    elements are `(images, labels)` pairs.
  """
  # train dataset: constant tensors, so every example is identical.
  num_train_examples = batch_size * steps_per_epoch
  x_train = tf.ones([num_train_examples, 28, 28, 1], dtype=tf.dtypes.float32)
  y_train = tf.ones([num_train_examples, 1], dtype=tf.dtypes.int32)
  train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_ds = train_ds.repeat()
  # No shuffling is applied to the training data.
  train_ds = train_ds.batch(64, drop_remainder=True)

  # eval dataset
  x_test = tf.random.uniform([10000, 28, 28, 1], dtype=tf.dtypes.float32)
  # `maxval` is exclusive for integer dtypes, so 10 is required for the labels
  # to cover all ten digit classes (0 through 9).
  y_test = tf.random.uniform([10000, 1],
                             minval=0,
                             maxval=10,
                             dtype=tf.dtypes.int32)
  eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  eval_ds = eval_ds.batch(64, drop_remainder=True)

  return train_ds, eval_ds
41
42
def get_mnist_model(input_shape):
  """Build and compile a small CNN with seeded kernel initializers.

  The network is: Conv2D(32, 3x3, relu) -> BatchNormalization -> the sum of
  two separate Flatten layers applied to the same tensor -> Dense(10, softmax).
  Both the convolution and the dense kernels use a `TruncatedNormal`
  initializer seeded with 99.

  Args:
    input_shape: Shape forwarded to `tf.keras.Input` as its `shape` argument.

  Returns:
    A `tf.keras.Model` compiled with sparse categorical crossentropy loss, an
    SGD optimizer (learning rate 0.001) and the "accuracy" metric.
  """
  image = tf.keras.Input(shape=input_shape)

  # Convolutional feature extractor.
  conv_layer = tf.keras.layers.Conv2D(
      32,
      kernel_size=(3, 3),
      activation="relu",
      kernel_initializer=tf.keras.initializers.TruncatedNormal(seed=99))
  features = conv_layer(image)
  features = tf.keras.layers.BatchNormalization()(features)

  # Two independent Flatten layers over the same tensor, summed elementwise.
  flat_first = tf.keras.layers.Flatten()(features)
  flat_second = tf.keras.layers.Flatten()(features)
  flat_sum = flat_first + flat_second

  # Classification head; it gets its own initializer instance.
  dense_layer = tf.keras.layers.Dense(
      10,
      activation="softmax",
      kernel_initializer=tf.keras.initializers.TruncatedNormal(seed=99))
  probabilities = dense_layer(flat_sum)

  model = tf.keras.Model(inputs=image, outputs=probabilities)

  # TODO(yuefengz): optimizer with slot variables doesn't work because of
  # optimizer's bug.
  # TODO(yuefengz): we should not allow non-v2 optimizer.
  sgd = tf.keras.optimizers.SGD(learning_rate=0.001)
  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=sgd,
      metrics=["accuracy"])
  return model
69