1# Lint as: python2, python3 2# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Utilities for collecting TFLite metrics.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import enum 24import functools 25from typing import Text 26 27from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2 28 29# pylint: disable=g-import-not-at-top 30try: 31 from tensorflow.lite.python import metrics_portable as metrics 32except ImportError: 33 from tensorflow.lite.python import metrics_nonportable as metrics 34# pylint: enable=g-import-not-at-top 35 36 37class Component(enum.Enum): 38 """Enum class defining name of the converter components.""" 39 # Validate the given input and prepare and optimize TensorFlow Model. 40 PREPARE_TF_MODEL = "PREPARE_TF_MODEL" 41 42 # Convert to TFLite model format. 43 CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL" 44 45 # RUN quantization and sparsification. 46 OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL" 47 48 49SubComponentItem = collections.namedtuple("SubComponentItem", 50 ["name", "component"]) 51 52 53class SubComponent(enum.Enum): 54 """Enum class defining name of the converter subcomponents. 55 56 This enum only defines the subcomponents in Python, there might be more 57 subcomponents defined in C++. 58 """ 59 60 def __str__(self): 61 return self.value.name 62 63 @property 64 def name(self): 65 return self.value.name 66 67 @property 68 def component(self): 69 return self.value.component 70 71 # The subcomponent name is unspecified. 72 UNSPECIFIED = SubComponentItem("UNSPECIFIED", None) 73 74 # Valid the given input and parameters. 75 VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS", 76 Component.PREPARE_TF_MODEL) 77 78 # Load GraphDef from SavedModel. 79 LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL", 80 Component.PREPARE_TF_MODEL) 81 82 # Convert a SavedModel to frozen graph. 83 FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL", 84 Component.PREPARE_TF_MODEL) 85 86 # Save a Keras model to SavedModel. 87 CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem( 88 "CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL) 89 90 # Save Concrete functions to SavedModel. 91 CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem( 92 "CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL) 93 94 # Convert a Keras model to a frozen graph. 95 FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL", 96 Component.PREPARE_TF_MODEL) 97 98 # Replace all the variables with constants in a ConcreteFunction. 99 FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION", 100 Component.PREPARE_TF_MODEL) 101 102 # Run grappler optimization. 103 OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL", 104 Component.PREPARE_TF_MODEL) 105 106 # Convert using the old TOCO converter. 107 CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem( 108 "CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER", 109 Component.CONVERT_TF_TO_TFLITE_MODEL) 110 111 # Convert a GraphDef to TFLite model. 112 CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF", 113 Component.CONVERT_TF_TO_TFLITE_MODEL) 114 115 # Convert a SavedModel to TFLite model. 116 CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL", 117 Component.CONVERT_TF_TO_TFLITE_MODEL) 118 119 # Do quantization by the deprecated quantizer. 120 QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem( 121 "QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL) 122 123 # Do calibration. 124 CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL) 125 126 # Do quantization by MLIR. 127 QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL) 128 129 # Do sparsification by MLIR. 130 SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL) 131 132 133class ConverterError(Exception): 134 """Raised when an error occurs during model conversion.""" 135 136 def __init__(self, message): 137 super(ConverterError, self).__init__(message) 138 self.errors = [] 139 self._parse_error_message(message) 140 141 def append_error(self, 142 error_data: converter_error_data_pb2.ConverterErrorData): 143 self.errors.append(error_data) 144 145 def _parse_error_message(self, message): 146 """If the message matches a pattern, assigns the associated error code. 147 148 It is difficult to assign an error code to some errrors in MLIR side, Ex: 149 errors thrown by other components than TFLite or not using mlir::emitError. 150 This function try to detect them by the error message and assign the 151 corresponding error code. 152 153 Args: 154 message: The error message of this exception. 155 """ 156 error_code_mapping = { 157 "Failed to functionalize Control Flow V1 ops. Consider using Control " 158 "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/" 159 "tf/compat/v1/enable_control_flow_v2.": 160 converter_error_data_pb2.ConverterErrorData 161 .ERROR_UNSUPPORTED_CONTROL_FLOW_V1, 162 } 163 for pattern, error_code in error_code_mapping.items(): 164 if pattern in message: 165 error_data = converter_error_data_pb2.ConverterErrorData() 166 error_data.error_message = message 167 error_data.error_code = error_code 168 self.append_error(error_data) 169 return 170 171 172def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED): 173 """The decorator to identify converter component and subcomponent. 174 175 Args: 176 component: Converter component name. 177 subcomponent: Converter subcomponent name. 178 179 Returns: 180 Forward the result from the wrapped function. 181 182 Raises: 183 ValueError: if component and subcomponent name is not valid. 184 """ 185 if component not in Component: 186 raise ValueError("Given component name not found") 187 if subcomponent not in SubComponent: 188 raise ValueError("Given subcomponent name not found") 189 if (subcomponent != SubComponent.UNSPECIFIED and 190 subcomponent.component != component): 191 raise ValueError("component and subcomponent name don't match") 192 193 def report_error(error_data: converter_error_data_pb2.ConverterErrorData): 194 # Always overwrites the component information, but only overwrites the 195 # subcomponent if it is not available. 196 error_data.component = component.value 197 if not error_data.subcomponent: 198 error_data.subcomponent = subcomponent.name 199 tflite_metrics = metrics.TFLiteConverterMetrics() 200 tflite_metrics.set_converter_error(error_data) 201 202 def report_error_message(error_message: Text): 203 error_data = converter_error_data_pb2.ConverterErrorData() 204 error_data.error_message = error_message 205 report_error(error_data) 206 207 def actual_decorator(func): 208 209 @functools.wraps(func) 210 def wrapper(*args, **kwargs): 211 try: 212 return func(*args, **kwargs) 213 except ConverterError as converter_error: 214 if converter_error.errors: 215 for error_data in converter_error.errors: 216 report_error(error_data) 217 else: 218 report_error_message(str(converter_error)) 219 raise converter_error from None # Re-throws the exception. 220 except Exception as error: 221 report_error_message(str(error)) 222 raise error from None # Re-throws the exception. 223 224 return wrapper 225 226 return actual_decorator 227