• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for lite.py functionality related to select TF op usage."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.lite.python import lite
28from tensorflow.lite.python import test_util as tflite_test_util
29from tensorflow.lite.python.convert import register_custom_opdefs
30from tensorflow.lite.python.interpreter import Interpreter
31from tensorflow.lite.python.testdata import double_op
32from tensorflow.python.client import session
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import test_util
38from tensorflow.python.framework.importer import import_graph_def
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import nn_ops
41from tensorflow.python.ops import variables
42from tensorflow.python.platform import test
43from tensorflow.python.saved_model import saved_model
44from tensorflow.python.training.tracking import tracking
45
46
class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Checks `TFLiteConverter.from_session` with `SELECT_TF_OPS` enabled."""

  def _assertModelDoublesInput(self, tflite_model):
    """Runs `tflite_model` on a fixed input and checks it was doubled.

    Args:
      tflite_model: Serialized TFLite model whose single input has shape
        [1, 4] and dtype float32 and whose output is expected to be `x + x`.
    """
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # assertAllEqual reports the differing values on failure, unlike
    # assertTrue((a == b).all()) which only says "False is not true".
    self.assertAllEqual(expected_output, output_data)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testFlexMode(self, enable_mlir):
    """Converts an add graph using only select TF ops and runs the result.

    Args:
      enable_mlir: Value assigned to `converter.experimental_new_converter`.
    """
    with ops.Graph().as_default():
      # Use the session as a context manager so it is closed after conversion
      # rather than left open for the rest of the test run.
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
        out_tensor = in_tensor + in_tensor

        # Convert model and ensure model is not None.
        converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                      [out_tensor])
        converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
        converter.experimental_new_converter = enable_mlir
        tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check the model works with TensorFlow ops.
    self._assertModelDoublesInput(tflite_model)

  def testFlexWithAutomaticPassThrough(self):
    """Checks that L2Loss is emitted as `FlexL2Loss` when all ops are allowed."""
    # Create a graph that has one L2Loss op.
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[4], dtype=dtypes.float32, name='input')
        out_tensor = nn_ops.l2_loss(in_tensor)
        converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                      [out_tensor])
        converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
        converter._experimental_allow_all_select_tf_ops = True
        tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertIn('FlexL2Loss', tflite_test_util.get_ops_list(tflite_model))

  def testDeprecatedFlags(self):
    """Checks the deprecated `target_ops` attribute still drives conversion."""
    with ops.Graph().as_default():
      # Use the session as a context manager so it is closed after conversion
      # rather than left open for the rest of the test run.
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
        out_tensor = in_tensor + in_tensor

        # Convert model and ensure model is not None.
        converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                      [out_tensor])
        converter.target_ops = {lite.OpsSet.SELECT_TF_OPS}

        # The deprecated flag must read back unchanged and also be reflected
        # in `target_spec.supported_ops`.
        self.assertEqual(converter.target_ops, {lite.OpsSet.SELECT_TF_OPS})
        self.assertEqual(converter.target_spec.supported_ops,
                         {lite.OpsSet.SELECT_TF_OPS})

        tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check the model works with TensorFlow ops.
    self._assertModelDoublesInput(tflite_model)
125
126
class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Checks `TFLiteConverterV2.from_concrete_functions` with select TF ops."""

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_v2_only
  def testFloat(self, enable_mlir):
    """Converts `v1 * v2 * x` with select TF ops and runs it on x = 4.

    Args:
      enable_mlir: Value assigned to `converter.experimental_new_converter`.
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func],
                                                               root)
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.experimental_new_converter = enable_mlir
    tflite_model = converter.convert()

    # Check the model works with TensorFlow ops.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([4.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    # 3 (v1) * 2 (v2) * 4 (input) == 24.
    expected_output = np.array([24.0], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # assertAllEqual reports the differing values on failure, unlike
    # assertTrue((a == b).all()) which only says "False is not true".
    self.assertAllEqual(expected_output, output_data)
161
162
class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Checks conversion of user-selected TF ops listed in the target spec."""

  def _createGraphWithCustomOp(self, opname='CustomAdd'):
    """Builds an add graph, renames its add op to `opname` and registers it.

    Args:
      opname: Op name written into the graph in place of the add op and used
        in the registered op definition.

    Returns:
      A `(graph_def, inputs, outputs)` tuple: the modified `GraphDef`, a dict
      mapping 'x' to the input placeholder and a dict mapping 'z' to the
      output tensor of the original (unmodified) graph.
    """
    # Text-format OpDef with two float inputs and one float output.
    custom_opdefs_str = (
        "name: '" + opname + "' input_arg: {name: 'Input1' type: DT_FLOAT} "
        "input_arg: {name: 'Input2' type: DT_FLOAT} output_arg: {name: "
        "'Output' type: DT_FLOAT}")

    # Create a graph that has one add op.
    new_graph = graph_pb2.GraphDef()
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
        out_tensor = in_tensor + in_tensor
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}

        new_graph.CopyFrom(sess.graph_def)

    # Rename Add op name to opname.
    for node in new_graph.node:
      if node.op.startswith('Add'):
        node.op = opname
        # The op definition above declares no 'T' attr, so drop it here.
        del node.attr['T']

    # Register custom op defs to import modified graph def.
    register_custom_opdefs([custom_opdefs_str])

    return (new_graph, inputs, outputs)

  def testFlexWithCustomOp(self):
    """Checks a registered custom op is emitted as `FlexCustomAdd4`."""
    new_graph, inputs, outputs = self._createGraphWithCustomOp(
        opname='CustomAdd4')

    # Import to load the custom opdef.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'model')
    with ops.Graph().as_default():
      with session.Session() as sess:
        import_graph_def(new_graph, name='')
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4']
    tflite_model = converter.convert()

    self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model))

  def testFlexWithDoubleOp(self):
    """Checks the `Double` test op converts to `FlexDouble` and executes."""
    # Create a graph that has one double op.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'model2')
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[1, 4], dtype=dtypes.int32, name='input')
        out_tensor = double_op.double(in_tensor)
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.target_spec.experimental_select_user_tf_ops = ['Double']
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model))

    # Check the model works with TensorFlow ops.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    # Integer literals: the placeholder and the arrays are int32.
    test_input = np.array([[1, 2, 3, 4]], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[2, 4, 6, 8]], dtype=np.int32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # assertAllEqual reports the differing values on failure, unlike
    # assertTrue((a == b).all()) which only says "False is not true".
    self.assertAllEqual(expected_output, output_data)
243
244
# Script entry point: hand control to `test.main()` from
# `tensorflow.python.platform.test` when this file is executed directly.
if __name__ == '__main__':
  test.main()
247