• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""An example training a Keras Model using MirroredStrategy and native APIs."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow as tf
21
22
23from tensorflow.python.distribute import mirrored_strategy
24from tensorflow.python.keras.optimizer_v2 import rmsprop
25
26
27NUM_CLASSES = 10
28
29
def get_input_datasets(use_bfloat16=False, batch_size=64):
  """Downloads the MNIST dataset and creates train and eval dataset objects.

  Both datasets go through the same pipeline: the images get a channel axis
  (placed according to the Keras `image_data_format`), are converted to
  float32 and divided by 255, and the labels are one-hot encoded over
  `NUM_CLASSES`. Each dataset is then repeated, cast and batched.

  Args:
    use_bfloat16: Boolean to determine if input should be cast to bfloat16
    batch_size: Number of examples per batch for both datasets. Incomplete
      batches are dropped. Defaults to 64, the value that used to be
      hard-coded.

  Returns:
    Train dataset, eval dataset and input shape.

  """
  # input image dimensions
  img_rows, img_cols = 28, 28
  cast_dtype = tf.bfloat16 if use_bfloat16 else tf.float32

  # the data, split between train and test sets
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  # The position of the single channel axis depends on the backend setting.
  if tf.keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
  else:
    input_shape = (img_rows, img_cols, 1)

  def _build_dataset(images, labels):
    """Turns raw image/label arrays into a repeated, batched dataset."""
    images = images.reshape((images.shape[0],) + input_shape)
    images = images.astype('float32') / 255
    # convert class vectors to binary class matrices
    labels = tf.keras.utils.to_categorical(labels, NUM_CLASSES)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    # `repeat()` comes before `batch()` on purpose, matching the original
    # pipeline: batches may span the boundary between two passes over the data.
    dataset = dataset.repeat()
    dataset = dataset.map(lambda x, y: (tf.cast(x, cast_dtype), y))
    return dataset.batch(batch_size, drop_remainder=True)

  train_ds = _build_dataset(x_train, y_train)
  eval_ds = _build_dataset(x_test, y_test)

  return train_ds, eval_ds, input_shape
78
79
def get_model(input_shape):
  """Builds a Sequential CNN model to recognize MNIST digits.

  The network is two 3x3 convolutions (32 then 64 filters), a 2x2 max pool,
  dropout, a 128-unit dense layer, dropout again, and a softmax output with
  `NUM_CLASSES` units.

  Args:
    input_shape: Shape of the input depending on the `image_data_format`.

  Returns:
    a Keras model

  """
  layers = tf.keras.layers

  # The layers are listed in the order they are applied to the input.
  cnn_layers = [
      layers.Conv2D(32, kernel_size=(3, 3),
                    activation='relu',
                    input_shape=input_shape),
      layers.Conv2D(64, (3, 3), activation='relu'),
      layers.MaxPooling2D(pool_size=(2, 2)),
      layers.Dropout(0.25),
      layers.Flatten(),
      layers.Dense(128, activation='relu'),
      layers.Dropout(0.5),
      layers.Dense(NUM_CLASSES, activation='softmax'),
  ]
  return tf.keras.models.Sequential(cnn_layers)
103
104
def main(_):
  """Trains the MNIST CNN under a MirroredStrategy and prints eval results.

  Args:
    _: Unused. Present because this function is launched through
      `tf.app.run()` at the bottom of the file.
  """
  # Turn on eager execution before any datasets or models are created.
  tf.enable_eager_execution()

  # Build the train and eval datasets from the MNIST data. The returned input
  # shape follows the `image_data_format`, i.e. channels_first or
  # channels_last.
  train_ds, eval_ds, input_shape = get_input_datasets()

  # Instantiate the MirroredStrategy object with an explicit device list.
  # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available.
  replica_devices = ['/gpu:0', '/cpu:0']
  strategy = mirrored_strategy.MirroredStrategy(replica_devices)

  # The model is created and compiled inside the strategy scope; `fit`,
  # `evaluate` and `predict` are then distributed according to the strategy
  # the model was compiled with.
  with strategy.scope():
    model = get_model(input_shape)
    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=rmsprop.RMSProp(learning_rate=0.001),
        metrics=['accuracy'])

  # Train the model with the train dataset.
  model.fit(x=train_ds, epochs=20, steps_per_epoch=468)

  # Evaluate the model with the eval dataset and report loss and accuracy.
  score = model.evaluate(eval_ds, steps=10, verbose=0)
  test_loss, test_accuracy = score[0], score[1]
  print('Test loss:', test_loss)
  print('Test accuracy:', test_accuracy)
135
136
# Script entry point: hand `main` to the TensorFlow app runner.
if __name__ == '__main__':
  tf.app.run(main=main)
139