# TensorFlow Lite LSTM ops API

TensorFlow Lite LSTM ops help developers deploy LSTM models to TensorFlow Lite.
This is currently an experimental API and is likely to change in the future.

## Introduction

LSTM ops in the TensorFlow Lite realm are expressed as "fused ops" (e.g.,
UnidirectionalSequenceRNN, BidirectionalSequenceLSTM, etc.). However, in
TensorFlow, LSTM ops are expressed as a "cell" (e.g., `tf.nn.rnn_cell.LSTMCell`,
`tf.nn.rnn_cell.BasicRNNCell`, etc., each of which contains multiple TensorFlow
ops) and an "rnn" (e.g., `tf.nn.static_rnn`,
`tf.nn.bidirectional_dynamic_rnn`).

The op-level breakdown in TensorFlow gives us flexibility, while the "fused op"
in TensorFlow Lite gives us a performance boost.

See the difference between TensorFlow LSTM and TensorFlow Lite LSTM below.

##### TensorFlow LSTM op ("cell")

![TensorFlow LSTM op](./images/tf_lstm.png)

##### TensorFlow Lite LSTM op ("fused ops")

![TensorFlow Lite LSTM op](./images/tflite_lstm.png)

The TensorFlow LSTM figure is credited to this
[blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

## How to use

Using TensorFlow Lite LSTM ops is actually pretty simple.

### 1) Training & Evaluation.

The first step is to replace `tf.nn.rnn_cell.LSTMCell` with
`tf.lite.experimental.nn.TFLiteLSTMCell` in the training phase, and to replace
`tf.nn.dynamic_rnn` with `tf.lite.experimental.nn.dynamic_rnn` if you are
using dynamic_rnn. Note that no change is needed if you're using static_rnn.

Both `tf.lite.experimental.nn.TFLiteLSTMCell` and
`tf.lite.experimental.nn.dynamic_rnn` are just the normal
`tf.nn.rnn_cell.LSTMCell` and `tf.nn.dynamic_rnn` with OpHinted nodes added to
help the graph transformation.
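
For example, the swap is a drop-in replacement (a minimal sketch; the shapes
and `num_units` below are illustrative):

```python
import tensorflow as tf

# [batch, time, features] input; the 28x28 shape is illustrative (MNIST rows).
x = tf.placeholder(tf.float32, [None, 28, 28], name='INPUT')
# OpHinted drop-in for tf.nn.rnn_cell.LSTMCell.
cell = tf.lite.experimental.nn.TFLiteLSTMCell(num_units=64, forget_bias=0)
# The TFLite wrapper takes time-major inputs, so transpose first.
inputs = tf.transpose(x, perm=[1, 0, 2])
# OpHinted drop-in for tf.nn.dynamic_rnn; requires TF_ENABLE_CONTROL_FLOW_V2=1
# (see the tutorial below).
outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
    cell, inputs, dtype='float32', time_major=True)
```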

Then you can train and export the model as usual.

### 2) Export for TensorFlow Lite inference.

When you want to convert to a TensorFlow Lite model, here's one simple step you
need to do for your frozen graph:

```python
with tf.Session() as sess:
  ophinted_graph = tf.lite.experimental.convert_op_hints_to_stubs(session=sess)
```

Then you can convert the model to TensorFlow Lite format as usual.

```python
converter = tf.lite.TFLiteConverter(ophinted_graph, [INPUTS], [OUTPUTS])
converter.post_training_quantize = True  # If post training quantize is desired.
tflite_model = converter.convert()  # You now have a tflite model!
```

#### Simple example diff for using original TF code vs. TensorFlow Lite code:

```python
@@ -56,7 +56,7 @@ class MnistLstmModel(object):
     for _ in range(self.num_lstm_layer):
       lstm_layers.append(
           # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`.
-          tf.nn.rnn_cell.LSTMCell(
+          tf.lite.experimental.nn.TFLiteLSTMCell(
               self.num_lstm_units, forget_bias=0))
     # Weights and biases for output softmax layer.
     out_weights = tf.Variable(tf.random_normal([self.units, self.num_class]))
@@ -67,7 +67,7 @@ class MnistLstmModel(object):
     lstm_cells = tf.nn.rnn_cell.MultiRNNCell(lstm_layers)
     # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major`
     # is set to True.
-    outputs, _ = tf.nn.dynamic_rnn(
+    outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
         lstm_cells, lstm_inputs, dtype='float32', time_major=True)

     # Transpose the outputs back to [batch, time, output]
@@ -154,7 +154,9 @@ def export(model, model_dir, tflite_model_file,
       sess, sess.graph_def, [output_class.op.name])

   # Convert ophinted lstm ops to tflite UnidirectionalSequenceLstm ops.
-  converted_graph = tf.graph_util.remove_training_nodes(frozen_graph)
+  converted_graph = tf.lite.experimental.convert_op_hints_to_stubs(
+      graph_def=frozen_graph)
+  converted_graph = tf.graph_util.remove_training_nodes(converted_graph)
   converter = tf.lite.TFLiteConverter(converted_graph, [x], [output_class])
   converter.post_training_quantize = use_post_training_quantize
   tflite = converter.convert()
```

## Why introduce another set of LSTM APIs?

Bridging TensorFlow LSTM and TensorFlow Lite is not easy, and the use of
`dynamic_rnn` adds additional complexity (as a while loop is introduced).
With the help of
[OpHint](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/op_hint.py)
(also see the next section), we create special wrappers around `rnn_cell` and
`rnn` to help us identify the inputs and outputs of the LSTM ops, and these
ops are converted to a single fused LSTM op when converting TensorFlow models
to TensorFlow Lite format.
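
As a quick sanity check after conversion (a sketch that continues from the
`converted_graph` produced by `tf.lite.experimental.convert_op_hints_to_stubs`
in the export step below; the stub op name `UnidirectionalSequenceLstm` is an
assumption based on the comment in that step):

```python
# Each hinted LSTM should now appear as a single stub node in the GraphDef.
lstm_nodes = [node.name for node in converted_graph.node
              if node.op == 'UnidirectionalSequenceLstm']
print('Found %d fused LSTM op(s): %s' % (len(lstm_nodes), lstm_nodes))
```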

### What's OpHint

`OpHint` is essentially an `Identity` op that is inserted after input tensors
and output tensors to "hint" the boundary of the customized op; see the
following figure.

##### Ophinted Customized Graph

Let's say we have a "customized conv" op, which is a normal conv2d op followed
by a bias add op and an activation op (graph on the left). We use `OpHint` to
track down all the inputs and outputs. During the graph transformation phase
(done by `tf.lite.experimental.convert_op_hints_to_stubs`), the conv2d op, bias
add op, and activation op become a single "my customized conv" op (see the
graph on the right), and all the "OpHinted" tensors become the inputs/outputs
of the "my customized conv" op.

![Ophinted Customized Graph](./images/op_hint.png)
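
For illustration, hinting such a "customized conv" directly might look like the
following sketch (it assumes `OpHint` is exposed as `tf.lite.OpHint`, and the
function and tensor names here are hypothetical):

```python
import tensorflow as tf

def my_customized_conv(x, conv_filter, bias):
  # Wrap the subgraph with an OpHint so the graph transformation can later
  # replace it with a single stub op named 'my_customized_conv'.
  custom = tf.lite.OpHint('my_customized_conv')
  x, conv_filter, bias = custom.add_inputs(x, conv_filter, bias)
  out = tf.nn.conv2d(x, conv_filter, strides=[1, 1, 1, 1], padding='SAME')
  out = tf.nn.bias_add(out, bias)
  out = tf.nn.relu(out)
  out, = custom.add_outputs(out)
  return out
```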

## Simple Tutorial

The following tutorial uses the MNIST dataset to build a simple two-layer LSTM
model and converts it to a quantized TensorFlow Lite model.

Note that since we will be using dynamic_rnn, we need to turn on
`control_flow_v2`.

### 0. Turn on `control_flow_v2`.

```python
# Note this needs to happen before importing tensorflow.
import os
os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1'
```

### 1. Build the model.

```python
class MnistLstmModel(object):
  """Build a simple LSTM based MNIST model.

  Attributes:
    time_steps: The maximum length of the time_steps, but since we're just using
      the 'width' dimension as time_steps, it's actually a fixed number.
    input_size: The LSTM layer input size.
    num_lstm_layer: Number of LSTM layers for the stacked LSTM cell case.
    num_lstm_units: Number of units in the LSTM cell.
    units: The units for the last layer.
    num_class: Number of classes to predict.
  """

  def __init__(self, time_steps, input_size, num_lstm_layer, num_lstm_units,
               units, num_class):
    self.time_steps = time_steps
    self.input_size = input_size
    self.num_lstm_layer = num_lstm_layer
    self.num_lstm_units = num_lstm_units
    self.units = units
    self.num_class = num_class

  def build_model(self):
    """Build the model using the given configs.

    Returns:
      x: The input placeholder tensor.
      logits: The logits of the output.
      output_class: The prediction.
    """
    x = tf.placeholder(
        'float32', [None, self.time_steps, self.input_size], name='INPUT')
    lstm_layers = []
    for _ in range(self.num_lstm_layer):
      lstm_layers.append(
          # Important:
          #
          # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`
          # (OpHinted LSTMCell).
          tf.lite.experimental.nn.TFLiteLSTMCell(
              self.num_lstm_units, forget_bias=0))
    # Weights and biases for output softmax layer.
    out_weights = tf.Variable(tf.random_normal([self.units, self.num_class]))
    out_bias = tf.Variable(tf.zeros([self.num_class]))

    # Transpose input x to make it time major.
    lstm_inputs = tf.transpose(x, perm=[1, 0, 2])
    lstm_cells = tf.keras.layers.StackedRNNCells(lstm_layers)
    # Important:
    #
    # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major`
    # is set to True.
    outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
        lstm_cells, lstm_inputs, dtype='float32', time_major=True)

    # Transpose the outputs back to [batch, time, output].
    outputs = tf.transpose(outputs, perm=[1, 0, 2])
    outputs = tf.unstack(outputs, axis=1)
    logits = tf.matmul(outputs[-1], out_weights) + out_bias
    output_class = tf.nn.softmax(logits, name='OUTPUT_CLASS')

    return x, logits, output_class
```
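
For instance, the model used in this tutorial (28x28 MNIST images, treating
each image row as one time step) is built as:

```python
model = MnistLstmModel(
    time_steps=28,
    input_size=28,
    num_lstm_layer=2,
    num_lstm_units=64,
    units=64,
    num_class=10)
x, logits, output_class = model.build_model()
```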

### 2. Let's define the train & eval function.

```python
def train(model,
          model_dir,
          batch_size=20,
          learning_rate=0.001,
          train_steps=2000,
          eval_steps=500,
          save_every_n_steps=1000):
  """Train & save the MNIST recognition model."""
  # Train & test dataset.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_iterator = train_dataset.shuffle(
      buffer_size=1000).batch(batch_size).repeat().make_one_shot_iterator()
  x, logits, output_class = model.build_model()
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  test_iterator = test_dataset.batch(
      batch_size).repeat().make_one_shot_iterator()
  # Input label placeholder.
  y = tf.placeholder(tf.int32, [
      None,
  ])
  one_hot_labels = tf.one_hot(y, depth=model.num_class)
  # Loss function.
  loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=one_hot_labels))
  correct = tf.nn.in_top_k(output_class, y, 1)
  accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
  # Optimization.
  opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

  # Initialize variables.
  init = tf.global_variables_initializer()
  saver = tf.train.Saver()
  batch_x, batch_y = train_iterator.get_next()
  batch_test_x, batch_test_y = test_iterator.get_next()
  with tf.Session() as sess:
    sess.run([init])
    for i in range(train_steps):
      batch_x_value, batch_y_value = sess.run([batch_x, batch_y])
      _, loss_value = sess.run([opt, loss],
                               feed_dict={
                                   x: batch_x_value,
                                   y: batch_y_value
                               })
      if i % 100 == 0:
        tf.logging.info('Training step %d, loss is %f' % (i, loss_value))
      if i > 0 and i % save_every_n_steps == 0:
        accuracy_sum = 0.0
        for _ in range(eval_steps):
          test_x_value, test_y_value = sess.run([batch_test_x, batch_test_y])
          accuracy_value = sess.run(
              accuracy, feed_dict={
                  x: test_x_value,
                  y: test_y_value
              })
          accuracy_sum += accuracy_value
        tf.logging.info('Training step %d, accuracy is %f' %
                        (i, accuracy_sum / (eval_steps * 1.0)))
        saver.save(sess, model_dir)
```

### 3. Let's define the export to TensorFlow Lite model function.

```python
def export(model, model_dir, tflite_model_file,
           use_post_training_quantize=True):
  """Export trained model to tflite model."""
  tf.reset_default_graph()
  x, _, output_class = model.build_model()
  saver = tf.train.Saver()
  sess = tf.Session()
  saver.restore(sess, model_dir)
  # Freeze the graph.
  frozen_graph = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, [output_class.op.name])

  # Important:
  #
  # Convert ophinted lstm ops to tflite UnidirectionalSequenceLstm ops.
  converted_graph = tf.lite.experimental.convert_op_hints_to_stubs(
      graph_def=frozen_graph)
  converted_graph = tf.graph_util.remove_training_nodes(converted_graph)
  converter = tf.lite.TFLiteConverter(converted_graph, [x], [output_class])
  converter.post_training_quantize = use_post_training_quantize
  tflite = converter.convert()
  with open(tflite_model_file, 'wb') as f:
    f.write(tflite)
```
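
To sanity-check the exported file, you can load it with the TensorFlow Lite
interpreter (a minimal sketch; the path is illustrative):

```python
interpreter = tf.lite.Interpreter(model_path='/tmp/mnist.tflite')
interpreter.allocate_tensors()
print(interpreter.get_input_details())
print(interpreter.get_output_details())
```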

### 4. Hook everything together.

```python
# Imports used in this step (in the full script these go at the top).
import argparse
import sys

from tensorflow.python.platform import app


def train_and_export(parsed_flags):
  """Train the MNIST LSTM model and export to TfLite."""
  model = MnistLstmModel(
      time_steps=28,
      input_size=28,
      num_lstm_layer=2,
      num_lstm_units=64,
      units=64,
      num_class=10)
  tf.logging.info('Starts training...')
  train(model, parsed_flags.model_dir)
  tf.logging.info('Finished training, starts exporting to tflite to %s ...' %
                  parsed_flags.tflite_model_file)
  export(model, parsed_flags.model_dir, parsed_flags.tflite_model_file,
         parsed_flags.use_post_training_quantize)
  tf.logging.info(
      'Finished exporting, model is %s' % parsed_flags.tflite_model_file)


def run_main(_):
  """Main in the TfLite LSTM tutorial."""
  parser = argparse.ArgumentParser(
      description=('Train a MNIST recognition model then export to TfLite.'))
  parser.add_argument(
      '--model_dir',
      type=str,
      help='Directory where the model will be stored.',
      required=True)
  parser.add_argument(
      '--tflite_model_file',
      type=str,
      help='Full filepath to the exported tflite model file.',
      required=True)
  parser.add_argument(
      '--use_post_training_quantize',
      action='store_true',
      default=True,
      help='Whether or not to use post_training_quantize.')
  parsed_flags, _ = parser.parse_known_args()
  train_and_export(parsed_flags)


def main():
  app.run(main=run_main, argv=sys.argv[:1])


if __name__ == '__main__':
  main()
```
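
Assuming the script above is saved as `mnist_lstm.py` (a hypothetical
filename), it can be run as, e.g.,
`python mnist_lstm.py --model_dir /tmp/mnist_lstm --tflite_model_file /tmp/mnist.tflite`.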

### 5. Visualize the exported TensorFlow Lite model.

Let's go to where the TensorFlow Lite model is exported and use
[Netron](https://github.com/lutzroeder/netron) to visualize the graph.

See below.

##### Exported TensorFlow Lite Model.

![Exported TensorFlow Lite Model](./images/exported_tflite_model.png)

## Caveat

*   Currently, `tf.lite.experimental.nn.dynamic_rnn` &
    `tf.lite.experimental.nn.bidirectional_dynamic_rnn` only support
    `control_flow_v2`. You can turn this on by setting the environment variable
    `TF_ENABLE_CONTROL_FLOW_V2=1`, as shown in the tutorial.
*   Currently, `sequence_length` is not supported; prefer to set it to None.
*   `num_unit_shards` & `num_proj_shards` in LSTMCell are not supported
    either.
*   Currently, `tf.lite.experimental.nn.dynamic_rnn` &
    `tf.lite.experimental.nn.bidirectional_dynamic_rnn` only take
    `time_major=True`.
*   `tf.lite.experimental.nn.bidirectional_dynamic_rnn` is a wrapper around
    `tf.nn.bidirectional_dynamic_rnn`, not
    `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`.
*   For bidirectional_rnn cases, make sure you include all the op_hinted nodes
    before freezing the graph. See below:

```python
all_output_nodes = [OUTPUT_NODES]
with tf.Session() as sess:
  all_output_nodes += tf.lite.find_all_hinted_output_nodes(sess)
  frozen_graph = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, all_output_nodes)
```