# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An example training a Keras Model using MirroredStrategy and native APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.keras.optimizer_v2 import rmsprop


NUM_CLASSES = 10


def _make_dataset(features, labels, cast_dtype, batch_size):
  """Builds an infinitely repeating, batched `tf.data.Dataset`.

  Shared by the train and eval pipelines so both are constructed identically.

  Args:
    features: Array of input images; each element is cast to `cast_dtype`
      inside the pipeline.
    labels: Array of one-hot labels, passed through unchanged.
    cast_dtype: `tf.DType` the features are cast to (float32 or bfloat16).
    batch_size: Number of examples per batch.

  Returns:
    A `tf.data.Dataset` yielding `(features, labels)` batches forever.
  """
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  # Repeat indefinitely; callers bound iteration with `steps_per_epoch` /
  # `steps` instead of relying on the dataset being exhausted.
  dataset = dataset.repeat()
  dataset = dataset.map(lambda x, y: (tf.cast(x, cast_dtype), y))
  # `drop_remainder=True` makes the batch dimension statically known.
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset


def get_input_datasets(use_bfloat16=False, batch_size=64):
  """Downloads the MNIST dataset and creates train and eval dataset objects.

  Args:
    use_bfloat16: Boolean to determine if input should be cast to bfloat16
      (otherwise it is cast to float32). Labels are not cast.
    batch_size: Number of examples per batch for both datasets. Whether this
      acts as a per-device or global batch under `MirroredStrategy` depends on
      the TF version's input-distribution semantics (NOTE(review): confirm).

  Returns:
    A tuple `(train_ds, eval_ds, input_shape)` where both datasets repeat
    forever and `input_shape` matches `tf.keras.backend.image_data_format()`.
  """
  # input image dimensions
  img_rows, img_cols = 28, 28
  cast_dtype = tf.bfloat16 if use_bfloat16 else tf.float32

  # the data, split between train and test sets
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  # Add a single channel axis where the backend expects it.
  if tf.keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
  else:
    input_shape = (img_rows, img_cols, 1)
  x_train = x_train.reshape((x_train.shape[0],) + input_shape)
  x_test = x_test.reshape((x_test.shape[0],) + input_shape)

  # Scale pixel values from [0, 255] to [0, 1].
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255

  # convert class vectors to binary class matrices
  y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
  y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

  train_ds = _make_dataset(x_train, y_train, cast_dtype, batch_size)
  eval_ds = _make_dataset(x_test, y_test, cast_dtype, batch_size)

  return train_ds, eval_ds, input_shape


def get_model(input_shape):
  """Builds a Sequential CNN model to recognize MNIST digits.

  Architecture: two 3x3 conv layers (32 and 64 filters), 2x2 max pooling,
  dropout, a 128-unit dense layer, dropout, and a softmax over NUM_CLASSES.

  Args:
    input_shape: Shape of the input depending on the `image_data_format`.

  Returns:
    An uncompiled Keras `Sequential` model.
  """
  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3),
                                   activation='relu',
                                   input_shape=input_shape))
  model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
  model.add(tf.keras.layers.Dropout(0.25))
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(128, activation='relu'))
  model.add(tf.keras.layers.Dropout(0.5))
  model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
  return model


def main(_):
  """Trains and evaluates the MNIST CNN under `MirroredStrategy`.

  Args:
    _: Unused; positional argv passed in by `tf.app.run()`.
  """
  # Eager execution must be enabled before any TF ops are created.
  tf.enable_eager_execution()

  # Build the train and eval datasets from the MNIST data. Also return the
  # input shape which is constructed based on the `image_data_format`
  # i.e channels_first or channels_last.
  train_ds, eval_ds, input_shape = get_input_datasets()

  # Mirror the model across the first GPU and the CPU. If the `devices`
  # argument (or `num_gpus`) were omitted, all GPUs available on the machine
  # would be used instead.
  # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available.
  strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0'])

  # Create and compile the model under Distribution strategy scope.
  # `fit`, `evaluate` and `predict` will be distributed based on the strategy
  # model was compiled with.
  with strategy.scope():
    model = get_model(input_shape)
    optimizer = rmsprop.RMSProp(learning_rate=0.001)
    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer=optimizer,
                  metrics=['accuracy'])

  # Train the model with the train dataset. The dataset repeats forever, so
  # `steps_per_epoch` defines an epoch.
  # NOTE(review): 468 ~= 60000 / 128, i.e. one pass over MNIST's training set
  # at a global batch of 128 (64 per device x 2 devices) -- confirm this
  # matches the batching semantics of the TF version in use.
  model.fit(x=train_ds, epochs=20, steps_per_epoch=468)

  # Evaluate the model with the eval dataset. Because the dataset repeats
  # forever, `steps` is required; 10 batches is presumably only a sample of
  # the test set, so the printed metrics are estimates.
  score = model.evaluate(eval_ds, steps=10, verbose=0)
  print('Test loss:', score[0])
  print('Test accuracy:', score[1])


if __name__ == '__main__':
  tf.app.run()