• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf

from tensorflow.lite.experimental.examples.lstm import input_data
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
28
29
# How many optimizer steps to run before the model is converted.
# A value of 0 skips training entirely: every weight keeps its initial value,
# which keeps the test small.
TRAIN_STEPS = 0
34
35
class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
  """Checks that converted TFLite LSTM models match the TensorFlow model.

  Each test builds a model over fake MNIST data on top of a stack of
  TFLiteLSTMCells, optionally trains it (see TRAIN_STEPS), round-trips it
  through a checkpoint, and compares the restored TensorFlow model's output
  with the TOCO- and MLIR-converted TFLite models on the same input.
  """

  def setUp(self):
    # Run TensorFlowTestCase's own per-test setup; overriding setUp without
    # calling it silently skips that setup.
    super(UnidirectionalSequenceLstmTest, self).setUp()
    tf.compat.v1.reset_default_graph()
    # Import MNIST dataset (fake data, so nothing is downloaded).
    self.mnist = input_data.read_data_sets(
        "/tmp/data/", fake_data=True, one_hot=True)

    # Define constants
    # Unrolled through 28 time steps
    self.time_steps = 28
    # Rows of 28 pixels
    self.n_input = 28
    # Learning rate for Adam optimizer
    self.learning_rate = 0.001
    # MNIST is meant to be classified in 10 classes(0-9).
    self.n_classes = 10
    # Batch size
    self.batch_size = 16
    # Lstm Units.
    self.num_units = 16

  def buildLstmLayer(self):
    """Returns a stack of four TFLiteLSTMCells with mixed configurations.

    The cells exercise peepholes, projections and different forget biases.
    The last cell has `num_units` outputs and no projection, which is what
    the output layer in `buildModel` expects.
    """
    return tf.keras.layers.StackedRNNCells([
        tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(
            self.num_units, use_peepholes=True, forget_bias=1.0, name="rnn1"),
        tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(
            self.num_units, num_proj=8, forget_bias=1.0, name="rnn2"),
        tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(
            self.num_units // 2,
            use_peepholes=True,
            num_proj=8,
            forget_bias=0,
            name="rnn3"),
        tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(
            self.num_units, forget_bias=1.0, name="rnn4")
    ])

  def buildModel(self, lstm_layer, is_dynamic_rnn):
    """Build Mnist recognition model.

    Args:
      lstm_layer: The lstm layer either a single lstm cell or a multi lstm cell.
      is_dynamic_rnn: Use dynamic_rnn or not.

    Returns:
     A tuple containing:

     - Input tensor of the model, shaped [batch, time_steps, n_input].
     - Prediction (logits) tensor of the model.
     - Output class tensor of the model (softmax of the logits).
    """
    # Weights and biases for output softmax layer.
    out_weights = tf.Variable(
        tf.random.normal([self.num_units, self.n_classes]))
    out_bias = tf.Variable(tf.random.normal([self.n_classes]))

    # input image placeholder
    x = tf.compat.v1.placeholder(
        "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE")

    # x is shaped [batch_size,time_steps,num_inputs]
    if is_dynamic_rnn:
      # The TFLite dynamic_rnn is fed time-major input here.
      lstm_input = tf.transpose(x, perm=[1, 0, 2])
      outputs, _ = tf.compat.v1.lite.experimental.nn.dynamic_rnn(
          lstm_layer, lstm_input, dtype="float32")
      outputs = tf.unstack(outputs, axis=0)
    else:
      lstm_input = tf.unstack(x, self.time_steps, 1)
      outputs, _ = tf.compat.v1.nn.static_rnn(
          lstm_layer, lstm_input, dtype="float32")

    # Compute logits by multiplying outputs[-1] of shape [batch_size,num_units]
    # by the softmax layer's out_weight of shape [num_units,n_classes]
    # plus out_bias
    prediction = tf.matmul(outputs[-1], out_weights) + out_bias
    output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS")

    return x, prediction, output_class

  def trainModel(self, x, prediction, output_class, sess):
    """Initializes the variables and trains the model for TRAIN_STEPS steps.

    Args:
      x: The input tensor.
      prediction: The logits tensor produced by `buildModel`.
      output_class: The output tensor. Unused; kept for interface stability.
      sess: The graph session.
    """
    # input label placeholder
    y = tf.compat.v1.placeholder("float", [None, self.n_classes])
    # Loss function
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
    # Optimization
    opt = tf.compat.v1.train.AdamOptimizer(
        learning_rate=self.learning_rate).minimize(loss)

    # Initialize variables (also needed when TRAIN_STEPS is 0, so the
    # checkpoint saved later has values).
    init = tf.compat.v1.global_variables_initializer()
    sess.run(init)
    for _ in range(TRAIN_STEPS):
      batch_x, batch_y = self.mnist.train.next_batch(
          batch_size=self.batch_size, fake_data=True)

      batch_x = np.array(batch_x)
      batch_y = np.array(batch_y)
      batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                 self.n_input))
      sess.run(opt, feed_dict={x: batch_x, y: batch_y})

  def saveAndRestoreModel(self, lstm_layer, sess, saver, is_dynamic_rnn):
    """Saves and restores the model to mimic the most common use case.

    The checkpoint is written inside this test's temporary directory. The
    default graph is then reset, the model is rebuilt, and the weights are
    restored into a fresh session.

    Args:
      lstm_layer: The lstm layer either a single lstm cell or a multi lstm cell.
      sess: Old session.
      saver: Saver created by tf.compat.v1.train.Saver()
      is_dynamic_rnn: Use dynamic_rnn or not.

    Returns:
      A tuple containing:

      - Input tensor of the restored model.
      - Prediction (logits) tensor of the restored model.
      - Output tensor, which is the softmax result of the prediction tensor.
      - new session of the restored model. The caller owns it and should
        close it.
    """
    # Use a file prefix *inside* the temp dir. Passing the directory itself
    # as the prefix would write the checkpoint files next to it instead.
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "model")
    save_path = saver.save(sess, checkpoint_prefix)

    # Reset the graph.
    tf.compat.v1.reset_default_graph()
    x, prediction, output_class = self.buildModel(lstm_layer, is_dynamic_rnn)

    new_sess = tf.compat.v1.Session()
    try:
      tf.compat.v1.train.Saver().restore(new_sess, save_path)
    except Exception:
      # Do not leak the session if the restore fails.
      new_sess.close()
      raise
    return x, prediction, output_class, new_sess

  def getInferenceResult(self, x, output_class, sess):
    """Get inference result given input tensor and output tensor.

    Args:
      x: The input tensor.
      output_class: The output tensor.
      sess: Current session.

    Returns:
     A tuple containing:

      - Input of the next batch, batch size is 1.
      - Expected output.

    """
    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
    b1 = np.array(b1, dtype=np.dtype("float32"))
    sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))

    expected_output = sess.run(output_class, feed_dict={x: sample_input})
    return sample_input, expected_output

  def tfliteInvoke(self,
                   sess,
                   test_inputs,
                   input_tensor,
                   output_tensor,
                   use_mlir_converter=False):
    """Get tflite inference result.

    This method will convert tensorflow from session to tflite model then based
    on the inputs, run tflite inference and return the results.

    Args:
      sess: Current tensorflow session.
      test_inputs: The test inputs for tflite.
      input_tensor: The input tensor of tensorflow graph.
      output_tensor: The output tensor of tensorflow graph.
      use_mlir_converter: Whether or not to use MLIRConverter to convert the
        model.

    Returns:
      The tflite inference result.

    Raises:
      AssertionError: (via `self.fail`) if the interpreter cannot allocate
        tensors for the converted model.
    """
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, [input_tensor], [output_tensor])
    converter.experimental_new_converter = use_mlir_converter
    tflite_model = converter.convert()

    interpreter = tf.lite.Interpreter(model_content=tflite_model)

    try:
      interpreter.allocate_tensors()
    except ValueError as e:
      # Fail with the underlying reason. A bare `assert False` loses the
      # message and is stripped entirely under `python -O`.
      self.fail("Failed to allocate tensors for the converted model "
                "(use_mlir_converter=%s): %s" % (use_mlir_converter, e))

    input_index = interpreter.get_input_details()[0]["index"]
    interpreter.set_tensor(input_index, test_inputs)
    interpreter.invoke()
    output_index = interpreter.get_output_details()[0]["index"]
    result = interpreter.get_tensor(output_index)
    # Reset all variables so it will not pollute other inferences.
    interpreter.reset_all_variables()
    return result

  def _checkConversionMatchesTensorFlow(self, is_dynamic_rnn):
    """Builds, trains, restores and converts the model, then compares outputs.

    Both the TOCO-converted and the MLIR-converted TFLite models must produce
    the same output as the restored TensorFlow model on one fake input.

    Args:
      is_dynamic_rnn: Use dynamic_rnn (True) or static_rnn (False).
    """
    # Not a `with` block: entering a Session also nests its graph as the
    # default, which would make reset_default_graph() in
    # saveAndRestoreModel fail.
    sess = tf.compat.v1.Session()
    try:
      x, prediction, output_class = self.buildModel(
          self.buildLstmLayer(), is_dynamic_rnn=is_dynamic_rnn)
      self.trainModel(x, prediction, output_class, sess)

      saver = tf.compat.v1.train.Saver()
      x, _, output_class, new_sess = self.saveAndRestoreModel(
          self.buildLstmLayer(), sess, saver, is_dynamic_rnn=is_dynamic_rnn)
    finally:
      # The original session is not needed once the model is restored.
      sess.close()

    try:
      test_inputs, expected_output = self.getInferenceResult(
          x, output_class, new_sess)

      # Test Toco-converted model.
      result = self.tfliteInvoke(
          new_sess, test_inputs, x, output_class, use_mlir_converter=False)
      self.assertAllClose(expected_output, result, rtol=1e-6, atol=1e-2)

      # Test MLIR-converted model.
      result = self.tfliteInvoke(
          new_sess, test_inputs, x, output_class, use_mlir_converter=True)
      self.assertAllClose(expected_output, result, rtol=1e-6, atol=1e-2)
    finally:
      new_sess.close()

  def testStaticRnnMultiRnnCell(self):
    self._checkConversionMatchesTensorFlow(is_dynamic_rnn=False)

  @test_util.enable_control_flow_v2
  def testDynamicRnnMultiRnnCell(self):
    self._checkConversionMatchesTensorFlow(is_dynamic_rnn=True)
287
288
if __name__ == "__main__":
  # The tests above rely on graph-mode APIs (placeholders, Sessions), so TF2
  # behavior is switched off before the test runner starts.
  tf.disable_v2_behavior()
  test.main()
292