# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST model float training script with TensorFlow graph execution."""

import os
from absl import flags

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs  # pylint: disable=unused-import
from tensorflow.python.framework import load_library

flags.DEFINE_integer('train_steps', 20, 'Number of steps in training.')

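# The generated bindings in gen_mnist_ops only register the op interfaces;
# the compiled kernels live in a shared library next to that module
# ('gen_mnist_ops.py' becomes 'mnist_ops.so', dropping the 'gen_' prefix),
# so locate and load it before any of the new_* ops are used.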
_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so')
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))

# MNIST dataset parameters.
num_classes = 10  # total classes (0-9 digits).
num_features = 784  # data features (img shape: 28*28).
num_channels = 1

# Training parameters.
learning_rate = 0.001
display_step = 10
batch_size = 32

# Network parameters.
n_hidden_1 = 32  # number of filters in the 1st conv layer.
n_hidden_2 = 64  # number of filters in the 2nd conv layer.
n_hidden_3 = 64  # number of neurons in the 1st fully connected layer.
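# Two 2x2, stride-2 max pools shrink each 28x28 image to 7x7, so the
# flattened feature map has 7 * 7 * n_hidden_2 = 784 // 16 * 64 = 3136
# elements.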
flatten_size = num_features // 16 * n_hidden_2

seed = 66478


class FloatModel(tf.Module):
51  """Float inference for mnist model."""

  def __init__(self):
    self.weights = {
        'f1':
            tf.Variable(
                tf.random.truncated_normal([5, 5, num_channels, n_hidden_1],
                                           stddev=0.1,
                                           seed=seed)),
        'f2':
            tf.Variable(
                tf.random.truncated_normal([5, 5, n_hidden_1, n_hidden_2],
                                           stddev=0.1,
                                           seed=seed)),
        'f3':
            tf.Variable(
                tf.random.truncated_normal([n_hidden_3, flatten_size],
                                           stddev=0.1,
                                           seed=seed)),
        'f4':
            tf.Variable(
                tf.random.truncated_normal([num_classes, n_hidden_3],
                                           stddev=0.1,
                                           seed=seed)),
    }

    self.biases = {
        'b1': tf.Variable(tf.zeros([n_hidden_1])),
        'b2': tf.Variable(tf.zeros([n_hidden_2])),
        'b3': tf.Variable(tf.zeros([n_hidden_3])),
        'b4': tf.Variable(tf.zeros([num_classes])),
    }

  @tf.function
  def __call__(self, data):
    """The Model definition."""
    x = tf.reshape(data, [-1, 28, 28, 1])

    # 2D convolution, with 'SAME' padding (i.e. the output feature map has
    # the same size as the input).

    # NOTE: The data/x/input is always specified in floating point precision.
    # output shape: [-1, 28, 28, 32]
    conv1 = gen_mnist_ops.new_conv2d(x, self.weights['f1'], self.biases['b1'],
                                     1, 1, 1, 1, 'SAME', 'RELU')

    # Max pooling. The kernel size spec (ksize) also follows the layout of
    # the data. Here we have a pooling window of 2, and a stride of 2.
    # output shape: [-1, 14, 14, 32]
    max_pool1 = gen_mnist_ops.new_max_pool(conv1, 2, 2, 2, 2, 'SAME')

    # output shape: [-1, 14, 14, 64]
    conv2 = gen_mnist_ops.new_conv2d(max_pool1, self.weights['f2'],
                                     self.biases['b2'], 1, 1, 1, 1, 'SAME',
                                     'RELU')

    # output shape: [-1, 7, 7, 64]
    max_pool2 = gen_mnist_ops.new_max_pool(conv2, 2, 2, 2, 2, 'SAME')

    # Reshape the feature map cuboid into a 2D matrix to feed it to the
    # fully connected layers.
    # output shape: [-1, 7*7*64]
    reshape = tf.reshape(max_pool2, [-1, flatten_size])

    # output shape: [-1, 64]
    fc1 = gen_mnist_ops.new_fully_connected(reshape, self.weights['f3'],
                                            self.biases['b3'], 'RELU')
    # output shape: [-1, 10]
    return gen_mnist_ops.new_fully_connected(fc1, self.weights['f4'],
                                             self.biases['b4'])


def main(strategy):
  """Trains an MNIST model using the given tf.distribute.Strategy."""
  # TODO(fengliuai): put this in some automatically generated code.
  os.environ[
      'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist'

  ds_train = tfds.load('mnist', split='train', shuffle_files=True)
  ds_train = ds_train.shuffle(1024).batch(batch_size).prefetch(64)
  ds_train = strategy.experimental_distribute_dataset(ds_train)

  with strategy.scope():
    # Build the MNIST float model and its optimizer under the strategy scope
    # so their variables are created and managed by the strategy.
    model = FloatModel()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

  def train_step(features):
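    # `tf.image.convert_image_dtype` rescales the uint8 MNIST images from
    # [0, 255] to floats in [0, 1] before they reach the model.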
    inputs = tf.image.convert_image_dtype(
        features['image'], dtype=tf.float32, saturate=False)
    labels = tf.one_hot(features['label'], num_classes)

    with tf.GradientTape() as tape:
      logits = model(inputs)
      loss_value = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels, logits))

    grads = tape.gradient(loss_value, model.trainable_variables)
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return accuracy, loss_value

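  # `strategy.run` invokes `train_step` once per replica; `strategy.reduce`
  # with axis=None then averages the resulting per-replica scalars across
  # replicas without reducing over any tensor axis.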
  @tf.function
  def distributed_train_step(dist_inputs):
    per_replica_accuracy, per_replica_losses = strategy.run(
        train_step, args=(dist_inputs,))
    accuracy = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_accuracy, axis=None)
    loss_value = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return accuracy, loss_value

  iterator = iter(ds_train)
  accuracy = 0.0
  for step in range(flags.FLAGS.train_steps):
    accuracy, loss_value = distributed_train_step(next(iterator))
    if step % display_step == 0:
      tf.print('Step %d:' % step)
      tf.print('    Loss = %f' % loss_value)
      tf.print('    Batch accuracy = %f' % accuracy)

  return accuracy

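
# Hedged usage sketch, not part of the original script (upstream drives `main`
# from a separate harness): pick any `tf.distribute.Strategy` and pass it in.
# `OneDeviceStrategy` and the `absl.app` entry point below are illustrative
# assumptions, the latter only so the `train_steps` flag gets parsed.
if __name__ == '__main__':
  from absl import app

  def _run(argv):
    del argv  # Unused by this sketch.
    main(tf.distribute.OneDeviceStrategy('/cpu:0'))

  app.run(_run)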