1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for lite.py functionality related to select TF op usage.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from absl.testing import parameterized 24import numpy as np 25 26from tensorflow.core.framework import graph_pb2 27from tensorflow.lite.python import lite 28from tensorflow.lite.python import test_util as tflite_test_util 29from tensorflow.lite.python.convert import register_custom_opdefs 30from tensorflow.lite.python.interpreter import Interpreter 31from tensorflow.lite.python.testdata import double_op 32from tensorflow.python.client import session 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import test_util 38from tensorflow.python.framework.importer import import_graph_def 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import nn_ops 41from tensorflow.python.ops import variables 42from tensorflow.python.platform import test 43from tensorflow.python.saved_model import saved_model 44from tensorflow.python.training.tracking import tracking 45 46 47class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase): 48 49 @parameterized.named_parameters( 50 ('EnableMlirConverter', True), # enable mlir 51 ('DisableMlirConverter', False)) # disable mlir 52 def testFlexMode(self, enable_mlir): 53 with ops.Graph().as_default(): 54 in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32) 55 out_tensor = in_tensor + in_tensor 56 sess = session.Session() 57 58 # Convert model and ensure model is not None. 59 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 60 [out_tensor]) 61 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 62 converter.experimental_new_converter = enable_mlir 63 tflite_model = converter.convert() 64 self.assertTrue(tflite_model) 65 66 # Check the model works with TensorFlow ops. 67 interpreter = Interpreter(model_content=tflite_model) 68 interpreter.allocate_tensors() 69 input_details = interpreter.get_input_details() 70 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) 71 interpreter.set_tensor(input_details[0]['index'], test_input) 72 interpreter.invoke() 73 74 output_details = interpreter.get_output_details() 75 expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32) 76 output_data = interpreter.get_tensor(output_details[0]['index']) 77 self.assertTrue((expected_output == output_data).all()) 78 79 def testFlexWithAutomaticPassThrough(self): 80 # Create a graph that has one L2Loss op. 81 with ops.Graph().as_default(): 82 with session.Session() as sess: 83 in_tensor = array_ops.placeholder( 84 shape=[4], dtype=dtypes.float32, name='input') 85 out_tensor = nn_ops.l2_loss(in_tensor) 86 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 87 [out_tensor]) 88 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 89 converter._experimental_allow_all_select_tf_ops = True 90 tflite_model = converter.convert() 91 self.assertTrue(tflite_model) 92 self.assertIn('FlexL2Loss', tflite_test_util.get_ops_list(tflite_model)) 93 94 def testDeprecatedFlags(self): 95 with ops.Graph().as_default(): 96 in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32) 97 out_tensor = in_tensor + in_tensor 98 sess = session.Session() 99 100 # Convert model and ensure model is not None. 101 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 102 [out_tensor]) 103 converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS]) 104 105 # Ensure `target_ops` is set to the correct value after flag deprecation. 106 self.assertEqual(converter.target_ops, set([lite.OpsSet.SELECT_TF_OPS])) 107 self.assertEqual(converter.target_spec.supported_ops, 108 set([lite.OpsSet.SELECT_TF_OPS])) 109 110 tflite_model = converter.convert() 111 self.assertTrue(tflite_model) 112 113 # Check the model works with TensorFlow ops. 114 interpreter = Interpreter(model_content=tflite_model) 115 interpreter.allocate_tensors() 116 input_details = interpreter.get_input_details() 117 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) 118 interpreter.set_tensor(input_details[0]['index'], test_input) 119 interpreter.invoke() 120 121 output_details = interpreter.get_output_details() 122 expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32) 123 output_data = interpreter.get_tensor(output_details[0]['index']) 124 self.assertTrue((expected_output == output_data).all()) 125 126 127class FromConcreteFunctionTest(test_util.TensorFlowTestCase, 128 parameterized.TestCase): 129 130 @parameterized.named_parameters( 131 ('EnableMlirConverter', True), # enable mlir 132 ('DisableMlirConverter', False)) # disable mlir 133 @test_util.run_v2_only 134 def testFloat(self, enable_mlir): 135 input_data = constant_op.constant(1., shape=[1]) 136 root = tracking.AutoTrackable() 137 root.v1 = variables.Variable(3.) 138 root.v2 = variables.Variable(2.) 139 root.f = def_function.function(lambda x: root.v1 * root.v2 * x) 140 concrete_func = root.f.get_concrete_function(input_data) 141 142 # Convert model. 143 converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func], 144 root) 145 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 146 converter.experimental_new_converter = enable_mlir 147 tflite_model = converter.convert() 148 149 # Check the model works with TensorFlow ops. 150 interpreter = Interpreter(model_content=tflite_model) 151 interpreter.allocate_tensors() 152 input_details = interpreter.get_input_details() 153 test_input = np.array([4.0], dtype=np.float32) 154 interpreter.set_tensor(input_details[0]['index'], test_input) 155 interpreter.invoke() 156 157 output_details = interpreter.get_output_details() 158 expected_output = np.array([24.0], dtype=np.float32) 159 output_data = interpreter.get_tensor(output_details[0]['index']) 160 self.assertTrue((expected_output == output_data).all()) 161 162 163class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): 164 165 def _createGraphWithCustomOp(self, opname='CustomAdd'): 166 custom_opdefs_str = ( 167 'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} ' 168 'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: ' 169 '\'Output\' type: DT_FLOAT}') 170 171 # Create a graph that has one add op. 172 new_graph = graph_pb2.GraphDef() 173 with ops.Graph().as_default(): 174 with session.Session() as sess: 175 in_tensor = array_ops.placeholder( 176 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') 177 out_tensor = in_tensor + in_tensor 178 inputs = {'x': in_tensor} 179 outputs = {'z': out_tensor} 180 181 new_graph.CopyFrom(sess.graph_def) 182 183 # Rename Add op name to opname. 184 for node in new_graph.node: 185 if node.op.startswith('Add'): 186 node.op = opname 187 del node.attr['T'] 188 189 # Register custom op defs to import modified graph def. 190 register_custom_opdefs([custom_opdefs_str]) 191 192 return (new_graph, inputs, outputs) 193 194 def testFlexWithCustomOp(self): 195 new_graph, inputs, outputs = self._createGraphWithCustomOp( 196 opname='CustomAdd4') 197 198 # Import to load the custom opdef. 199 saved_model_dir = os.path.join(self.get_temp_dir(), 'model') 200 with ops.Graph().as_default(): 201 with session.Session() as sess: 202 import_graph_def(new_graph, name='') 203 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 204 205 converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir) 206 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 207 converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4'] 208 tflite_model = converter.convert() 209 210 self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model)) 211 212 def testFlexWithDoubleOp(self): 213 # Create a graph that has one double op. 214 saved_model_dir = os.path.join(self.get_temp_dir(), 'model2') 215 with ops.Graph().as_default(): 216 with session.Session() as sess: 217 in_tensor = array_ops.placeholder( 218 shape=[1, 4], dtype=dtypes.int32, name='input') 219 out_tensor = double_op.double(in_tensor) 220 inputs = {'x': in_tensor} 221 outputs = {'z': out_tensor} 222 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 223 224 converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir) 225 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 226 converter.target_spec.experimental_select_user_tf_ops = ['Double'] 227 tflite_model = converter.convert() 228 self.assertTrue(tflite_model) 229 self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model)) 230 231 # Check the model works with TensorFlow ops. 232 interpreter = Interpreter(model_content=tflite_model) 233 interpreter.allocate_tensors() 234 input_details = interpreter.get_input_details() 235 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.int32) 236 interpreter.set_tensor(input_details[0]['index'], test_input) 237 interpreter.invoke() 238 239 output_details = interpreter.get_output_details() 240 expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.int32) 241 output_data = interpreter.get_tensor(output_details[0]['index']) 242 self.assertTrue((expected_output == output_data).all()) 243 244 245if __name__ == '__main__': 246 test.main() 247