# TensorFlow Lite LSTM ops API

TensorFlow Lite LSTM ops help developers deploy LSTM models to TensorFlow Lite.
This is currently an experimental API and is likely to change in the future.

## Introduction

In TensorFlow Lite, LSTM ops are expressed as "fused ops" (e.g.,
UnidirectionalSequenceRNN, BidirectionalSequenceLSTM, etc.). In TensorFlow,
however, an LSTM is expressed as a "cell" (e.g., `tf.nn.rnn_cell.LSTMCell`,
`tf.nn.rnn_cell.BasicRNNCell`, etc., each of which contains multiple TensorFlow
ops) together with an "rnn" (e.g., `tf.nn.static_rnn`,
`tf.nn.bidirectional_dynamic_rnn`).

Breaking an LSTM into many small ops gives TensorFlow flexibility, while the
"fused op" in TensorFlow Lite gives a performance boost.

The figures below show the difference between a TensorFlow LSTM and a
TensorFlow Lite LSTM.

##### TensorFlow LSTM op ("cell")

![TensorFlow LSTM op](./images/tf_lstm.png)

##### TensorFlow Lite LSTM op ("fused ops")

![TensorFlow Lite LSTM op](./images/tflite_lstm.png)

The TensorFlow LSTM figure is credited to this
[blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

## How to use

Using TensorFlow Lite LSTM ops is fairly simple.

### 1) Training & Evaluation.

First, in the training phase, replace `tf.nn.rnn_cell.LSTMCell` with
`tf.lite.experimental.nn.TFLiteLSTMCell` and, if you are using `dynamic_rnn`,
replace `tf.nn.dynamic_rnn` with `tf.lite.experimental.nn.dynamic_rnn`. No
change is needed if you are using `static_rnn`.

Both `tf.lite.experimental.nn.TFLiteLSTMCell` and
`tf.lite.experimental.nn.dynamic_rnn` are just the regular
`tf.nn.rnn_cell.LSTMCell` and `tf.nn.dynamic_rnn` with OpHinted nodes inserted
to help the graph transformation.

Then you can train and export the model as usual.

### 2) Export for TensorFlow Lite inference.

When you want to convert to a TensorFlow Lite model, there is one extra step:
convert the OpHinted nodes in your graph into stub ops.

```python
with tf.Session() as sess:
  ophinted_graph = tf.lite.experimental.convert_op_hints_to_stubs(session=sess)
```

`convert_op_hints_to_stubs` can also take a frozen `GraphDef` through its
`graph_def=` argument, as in the example diff and the tutorial export function.

Then you can convert the model to a TensorFlow Lite model as usual.

```python
# INPUTS / OUTPUTS: lists of the graph's input and output tensors.
converter = tf.lite.TFLiteConverter(ophinted_graph, [INPUTS], [OUTPUTS])
converter.post_training_quantize = True  # If post-training quantization is desired.
tflite_model = converter.convert()  # You now have a TFLite model (serialized bytes).
```

#### Simple example diff: original TF code vs. TensorFlow Lite code

```diff
@@ -56,7 +56,7 @@ class MnistLstmModel(object):
     for _ in range(self.num_lstm_layer):
       lstm_layers.append(
           # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`.
-          tf.nn.rnn_cell.LSTMCell(
+          tf.lite.experimental.nn.TFLiteLSTMCell(
               self.num_lstm_units, forget_bias=0))
     # Weights and biases for output softmax layer.
     out_weights = tf.Variable(tf.random_normal([self.units, self.num_class]))
@@ -67,7 +67,7 @@ class MnistLstmModel(object):
     lstm_cells = tf.nn.rnn_cell.MultiRNNCell(lstm_layers)
     # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major`
     # is set to True.
-    outputs, _ = tf.nn.dynamic_rnn(
+    outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
         lstm_cells, lstm_inputs, dtype='float32', time_major=True)
 
     # Transpose the outputs back to [batch, time, output]
@@ -154,7 +154,9 @@ def export(model, model_dir, tflite_model_file,
       sess, sess.graph_def, [output_class.op.name])
 
   # Convert ophinted lstm ops to tflite UnidirectionalSequenceLstm ops.
-  converted_graph = tf.graph_util.remove_training_nodes(frozen_graph)
+  converted_graph = tf.lite.experimental.convert_op_hints_to_stubs(
+      graph_def=frozen_graph)
+  converted_graph = tf.graph_util.remove_training_nodes(converted_graph)
   converter = tf.lite.TFLiteConverter(converted_graph, [x], [output_class])
   converter.post_training_quantize = use_post_training_quantize
   tflite = converter.convert()
```

## Why introduce another set of LSTM APIs?

Bridging TensorFlow LSTM and TensorFlow Lite is not easy, and using
`dynamic_rnn` adds further complexity because it introduces a while loop.
With the help of
[OpHint](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/op_hint.py)
(see the next section), we create special wrappers around `rnn_cell` and
`rnn` that identify the inputs and outputs of the LSTM ops; these ops are then
converted into a single fused LSTM op when the TensorFlow model is converted to
TensorFlow Lite format.

### What's OpHint

`OpHint` is essentially an `Identity` op inserted after the input tensors and
output tensors to "hint" at the boundary of a customized op, as shown in the
following figure.

##### Ophinted Customized Graph

Let's say we have a "customized conv" op, which is a normal conv2d op followed
by a bias add op and an activation op (graph on the left). We use `OpHint` to
track all of its inputs and outputs. During the graph transformation phase
(done by `tf.lite.experimental.convert_op_hints_to_stubs`), the conv2d op, bias
add op, and activation op become a single "my customized conv" op (graph on
the right), and all the "OpHinted" tensors become the inputs/outputs of the
"my customized conv" op.
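The sketch below is a toy, pure-Python illustration of that boundary idea, not TensorFlow code. `Node` and `collapse_hinted_region` are hypothetical names; the graph is assumed to be a topologically sorted list, and the hinted region is assumed to be the contiguous run of nodes from the input hint to the output hint. The real `tf.lite.experimental.convert_op_hints_to_stubs` operates on a `GraphDef` and handles more cases.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
  """Hypothetical toy graph node: a name, an op type, and input node names."""
  name: str
  op: str
  inputs: List[str] = field(default_factory=list)


def collapse_hinted_region(nodes, hint_in, hint_out, fused_name, fused_op):
  """Replaces the nodes from `hint_in` to `hint_out` (inclusive) with one node.

  Toy assumption: `nodes` is topologically sorted and the hinted region is
  contiguous in that order.
  """
  names = [n.name for n in nodes]
  start, end = names.index(hint_in), names.index(hint_out)
  # The fused node consumes whatever fed the input hint.
  fused = Node(fused_name, fused_op, list(nodes[start].inputs))
  result = nodes[:start] + [fused]
  for node in nodes[end + 1:]:
    # Consumers of the output hint now read from the fused node instead.
    rewired = [fused_name if i == hint_out else i for i in node.inputs]
    result.append(Node(node.name, node.op, rewired))
  return result


# The "customized conv" example: conv2d -> bias add -> activation between hints.
graph = [
    Node('input', 'Placeholder'),
    Node('hint_in', 'Identity', ['input']),
    Node('conv', 'Conv2D', ['hint_in']),
    Node('bias', 'BiasAdd', ['conv']),
    Node('act', 'Relu', ['bias']),
    Node('hint_out', 'Identity', ['act']),
    Node('output', 'Identity', ['hint_out']),
]
fused_graph = collapse_hinted_region(
    graph, 'hint_in', 'hint_out', 'my_conv', 'MyCustomizedConv')
for n in fused_graph:
  print(n.name, n.op, n.inputs)
# input Placeholder []
# my_conv MyCustomizedConv ['input']
# output Identity ['my_conv']
```

Both `Identity` hints and the three ops between them disappear; the single fused node reads the original input and feeds the original consumer, which mirrors the left-to-right transformation in the figure.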

![Ophinted Customized Graph](./images/op_hint.png)


## Simple Tutorial

The following tutorial uses the MNIST dataset to build a simple two-layer LSTM
model and convert it to a quantized TensorFlow Lite model.

Note that since we will be using `dynamic_rnn`, we need to turn on
`control_flow_v2`.

### 0. Turn on `control_flow_v2`.

```python
# Note: this needs to happen before importing TensorFlow.
import os
os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1'

import tensorflow as tf  # TensorFlow 1.x: the tutorial uses the tf.Session API.
```

### 1. Build the model.

```python
class MnistLstmModel(object):
  """Build a simple LSTM based MNIST model.

  Attributes:
    time_steps: The maximum length of the time_steps, but since we're just
      using the 'width' dimension as time_steps, it's actually a fixed number.
    input_size: The LSTM layer input size.
    num_lstm_layer: Number of LSTM layers for the stacked LSTM cell case.
    num_lstm_units: Number of units in the LSTM cell.
    units: The input size of the last layer; must equal num_lstm_units.
    num_class: Number of classes to predict.
  """

  def __init__(self, time_steps, input_size, num_lstm_layer, num_lstm_units,
               units, num_class):
    self.time_steps = time_steps
    self.input_size = input_size
    self.num_lstm_layer = num_lstm_layer
    self.num_lstm_units = num_lstm_units
    self.units = units
    self.num_class = num_class

  def build_model(self):
    """Build the model using the given configs.

    Returns:
      x: The input placeholder tensor.
      logits: The logits of the output.
      output_class: The prediction.
    """
    x = tf.placeholder(
        'float32', [None, self.time_steps, self.input_size], name='INPUT')
    lstm_layers = []
    for _ in range(self.num_lstm_layer):
      lstm_layers.append(
          # Important:
          #
          # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`
          # (OpHinted LSTMCell).
          tf.lite.experimental.nn.TFLiteLSTMCell(
              self.num_lstm_units, forget_bias=0))
    # Weights and biases for output softmax layer.
    out_weights = tf.Variable(tf.random_normal([self.units, self.num_class]))
    out_bias = tf.Variable(tf.zeros([self.num_class]))

    # Transpose input x to make it time major.
    lstm_inputs = tf.transpose(x, perm=[1, 0, 2])
    lstm_cells = tf.keras.layers.StackedRNNCells(lstm_layers)
    # Important:
    #
    # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major`
    # is set to True.
    outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
        lstm_cells, lstm_inputs, dtype='float32', time_major=True)

    # Transpose the outputs back to [batch, time, output]
    outputs = tf.transpose(outputs, perm=[1, 0, 2])
    outputs = tf.unstack(outputs, axis=1)
    logits = tf.matmul(outputs[-1], out_weights) + out_bias
    output_class = tf.nn.softmax(logits, name='OUTPUT_CLASS')

    return x, logits, output_class
```

### 2. Let's define the train & eval function.

```python
def train(model,
          model_dir,
          batch_size=20,
          learning_rate=0.001,
          train_steps=2000,
          eval_steps=500,
          save_every_n_steps=1000):
  """Train & save the MNIST recognition model."""
  # Train & test dataset.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_iterator = train_dataset.shuffle(
      buffer_size=1000).batch(batch_size).repeat().make_one_shot_iterator()
  x, logits, output_class = model.build_model()
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  test_iterator = test_dataset.batch(
      batch_size).repeat().make_one_shot_iterator()
  # Input label placeholder.
  y = tf.placeholder(tf.int32, [None])
  one_hot_labels = tf.one_hot(y, depth=model.num_class)
  # Loss function.
  loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=one_hot_labels))
  correct = tf.nn.in_top_k(output_class, y, 1)
  accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
  # Optimization.
  opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

  # Initialize variables.
  init = tf.global_variables_initializer()
  saver = tf.train.Saver()
  batch_x, batch_y = train_iterator.get_next()
  batch_test_x, batch_test_y = test_iterator.get_next()
  with tf.Session() as sess:
    sess.run([init])
    for i in range(train_steps):
      batch_x_value, batch_y_value = sess.run([batch_x, batch_y])
      _, loss_value = sess.run([opt, loss],
                               feed_dict={
                                   x: batch_x_value,
                                   y: batch_y_value
                               })
      if i % 100 == 0:
        tf.logging.info('Training step %d, loss is %f' % (i, loss_value))
      if i > 0 and i % save_every_n_steps == 0:
        accuracy_sum = 0.0
        for _ in range(eval_steps):
          test_x_value, test_y_value = sess.run([batch_test_x, batch_test_y])
          accuracy_value = sess.run(
              accuracy, feed_dict={
                  x: test_x_value,
                  y: test_y_value
              })
          accuracy_sum += accuracy_value
        tf.logging.info('Training step %d, accuracy is %f' %
                        (i, accuracy_sum / (eval_steps * 1.0)))
        saver.save(sess, model_dir)
```

### 3. Let's define the export to TensorFlow Lite model function.

```python
def export(model, model_dir, tflite_model_file,
           use_post_training_quantize=True):
  """Export trained model to tflite model."""
  tf.reset_default_graph()
  x, _, output_class = model.build_model()
  saver = tf.train.Saver()
  sess = tf.Session()
  saver.restore(sess, model_dir)
  # Freeze the graph.
  frozen_graph = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, [output_class.op.name])

  # Important:
  #
  # Convert ophinted lstm ops to tflite UnidirectionalSequenceLstm ops.
  converted_graph = tf.lite.experimental.convert_op_hints_to_stubs(
      graph_def=frozen_graph)
  converted_graph = tf.graph_util.remove_training_nodes(converted_graph)
  converter = tf.lite.TFLiteConverter(converted_graph, [x], [output_class])
  converter.post_training_quantize = use_post_training_quantize
  tflite = converter.convert()
  # convert() returns bytes, so the file must be opened in binary mode.
  with open(tflite_model_file, 'wb') as f:
    f.write(tflite)
```

### 4. Hook everything together.

```python
import argparse
import sys

from absl import app  # absl is installed as a TensorFlow dependency.


def train_and_export(parsed_flags):
  """Train the MNIST LSTM model and export to TfLite."""
  model = MnistLstmModel(
      time_steps=28,
      input_size=28,
      num_lstm_layer=2,
      num_lstm_units=64,
      units=64,
      num_class=10)
  tf.logging.info('Starts training...')
  train(model, parsed_flags.model_dir)
  tf.logging.info('Finished training, starts exporting to tflite to %s ...' %
                  parsed_flags.tflite_model_file)
  export(model, parsed_flags.model_dir, parsed_flags.tflite_model_file,
         parsed_flags.use_post_training_quantize)
  tf.logging.info(
      'Finished exporting, model is %s' % parsed_flags.tflite_model_file)


def run_main(_):
  """Main in the TfLite LSTM tutorial."""
  tf.logging.set_verbosity(tf.logging.INFO)  # Show the tf.logging.info output.
  parser = argparse.ArgumentParser(
      description=('Train a MNIST recognition model then export to TfLite.'))
  parser.add_argument(
      '--model_dir',
      type=str,
      help='Directory where the models will be stored.',
      required=True)
  parser.add_argument(
      '--tflite_model_file',
      type=str,
      help='Full filepath to the exported tflite model file.',
      required=True)
  # Note: with action='store_true' and default=True, this flag is always True.
  parser.add_argument(
      '--use_post_training_quantize',
      action='store_true',
      default=True,
      help='Whether or not to use post_training_quantize.')
  parsed_flags, _ = parser.parse_known_args()
  train_and_export(parsed_flags)


def main():
  # Pass only the program name to absl; argparse handles the custom flags.
  app.run(main=run_main, argv=sys.argv[:1])


if __name__ == '__main__':
  main()
```

### 5. Visualize the exported TensorFlow Lite model.

Go to where the TensorFlow Lite model was exported and use
[Netron](https://github.com/lutzroeder/netron) to visualize the graph, as shown
below.

##### Exported TensorFlow Lite Model.

![Exported TensorFlow Lite Model](./images/exported_tflite_model.png)

## Caveat

* Currently, `tf.lite.experimental.nn.dynamic_rnn` and
  `tf.lite.experimental.nn.bidirectional_dynamic_rnn` only support
  `control_flow_v2`. You can turn this on by setting the environment variable
  `TF_ENABLE_CONTROL_FLOW_V2=1`, as shown in the tutorial.
* Currently, `sequence_length` is not supported; leave it as `None`.
* `num_unit_shards` and `num_proj_shards` in LSTMCell are not supported
  either.
* Currently, `tf.lite.experimental.nn.dynamic_rnn` and
  `tf.lite.experimental.nn.bidirectional_dynamic_rnn` only accept
  `time_major=True`.
* `tf.lite.experimental.nn.bidirectional_dynamic_rnn` is a wrapper around
  `tf.nn.bidirectional_dynamic_rnn`, not
  `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`, and behaves accordingly.
* For bidirectional RNN cases, make sure you include all the OpHinted nodes
  before freezing the graph, as below:

```python
all_output_nodes = [OUTPUT_NODES]  # Names of your graph's output nodes.
with tf.Session() as sess:
  all_output_nodes += tf.lite.find_all_hinted_output_nodes(sess)
  frozen_graph = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, all_output_nodes)
```
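Because these wrappers only accept `time_major=True`, batch-major input has to be transposed first, as the tutorial's `build_model()` does with `tf.transpose(x, perm=[1, 0, 2])`. The sketch below uses NumPy as a stand-in for the TensorFlow ops to show the layout change; the shapes are the tutorial's MNIST shapes, random data replaces real images, and the batch size of 20 is taken from `train()`'s default.

```python
import numpy as np

# Tutorial shapes: 28 time steps (image rows) of 28 features each.
# batch_size=20 is train()'s default; random data stands in for MNIST images.
batch_size, time_steps, input_size = 20, 28, 28
x = np.random.rand(batch_size, time_steps, input_size).astype(np.float32)

# Batch-major [batch, time, feature] -> time-major [time, batch, feature],
# mirroring tf.transpose(x, perm=[1, 0, 2]) in build_model().
x_time_major = np.transpose(x, (1, 0, 2))
print(x_time_major.shape)  # (28, 20, 28)

# perm=[1, 0, 2] only swaps the first two axes, so it is its own inverse;
# build_model() applies it again to bring the outputs back to batch-major.
x_batch_major = np.transpose(x_time_major, (1, 0, 2))

# build_model() unstacks along axis 1 and keeps the last step; in time-major
# layout the same slice is simply the last entry along axis 0.
last_step_unstacked = [x_batch_major[:, t, :] for t in range(time_steps)][-1]
last_step_time_major = x_time_major[-1]
print(np.array_equal(last_step_unstacked, last_step_time_major))  # True
```

Reading the last time step directly from the time-major tensor would skip the second transpose and the unstack; the tutorial keeps them so the code mirrors the original batch-major model.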