# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST model float training script with TensorFlow graph execution."""

import os
from absl import flags

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs  # pylint: disable=unused-import
from tensorflow.python.framework import load_library

flags.DEFINE_integer('train_steps', 20, 'Number of steps in training.')

_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so')
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))

# MNIST dataset parameters.
num_classes = 10  # total classes (0-9 digits).
num_features = 784  # data features (img shape: 28*28).
num_channels = 1

# Training parameters.
learning_rate = 0.001
display_step = 10
batch_size = 32

# Network parameters.
n_hidden_1 = 32  # 1st conv layer number of neurons.
n_hidden_2 = 64  # 2nd conv layer number of neurons.
n_hidden_3 = 1024  # 1st fully connected layer number of neurons.
# Two 2x2 max pools shrink the 28x28 input to 7x7 (a factor of 16), so the
# flattened feature map has 784 // 16 * 64 = 3136 values per example.
flatten_size = num_features // 16 * n_hidden_2

seed = 66478


class FloatModel(tf.Module):
  """Float inference for MNIST model."""

  def __init__(self):
    self.weights = {
        'f1':
            tf.Variable(
                tf.random.truncated_normal([5, 5, num_channels, n_hidden_1],
                                           stddev=0.1,
                                           seed=seed)),
        'f2':
            tf.Variable(
                tf.random.truncated_normal([5, 5, n_hidden_1, n_hidden_2],
                                           stddev=0.1,
                                           seed=seed)),
        'f3':
            tf.Variable(
                tf.random.truncated_normal([n_hidden_3, flatten_size],
                                           stddev=0.1,
                                           seed=seed)),
        'f4':
            tf.Variable(
                tf.random.truncated_normal([num_classes, n_hidden_3],
                                           stddev=0.1,
                                           seed=seed)),
    }

    self.biases = {
        'b1': tf.Variable(tf.zeros([n_hidden_1])),
        'b2': tf.Variable(tf.zeros([n_hidden_2])),
        'b3': tf.Variable(tf.zeros([n_hidden_3])),
        'b4': tf.Variable(tf.zeros([num_classes])),
    }

  @tf.function
  def __call__(self, data):
    """The Model definition."""
    x = tf.reshape(data, [-1, 28, 28, 1])

    # 2D convolution, with 'SAME' padding (i.e. the output feature map has
    # the same size as the input).

    # NOTE: The data/x/input is always specified in floating point precision.
    # output shape: [-1, 28, 28, 32]
    conv1 = gen_mnist_ops.new_conv2d(x, self.weights['f1'], self.biases['b1'],
                                     1, 1, 1, 1, 'SAME', 'RELU')

    # Max pooling. The kernel size spec {ksize} also follows the layout of
    # the data. Here we have a pooling window of 2, and a stride of 2.
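    # NOTE(assumption): for orientation only, a rough stock-TF equivalent of
    # the composite op below would be
    #   tf.nn.max_pool2d(conv1, ksize=2, strides=2, padding='SAME')
    # i.e. a 2x2 window with stride 2 in both spatial dimensions, NHWC layout.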
    # output shape: [-1, 14, 14, 32]
    max_pool1 = gen_mnist_ops.new_max_pool(conv1, 2, 2, 2, 2, 'SAME')

    # output shape: [-1, 14, 14, 64]
    conv2 = gen_mnist_ops.new_conv2d(max_pool1, self.weights['f2'],
                                     self.biases['b2'], 1, 1, 1, 1, 'SAME',
                                     'RELU')

    # output shape: [-1, 7, 7, 64]
    max_pool2 = gen_mnist_ops.new_max_pool(conv2, 2, 2, 2, 2, 'SAME')

    # Reshape the feature map cuboid into a 2D matrix to feed it to the
    # fully connected layers.
    # output shape: [-1, 7*7*64]
    reshape = tf.reshape(max_pool2, [-1, flatten_size])

    # output shape: [-1, 1024]
    fc1 = gen_mnist_ops.new_fully_connected(reshape, self.weights['f3'],
                                            self.biases['b3'], 'RELU')
    # output shape: [-1, 10]
    return gen_mnist_ops.new_fully_connected(fc1, self.weights['f4'],
                                             self.biases['b4'])


def main(strategy):
  """Trains an MNIST model using the given tf.distribute.Strategy."""
  # TODO(fengliuai): put this in some automatically generated code.
  os.environ[
      'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist'

  ds_train = tfds.load('mnist', split='train', shuffle_files=True)
  ds_train = ds_train.shuffle(1024).batch(batch_size).prefetch(64)
  ds_train = strategy.experimental_distribute_dataset(ds_train)

  with strategy.scope():
    # Create an MNIST float model with the specified float state.
    model = FloatModel()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    def train_step(features):
      inputs = tf.image.convert_image_dtype(
          features['image'], dtype=tf.float32, saturate=False)
      labels = tf.one_hot(features['label'], num_classes)

      with tf.GradientTape() as tape:
        logits = model(inputs)
        loss_value = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels, logits))

      grads = tape.gradient(loss_value, model.trainable_variables)
      correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return accuracy, loss_value

    @tf.function
    def distributed_train_step(dist_inputs):
      per_replica_accuracy, per_replica_losses = strategy.run(
          train_step, args=(dist_inputs,))
      accuracy = strategy.reduce(
          tf.distribute.ReduceOp.MEAN, per_replica_accuracy, axis=None)
      loss_value = strategy.reduce(
          tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
      return accuracy, loss_value

    iterator = iter(ds_train)
    accuracy = 0.0
    for step in range(flags.FLAGS.train_steps):
      accuracy, loss_value = distributed_train_step(next(iterator))
      if step % display_step == 0:
        tf.print('Step %d:' % step)
        tf.print('  Loss = %f' % loss_value)
        tf.print('  Batch accuracy = %f' % accuracy)

    return accuracy
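

# A minimal, illustrative entry point. This guard is an assumption, not part
# of the original script: upstream, `main` is driven by an external harness
# that supplies the tf.distribute strategy. MirroredStrategy falls back to a
# single visible device, so the sketch also runs on CPU-only machines.
if __name__ == '__main__':
  from absl import app  # pylint: disable=g-import-not-at-top

  def _run(argv):
    del argv  # Unused; flags are read via flags.FLAGS.
    main(tf.distribute.MirroredStrategy())

  app.run(_run)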