# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import logging
import os
import tempfile

from absl.testing import parameterized
import numpy as np
import six
from six.moves import range
from tensorflow import keras

from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import util
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph


class LiteTest(test_util.TensorFlowTestCase):
  """Base class of all the tests in this module."""


class TestModels(LiteTest):

  def assertValidDebugInfo(self, debug_info):
    """Verifies that the given DebugInfo is valid."""
    file_names = set()
    for file_path in debug_info.files:
      file_names.add(os.path.basename(file_path))
    # To make the test independent of how the nodes are created, we only
    # assert the name of this test file.
    self.assertIn('lite_test.py', file_names)
    self.assertNotIn('lite_v2_test.py', file_names)


class FromConstructor(TestModels):

  # Tests invalid constructors using a dummy value for the GraphDef.
  def testInvalidConstructor(self):
    message = (
        'If input_tensors and output_tensors are None, both '
        'input_arrays_with_shape and output_arrays|control_output_arrays must '
        'be defined.')

    # `output_arrays` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(
          None, None, [],
          input_arrays_with_shape=[('input', [3, 9])]).convert()
    self.assertEqual(message, str(error.exception))

    # `input_arrays_with_shape` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(None, [], None, output_arrays=['output']).convert()
    self.assertEqual(message, str(error.exception))

  # Tests valid constructors using a dummy value for the GraphDef.
  def testValidConstructor(self):
    converter = lite.TFLiteConverter(
        None,
        None,
        None,
        input_arrays_with_shape=[('input', [3, 9])],
        output_arrays=['output'])
    self.assertFalse(converter._has_valid_tensors())
    self.assertEqual(converter.get_input_arrays(), ['input'])

    with self.assertRaises(ValueError) as error:
      converter._set_batch_size(1)
    self.assertEqual(
        'The batch size cannot be set for this model. Please use '
        'input_shapes parameter.', str(error.exception))

    converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
    self.assertTrue(converter._has_valid_tensors())

  def testRedundantArgumentsWarning(self):
    """Tests the warning messages shown when there are redundant arguments."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      out_tensor = math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Capture logging output so the warning messages can be verified.
    log = io.BytesIO() if six.PY2 else io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor],
                                     [out_tensor],
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])

    input_warning_message = 'input_arrays_with_shape will be ignored'
    output_warning_message = 'output_arrays will be ignored'

    # Convert model and ensure model is not None.
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    self.assertIn(input_warning_message, log.getvalue())
    self.assertIn(output_warning_message, log.getvalue())
    logging.root.removeHandler(handler)

  def testShapeOverriding(self):
    """Tests a shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('in_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testPartialShapeOverriding(self):
    """Tests a partial shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_a')
      in_tensor_b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b')
      math_ops.add(in_tensor_a, in_tensor_b, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Override only one of the two input shapes.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor_a', [2, 16, 16, 3])],
                                     ['add'])
    # There is an unhandled Placeholder op, so conversion should fail.
    with self.assertRaises(ConverterError):
      converter.convert()

  def testInvalidShapeOverriding(self):
    """Tests an invalid shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Override a tensor name that does not exist in the graph, so conversion
    # should fail.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('wrong_tensor', [2, 16, 16, 3])],
                                     ['add'])
    with self.assertRaises(ConverterError):
      converter.convert()


class FromSessionTest(TestModels, parameterized.TestCase):

  def testFloatModel(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatModelQuantizedInput(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_input_type = dtypes.uint8
    converter.inference_type = dtypes.float32
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float

  def testForgottenCallToAllocateTensors(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Accessing tensors before calling allocate_tensors() should fail.
    interpreter = Interpreter(model_content=tflite_model)
    input_index = interpreter.get_input_details()[0]['index']
    dummy_tensor = np.ones(shape=[1, 16, 16, 3], dtype=np.float32)
    with self.assertRaises(ValueError):
      interpreter.set_tensor(input_index, dummy_tensor)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  def testIntegerQuantizationWithUnsupportedOps(self,
                                                is_int_only,
                                                is_int16_quantize,
                                                inference_input_output_type,
                                                enable_mlir_quantizer=False):
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      # The ceil kernel supports neither int8 nor int16 types.
      left = math_ops.ceil(in_tensor_a)
      out_tensor_b = math_ops.tanh(in_tensor_b)
      add = math_ops.add(left, out_tensor_b)
      # The ceil kernel supports neither int8 nor int16 types.
      out_tensor_a = math_ops.ceil(add)
      sess = session.Session()

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
        ]

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Allow float32 for fallback on the non-quantizable ceil op.
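    # As the assertions below reflect, the pre-MLIR quantizer appears to leave
    # the tensors feeding the unsupported ceil op in float32, while the MLIR
    # quantizer still applies the requested input/output type to them.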
    expected_ceil_dtype = (
        expected_dtype if enable_mlir_quantizer else dtypes.float32)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(input_details[1]['dtype'], expected_dtype)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(output_details[1]['dtype'], expected_dtype)

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False), ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False), ('_PerTensorMlirQuant', True, True))
  def testDisablePerChannelQuantization(self,
                                        disable_per_channel=False,
                                        enable_mlir_quantizer=False):
    k_conv_name = 'Conv2D1'
    k_num_filters = 16
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    detail = next((d for d in interpreter.get_tensor_details()
                   if d['name'] == k_conv_name))
    quant_params = detail['quantization_parameters']
    expected_num_params = 1 if disable_per_channel else k_num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testString(self, enable_mlir_converter):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertAllEqual([4], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('Reshape', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertAllEqual([2, 2], output_details[0]['shape'])
    # TODO(b/122659643): Test setting/getting string data via the python
    # interpreter API after support has been added.

  def testIntermediateInputArray(self):
    """Convert a model from an intermediate input array."""
    with ops.Graph().as_default():
      in_tensor_init = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      in_tensor_final = in_tensor_init + in_tensor_init
      out_tensor = in_tensor_final + in_tensor_final
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('add', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add_1', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testSizeNoneInvalid(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
    # order to invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
                     str(error.exception))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testScalarValid(self, enable_mlir_converter):
    # Construct a graph using a scalar (empty shape) input.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test conversion with the scalar input shape.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEmpty(input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertEmpty(output_details[0]['shape'])

    # Validate inference using the scalar inputs/outputs.
    test_input = np.array(4.0, dtype=np.float32)
    expected_output = np.array(8.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_output, output_data)

  def testSizeInvalid(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, None, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test invalid shape: None after the 1st dimension. Run with TOCO in
    # order to invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))

  def testSizeNone(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, None, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test None after the 1st dimension.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 1, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    # Resize tensor with strict checking.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize tensor and invoke.
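    # With strict=True, only the dimensions marked -1 in the shape signature
    # (here the second dimension) may be mutated, per the error above.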
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()
    interpreter.invoke()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])

    output_details = interpreter.get_output_details()
    self.assertAllEqual([1, -1, 16, 3], output_details[0]['shape_signature'])

  def testResizeTensorInputStrict(self):
    # Ensures that resize_tensor_input(strict=True) works as expected.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)

    # Resize incorrect value.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize correct value.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

  def testBatchSizeValid(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testBatchSizeNonZero(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[None, 4], dtype=dtypes.float32, name='input1')
      in_tensor_2 = array_ops.placeholder(
          shape=[4, 10], dtype=dtypes.float32, name='input2')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('input1', input_details[0]['name'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual('input2', input_details[1]['name'])
    self.assertAllEqual([4, 10], input_details[1]['shape'])

  def testFreezeGraph(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      var = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      # Get the second output to ensure freezing properly processes tensor
      # names like 'X:1'.
      out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1]
      sess = session.Session()
      sess.run(_global_variables_initializer())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('top_k:1', output_details[0]['name'])
    self.assertEqual(np.int32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 1], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testGraphviz(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.output_format = lite_constants.GRAPHVIZ_DOT
    graphviz_output = converter.convert()
    self.assertIsNotNone(graphviz_output)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testDumpGraphviz(self, enable_mlir_converter):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = enable_mlir_converter
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure interpreter is able to allocate and check graphviz data.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    num_items_graphviz = len(os.listdir(graphviz_dir))
    self.assertGreater(num_items_graphviz, 0)
    self.assertTrue(
        os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
    self.assertTrue(
        os.path.exists(
            os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))

    # The new converter doesn't support the `dump_graphviz_video` flag.
    if not enable_mlir_converter:
      # Convert model and ensure model is not None.
      converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                    [out_tensor])
      converter.experimental_new_converter = enable_mlir_converter
      graphviz_dir = self.get_temp_dir()
      converter.dump_graphviz_dir = graphviz_dir
      converter.dump_graphviz_video = True
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

      # Ensure graphviz folder has more data after using video flag.
      num_items_graphviz_video = len(os.listdir(graphviz_dir))
      self.assertGreater(num_items_graphviz_video, num_items_graphviz)

  def testDumpConversionSummary(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    self.assertNotEmpty(os.listdir(log_dir))

  def testDumpConversionSummaryWithOldConverter(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check nothing is generated under the conversion summary path.
    num_items_conversion_summary = len(os.listdir(log_dir))
    self.assertEqual(num_items_conversion_summary, 0)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeDynamicRange(self, enable_mlir_converter):
    np.random.seed(0)
    with ops.Graph().as_default():
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    # Convert float model.
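    # This float conversion is the size baseline: dynamic range quantization
    # (Optimize.DEFAULT without a representative dataset) should store the
    # large weight constant as int8 and produce a smaller flatbuffer.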
    float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
                                                        [out_tensor])
    float_converter.experimental_new_converter = enable_mlir_converter
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])

    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute(
      self, enable_mlir_converter):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    self.assertFalse(quantized_converter.post_training_quantize)
    quantized_converter.experimental_new_converter = enable_mlir_converter

    quantized_converter.post_training_quantize = True
    self.assertTrue(quantized_converter.post_training_quantize)
    self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT])

    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

  def _getIntegerQuantizeModel(self):
    np.random.seed(0)
    inp = array_ops.placeholder(
        dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
    conv = nn_ops.conv2d(
        inp,
        filter=array_ops.ones([3, 3, 3, 16]),
        strides=[1, 1, 1, 1],
        padding='SAME')
    output = nn_ops.relu(conv, name='output')

    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    return (inp, output, calibration_gen)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeInt8AllowFloat(self, enable_mlir_converter):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
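    # No inference_input_type/inference_output_type was set above, so the
    # converter is expected to keep a float32 interface even though the graph
    # internals are quantized.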
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      # Quantize model to Int8: with enable mlir
      ('UseTfliteBuiltinsIntEnableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8],
       True),
      # Quantize model to Int8: with disable mlir
      ('UseTfliteBuiltinsIntDisableMLIR', [lite.OpsSet.TFLITE_BUILTINS_INT8],
       False),
      # Quantize model to Int16: with disable mlir
      ('UseTfliteBuiltinsInt16DisableMLIR', [
          lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
      ], False),
      # Quantize model to Int16: with enable mlir
      ('UseTfliteBuiltinsInt16EnableMLIR', [
          lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
      ], True))
  def testQuantizeInt8And16x8(self, supported_ops, enable_mlir_converter):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_converter.experimental_new_converter = enable_mlir_converter
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert model by specifying target spec (instead of optimizations), since
    # when targeting an integer only backend, quantization is mandatory.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_ops = supported_ops
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeInt8InputOutput(self, enable_mlir_converter):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_converter.experimental_new_converter = enable_mlir_converter
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert the model with int8 input and output types.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.inference_input_type = dtypes.int8
    quantized_converter.inference_output_type = dtypes.int8
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The input and output types should be int8.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.int8, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.int8, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testInvalidQuantizeInt8(self, enable_mlir_converter):
    np.random.seed(0)
    with ops.Graph().as_default():
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    # Attempt to convert to quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    # Restrict to int8 types only.
    quantized_converter.target_spec.supported_types = [dtypes.int8]
    # A representative dataset is required for full fixed point quantization.
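    # (One is deliberately not provided here; see the calibration_gen
    # generator returned by _getIntegerQuantizeModel for what such a dataset
    # looks like.)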
    with self.assertRaises(ValueError) as error:
      quantized_converter.convert()
    self.assertEqual(
        'representative_dataset is required when specifying '
        'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeUInt8(self, enable_mlir_converter):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {
        'inputA': (0., 1.),
        'inputB': (0., 1.)
    }  # mean, std_dev
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('inputA', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    self.assertEqual('inputB', input_details[1]['name'])
    self.assertEqual(np.uint8, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((1., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale

  def testQuantizeUInt8UsingDefaultRangeStats(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    converter.default_ranges_stats = (0, 6)  # min, max
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale

  @parameterized.named_parameters(
      # Quantize to Float16 even if rep data provided.
      ('UseRepresentativeData', True, False, True, False, False, False, False,
       False),
      # Quantize to Float16 if no rep data provided.
      ('NoRepresentativeData', False, False, True, False, False, False, False,
       False),
      # Quantize to Float16 and set Float16 accumulation.
      ('SpecifyFloat16Accumulation', False, False, True, True, False, False,
       False, False),
      # Post training quantization if both rep data and int8 included.
      ('UseSampleDataIncludeInt8', True, True, False, False, False, True, False,
       False),
      # Quantize to Float16 even if rep data provided with mlir.
      ('UseRepresentativeDataMlir', True, False, True, False, False, False,
       True, False),
      # Quantize to Float16 if no rep data provided with mlir.
      ('NoRepresentativeDataMlir', False, False, True, False, False, False,
       True, False),
      # Post training quantization if both rep data and int8 included with
      # mlir.
      ('SampleDataIncludeInt8Mlir', True, True, False, False, False, True, True,
       False),
      # Same as above, but using the MLIR quantizer.
      ('SampleDataIncludeInt8MlirQuant', True, True, False, False, False, True,
       True, True))
  def testQuantizeFloat16(self, use_rep_data, include_int8,
                          is_float16_quantized, is_float16_accumulation,
                          is_error, is_post_training_quantized,
                          enable_mlir_converter, enable_mlir_quantizer):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    bias_idx = 1 if enable_mlir_converter else 0
    bias_name = 'Conv2D' if enable_mlir_converter else 'Conv2D_bias'

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_converter.experimental_new_converter = enable_mlir_converter
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    interpreter = Interpreter(model_content=float_tflite_model)
    interpreter.allocate_tensors()
    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
                     bias_name)
    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                     dtypes.float32)

    # The MLIR quantizer uses a different bias index.
    if enable_mlir_quantizer:
      bias_idx = 2

    # Convert model to quantized version.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_types = [dtypes.float16]
    if include_int8:
      quantized_converter.target_spec.supported_types.append(dtypes.int8)
    if use_rep_data:
      quantized_converter.representative_dataset = calibration_gen
    if is_float16_accumulation:
      quantized_converter.target_spec.experimental_supported_accumulation_type = dtypes.float16  # pylint: disable=line-too-long

    if is_error:
      with self.assertRaises(ValueError) as error:
        quantized_converter.convert()
      self.assertEqual(
          'representative_dataset is required when specifying '
          'TFLITE_BUILTINS_INT8 or INT8 supported types.',
          str(error.exception))
    else:
      quantized_tflite_model = quantized_converter.convert()
      self.assertIsNotNone(quantized_tflite_model)
      interpreter = Interpreter(model_content=quantized_tflite_model)
      interpreter.allocate_tensors()
      self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
                       bias_name)

      if is_float16_quantized:
        # Verify that the bias constant is float16 type.
        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                         dtypes.float16)
      elif is_post_training_quantized:
        # Verify that the bias constant is int32 type.
        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                         dtypes.int32)
      else:
        raise ValueError('Invalid test options.')

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testInvalidQuantizeFloat16(self, enable_mlir_converter):
    with ops.Graph().as_default():
      inp, output, _ = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Specify float16 quantization.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_types = [dtypes.float16]
    # Specify only int8 builtin ops.
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    with self.assertRaises(ValueError) as error:
      quantized_converter.convert()
    self.assertEqual(
        'TFLITE_BUILTINS_INT8 requires smallest supported type to be INT8.',
        str(error.exception))

  @parameterized.named_parameters(('InferenceType_INT8', dtypes.int8),
                                  ('InferenceType_UINT8', dtypes.uint8))
  def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor + in_tensor, min=0., max=1.)
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor], [out_tensor])

    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_type = quantized_type
      quantized_converter.convert()
    self.assertEqual(
        'The `quantized_input_stats` flag must be defined when either '
        '`inference_type` flag or `inference_input_type` flag is set to '
        'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
        '`inference_input_type=None`.'.format(quantized_type.name),
        str(error.exception))

    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_type = dtypes.float32
      quantized_converter.inference_input_type = quantized_type
      quantized_converter.convert()
    self.assertEqual(
        'The `quantized_input_stats` flag must be defined when either '
        '`inference_type` flag or `inference_input_type` flag is set to '
        'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
        '`inference_input_type=tf.{}`.'.format(quantized_type.name),
        str(error.exception))

    quantized_converter.inference_type = quantized_type
    quantized_converter.inference_input_type = quantized_type

    input_arrays = quantized_converter.get_input_arrays()
    quantized_converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}
    quantized_converter.convert()

  def testInvalidQuantizeQATModelMissingInputStats(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
      sess = session.Session()

    # Stats are provided for only one of the two inputs, so conversion should
    # fail.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'Quantization input stats are not available for input tensors '
        '\'inputB\'.', str(error.exception))

  def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
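    # The float conversion below serves only as a size baseline for the
    # quantized conversion that follows it.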
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    converter = lite.TFLiteConverter.from_session(sess, [inp], [output])

    # Extra flags to trigger training time quantization conversion.
    converter.inference_type = dtypes.int8
    converter.inference_input_type = dtypes.float32
    converter.inference_output_type = dtypes.float32
    input_arrays = converter.get_input_arrays()
    converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}
    # Trigger post-training quantization.
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.experimental_new_quantizer = True
    quantized_tflite_model = converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

    # Calibration-only API: convert() returns a calibrated but unquantized
    # model, which is then quantized separately.
    converter._experimental_calibrate_only = True
    calibrated_tflite = converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated_tflite, fully_quantize=True)
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertEqual(np.int8, input_details[0]['dtype'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(np.int8, output_details[0]['dtype'])

  def testFloatTocoConverter(self):
    """Tests the deprecated TocoConverter interface."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the interpreter is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testMultipleOutputNodeNames(self):
    """Tests converting a graph with an op that has multiple outputs."""
    with ops.Graph().as_default():
      input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
      out0, out1, out2, out3 = array_ops.split(
          input_tensor, [1, 1, 1, 1], axis=0)
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
                                                  [out0, out1, out2, out3])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    interpreter.set_tensor(input_details[0]['index'],
                           np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 4)
    self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
    self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
    self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
    self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_in_graph_and_eager_modes
  def testFunctions(self, enable_mlir_converter):
    """Tests tf.function in 1.X."""

    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1], name='input')
      variable_node = variables.Variable(1.0, name='variable_node')
      defun_node = plus_placeholder(variable_node, placeholder)
      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')

      # Initialize variables in the model.
      sess = session.Session()
      sess.run(variables.variables_initializer([variable_node]))

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
                                                  [output_node])
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('output_node', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testInferenceInputOutputTypeFloatDefault(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
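    # With no inference_type or inference_input_type overrides set, the
    # converted model is expected to keep float32 inputs and outputs.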
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])

  def testInferenceInputOutputTypeQuantizedUint8Default(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor + in_tensor, min=0., max=1., name='output')
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])

  def testReusingConverterWithDifferentPostTrainingQuantization(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor + in_tensor, min=0., max=1., name='output')
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])

    converter.post_training_quantize = True
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    converter.post_training_quantize = False
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  def testResizeWithShape(self):
    with ops.Graph().as_default():
      # Construct a graph with a dynamically shaped input and an internal node
      # that relies on the output of that input's shape.
      in_tensor = array_ops.placeholder(
          shape=[None, None], dtype=dtypes.float32)
      in_tensor2 = [[1, 2], [3, 4]]
      out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
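    # Before resizing, unknown dimensions are materialized as 1 in 'shape',
    # while 'shape_signature' preserves the original -1 (dynamic) dimensions.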
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([1, 1], input_details[0]['shape'])
    self.assertAllEqual([-1, -1], input_details[0]['shape_signature'])

    # Resize tensor and invoke.
    interpreter.resize_tensor_input(0, [4])
    interpreter.allocate_tensors()
    interpreter.invoke()

    # The output should be reshaped properly according to the resized input.
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.int32, output_details[0]['dtype'])
    self.assertAllEqual([4], output_details[0]['shape'])
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual([1, 2, 3, 4], output_data)

  def testResizingIntermediateDynamicTensor(self):
    # This is a regression test for the case where the shape of dynamic output
    # tensors changes between invocations.
    # See also https://github.com/tensorflow/tensorflow/issues/26549
    with ops.Graph().as_default():
      input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32)
      input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32)

      # The bug is triggered only when the dynamic tensor is intermediate, so
      # we put some other ops around it.
      neg = math_ops.negative(input2_tensor)
      padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32)
      output_tensor = array_ops.pad(input_tensor, padding) + neg

      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [input_tensor, padding, input2_tensor], [output_tensor])
    tflite_model = converter.convert()

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    interpreter.set_tensor(input_details[1]['index'],
                           np.array([[1, 1], [1, 1]], dtype=np.int32))
    interpreter.invoke()

    # Without the fix, invocation will fail when changing the shape of
    # intermediate dynamic tensors.
    interpreter.set_tensor(input_details[1]['index'],
                           np.array([[2, 2], [2, 2]], dtype=np.int32))
    interpreter.invoke()

  def testGraphDebugInfo(self):
    """Test that a session has debug info captured."""

    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1], name='input')
      variable_node = variables.Variable(1.0, name='variable_node')
      defun_node = plus_placeholder(variable_node, placeholder)
      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')

      # Initialize variables in the model.
      sess = session.Session()
      sess.run(variables.variables_initializer([variable_node]))

    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
                                                  [output_node])
    converter.convert()
    self.assertValidDebugInfo(converter._debug_info)

    # Check that the add node in the inlined function is included.
    func = sess.graph.as_graph_def().library.function[0].signature.name
    self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)

  def testOutputOnlyModel(self):
    with ops.Graph().as_default():
      out_tensor = random_ops.random_normal(shape=[3])
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [], [out_tensor])
    converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS,
        lite.OpsSet.SELECT_TF_OPS,
    ]

    # Empty input array is a valid input.
    self.assertTrue(converter._has_valid_tensors())

    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)


class FromFrozenGraphFile(LiteTest):

  def testFloat(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatWithShapesArray(self):
    """Test a shape overriding case."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['Placeholder'], ['add'],
        input_shapes={'Placeholder': [2, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])

  def testInvalidShapesArray(self):
    """Test an invalid shape overriding case that uses a wrong input name."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Creating a converter with a wrong input name raises an error.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_frozen_graph(
          graph_def_file, ['Placeholder'], ['add'],
          input_shapes={'wrong_input': [2, 16, 16, 3]})

  def testPartialShapesArray(self):
    """Test a shape overriding case that overrides only one of two inputs."""
    with ops.Graph().as_default():
      a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
      b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
      _ = math_ops.add(a, b, name='add')
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['a', 'b'], ['add'],
        input_shapes={'a': [2, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])

  def testFreezeGraph(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      var = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + var
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Ensure the graph with variables cannot be converted.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual('Please freeze the graph using freeze_graph.py.',
                     str(error.exception))

  def testPbtxt(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
    write_graph(sess.graph_def, '', graph_def_file, True)
    sess.close()

    # Convert model and ensure model is not None.
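    # from_frozen_graph also handles text-format GraphDefs like the pbtxt
    # written above; the converter falls back to text parsing when binary
    # parsing of the file fails.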
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testInvalidFileNotFound(self):
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
                                             ['add'])
    self.assertEqual('File \'invalid_file\' does not exist.',
                     str(error.exception))

  def testInvalidFileBadData(self):
    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
    with gfile.Open(graph_def_file, 'wb') as temp_file:
      temp_file.write('bad data')
      temp_file.flush()

    # Attempt to convert the invalid model.
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual(
        'Unable to parse input file \'{}\'.'.format(graph_def_file),
        str(error.exception))

  def testFloatTocoConverter(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testGraphDebugInfo(self):
    """Test that a frozen graph doesn't have debug info captured."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    converter.convert()
    # GraphDebugInfo should be None for a frozen graph.
    self.assertFalse(converter._debug_info)


class FromFrozenGraphObjectDetection(LiteTest):

  def _initObjectDetectionArgs(self):
    # Initializes the arguments required for the object detection model.
    # Looks for the model file which is saved in a different location
    # internally and externally.
    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
    if not os.path.exists(filename):
      filename = os.path.join(
          resource_loader.get_root_dir_with_all_resources(),
          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
      if not os.path.exists(filename):
        raise IOError("File '{0}' does not exist.".format(filename))

    self._graph_def_file = filename
    self._input_arrays = ['normalized_input_image_tensor']
    self._output_arrays = [
        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
    ]
    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}

  def testTFLiteGraphDef(self):
    # Tests the object detection model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
                                                       self._input_arrays,
                                                       self._output_arrays,
                                                       self._input_shapes)
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 300, 300, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 4)
    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 10, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual('TFLite_Detection_PostProcess:1',
                     output_details[1]['name'])
    self.assertAllEqual([1, 10], output_details[1]['shape'])
    self.assertEqual('TFLite_Detection_PostProcess:2',
                     output_details[2]['name'])
    self.assertAllEqual([1, 10], output_details[2]['shape'])
    self.assertEqual('TFLite_Detection_PostProcess:3',
                     output_details[3]['name'])
    self.assertAllEqual([1], output_details[3]['shape'])

  def testTFLiteGraphDefWithControlOutput(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[5, 5], dtype=dtypes.float32, name='input')
      out_tensor = in_tensor + in_tensor
      logging_ops.print_v2(out_tensor)
      sess = session.Session()

    converter = lite.TFLiteConverter(
        sess.graph_def,
        input_tensors=None,
        output_tensors=None,
        input_arrays_with_shape=[('input', [5, 5])],
        output_arrays=None,
        experimental_debug_info_func=None)
    converter._control_output_arrays = ['PrintV2']
    converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS,
        lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
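    # PrintV2 and StringFormat have no TFLite builtins, so with SELECT_TF_OPS
    # they are expected to be exported as Flex custom ops (checked below).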
    self.assertIsNotNone(tflite_model)

    model = util._convert_model_from_bytearray_to_object(tflite_model)
    self.assertEqual(model.operatorCodes[0].builtinCode,
                     schema_fb.BuiltinOperator.ADD)
    self.assertEqual(model.operatorCodes[1].builtinCode,
                     schema_fb.BuiltinOperator.CUSTOM)
    self.assertEqual(model.operatorCodes[1].customCode, b'FlexStringFormat')
    self.assertEqual(model.operatorCodes[2].builtinCode,
                     schema_fb.BuiltinOperator.CUSTOM)
    self.assertEqual(model.operatorCodes[2].customCode, b'FlexPrintV2')

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([5, 5], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 0)

  def testModifyIOToUint8(self):
    # Tests the object detection model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    def representative_dataset_gen():
      for _ in range(2):
        yield [
            np.random.uniform(low=0, high=1,
                              size=(1, 300, 300, 3)).astype(np.float32)
        ]

    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
                                                       self._input_arrays,
                                                       self._output_arrays,
                                                       self._input_shapes)
    converter.representative_dataset = representative_dataset_gen
    converter.target_spec.supported_ops = {lite.OpsSet.TFLITE_BUILTINS_INT8}
    converter.inference_type = dtypes.int8
    converter.inference_input_type = dtypes.uint8
    converter.inference_output_type = dtypes.uint8
    converter.experimental_new_quantizer = True
    converter.quantized_input_stats = {
        'normalized_input_image_tensor': (0., 1.)
    }  # mean, std_dev
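    # The uint8 overrides apply only at the model boundary; the body of the
    # model is still quantized to int8 kernels (TFLITE_BUILTINS_INT8).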
    converter.allow_custom_ops = True
    tflite_model = converter.convert()

    self.assertIsNotNone(tflite_model)

    model = util._convert_model_from_bytearray_to_object(tflite_model)
    quant_opcode_idxs = util.get_quantize_opcode_idx(model)

    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    operators = subgraph.operators
    for op in operators:
      if op.opcodeIndex in quant_opcode_idxs:
        input_type = util._convert_tflite_enum_type_to_tf_type(
            tensors[op.inputs[0]].type)
        if op.outputs[0] in subgraph.outputs:
          self.assertEqual(input_type, dtypes.float32)


class FromSavedModelTest(TestModels):

  def _createSavedModel(self, shape):
    """Create a simple SavedModel."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor_1 = array_ops.placeholder(
            shape=shape, dtype=dtypes.float32, name='inputB')
        in_tensor_2 = array_ops.placeholder(
            shape=shape, dtype=dtypes.float32, name='inputA')
        out_tensor = in_tensor_1 + in_tensor_2
        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def testSimpleModel(self):
    """Test a SavedModel."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testOldConverterWarning(self):
    """Test that the warning message is logged when using TOCO."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
    log = io.BytesIO() if six.PY2 else io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    warning_message = 'Please consider switching to the new converter'
    # Convert model and ensure model is not None.
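    # Disabling experimental_new_converter falls back to the old TOCO-based
    # converter, which should log the deprecation warning captured above.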
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.experimental_new_converter = False
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    self.assertIn(warning_message, log.getvalue())
    logging.root.removeHandler(handler)

  def testNewConverterOptOut(self):
    """Test that the opt-out message is logged when using the new converter."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
    log = io.BytesIO() if six.PY2 else io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    optout_message = ('Using experimental converter: '
                      'If you encountered a problem')
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    self.assertIn(optout_message, log.getvalue())
    logging.root.removeHandler(handler)

  def testNoneBatchSize(self):
    """Test a SavedModel with None in the input tensor's shape."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testOrderInputArrays(self):
    """Test the ordering of input arrays in a SavedModel."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputB', 'inputA'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testShapeOverriding(self):
    """Test a SavedModel with the input_shapes argument."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_shapes={
            'inputA': [2, 16, 16, 3],
            'inputB': [2, 16, 16, 3]
        })
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testWrongInputShapes(self):
    """Test a SavedModel with a wrong name in the input_shapes argument."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check the case where a wrong input name is given.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_saved_model(
          saved_model_dir,
          input_arrays=['inputA'],
          input_shapes={'wrong_input': [1, 16, 16, 3]})

  def testSubsetInputShapes(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check the case where the input shape is given.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_arrays=['inputA'],
        input_shapes={'inputA': [1, 16, 16, 3]})

    # Since we only partially specify the input, this is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

    # Check the case where the input shape is None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})

    # Since we only partially specify the input, this is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

  def testSimpleModelTocoConverter(self):
    """Test a SavedModel with the deprecated TocoConverter."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testGraphDebugInfo(self):
    """Test that a SavedModel has debug info captured."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.convert()
    self.assertValidDebugInfo(converter._debug_info)


class MyAddLayer(keras.layers.Layer):

  def __init__(self, increment, **kwargs):
    super(MyAddLayer, self).__init__(**kwargs)
    self._increment = increment

  def call(self, inputs):
    return inputs + self._increment

  def get_config(self):
    config = super(MyAddLayer, self).get_config()
    config['increment'] = self._increment
    return config


class FromKerasFile(TestModels, parameterized.TestCase):

  def setUp(self):
    super(FromKerasFile, self).setUp()
    self._keras_file = None
    self._custom_objects = None
    if not context.executing_eagerly():
      keras.backend.clear_session()

  def tearDown(self):
    if self._keras_file:
      os.remove(self._keras_file)
    super(FromKerasFile, self).tearDown()

  def _getSequentialModel(self, include_custom_layer=False):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    if include_custom_layer:
      model.add(MyAddLayer(1.0))
    model.add(keras.layers.RepeatVector(3))
    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.categorical_accuracy],
        sample_weight_mode='temporal')
    x = np.random.random((1, 3))
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)
    model.predict(x)

    try:
      fd, self._keras_file = tempfile.mkstemp('.h5')
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

    if include_custom_layer:
      self._custom_objects = {'MyAddLayer': MyAddLayer}

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testSequentialModel(self, test_context):
    """Test a Sequential tf.keras model with default inputs."""
    with test_context():
      self._getSequentialModel()

      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(self._keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testCustomLayer(self, test_context):
    """Test a Sequential tf.keras model with a custom layer."""
    with test_context():
      self._getSequentialModel(include_custom_layer=True)

      converter = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, custom_objects=self._custom_objects)
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(
        self._keras_file, custom_objects=self._custom_objects)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def testSequentialModelInputArray(self):
    """Test a Sequential tf.keras model testing the input arrays argument."""
    ops.disable_eager_execution()
    self._getSequentialModel()

    # An invalid input array raises an error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_arrays=['invalid-input'])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

    # Valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_arrays=['dense_input'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  def testSequentialModelInputShape(self):
    """Test a Sequential tf.keras model testing the input shapes argument."""
    self._getSequentialModel()

    # Passing in the shape of an invalid input array raises an error.
    with self.assertRaises(ValueError) as error:
      converter = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_shapes={'invalid-input': [2, 3]})
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.",
        str(error.exception))

    # Passing in the shape of a valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={'dense_input': [2, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check input shape from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertAllEqual([2, 3], input_details[0]['shape'])

  def testSequentialModelOutputArray(self):
    """Test a Sequential tf.keras model testing the output arrays argument."""
    ops.disable_eager_execution()
    self._getSequentialModel()

    # An invalid output array raises an error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, output_arrays=['invalid-output'])
    self.assertEqual("Invalid tensors 'invalid-output' were found.",
                     str(error.exception))

    # Valid output array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, output_arrays=['time_distributed/Reshape_1'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testFunctionalModel(self, test_context):
    """Test a Functional tf.keras model with default inputs."""
    with test_context():
      inputs = keras.layers.Input(shape=(3,), name='input')
      x = keras.layers.Dense(2)(inputs)
      output = keras.layers.Dense(3)(x)

      model = keras.models.Model(inputs, output)
      model.compile(
          loss=keras.losses.MSE,
          optimizer='sgd',
          metrics=[keras.metrics.categorical_accuracy])
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)

      model.predict(x)
      fd, self._keras_file = tempfile.mkstemp('.h5')
      try:
        keras.models.save_model(model, self._keras_file)
      finally:
        os.close(fd)

      # Convert to TFLite model.
      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(self._keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def _getFunctionalModelMultipleInputs(self):
    a = keras.layers.Input(shape=(3,), name='input_a')
    b = keras.layers.Input(shape=(3,), name='input_b')
    dense = keras.layers.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = keras.layers.Dropout(0.5, name='dropout')(c)

    model = keras.models.Model([a, b], [d, e])
    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.mae],
        loss_weights=[1., 0.5])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))
    model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])

    model.predict([input_a_np, input_b_np], batch_size=5)
    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

  def testFunctionalModelMultipleInputs(self):
    """Test a Functional tf.keras model with multiple inputs and outputs."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([1, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testShapeOverriding(self):
    """Test a Functional tf.keras model with input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model. Shapes are lists, not sets: the dimension
    # order matters.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={
            'input_a': [2, 3],
            'input_b': [2, 3]
        })
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([2, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([2, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testPartialShapeOverriding(self):
    """Test a Functional tf.keras model with partial input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model, overriding only one of the two input shapes.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={'input_a': [2, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([2, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testWrongShapeOverriding(self):
    """Test a Functional tf.keras model with wrong input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_shapes={'wrong_input': [2, 3]})

  def testFunctionalSequentialModel(self):
    """Test a Functional tf.keras model containing a Sequential model."""
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    model.add(keras.layers.RepeatVector(3))
    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
    model = keras.models.Model(model.input, model.output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.categorical_accuracy],
        sample_weight_mode='temporal')
    x = np.random.random((1, 3))
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)
    model.predict(x)

    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(self._keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def testSequentialModelTocoConverter(self):
    """Test a Sequential tf.keras model with the deprecated TocoConverter."""
    self._getSequentialModel()

    converter = lite.TocoConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testGraphDebugInfo(self, test_context):
    """Test that a Sequential tf.keras model has debug info captured."""
    with test_context():
      self._getSequentialModel()
      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
      converter.convert()
      self.assertValidDebugInfo(converter._debug_info)

  def testSparsifyModel(self):
    self._getSequentialModel()

    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY}
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def testSparsifyQuantizedModel(self):
    self._getSequentialModel()

    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    converter.optimizations = {
        lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY
    }
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)


class GrapplerTest(TestModels, parameterized.TestCase):

  def testConstantFolding(self):
    ops.disable_eager_execution()
    # Constant folding handles the tf.broadcast_to operation, which was not
    # supported by TFLite at the time this test was added.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
      y_const = constant_op.constant([1., 2., 3.])
      y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
      out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([3, 3], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([3, 3], output_details[0]['shape'])

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testInputNodeIsNotFolded(self, enable_mlir_converter):
    ops.disable_eager_execution()
    # Constant folding must not fold away nodes that are used as model inputs,
    # even when they are foldable constants.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      y_const = constant_op.constant([1., 2., 3.])
      y_add = y_const + y_const
      out_tensor = in_tensor * y_add
      sess = session.Session()

    # Convert model.
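    # Passing y_const as a model input should keep Grappler from constant-
    # folding it away; the check below expects 'Const' to remain an input.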
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor, y_const],
                                                  [out_tensor])
    converter.experimental_new_converter = enable_mlir_converter
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual('Const', input_details[1]['name'])

  def testGrapplerConstFolding(self):
    # Constant folding converts the following add operation into a
    # tf.broadcast_to operation, which was not supported by TFLite at the
    # time this test was added.
    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
      out_tensor = plus_placeholder(
          array_ops.zeros([2, 2, 2]),
          array_ops.reshape(in_tensor, shape=[2, 2]))
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])


class DefaultConverterAttrsTest(LiteTest):

  def testAttrs(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])

    # Assert output format.
    self.assertEqual(converter.output_format, lite_constants.TFLITE)

    # Assert the default inference type is float.
    self.assertEqual(converter.inference_type, dtypes.float32)

    # Assert the default inference type overrides are None.
    self.assertIsNone(converter.inference_input_type)
    self.assertIsNone(converter.inference_output_type)

    # Assert the default quantization options are not set.
    self.assertEqual(converter.quantized_input_stats, {})
    self.assertIsNone(converter.default_ranges_stats)
    self.assertFalse(converter.reorder_across_fake_quant)
    self.assertFalse(converter.change_concat_input_ranges)

    # Assert dropping control dependency is enabled by default.
    self.assertTrue(converter.drop_control_dependency)

    # Assert dumping extra information is disabled by default.
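    # "Extra information" covers the GraphViz debug dumps and the conversion
    # summary directory.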
    self.assertIsNone(converter.dump_graphviz_dir)
    self.assertFalse(converter.dump_graphviz_video)
    self.assertIsNone(converter.conversion_summary_dir)


class ControlFlowV1OpsTest(LiteTest):

  def testConverterErrorOnControlFlowV1Ops(self):
    graph_def_file = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1.pbtxt')
    input_arrays = ['a', 'b', 'c', 'd']
    output_arrays = ['Merge']

    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       input_arrays,
                                                       output_arrays)
    with self.assertRaises(ConverterError) as error:
      converter.convert()
    self.assertIn(
        'Failed to functionalize Control Flow V1 ops. Consider using Control '
        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))


class QuantizationModeTest(LiteTest, parameterized.TestCase):

  @parameterized.named_parameters(
      ('size', lite.Optimize.OPTIMIZE_FOR_SIZE),
      ('latency', lite.Optimize.OPTIMIZE_FOR_LATENCY))
  def testDeprecatedOptionWarning(self, optimization):
    """Test that a warning is logged for deprecated Optimize options."""
    log = io.BytesIO() if six.PY2 else io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    warning_message = 'please use optimizations=[Optimize.DEFAULT] instead.'
    lite.QuantizationMode([optimization], lite.TargetSpec(), None, None)
    self.assertIn(warning_message, log.getvalue())
    logging.root.removeHandler(handler)


if __name__ == '__main__':
  test.main()