# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collecting TFLite metrics."""

import collections
import enum
import functools
from typing import Text

from tensorflow.lite.python.metrics import converter_error_data_pb2
from tensorflow.lite.python.metrics import metrics


class Component(enum.Enum):
  """Enum class defining name of the converter components."""
  # Validate the given input and prepare and optimize TensorFlow Model.
  PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

  # Convert to TFLite model format.
  CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

  # Run quantization and sparsification.
  OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"


# Value type of every `SubComponent` member: the subcomponent's name plus the
# `Component` it belongs to (None for UNSPECIFIED).
SubComponentItem = collections.namedtuple("SubComponentItem",
                                          ["name", "component"])


class SubComponent(enum.Enum):
  """Enum class defining name of the converter subcomponents.

  This enum only defines the subcomponents in Python, there might be more
  subcomponents defined in C++.

  Each member's value is a `SubComponentItem`; the `name` and `component`
  properties below expose its two fields directly on the member.
  """

  def __str__(self):
    return self.value.name

  @property
  def name(self):
    # Overrides Enum's `name` with the name stored in the SubComponentItem.
    return self.value.name

  @property
  def component(self):
    return self.value.component

  # The subcomponent name is unspecified.
  UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)

  # Validate the given input and parameters.
  VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
                                     Component.PREPARE_TF_MODEL)

  # Load GraphDef from SavedModel.
  LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
                                      Component.PREPARE_TF_MODEL)

  # Convert a SavedModel to frozen graph.
  FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Save a Keras model to SavedModel.
  CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Save Concrete functions to SavedModel.
  CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Convert a Keras model to a frozen graph.
  FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Replace all the variables with constants in a ConcreteFunction.
  FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
                                              Component.PREPARE_TF_MODEL)

  # Run grappler optimization.
  OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
                                       Component.PREPARE_TF_MODEL)

  # Convert using the old TOCO converter.
  CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
      "CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a GraphDef to TFLite model.
  CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
                                      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a SavedModel to TFLite model.
  CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
                                         Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a Jax HLO to TFLite model.
  CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
                                     Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Do quantization by the deprecated quantizer.
  QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
      "QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)

  # Do calibration.
  CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)

  # Do quantization by MLIR.
  QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)

  # Do sparsification by MLIR.
  SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)


class ConverterError(Exception):
  """Raised when an error occurs during model conversion.

  Attributes:
    errors: List of `ConverterErrorData` protos attached to this exception,
      either via `append_error` or by matching the message against known
      patterns at construction time.
  """

  def __init__(self, message):
    super().__init__(message)
    self.errors = []
    self._parse_error_message(message)

  def append_error(self,
                   error_data: converter_error_data_pb2.ConverterErrorData):
    """Attaches one more `ConverterErrorData` record to this exception."""
    self.errors.append(error_data)

  def _parse_error_message(self, message):
    """If the message matches a pattern, assigns the associated error code.

    It is difficult to assign an error code to some errors in MLIR side, Ex:
    errors thrown by other components than TFLite or not using mlir::emitError.
    This function tries to detect them by the error message and assign the
    corresponding error code.

    Only the first matching pattern is recorded.

    Args:
      message: The error message of this exception.
    """
    error_code_mapping = {
        "Failed to functionalize Control Flow V1 ops. Consider using Control "
        "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
        "tf/compat/v1/enable_control_flow_v2.":
            converter_error_data_pb2.ConverterErrorData
            .ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
    }
    for pattern, error_code in error_code_mapping.items():
      if pattern in message:
        error_data = converter_error_data_pb2.ConverterErrorData()
        error_data.error_message = message
        error_data.error_code = error_code
        self.append_error(error_data)
        return


def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
  """The decorator to identify converter component and subcomponent.

  Any exception escaping the wrapped function is reported through
  `metrics.TFLiteConverterMetrics().set_converter_error` tagged with the given
  component/subcomponent, and is then re-raised.

  Args:
    component: Converter component. A `Component` member, or a value that
      `Component(...)` can resolve to a member.
    subcomponent: Converter subcomponent. A `SubComponent` member, or a value
      that `SubComponent(...)` can resolve to a member.

  Returns:
    A decorator. The decorated function forwards the result of the wrapped
    function.

  Raises:
    ValueError: if component and subcomponent name is not valid.
  """
  # Resolve through the enum constructors instead of `x in EnumClass`: the
  # containment test raises TypeError for non-members on Python 3.8-3.11 and
  # accepts raw values on 3.12+, which would later break `component.value`
  # while an error is being reported.
  try:
    component = Component(component)
  except ValueError:
    raise ValueError("Given component name not found") from None
  try:
    subcomponent = SubComponent(subcomponent)
  except ValueError:
    raise ValueError("Given subcomponent name not found") from None
  if (subcomponent != SubComponent.UNSPECIFIED and
      subcomponent.component != component):
    raise ValueError("component and subcomponent name don't match")

  def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
    # Always overwrites the component information, but only overwrites the
    # subcomponent if it is not available.
    error_data.component = component.value
    if not error_data.subcomponent:
      error_data.subcomponent = subcomponent.name
    tflite_metrics = metrics.TFLiteConverterMetrics()
    tflite_metrics.set_converter_error(error_data)

  def report_error_message(error_message: Text):
    # Wraps a plain message into a ConverterErrorData before reporting it.
    error_data = converter_error_data_pb2.ConverterErrorData()
    error_data.error_message = error_message
    report_error(error_data)

  def actual_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      try:
        return func(*args, **kwargs)
      except ConverterError as converter_error:
        # Prefer the structured records carried by the exception; fall back to
        # its string form when it has none.
        if converter_error.errors:
          for error_data in converter_error.errors:
            report_error(error_data)
        else:
          report_error_message(str(converter_error))
        raise converter_error from None  # Re-throws the exception.
      except Exception as error:
        report_error_message(str(error))
        raise error from None  # Re-throws the exception.

    return wrapper

  return actual_decorator