1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converts a model's graph def into a tflite model with MLIR-based conversion.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21import tempfile 22 23import numpy as np 24import tensorflow.compat.v1 as tf 25from tensorflow.lite.python import test_util as tflite_test_util 26from tensorflow.lite.testing import zip_test_utils 27from tensorflow.python.platform import resource_loader 28 29 30def mlir_convert(options, graph_def, input_tensors, output_tensors, **kwargs): 31 """Convert a model's graph def into a tflite model with MLIR-based conversion. 32 33 Args: 34 options: A lite.testing.generate_examples_lib.Options instance. 35 graph_def: A GraphDef object. 36 input_tensors: List of input tensor tuples `(name, shape, type)`. 37 output_tensors: List of output tensors (names). 38 **kwargs: Extra parameters. 39 40 Returns: 41 output tflite model, log_txt from conversion 42 or None, log_txt if it did not convert properly. 43 """ 44 test_params = kwargs.get("test_params", {}) 45 # TODO(b/146025965): Rename ExtraTocoOptions to ExtraConvertOptions or 46 # something else. 47 extra_toco_options = kwargs.get("extra_toco_options", 48 zip_test_utils.ExtraTocoOptions()) 49 input_arrays = [x[0] for x in input_tensors] 50 input_shapes = zip_test_utils.get_input_shapes_map(input_tensors) 51 52 tflite_model = None 53 log = "" 54 55 with tempfile.NamedTemporaryFile() as graphdef_file: 56 graphdef_file.write(graph_def.SerializeToString()) 57 graphdef_file.flush() 58 converter = tf.lite.TFLiteConverter.from_frozen_graph( 59 graphdef_file.name, input_arrays, output_tensors, input_shapes) 60 converter.allow_custom_ops = extra_toco_options.allow_custom_ops 61 converter.experimental_new_quantizer = options.mlir_quantizer 62 63 if options.run_with_flex: 64 converter.supported_ops = set([ 65 tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]) 66 67 if test_params.get("dynamic_range_quantize", False): 68 converter.optimizations = [tf.lite.Optimize.DEFAULT] 69 70 if test_params.get("fully_quantize", False): 71 converter.optimizations = [tf.lite.Optimize.DEFAULT] 72 73 # Read the input range for the representative dataset from parameters. 74 min_value, max_value = test_params.get("input_range", (-1, 1)) 75 76 def representative_dataset(input_tensors): 77 calibration_inputs = [] 78 for _, shape, _ in input_tensors: 79 if shape: 80 dims = [1 if dim.value is None else dim.value for dim in shape.dims] 81 calibration_inputs.append( 82 np.random.uniform(min_value, max_value, 83 tuple(dims)).astype(np.float32)) 84 return calibration_inputs 85 86 def representative_dataset_gen(): 87 for _ in range(100): 88 yield representative_dataset(input_tensors) 89 90 if test_params.get("quant_16x8", False): 91 converter.target_spec.supported_ops = [ 92 tf.lite.OpsSet.\ 93 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 94 ] 95 else: 96 converter.target_spec.supported_ops = [ 97 tf.lite.OpsSet.TFLITE_BUILTINS_INT8 98 ] 99 100 converter.representative_dataset = representative_dataset_gen 101 if extra_toco_options.inference_input_type: 102 converter.inference_input_type = ( 103 extra_toco_options.inference_input_type) 104 105 if extra_toco_options.inference_output_type: 106 converter.inference_output_type = ( 107 extra_toco_options.inference_output_type) 108 109 try: 110 tflite_model = converter.convert() 111 if options.expected_ops_in_converted_model: 112 ops_list = tflite_test_util.get_ops_list(tflite_model) 113 for expected_op in options.expected_ops_in_converted_model: 114 if expected_op not in ops_list: 115 # Force the test to fail. 116 tflite_model = None 117 raise ValueError( 118 "{} op not found in the converted model".format(expected_op)) 119 except Exception as e: # pylint: disable=broad-except 120 log = str(e) 121 122 return tflite_model, log 123 124 125def mlir_convert_file(graph_def_filename, 126 input_tensors, 127 output_tensors, 128 quantization_params=None, 129 additional_flags=""): 130 """Convert a graphdef file into a tflite model with MLIR-based conversion. 131 132 NOTE: this currently shells out to the MLIR binary binary, but we would like 133 convert to Python API tooling in the future. 134 135 Args: 136 graph_def_filename: A GraphDef file. 137 input_tensors: List of input tensor tuples `(name, shape, type)`. name 138 should be a string. shape should be a tuple of integers. type should be a 139 string, for example 'DT_FLOAT' 140 output_tensors: List of output tensors (names). 141 quantization_params: parameters `(inference_type, min_values, max_values)` 142 to quantize the model. 143 additional_flags: A string of additional command line flags to be passed 144 to MLIR converter. 145 146 Returns: 147 output tflite model, log_txt from conversion 148 or None, log_txt if it did not convert properly. 149 """ 150 bin_path = resource_loader.get_path_to_datafile( 151 "../../../../compiler/mlir/lite/tf_tfl_translate") 152 153 with tempfile.NamedTemporaryFile() as output_file, \ 154 tempfile.NamedTemporaryFile("w+") as stdout_file: 155 input_shapes = [] 156 for input_tensor in input_tensors: 157 shape = input_tensor[1] 158 input_shapes.append(",".join([str(dim) for dim in shape])) 159 input_shapes_str = ":".join(input_shapes) 160 161 input_types = ",".join([x[2] for x in input_tensors]) 162 163 quant_flags = "" 164 if quantization_params is not None: 165 min_vals = ",".join([str(val) for val in quantization_params[1]]) 166 max_vals = ",".join([str(val) for val in quantization_params[2]]) 167 quant_flags = ("-tf-inference-type=" + quantization_params[0] + 168 " -tf-input-min-values='" + min_vals + 169 "' -tf-input-max-values='" + max_vals + "' " + 170 "-emit-quant-adaptor-ops ") 171 cmd = ("%s -tf-input-arrays=%s -tf-input-data-types=%s -tf-input-shapes=%s " 172 "-tf-output-arrays=%s " + quant_flags + additional_flags + 173 "%s -o %s") 174 cmd = cmd % ( 175 bin_path, 176 ",".join([x[0] for x in input_tensors]), 177 input_types, 178 input_shapes_str, 179 ",".join(output_tensors), 180 graph_def_filename, 181 output_file.name, 182 ) 183 exit_code = os.system(cmd) 184 log = ( 185 cmd + "exited with code %d" % exit_code + "\n------------------\n" + 186 stdout_file.read()) 187 return (None if exit_code != 0 else output_file.read()), log 188 189 190def mlir_convert_saved_model(saved_model_dir, 191 is_signature_def_saved_model, 192 tags=(), 193 exported_names=(), 194 additional_flags=""): 195 """Convert a saved_model into a tflite model with MLIR-based conversion. 196 197 Args: 198 saved_model_dir: Saved model dir. 199 is_signature_def_saved_model: Whether the SavedModel SignatureDef importer 200 or ObjectGraph importer should be used. 201 tags: Set of tags identifying the MetaGraphDef within the SavedModel to 202 analyze. All tags in the tag set must be present. 203 exported_names: Names to export from SavedModel. 204 additional_flags: A string of additional command line flags to be passed to 205 MLIR converter. 206 207 Returns: 208 output tflite model, log_txt from conversion 209 or None, log_txt if it did not convert properly. 210 """ 211 bin_path = resource_loader.get_path_to_datafile( 212 "../../../../compiler/mlir/lite/tf_tfl_translate") 213 with tempfile.NamedTemporaryFile() as output_file, \ 214 tempfile.NamedTemporaryFile("w+") as stdout_file: 215 tags_str = ",".join(tags) 216 exported_names_str = ",".join(exported_names) 217 218 saved_model_flag = "-savedmodel-objectgraph-to-mlir" 219 if is_signature_def_saved_model: 220 saved_model_flag = "-savedmodel-signaturedefs-to-mlir" 221 222 cmd = ("%s %s --tf-savedmodel-tags=%s --tf-savedmodel-exported-names=%s " + 223 additional_flags + " %s --o=%s") 224 cmd = cmd % ( 225 bin_path, 226 saved_model_flag, 227 tags_str, 228 exported_names_str, 229 saved_model_dir, 230 output_file.name, 231 ) 232 exit_code = os.system(cmd) 233 log = ( 234 cmd + "exited with code %d" % exit_code + "\n------------------\n" + 235 stdout_file.read()) 236 return (None if exit_code != 0 else output_file.read()), log 237