1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functions to test TFLite models.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import numpy as np 23 24from tensorflow.core.framework import graph_pb2 as _graph_pb2 25from tensorflow.lite.python import convert_saved_model as _convert_saved_model 26from tensorflow.lite.python import lite as _lite 27from tensorflow.python import keras as _keras 28from tensorflow.python.client import session as _session 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework.importer import import_graph_def as _import_graph_def 31from tensorflow.python.keras.preprocessing import image 32from tensorflow.python.lib.io import file_io as _file_io 33from tensorflow.python.platform import resource_loader as _resource_loader 34from tensorflow.python.saved_model import load as _load 35from tensorflow.python.saved_model import loader as _loader 36from tensorflow.python.saved_model import signature_constants as _signature_constants 37from tensorflow.python.saved_model import tag_constants as _tag_constants 38 39 40def get_filepath(filename, base_dir=None): 41 """Returns the full path of the filename. 42 43 Args: 44 filename: Subdirectory and name of the model file. 
45 base_dir: Base directory containing model file. 46 47 Returns: 48 str. 49 """ 50 if base_dir is None: 51 base_dir = "learning/brain/mobile/tflite_compat_models" 52 return os.path.join(_resource_loader.get_root_dir_with_all_resources(), 53 base_dir, filename) 54 55 56def get_image(size): 57 """Returns an image loaded into an np.ndarray with dims [1, size, size, 3]. 58 59 Args: 60 size: Size of image. 61 62 Returns: 63 np.ndarray. 64 """ 65 img_filename = _resource_loader.get_path_to_datafile( 66 "testdata/grace_hopper.jpg") 67 img = image.load_img(img_filename, target_size=(size, size)) 68 img_array = image.img_to_array(img) 69 img_array = np.expand_dims(img_array, axis=0) 70 return img_array 71 72 73def _convert(converter, **kwargs): 74 """Converts the model. 75 76 Args: 77 converter: TFLiteConverter object. 78 **kwargs: Additional arguments to be passed into the converter. Supported 79 flags are {"target_ops", "post_training_quantize"}. 80 81 Returns: 82 The converted TFLite model in serialized format. 83 """ 84 if "target_ops" in kwargs: 85 converter.target_ops = kwargs["target_ops"] 86 if "post_training_quantize" in kwargs: 87 converter.post_training_quantize = kwargs["post_training_quantize"] 88 return converter.convert() 89 90 91def _generate_random_input_data(tflite_model, seed=None): 92 """Generates input data based on the input tensors in the TFLite model. 93 94 Args: 95 tflite_model: Serialized TensorFlow Lite model. 96 seed: Integer seed for the random generator. (default None) 97 98 Returns: 99 List of np.ndarray. 
100 """ 101 interpreter = _lite.Interpreter(model_content=tflite_model) 102 interpreter.allocate_tensors() 103 input_details = interpreter.get_input_details() 104 105 if seed: 106 np.random.seed(seed=seed) 107 return [ 108 np.array( 109 np.random.random_sample(input_tensor["shape"]), 110 dtype=input_tensor["dtype"]) for input_tensor in input_details 111 ] 112 113 114def _evaluate_tflite_model(tflite_model, input_data): 115 """Returns evaluation of input data on TFLite model. 116 117 Args: 118 tflite_model: Serialized TensorFlow Lite model. 119 input_data: List of np.ndarray. 120 121 Returns: 122 List of np.ndarray. 123 """ 124 interpreter = _lite.Interpreter(model_content=tflite_model) 125 interpreter.allocate_tensors() 126 127 input_details = interpreter.get_input_details() 128 output_details = interpreter.get_output_details() 129 130 for input_tensor, tensor_data in zip(input_details, input_data): 131 interpreter.set_tensor(input_tensor["index"], tensor_data) 132 133 interpreter.invoke() 134 output_data = [ 135 interpreter.get_tensor(output_tensor["index"]) 136 for output_tensor in output_details 137 ] 138 return output_data 139 140 141def evaluate_frozen_graph(filename, input_arrays, output_arrays): 142 """Returns a function that evaluates the frozen graph on input data. 143 144 Args: 145 filename: Full filepath of file containing frozen GraphDef. 146 input_arrays: List of input tensors to freeze graph with. 147 output_arrays: List of output tensors to freeze graph with. 148 149 Returns: 150 Lambda function ([np.ndarray data] : [np.ndarray result]). 
151 """ 152 with _session.Session().as_default() as sess: 153 with _file_io.FileIO(filename, "rb") as f: 154 file_content = f.read() 155 156 graph_def = _graph_pb2.GraphDef() 157 graph_def.ParseFromString(file_content) 158 _import_graph_def(graph_def, name="") 159 160 inputs = _convert_saved_model.get_tensors_from_tensor_names( 161 sess.graph, input_arrays) 162 outputs = _convert_saved_model.get_tensors_from_tensor_names( 163 sess.graph, output_arrays) 164 165 return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data))) 166 167 168def evaluate_saved_model(directory, tag_set, signature_key): 169 """Returns a function that evaluates the SavedModel on input data. 170 171 Args: 172 directory: SavedModel directory to convert. 173 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 174 analyze. All tags in the tag set must be present. 175 signature_key: Key identifying SignatureDef containing inputs and outputs. 176 177 Returns: 178 Lambda function ([np.ndarray data] : [np.ndarray result]). 179 """ 180 with _session.Session().as_default() as sess: 181 if tag_set is None: 182 tag_set = set([_tag_constants.SERVING]) 183 if signature_key is None: 184 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 185 186 meta_graph = _loader.load(sess, tag_set, directory) 187 signature_def = _convert_saved_model.get_signature_def( 188 meta_graph, signature_key) 189 inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def) 190 191 return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data))) 192 193 194def evaluate_keras_model(filename): 195 """Returns a function that evaluates the tf.keras model on input data. 196 197 Args: 198 filename: Full filepath of HDF5 file containing the tf.keras model. 199 200 Returns: 201 Lambda function ([np.ndarray data] : [np.ndarray result]). 
202 """ 203 keras_model = _keras.models.load_model(filename) 204 return lambda input_data: [keras_model.predict(input_data)] 205 206 207def compare_models(tflite_model, tf_eval_func, input_data=None, tolerance=5): 208 """Compares TensorFlow and TFLite models. 209 210 Unless the input data is provided, the models are compared with random data. 211 212 Args: 213 tflite_model: Serialized TensorFlow Lite model. 214 tf_eval_func: Lambda function that takes in input data and outputs the 215 results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]). 216 input_data: np.ndarray to pass into models during inference. (default None) 217 tolerance: Decimal place to check accuracy to. (default 5) 218 """ 219 if input_data is None: 220 input_data = _generate_random_input_data(tflite_model) 221 tf_results = tf_eval_func(input_data) 222 tflite_results = _evaluate_tflite_model(tflite_model, input_data) 223 for tf_result, tflite_result in zip(tf_results, tflite_results): 224 np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) 225 226 227def compare_models_v2(tflite_model, concrete_func, input_data=None, 228 tolerance=5): 229 """Compares TensorFlow and TFLite models for TensorFlow 2.0. 230 231 Unless the input data is provided, the models are compared with random data. 232 Currently only 1 input and 1 output are supported by this function. 233 234 Args: 235 tflite_model: Serialized TensorFlow Lite model. 236 concrete_func: TensorFlow ConcreteFunction. 237 input_data: np.ndarray to pass into models during inference. (default None) 238 tolerance: Decimal place to check accuracy to. (default 5) 239 """ 240 if input_data is None: 241 input_data = _generate_random_input_data(tflite_model) 242 input_data_func = constant_op.constant(input_data[0]) 243 244 # Gets the TensorFlow results as a map from the output names to outputs. 245 # Converts the map into a list that is equivalent to the TFLite list. 
246 tf_results_map = concrete_func(input_data_func) 247 tf_results = [tf_results_map[tf_results_map.keys()[0]]] 248 tflite_results = _evaluate_tflite_model(tflite_model, input_data) 249 for tf_result, tflite_result in zip(tf_results, tflite_results): 250 np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) 251 252 253def test_frozen_graph_quant(filename, 254 input_arrays, 255 output_arrays, 256 input_shapes=None, 257 **kwargs): 258 """Sanity check to validate post quantize flag alters the graph. 259 260 This test does not check correctness of the converted model. It converts the 261 TensorFlow frozen graph to TFLite with and without the post_training_quantized 262 flag. It ensures some tensors have different types between the float and 263 quantized models in the case of an all TFLite model or mix-and-match model. 264 It ensures tensor types do not change in the case of an all Flex model. 265 266 Args: 267 filename: Full filepath of file containing frozen GraphDef. 268 input_arrays: List of input tensors to freeze graph with. 269 output_arrays: List of output tensors to freeze graph with. 270 input_shapes: Dict of strings representing input tensor names to list of 271 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 272 Automatically determined when input shapes is None (e.g., {"foo" : None}). 273 (default None) 274 **kwargs: Additional arguments to be passed into the converter. 275 276 Raises: 277 ValueError: post_training_quantize flag doesn't act as intended. 278 """ 279 # Convert and load the float model. 280 converter = _lite.TFLiteConverter.from_frozen_graph( 281 filename, input_arrays, output_arrays, input_shapes) 282 tflite_model_float = _convert(converter, **kwargs) 283 284 interpreter_float = _lite.Interpreter(model_content=tflite_model_float) 285 interpreter_float.allocate_tensors() 286 float_tensors = interpreter_float.get_tensor_details() 287 288 # Convert and load the quantized model. 
289 converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays, 290 output_arrays) 291 tflite_model_quant = _convert( 292 converter, post_training_quantize=True, **kwargs) 293 294 interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant) 295 interpreter_quant.allocate_tensors() 296 quant_tensors = interpreter_quant.get_tensor_details() 297 quant_tensors_map = { 298 tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors 299 } 300 301 # Check if weights are of different types in the float and quantized models. 302 num_tensors_float = len(float_tensors) 303 num_tensors_same_dtypes = sum( 304 float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"] 305 for float_tensor in float_tensors) 306 has_quant_tensor = num_tensors_float != num_tensors_same_dtypes 307 308 if ("target_ops" in kwargs and 309 set(kwargs["target_ops"]) == set([_lite.OpsSet.SELECT_TF_OPS])): 310 if has_quant_tensor: 311 raise ValueError("--post_training_quantize flag unexpectedly altered the " 312 "full Flex mode graph.") 313 elif not has_quant_tensor: 314 raise ValueError("--post_training_quantize flag was unable to quantize the " 315 "graph as expected in TFLite and mix-and-match mode.") 316 317 318def test_frozen_graph(filename, 319 input_arrays, 320 output_arrays, 321 input_shapes=None, 322 input_data=None, 323 **kwargs): 324 """Validates the TensorFlow frozen graph converts to a TFLite model. 325 326 Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the 327 model on random data. 328 329 Args: 330 filename: Full filepath of file containing frozen GraphDef. 331 input_arrays: List of input tensors to freeze graph with. 332 output_arrays: List of output tensors to freeze graph with. 333 input_shapes: Dict of strings representing input tensor names to list of 334 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 335 Automatically determined when input shapes is None (e.g., {"foo" : None}). 
336 (default None) 337 input_data: np.ndarray to pass into models during inference. (default None) 338 **kwargs: Additional arguments to be passed into the converter. 339 """ 340 converter = _lite.TFLiteConverter.from_frozen_graph( 341 filename, input_arrays, output_arrays, input_shapes) 342 tflite_model = _convert(converter, **kwargs) 343 344 tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays) 345 compare_models(tflite_model, tf_eval_func, input_data=input_data) 346 347 348def test_saved_model(directory, 349 input_shapes=None, 350 tag_set=None, 351 signature_key=None, 352 input_data=None, 353 **kwargs): 354 """Validates the TensorFlow SavedModel converts to a TFLite model. 355 356 Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the 357 model on random data. 358 359 Args: 360 directory: SavedModel directory to convert. 361 input_shapes: Dict of strings representing input tensor names to list of 362 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 363 Automatically determined when input shapes is None (e.g., {"foo" : None}). 364 (default None) 365 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 366 analyze. All tags in the tag set must be present. 367 signature_key: Key identifying SignatureDef containing inputs and outputs. 368 input_data: np.ndarray to pass into models during inference. (default None) 369 **kwargs: Additional arguments to be passed into the converter. 370 """ 371 converter = _lite.TFLiteConverter.from_saved_model( 372 directory, 373 input_shapes=input_shapes, 374 tag_set=tag_set, 375 signature_key=signature_key) 376 tflite_model = _convert(converter, **kwargs) 377 378 tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key) 379 compare_models(tflite_model, tf_eval_func, input_data=input_data) 380 381 382# TODO(nupurgarg): Remove input_shape parameter after bug with shapes is fixed. 
def test_saved_model_v2(directory,
                        input_shape=None,
                        tag_set=None,
                        signature_key=None,
                        input_data=None,
                        **kwargs):
  """Checks that a TF 2.0 SavedModel signature converts to an accurate TFLite model.

  The selected signature's ConcreteFunction is converted with
  TFLiteConverterV2. Its output is then compared against the TFLite model on
  random data, or on `input_data` when that is given.

  Args:
    directory: SavedModel directory to convert.
    input_shape: Input shape for the single input array as a list of integers.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.
      Any falsy value selects DEFAULT_SERVING_SIGNATURE_DEF_KEY.
    input_data: List of np.ndarray to pass into models during inference.
      (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  loaded = _load.load(directory, tags=tag_set)
  key = signature_key or _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  signature_fn = loaded.signatures[key]
  # Pin the (single) input's shape before conversion.
  signature_fn.inputs[0].set_shape(input_shape)

  serialized = _convert(
      _lite.TFLiteConverterV2.from_concrete_function(signature_fn), **kwargs)
  compare_models_v2(serialized, signature_fn, input_data=input_data)


def test_keras_model(filename,
                     input_arrays=None,
                     input_shapes=None,
                     input_data=None,
                     **kwargs):
  """Checks that a tf.keras HDF5 model converts to an accurate TFLite model.

  The model file is converted with TFLiteConverter. The TFLite result is then
  compared against Keras `predict` on random data, or on `input_data` when
  that is given.

  Args:
    filename: Full filepath of HDF5 file containing the tf.keras model.
    input_arrays: List of input tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
      (default None)
    input_data: List of np.ndarray to pass into models during inference.
      (default None)
    **kwargs: Additional arguments to be passed into the converter.
  """
  serialized = _convert(
      _lite.TFLiteConverter.from_keras_model_file(
          filename, input_arrays=input_arrays, input_shapes=input_shapes),
      **kwargs)
  compare_models(
      serialized, evaluate_keras_model(filename), input_data=input_data)