1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functions to test TFLite models.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23import numpy as np 24from six import PY2 25from tensorflow import keras 26 27from google.protobuf import text_format as _text_format 28from google.protobuf.message import DecodeError 29from tensorflow.core.framework import graph_pb2 as _graph_pb2 30from tensorflow.lite.python import convert_saved_model as _convert_saved_model 31from tensorflow.lite.python import interpreter as _interpreter 32from tensorflow.lite.python import lite as _lite 33from tensorflow.lite.python import util as _util 34from tensorflow.python.client import session as _session 35from tensorflow.python.framework import constant_op 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import ops 38from tensorflow.python.framework.importer import import_graph_def as _import_graph_def 39from tensorflow.python.lib.io import file_io as _file_io 40from tensorflow.python.platform import resource_loader as _resource_loader 41from tensorflow.python.platform import tf_logging as logging 42from tensorflow.python.saved_model import load as _load 43from tensorflow.python.saved_model import loader as _loader 44from tensorflow.python.saved_model import signature_constants as _signature_constants 45from tensorflow.python.saved_model import tag_constants as _tag_constants 46 47 48_GOLDENS_UPDATE_WARNING = """ 49 Golden file update requested! 50 This test is now going to write new golden files. 51 52 Make sure to package the updates together with your CL. 53""" 54 55 56def get_filepath(filename, base_dir=None): 57 """Returns the full path of the filename. 58 59 Args: 60 filename: Subdirectory and name of the model file. 61 base_dir: Base directory containing model file. 62 63 Returns: 64 str. 65 """ 66 if base_dir is None: 67 base_dir = "learning/brain/mobile/tflite_compat_models" 68 return os.path.join(_resource_loader.get_root_dir_with_all_resources(), 69 base_dir, filename) 70 71 72def get_golden_filepath(name): 73 """Returns the full path to a golden values file. 74 75 Args: 76 name: the name of the golden data, usually same as the test name. 77 """ 78 goldens_directory = os.path.join(_resource_loader.get_data_files_path(), 79 "testdata", "golden") 80 return os.path.join(goldens_directory, "%s.npy.golden" % name) 81 82 83def get_image(size): 84 """Returns an image loaded into an np.ndarray with dims [1, size, size, 3]. 85 86 Args: 87 size: Size of image. 88 89 Returns: 90 np.ndarray. 91 """ 92 img_filename = _resource_loader.get_path_to_datafile( 93 "testdata/grace_hopper.jpg") 94 img = keras.preprocessing.image.load_img( 95 img_filename, target_size=(size, size)) 96 img_array = keras.preprocessing.image.img_to_array(img) 97 img_array = np.expand_dims(img_array, axis=0) 98 return img_array 99 100 101def _get_calib_data_func(input_size): 102 """Returns a function to generate a representative data set. 103 104 Args: 105 input_size: 3D shape of the representative data. 106 """ 107 def representative_data_gen(): 108 num_calibration = 20 109 for _ in range(num_calibration): 110 yield [ 111 np.random.rand( 112 1, 113 input_size[0], 114 input_size[1], 115 input_size[2], 116 ).astype(np.float32) 117 ] 118 119 return representative_data_gen 120 121 122def _convert(converter, **kwargs): 123 """Converts the model. 124 125 Args: 126 converter: TFLiteConverter object. 127 **kwargs: Additional arguments to be passed into the converter. Supported 128 flags are {"target_ops", "post_training_quantize", "quantize_to_float16", 129 "post_training_quantize_int8", "post_training_quantize_16x8", 130 "model_input_size"}. 131 132 Returns: 133 The converted TFLite model in serialized format. 134 135 Raises: 136 ValueError: Invalid version number. 137 """ 138 139 if "target_ops" in kwargs: 140 converter.target_spec.supported_ops = kwargs["target_ops"] 141 if "post_training_quantize" in kwargs: 142 converter.optimizations = [_lite.Optimize.DEFAULT] 143 if kwargs.get("quantize_to_float16", False): 144 converter.target_spec.supported_types = [dtypes.float16] 145 if kwargs.get("post_training_quantize_int8", False): 146 input_size = kwargs.get("model_input_size") 147 converter.optimizations = [_lite.Optimize.DEFAULT] 148 converter.target_spec.supported_ops = [_lite.OpsSet.TFLITE_BUILTINS_INT8] 149 converter.representative_dataset = _get_calib_data_func(input_size) 150 # Note that the full integer quantization is by the mlir quantizer 151 converter.experimental_new_quantizer = True 152 if kwargs.get("post_training_quantize_16x8", False): 153 input_size = kwargs.get("model_input_size") 154 converter.optimizations = [_lite.Optimize.DEFAULT] 155 converter.target_spec.supported_ops = \ 156 [_lite.OpsSet.\ 157 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8] 158 converter.representative_dataset = _get_calib_data_func(input_size) 159 return converter.convert() 160 161 162def _check_model_quantized_to_16x8(tflite_model): 163 """Checks that the activations are quantized into int16. 164 165 Args: 166 tflite_model: Serialized TensorFlow Lite model. 167 168 Raises: 169 ValueError: Activations with int16 type are not found. 170 """ 171 interpreter = _get_tflite_interpreter(tflite_model) 172 interpreter.allocate_tensors() 173 all_tensor_details = interpreter.get_tensor_details() 174 175 found_input = False 176 for tensor in all_tensor_details: 177 if "_int16" in tensor["name"]: 178 found_input = True 179 if tensor["dtype"] is not np.int16: 180 raise ValueError("Activations should be int16.") 181 182 # Check that we found activations in the correct type: int16 183 if not found_input: 184 raise ValueError("Could not find int16 activations.") 185 186 187def _get_tflite_interpreter(tflite_model, 188 input_shapes_resize=None, 189 custom_op_registerers=None): 190 """Creates a TFLite interpreter with resized input tensors. 191 192 Args: 193 tflite_model: Serialized TensorFlow Lite model. 194 input_shapes_resize: A map where the key is the input tensor name and the 195 value is the shape of the input tensor. This resize happens after model 196 conversion, prior to calling allocate tensors. (default None) 197 custom_op_registerers: Op registerers for custom ops. 198 199 Returns: 200 lite.Interpreter 201 """ 202 if custom_op_registerers is None: 203 custom_op_registerers = [] 204 interpreter = _interpreter.InterpreterWithCustomOps( 205 model_content=tflite_model, custom_op_registerers=custom_op_registerers) 206 if input_shapes_resize: 207 input_details = interpreter.get_input_details() 208 input_details_map = { 209 detail["name"]: detail["index"] for detail in input_details 210 } 211 for name, shape in input_shapes_resize.items(): 212 idx = input_details_map[name] 213 interpreter.resize_tensor_input(idx, shape) 214 return interpreter 215 216 217def _get_input_data_map(tflite_model, input_data, custom_op_registerers=None): 218 """Generates a map of input data based on the TFLite model. 219 220 Args: 221 tflite_model: Serialized TensorFlow Lite model. 222 input_data: List of np.ndarray. 223 custom_op_registerers: Op registerers for custom ops. 224 225 Returns: 226 {str: [np.ndarray]}. 227 """ 228 interpreter = _get_tflite_interpreter( 229 tflite_model, custom_op_registerers=custom_op_registerers) 230 interpreter.allocate_tensors() 231 input_details = interpreter.get_input_details() 232 return { 233 input_tensor["name"]: data 234 for input_tensor, data in zip(input_details, input_data) 235 } 236 237 238def _generate_random_input_data(tflite_model, 239 seed=None, 240 input_data_range=None, 241 input_shapes_resize=None, 242 custom_op_registerers=None): 243 """Generates input data based on the input tensors in the TFLite model. 244 245 Args: 246 tflite_model: Serialized TensorFlow Lite model. 247 seed: Integer seed for the random generator. (default None) 248 input_data_range: A map where the key is the input tensor name and 249 the value is a tuple (min_val, max_val) which specifies the value range of 250 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 251 generate a random value for tensor `input1` within range [1.0, 5.0) 252 (half-inclusive). (default None) 253 input_shapes_resize: A map where the key is the input tensor name and the 254 value is the shape of the input tensor. This resize happens after model 255 conversion, prior to calling allocate tensors. (default None) 256 custom_op_registerers: Op registerers for custom ops. 257 258 Returns: 259 ([np.ndarray], {str : [np.ndarray]}). 260 """ 261 interpreter = _get_tflite_interpreter( 262 tflite_model, 263 input_shapes_resize, 264 custom_op_registerers=custom_op_registerers) 265 interpreter.allocate_tensors() 266 input_details = interpreter.get_input_details() 267 268 if seed: 269 np.random.seed(seed=seed) 270 271 # Generate random input data. If a tensor's value range is specified, say 272 # [a, b), then the generated value will be (b - a) * Unif[0.0, 1.0) + a, 273 # otherwise it's Unif[0.0, 1.0). 274 input_data = [] 275 for input_tensor in input_details: 276 val = np.random.random_sample(input_tensor["shape"]) 277 if (input_data_range is not None and 278 input_tensor["name"] in input_data_range): 279 val = (input_data_range[input_tensor["name"]][1] - 280 input_data_range[input_tensor["name"]][0] 281 ) * val + input_data_range[input_tensor["name"]][0] 282 input_data.append(np.array(val, dtype=input_tensor["dtype"])) 283 284 input_data_map = _get_input_data_map( 285 tflite_model, input_data, custom_op_registerers=custom_op_registerers) 286 return input_data, input_data_map 287 288 289def _evaluate_tflite_model(tflite_model, 290 input_data, 291 input_shapes_resize=None, 292 custom_op_registerers=None): 293 """Returns evaluation of input data on TFLite model. 294 295 Args: 296 tflite_model: Serialized TensorFlow Lite model. 297 input_data: List of np.ndarray. 298 input_shapes_resize: A map where the key is the input tensor name and the 299 value is the shape of the input tensor. This resize happens after model 300 conversion, prior to calling allocate tensors. (default None) 301 custom_op_registerers: Op registerers for custom ops. 302 303 Returns: 304 List of np.ndarray. 305 """ 306 interpreter = _get_tflite_interpreter( 307 tflite_model, 308 input_shapes_resize, 309 custom_op_registerers=custom_op_registerers) 310 interpreter.allocate_tensors() 311 312 input_details = interpreter.get_input_details() 313 output_details = interpreter.get_output_details() 314 315 for input_tensor, tensor_data in zip(input_details, input_data): 316 interpreter.set_tensor(input_tensor["index"], tensor_data) 317 318 interpreter.invoke() 319 output_data = [ 320 interpreter.get_tensor(output_tensor["index"]) 321 for output_tensor in output_details 322 ] 323 output_labels = [output_tensor["name"] for output_tensor in output_details] 324 return output_data, output_labels 325 326 327def evaluate_frozen_graph(filename, input_arrays, output_arrays): 328 """Returns a function that evaluates the frozen graph on input data. 329 330 Args: 331 filename: Full filepath of file containing frozen GraphDef. 332 input_arrays: List of input tensors to freeze graph with. 333 output_arrays: List of output tensors to freeze graph with. 334 335 Returns: 336 Lambda function ([np.ndarray data] : [np.ndarray result]). 337 """ 338 with _file_io.FileIO(filename, "rb") as f: 339 file_content = f.read() 340 341 graph_def = _graph_pb2.GraphDef() 342 try: 343 graph_def.ParseFromString(file_content) 344 except (_text_format.ParseError, DecodeError): 345 if not isinstance(file_content, str): 346 if PY2: 347 file_content = file_content.encode("utf-8") 348 else: 349 file_content = file_content.decode("utf-8") 350 _text_format.Merge(file_content, graph_def) 351 352 graph = ops.Graph() 353 with graph.as_default(): 354 _import_graph_def(graph_def, name="") 355 inputs = _util.get_tensors_from_tensor_names(graph, input_arrays) 356 outputs = _util.get_tensors_from_tensor_names(graph, output_arrays) 357 358 def run_session(input_data): 359 with _session.Session(graph=graph) as sess: 360 return sess.run(outputs, dict(zip(inputs, input_data))) 361 362 return run_session 363 364 365def evaluate_saved_model(directory, tag_set, signature_key): 366 """Returns a function that evaluates the SavedModel on input data. 367 368 Args: 369 directory: SavedModel directory to convert. 370 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 371 analyze. All tags in the tag set must be present. 372 signature_key: Key identifying SignatureDef containing inputs and outputs. 373 374 Returns: 375 Lambda function ([np.ndarray data] : [np.ndarray result]). 376 """ 377 with _session.Session().as_default() as sess: 378 if tag_set is None: 379 tag_set = set([_tag_constants.SERVING]) 380 if signature_key is None: 381 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 382 383 meta_graph = _loader.load(sess, tag_set, directory) 384 signature_def = _convert_saved_model.get_signature_def( 385 meta_graph, signature_key) 386 inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def) 387 388 return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data))) 389 390 391def evaluate_keras_model(filename): 392 """Returns a function that evaluates the tf.keras model on input data. 393 394 Args: 395 filename: Full filepath of HDF5 file containing the tf.keras model. 396 397 Returns: 398 Lambda function ([np.ndarray data] : [np.ndarray result]). 399 """ 400 keras_model = keras.models.load_model(filename) 401 return lambda input_data: [keras_model.predict(input_data)] 402 403 404def compare_models(tflite_model, 405 tf_eval_func, 406 input_shapes_resize=None, 407 input_data=None, 408 input_data_range=None, 409 tolerance=5): 410 """Compares TensorFlow and TFLite models. 411 412 Unless the input data is provided, the models are compared with random data. 413 414 Args: 415 tflite_model: Serialized TensorFlow Lite model. 416 tf_eval_func: Lambda function that takes in input data and outputs the 417 results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]). 418 input_shapes_resize: A map where the key is the input tensor name and the 419 value is the shape of the input tensor. This resize happens after model 420 conversion, prior to calling allocate tensors. (default None) 421 input_data: np.ndarray to pass into models during inference. (default None) 422 input_data_range: A map where the key is the input tensor name and 423 the value is a tuple (min_val, max_val) which specifies the value range of 424 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 425 generate a random value for tensor `input1` within range [1.0, 5.0) 426 (half-inclusive). (default None) 427 tolerance: Decimal place to check accuracy to. (default 5). 428 """ 429 if input_data is None: 430 input_data, _ = _generate_random_input_data( 431 tflite_model=tflite_model, 432 input_data_range=input_data_range, 433 input_shapes_resize=input_shapes_resize) 434 tf_results = tf_eval_func(input_data) 435 tflite_results, _ = _evaluate_tflite_model( 436 tflite_model, input_data, input_shapes_resize=input_shapes_resize) 437 for tf_result, tflite_result in zip(tf_results, tflite_results): 438 np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) 439 440 441def _compare_tf_tflite_results(tf_results, 442 tflite_results, 443 tflite_labels, 444 tolerance=5): 445 """Compare the result of TF and TFLite model. 446 447 Args: 448 tf_results: results returned by the TF model. 449 tflite_results: results returned by the TFLite model. 450 tflite_labels: names of the output tensors in the TFlite model. 451 tolerance: Decimal place to check accuracy to. (default 5). 452 """ 453 # Convert the output TensorFlow results into an ordered list. 454 if isinstance(tf_results, dict): 455 if len(tf_results) == 1: 456 tf_results = [tf_results[list(tf_results.keys())[0]]] 457 else: 458 tf_results = [tf_results[tflite_label] for tflite_label in tflite_labels] 459 else: 460 tf_results = [tf_results] 461 462 for tf_result, tflite_result in zip(tf_results, tflite_results): 463 np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) 464 465 466def compare_models_v2(tflite_model, 467 tf_eval_func, 468 input_data=None, 469 input_data_range=None, 470 tolerance=5): 471 """Compares TensorFlow and TFLite models for TensorFlow 2.0. 472 473 Unless the input data is provided, the models are compared with random data. 474 Currently only 1 input and 1 output are supported by this function. 475 476 Args: 477 tflite_model: Serialized TensorFlow Lite model. 478 tf_eval_func: Function to evaluate TensorFlow model. Either a lambda 479 function that takes in input data and outputs the results or a TensorFlow 480 ConcreteFunction. 481 input_data: np.ndarray to pass into models during inference. (default None). 482 input_data_range: A map where the key is the input tensor name and 483 the value is a tuple (min_val, max_val) which specifies the value range of 484 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 485 generate a random value for tensor `input1` within range [1.0, 5.0) 486 (half-inclusive). (default None) 487 tolerance: Decimal place to check accuracy to. (default 5) 488 """ 489 # Convert the input data into a map. 490 if input_data is None: 491 input_data, input_data_map = _generate_random_input_data( 492 tflite_model=tflite_model, input_data_range=input_data_range) 493 else: 494 input_data_map = _get_input_data_map(tflite_model, input_data) 495 input_data_func_map = { 496 input_name: constant_op.constant(input_data) 497 for input_name, input_data in input_data_map.items() 498 } 499 500 if len(input_data) > 1: 501 tf_results = tf_eval_func(**input_data_func_map) 502 else: 503 tf_results = tf_eval_func(constant_op.constant(input_data[0])) 504 tflite_results, tflite_labels = _evaluate_tflite_model( 505 tflite_model, input_data) 506 507 _compare_tf_tflite_results(tf_results, tflite_results, tflite_labels, 508 tolerance) 509 510 511def compare_tflite_keras_models_v2(tflite_model, 512 keras_model, 513 input_data=None, 514 input_data_range=None, 515 tolerance=5, 516 custom_op_registerers=None): 517 """Similar to compare_models_v2 but accept Keras model. 518 519 Unless the input data is provided, the models are compared with random data. 520 Currently only 1 input and 1 output are supported by this function. 521 522 Args: 523 tflite_model: Serialized TensorFlow Lite model. 524 keras_model: Keras model to evaluate. 525 input_data: np.ndarray to pass into models during inference. (default None). 526 input_data_range: A map where the key is the input tensor name and the value 527 is a tuple (min_val, max_val) which specifies the value range of 528 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 529 generate a random value for tensor `input1` within range [1.0, 5.0) 530 (half-inclusive). (default None) 531 tolerance: Decimal place to check accuracy to. (default 5) 532 custom_op_registerers: Op registerers for custom ops. 533 """ 534 # Generate random input data if not provided. 535 if input_data is None: 536 input_data, _ = _generate_random_input_data( 537 tflite_model=tflite_model, 538 input_data_range=input_data_range, 539 custom_op_registerers=custom_op_registerers) 540 541 if len(input_data) > 1: 542 tf_results = keras_model.predict(input_data) 543 else: 544 tf_results = keras_model.predict(input_data[0]) 545 tflite_results, tflite_labels = _evaluate_tflite_model( 546 tflite_model, input_data, custom_op_registerers=custom_op_registerers) 547 548 _compare_tf_tflite_results(tf_results, tflite_results, tflite_labels, 549 tolerance) 550 551 552def compare_model_golden(tflite_model, 553 input_data, 554 golden_name, 555 update_golden=False, 556 tolerance=5): 557 """Compares the output of a TFLite model against pre-existing golden values. 558 559 Args: 560 tflite_model: Serialized TensorFlow Lite model. 561 input_data: np.ndarray to pass into models during inference. 562 golden_name: Name of the file containing the (expected) golden values. 563 update_golden: Whether to update the golden values with the model output 564 instead of comparing against them. This should only be done when a change 565 in TFLite warrants it. 566 tolerance: Decimal place to check accuracy to. (default 5). 567 """ 568 tflite_results, _ = _evaluate_tflite_model(tflite_model, input_data) 569 golden_file = get_golden_filepath(golden_name) 570 if update_golden: 571 logging.warning(_GOLDENS_UPDATE_WARNING) 572 logging.warning("Updating golden values in file %s.", golden_file) 573 if not os.path.exists(golden_file): 574 golden_relative_path = os.path.relpath( 575 golden_file, _resource_loader.get_root_dir_with_all_resources()) 576 logging.warning( 577 "Golden file not found. Manually create it first:\ntouch %r", 578 golden_relative_path) 579 580 with open(golden_file, "wb") as f: 581 np.save(f, tflite_results, allow_pickle=False) 582 else: 583 golden_data = np.load(golden_file, allow_pickle=False) 584 np.testing.assert_almost_equal(golden_data, tflite_results, tolerance) 585 586 587def test_frozen_graph_quant(filename, 588 input_arrays, 589 output_arrays, 590 input_shapes=None, 591 **kwargs): 592 """Sanity check to validate post quantize flag alters the graph. 593 594 This test does not check correctness of the converted model. It converts the 595 TensorFlow frozen graph to TFLite with and without the post_training_quantized 596 flag. It ensures some tensors have different types between the float and 597 quantized models in the case of an all TFLite model or mix-and-match model. 598 It ensures tensor types do not change in the case of an all Flex model. 599 600 Args: 601 filename: Full filepath of file containing frozen GraphDef. 602 input_arrays: List of input tensors to freeze graph with. 603 output_arrays: List of output tensors to freeze graph with. 604 input_shapes: Dict of strings representing input tensor names to list of 605 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 606 Automatically determined when input shapes is None (e.g., {"foo" : None}). 607 (default None) 608 **kwargs: Additional arguments to be passed into the converter. 609 610 Raises: 611 ValueError: post_training_quantize flag doesn't act as intended. 612 """ 613 # Convert and load the float model. 614 converter = _lite.TFLiteConverter.from_frozen_graph( 615 filename, input_arrays, output_arrays, input_shapes) 616 tflite_model_float = _convert(converter, **kwargs) 617 618 interpreter_float = _get_tflite_interpreter(tflite_model_float) 619 interpreter_float.allocate_tensors() 620 float_tensors = interpreter_float.get_tensor_details() 621 622 # Convert and load the quantized model. 623 converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays, 624 output_arrays, 625 input_shapes) 626 tflite_model_quant = _convert( 627 converter, post_training_quantize=True, **kwargs) 628 629 interpreter_quant = _get_tflite_interpreter(tflite_model_quant) 630 interpreter_quant.allocate_tensors() 631 quant_tensors = interpreter_quant.get_tensor_details() 632 quant_tensors_map = { 633 tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors 634 } 635 quantized_tensors = { 636 tensor_detail["name"]: tensor_detail 637 for tensor_detail in quant_tensors 638 if tensor_detail["quantization_parameters"] 639 } 640 641 # Check if weights are of different types in the float and quantized models. 642 num_tensors_float = len(float_tensors) 643 num_tensors_same_dtypes = sum( 644 float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"] 645 for float_tensor in float_tensors) 646 has_quant_tensor = num_tensors_float != num_tensors_same_dtypes 647 648 # For the "flex" case, post_training_quantize should not alter the graph, 649 # unless we are quantizing to float16. 650 if ("target_ops" in kwargs and 651 not kwargs.get("quantize_to_float16", False) and 652 not kwargs.get("post_training_quantize_int8", False) and 653 not kwargs.get("post_training_quantize_16x8", False) and 654 set(kwargs["target_ops"]) == set([_lite.OpsSet.SELECT_TF_OPS])): 655 if has_quant_tensor: 656 raise ValueError("--post_training_quantize flag unexpectedly altered the " 657 "full Flex mode graph.") 658 elif kwargs.get("post_training_quantize_int8", False): 659 # Instead of using tensor names, we use the number of tensors which have 660 # quantization parameters to verify the model is quantized. 661 if not quantized_tensors: 662 raise ValueError("--post_training_quantize flag was unable to quantize " 663 "the graph as expected in TFLite.") 664 elif not has_quant_tensor: 665 raise ValueError("--post_training_quantize flag was unable to quantize the " 666 "graph as expected in TFLite and mix-and-match mode.") 667 668 669def test_frozen_graph(filename, 670 input_arrays, 671 output_arrays, 672 input_shapes=None, 673 input_shapes_resize=None, 674 input_data=None, 675 input_data_range=None, 676 **kwargs): 677 """Validates the TensorFlow frozen graph converts to a TFLite model. 678 679 Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the 680 model on random data. 681 682 Args: 683 filename: Full filepath of file containing frozen GraphDef. 684 input_arrays: List of input tensors to freeze graph with. 685 output_arrays: List of output tensors to freeze graph with. 686 input_shapes: Dict of strings representing input tensor names to list of 687 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 688 Automatically determined when input shapes is None (e.g., {"foo" : None}). 689 (default None) 690 input_shapes_resize: A map where the key is the input tensor name and the 691 value is the shape of the input tensor. This resize happens after model 692 conversion, prior to calling allocate tensors. (default None) 693 input_data: np.ndarray to pass into models during inference. (default None). 694 input_data_range: A map where the key is the input tensor name and 695 the value is a tuple (min_val, max_val) which specifies the value range of 696 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 697 generate a random value for tensor `input1` within range [1.0, 5.0) 698 (half-inclusive). (default None) 699 **kwargs: Additional arguments to be passed into the converter. 700 """ 701 converter = _lite.TFLiteConverter.from_frozen_graph( 702 filename, input_arrays, output_arrays, input_shapes) 703 tflite_model = _convert(converter, **kwargs) 704 705 tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays) 706 compare_models( 707 tflite_model, 708 tf_eval_func, 709 input_shapes_resize=input_shapes_resize, 710 input_data=input_data, 711 input_data_range=input_data_range) 712 713 714def test_saved_model(directory, 715 input_shapes=None, 716 tag_set=None, 717 signature_key=None, 718 input_data=None, 719 input_data_range=None, 720 **kwargs): 721 """Validates the TensorFlow SavedModel converts to a TFLite model. 722 723 Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the 724 model on random data. 725 726 Args: 727 directory: SavedModel directory to convert. 728 input_shapes: Dict of strings representing input tensor names to list of 729 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 730 Automatically determined when input shapes is None (e.g., {"foo" : None}). 731 (default None) 732 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 733 analyze. All tags in the tag set must be present. 734 signature_key: Key identifying SignatureDef containing inputs and outputs. 735 input_data: np.ndarray to pass into models during inference. (default None). 736 input_data_range: A map where the key is the input tensor name and 737 the value is a tuple (min_val, max_val) which specifies the value range of 738 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 739 generate a random value for tensor `input1` within range [1.0, 5.0) 740 (half-inclusive). (default None) 741 **kwargs: Additional arguments to be passed into the converter. 742 """ 743 converter = _lite.TFLiteConverter.from_saved_model( 744 directory, 745 input_shapes=input_shapes, 746 tag_set=tag_set, 747 signature_key=signature_key) 748 tflite_model = _convert(converter, **kwargs) 749 750 # 5 decimal places by default 751 tolerance = 5 752 if kwargs.get("post_training_quantize_16x8", False): 753 _check_model_quantized_to_16x8(tflite_model) 754 # only 2 decimal places for full quantization 755 tolerance = 2 756 757 tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key) 758 compare_models( 759 tflite_model, 760 tf_eval_func, 761 input_data=input_data, 762 input_data_range=input_data_range, 763 tolerance=tolerance) 764 765 766def test_saved_model_v2(directory, 767 tag_set=None, 768 signature_key=None, 769 input_data=None, 770 input_data_range=None, 771 **kwargs): 772 """Validates the TensorFlow SavedModel converts to a TFLite model. 773 774 Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the 775 model on random data. 776 777 Args: 778 directory: SavedModel directory to convert. 779 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 780 analyze. All tags in the tag set must be present. 781 signature_key: Key identifying SignatureDef containing inputs and outputs. 782 input_data: np.ndarray to pass into models during inference. (default None). 783 input_data_range: A map where the key is the input tensor name and 784 the value is a tuple (min_val, max_val) which specifies the value range of 785 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 786 generate a random value for tensor `input1` within range [1.0, 5.0) 787 (half-inclusive). (default None) 788 **kwargs: Additional arguments to be passed into the converter. 789 """ 790 model = _load.load(directory, tags=tag_set) 791 if not signature_key: 792 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 793 concrete_func = model.signatures[signature_key] 794 795 converter = _lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) 796 tflite_model = _convert(converter, **kwargs) 797 798 compare_models_v2( 799 tflite_model, 800 concrete_func, 801 input_data=input_data, 802 input_data_range=input_data_range) 803 804 805def _test_conversion_quant_float16(converter, 806 input_data, 807 golden_name=None, 808 update_golden=False, 809 **kwargs): 810 """Validates conversion with float16 quantization. 811 812 Args: 813 converter: TFLite converter instance for the model to convert. 814 input_data: np.ndarray to pass into models during inference. 815 golden_name: Optional golden values to compare the output of the model 816 against. 817 update_golden: Whether to update the golden values with the model output 818 instead of comparing against them. 819 **kwargs: Additional arguments to be passed into the converter. 820 """ 821 tflite_model_float = _convert(converter, version=2, **kwargs) 822 823 interpreter_float = _get_tflite_interpreter(tflite_model_float) 824 interpreter_float.allocate_tensors() 825 float_tensors = interpreter_float.get_tensor_details() 826 827 tflite_model_quant = _convert( 828 converter, 829 version=2, 830 post_training_quantize=True, 831 quantize_to_float16=True, 832 **kwargs) 833 834 interpreter_quant = _get_tflite_interpreter(tflite_model_quant) 835 interpreter_quant.allocate_tensors() 836 quant_tensors = interpreter_quant.get_tensor_details() 837 quant_tensors_map = { 838 tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors 839 } 840 841 # Check if weights are of different types in the float and quantized models. 842 num_tensors_float = len(float_tensors) 843 num_tensors_same_dtypes = sum( 844 float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"] 845 for float_tensor in float_tensors) 846 has_quant_tensor = num_tensors_float != num_tensors_same_dtypes 847 848 if not has_quant_tensor: 849 raise ValueError("--post_training_quantize flag was unable to quantize the " 850 "graph as expected.") 851 852 if golden_name: 853 compare_model_golden(tflite_model_quant, input_data, golden_name, 854 update_golden) 855 856 857def test_saved_model_v2_quant_float16(directory, 858 input_data, 859 golden_name=None, 860 update_golden=False, 861 **kwargs): 862 """Validates conversion of a saved model to TFLite with float16 quantization. 863 864 Args: 865 directory: SavedModel directory to convert. 866 input_data: np.ndarray to pass into models during inference. 867 golden_name: Optional golden values to compare the output of the model 868 against. 869 update_golden: Whether to update the golden values with the model output 870 instead of comparing against them. 871 **kwargs: Additional arguments to be passed into the converter. 872 """ 873 converter = _lite.TFLiteConverterV2.from_saved_model(directory) 874 _test_conversion_quant_float16(converter, input_data, golden_name, 875 update_golden, **kwargs) 876 877 878def test_frozen_graph_quant_float16(filename, 879 input_arrays, 880 output_arrays, 881 input_data, 882 input_shapes=None, 883 golden_name=None, 884 update_golden=False, 885 **kwargs): 886 """Validates conversion of a frozen graph to TFLite with float16 quantization. 887 888 Args: 889 filename: Full filepath of file containing frozen GraphDef. 890 input_arrays: List of input tensors to freeze graph with. 891 output_arrays: List of output tensors to freeze graph with. 892 input_data: np.ndarray to pass into models during inference. 893 input_shapes: Dict of strings representing input tensor names to list of 894 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 895 Automatically determined when input shapes is None (e.g., {"foo" : None}). 896 (default None) 897 golden_name: Optional golden values to compare the output of the model 898 against. 899 update_golden: Whether to update the golden values with the model output 900 instead of comparing against them. 901 **kwargs: Additional arguments to be passed into the converter. 902 """ 903 converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays, 904 output_arrays, 905 input_shapes) 906 _test_conversion_quant_float16(converter, input_data, 907 golden_name, update_golden, **kwargs) 908 909 910def test_keras_model(filename, 911 input_arrays=None, 912 input_shapes=None, 913 input_data=None, 914 input_data_range=None, 915 **kwargs): 916 """Validates the tf.keras model converts to a TFLite model. 917 918 Converts the tf.keras model to TFLite and checks the accuracy of the model on 919 random data. 920 921 Args: 922 filename: Full filepath of HDF5 file containing the tf.keras model. 923 input_arrays: List of input tensors to freeze graph with. 924 input_shapes: Dict of strings representing input tensor names to list of 925 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 926 Automatically determined when input shapes is None (e.g., {"foo" : None}). 927 (default None) 928 input_data: np.ndarray to pass into models during inference. (default None). 929 input_data_range: A map where the key is the input tensor name and 930 the value is a tuple (min_val, max_val) which specifies the value range of 931 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 932 generate a random value for tensor `input1` within range [1.0, 5.0) 933 (half-inclusive). (default None) 934 **kwargs: Additional arguments to be passed into the converter. 935 """ 936 converter = _lite.TFLiteConverter.from_keras_model_file( 937 filename, input_arrays=input_arrays, input_shapes=input_shapes) 938 tflite_model = _convert(converter, **kwargs) 939 940 tf_eval_func = evaluate_keras_model(filename) 941 compare_models( 942 tflite_model, 943 tf_eval_func, 944 input_data=input_data, 945 input_data_range=input_data_range) 946 947 948def test_keras_model_v2(filename, 949 input_shapes=None, 950 input_data=None, 951 input_data_range=None, 952 **kwargs): 953 """Validates the tf.keras model converts to a TFLite model. 954 955 Converts the tf.keras model to TFLite and checks the accuracy of the model on 956 random data. 957 958 Args: 959 filename: Full filepath of HDF5 file containing the tf.keras model. 960 input_shapes: List of list of integers representing input shapes in the 961 order of the tf.keras model's .input attribute (e.g., [[1, 16, 16, 3]]). 962 (default None) 963 input_data: np.ndarray to pass into models during inference. (default None). 964 input_data_range: A map where the key is the input tensor name and 965 the value is a tuple (min_val, max_val) which specifies the value range of 966 the corresponding input tensor. For example, '{'input1': (1, 5)}' means to 967 generate a random value for tensor `input1` within range [1.0, 5.0) 968 (half-inclusive). (default None) 969 **kwargs: Additional arguments to be passed into the converter. 970 """ 971 keras_model = keras.models.load_model(filename) 972 if input_shapes: 973 for tensor, shape in zip(keras_model.inputs, input_shapes): 974 tensor.set_shape(shape) 975 976 converter = _lite.TFLiteConverterV2.from_keras_model(keras_model) 977 tflite_model = _convert(converter, **kwargs) 978 979 tf_eval_func = evaluate_keras_model(filename) 980 compare_models_v2( 981 tflite_model, 982 tf_eval_func, 983 input_data=input_data, 984 input_data_range=input_data_range) 985