• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for collecting TFLite metrics."""
16
17import collections
18import enum
19import functools
20from typing import Text
21
22from tensorflow.lite.python.metrics import converter_error_data_pb2
23from tensorflow.lite.python.metrics import metrics
24
25
class Component(enum.Enum):
  """Top-level stages of the TFLite conversion pipeline.

  The value of every member is the same string as the member's name.
  """

  # Input validation plus preparation and optimization of the TensorFlow model.
  PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

  # Translation of the TensorFlow model into the TFLite model format.
  CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

  # Quantization and sparsification of the TFLite model.
  OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
36
37
# Value type of the SubComponent enum members: pairs a subcomponent's name with
# the Component it is registered under.
SubComponentItem = collections.namedtuple(
    "SubComponentItem", ("name", "component"))
40
41
class SubComponent(enum.Enum):
  """Fine-grained converter steps, each registered under a `Component`.

  Every member's value is a `SubComponentItem` holding the step's name and the
  `Component` it belongs to. Only the subcomponents used from Python are listed
  here; there might be more subcomponents defined in C++.
  """

  # Placeholder for "no subcomponent given"; it is tied to no Component.
  UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)

  # --- Steps under Component.PREPARE_TF_MODEL ---

  # Validate the given input and parameters.
  VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
                                     Component.PREPARE_TF_MODEL)

  # Load the GraphDef out of a SavedModel.
  LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
                                      Component.PREPARE_TF_MODEL)

  # Turn a SavedModel into a frozen graph.
  FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Save a Keras model as a SavedModel.
  CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Save concrete functions as a SavedModel.
  CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Turn a Keras model into a frozen graph.
  FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Replace every variable of a ConcreteFunction with a constant.
  FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
                                              Component.PREPARE_TF_MODEL)

  # Run the grappler optimization.
  OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
                                       Component.PREPARE_TF_MODEL)

  # --- Steps under Component.CONVERT_TF_TO_TFLITE_MODEL ---

  # Convert with the old TOCO converter.
  CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
      "CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a GraphDef into a TFLite model.
  CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
                                      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a SavedModel into a TFLite model.
  CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
                                         Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a Jax HLO into a TFLite model.
  CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
                                     Component.CONVERT_TF_TO_TFLITE_MODEL)

  # --- Steps under Component.OPTIMIZE_TFLITE_MODEL ---

  # Quantize with the deprecated quantizer.
  QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
      "QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)

  # Run calibration.
  CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)

  # Quantize with MLIR.
  QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)

  # Sparsify with MLIR.
  SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)

  @property
  def name(self):
    """The `name` field of this member's `SubComponentItem` value."""
    return self.value.name

  @property
  def component(self):
    """The `component` field of this member's `SubComponentItem` value."""
    return self.value.component

  def __str__(self):
    """Returns the subcomponent name stored in the member's value."""
    return self.value.name
124
125
class ConverterError(Exception):
  """Exception type for failures that happen while converting a model.

  Attributes:
    errors: List of `ConverterErrorData` protos attached to this exception. It
      starts empty unless the message matches a known pattern (see
      `_parse_error_message`).
  """

  def __init__(self, message):
    super().__init__(message)
    self.errors = []
    self._parse_error_message(message)

  def append_error(self,
                   error_data: converter_error_data_pb2.ConverterErrorData):
    """Attaches one more `ConverterErrorData` record to this exception."""
    self.errors.append(error_data)

  def _parse_error_message(self, message):
    """Derives an error code from the message text when a known pattern fits.

    Some errors on the MLIR side are hard to tag with an error code at the
    source, e.g. errors thrown by components other than TFLite, or errors that
    do not go through mlir::emitError. This method recognizes them by their
    message and records a `ConverterErrorData` carrying the matching code. Only
    the first matching pattern is recorded.

    Args:
      message: The error message of this exception.
    """
    error_proto = converter_error_data_pb2.ConverterErrorData
    # Maps a substring of the message to the error code it stands for.
    known_patterns = {
        "Failed to functionalize Control Flow V1 ops. Consider using Control "
        "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
        "tf/compat/v1/enable_control_flow_v2.":
            error_proto.ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
    }
    for needle, code in known_patterns.items():
      if needle not in message:
        continue
      record = error_proto()
      record.error_message = message
      record.error_code = code
      self.append_error(record)
      return
163
164
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
  """Returns a decorator that tags a conversion step for error reporting.

  Any exception escaping the decorated function is reported through
  `metrics.TFLiteConverterMetrics().set_converter_error`, annotated with the
  given component and subcomponent, and is then re-raised.

  Args:
    component: The `Component` member the decorated function belongs to.
    subcomponent: The `SubComponent` member the decorated function belongs to.
      Defaults to `SubComponent.UNSPECIFIED`.

  Returns:
    A decorator. The function it produces forwards its arguments to the wrapped
    function and returns that function's result.

  Raises:
    ValueError: if `component` is not a `Component` member, if `subcomponent`
      is not a `SubComponent` member, or if the subcomponent is registered
      under a different component.
  """
  # isinstance() is used instead of `x not in EnumClass`: before Python 3.12
  # the containment test raises TypeError for non-enum operands, and from 3.12
  # on it also accepts raw member values (e.g. the plain string), which would
  # pass validation here and only fail later on `component.value`.
  if not isinstance(component, Component):
    raise ValueError("Given component name not found")
  if not isinstance(subcomponent, SubComponent):
    raise ValueError("Given subcomponent name not found")
  if (subcomponent != SubComponent.UNSPECIFIED and
      subcomponent.component != component):
    raise ValueError("component and subcomponent name don't match")

  def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
    """Stamps the phase onto `error_data` and hands it to the metrics."""
    # Always overwrites the component information, but only overwrites the
    # subcomponent if it is not available.
    error_data.component = component.value
    if not error_data.subcomponent:
      error_data.subcomponent = subcomponent.name
    tflite_metrics = metrics.TFLiteConverterMetrics()
    tflite_metrics.set_converter_error(error_data)

  def report_error_message(error_message: Text):
    """Reports an error that only has a message and no structured data."""
    error_data = converter_error_data_pb2.ConverterErrorData()
    error_data.error_message = error_message
    report_error(error_data)

  def actual_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      try:
        return func(*args, **kwargs)
      except ConverterError as converter_error:
        # Prefer the structured records attached to the exception; fall back
        # to its message when there are none.
        if converter_error.errors:
          for error_data in converter_error.errors:
            report_error(error_data)
        else:
          report_error_message(str(converter_error))
        raise converter_error from None  # Re-throws the exception.
      except Exception as error:
        report_error_message(str(error))
        raise error from None  # Re-throws the exception.

    return wrapper

  return actual_decorator
220