# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple MNIST model for testing multi-worker distribution strategies with Keras."""

import tensorflow as tf

# Number of output classes; shared by the eval labels and the model head so
# the two cannot drift apart.
_NUM_CLASSES = 10
# Per-example image shape (height, width, channels) of the synthetic data.
_IMAGE_SHAPE = (28, 28, 1)
# Number of examples in the synthetic evaluation set.
_NUM_EVAL_EXAMPLES = 10000


def mnist_synthetic_dataset(batch_size, steps_per_epoch):
  """Generate synthetic MNIST-shaped datasets for testing.

  The training set contains `batch_size * steps_per_epoch` identical examples
  (all-ones float32 images of shape 28x28x1, each labelled 1). It repeats
  indefinitely and is batched by `batch_size`, so one epoch of
  `steps_per_epoch` steps consumes exactly the generated data once.

  The evaluation set contains 10000 images with values drawn uniformly from
  [0, 1) and int32 labels drawn uniformly from [0, 10). It is batched by
  `batch_size` and is not repeated.

  Both datasets use `drop_remainder=True`, so every batch has a static size
  of `batch_size`.

  Args:
    batch_size: Number of examples per batch for both datasets.
    steps_per_epoch: Number of training batches that make up one pass over
      the generated training data.

  Returns:
    A `(train_ds, eval_ds)` tuple of `tf.data.Dataset`s yielding
    `(images, labels)` batches with images of shape
    `[batch_size, 28, 28, 1]` and labels of shape `[batch_size, 1]`.
  """
  # Train dataset.
  num_train_examples = batch_size * steps_per_epoch
  x_train = tf.ones([num_train_examples, *_IMAGE_SHAPE],
                    dtype=tf.dtypes.float32)
  y_train = tf.ones([num_train_examples, 1], dtype=tf.dtypes.int32)
  train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_ds = train_ds.repeat()
  # No shuffle: every training example is identical, so shuffling would
  # have no effect.
  # Batch with the caller's batch_size rather than a hard-coded constant so
  # that `steps_per_epoch` steps cover exactly the generated examples.
  train_ds = train_ds.batch(batch_size, drop_remainder=True)

  # Eval dataset.
  x_test = tf.random.uniform([_NUM_EVAL_EXAMPLES, *_IMAGE_SHAPE],
                             dtype=tf.dtypes.float32)
  # For integer dtypes `maxval` is exclusive, so this draws labels from the
  # full range [0, _NUM_CLASSES).
  y_test = tf.random.uniform([_NUM_EVAL_EXAMPLES, 1],
                             minval=0,
                             maxval=_NUM_CLASSES,
                             dtype=tf.dtypes.int32)
  eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  eval_ds = eval_ds.batch(batch_size, drop_remainder=True)

  return train_ds, eval_ds


def get_mnist_model(input_shape):
  """Define a deterministically-initialized CNN model for MNIST testing.

  The architecture is Conv2D(32, 3x3, relu) -> BatchNormalization ->
  (Flatten + Flatten) -> Dense(10, softmax). Both kernel initializers use a
  fixed seed (99), so the initial weights are reproducible across runs.

  Args:
    input_shape: Per-example input shape passed to `tf.keras.Input`;
      presumably `(28, 28, 1)` to match `mnist_synthetic_dataset` (assumption,
      not enforced here).

  Returns:
    A `tf.keras.Model` compiled with sparse categorical cross-entropy loss,
    SGD (learning rate 0.001), and an accuracy metric.
  """
  inputs = tf.keras.Input(shape=input_shape)
  x = tf.keras.layers.Conv2D(
      32,
      kernel_size=(3, 3),
      activation="relu",
      kernel_initializer=tf.keras.initializers.TruncatedNormal(seed=99))(
          inputs)
  x = tf.keras.layers.BatchNormalization()(x)
  # NOTE(review): adding two Flatten outputs looks deliberate, presumably to
  # exercise an element-wise op between layer outputs under a distribution
  # strategy. Confirm before simplifying.
  x = tf.keras.layers.Flatten()(x) + tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(
      _NUM_CLASSES,
      activation="softmax",
      kernel_initializer=tf.keras.initializers.TruncatedNormal(seed=99))(
          x)
  model = tf.keras.Model(inputs=inputs, outputs=x)

  # TODO(yuefengz): optimizer with slot variables doesn't work because of
  # optimizer's bug.
  # TODO(yuefengz): we should not allow non-v2 optimizer.
  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=["accuracy"])
  return model