• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Utilities for collecting TFLite metrics."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import enum
24import functools
25from typing import Text
26
27from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
28
29# pylint: disable=g-import-not-at-top
30try:
31  from tensorflow.lite.python import metrics_portable as metrics
32except ImportError:
33  from tensorflow.lite.python import metrics_nonportable as metrics
34# pylint: enable=g-import-not-at-top
35
36
class Component(enum.Enum):
  """Names of the top-level converter components.

  Every member's value is the same string as the member's own name.
  """

  # Input validation, plus preparation and optimization of the TensorFlow
  # model.
  PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

  # Conversion into the TFLite model format.
  CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

  # Quantization and sparsification.
  OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
47
48
# Value type used by the SubComponent enum members: the subcomponent's name
# paired with the component it is associated with.
SubComponentItem = collections.namedtuple(
    "SubComponentItem", ("name", "component"))
51
52
class SubComponent(enum.Enum):
  """Names of the converter subcomponents known to the Python side.

  Each member's value is a `SubComponentItem` holding the subcomponent name and
  the `Component` it is associated with. Only the subcomponents used from
  Python are listed here; there might be more subcomponents defined in C++.
  """

  def __str__(self):
    # Render as the bare subcomponent name instead of "SubComponent.X".
    return self.value.name

  @property
  def name(self):
    """The subcomponent name stored in this member's `SubComponentItem`."""
    return self.value.name

  @property
  def component(self):
    """The `Component` stored in this member's item (None for UNSPECIFIED)."""
    return self.value.component

  # No subcomponent was specified; associated with no component.
  UNSPECIFIED = SubComponentItem(name="UNSPECIFIED", component=None)

  # Validate the given input and parameters.
  VALIDATE_INPUTS = SubComponentItem(
      name="VALIDATE_INPUTS", component=Component.PREPARE_TF_MODEL)

  # Load a GraphDef from a SavedModel.
  LOAD_SAVED_MODEL = SubComponentItem(
      name="LOAD_SAVED_MODEL", component=Component.PREPARE_TF_MODEL)

  # Convert a SavedModel to a frozen graph.
  FREEZE_SAVED_MODEL = SubComponentItem(
      name="FREEZE_SAVED_MODEL", component=Component.PREPARE_TF_MODEL)

  # Save a Keras model as a SavedModel.
  CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
      name="CONVERT_KERAS_TO_SAVED_MODEL",
      component=Component.PREPARE_TF_MODEL)

  # Save concrete functions as a SavedModel.
  CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
      name="CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL",
      component=Component.PREPARE_TF_MODEL)

  # Convert a Keras model to a frozen graph.
  FREEZE_KERAS_MODEL = SubComponentItem(
      name="FREEZE_KERAS_MODEL", component=Component.PREPARE_TF_MODEL)

  # Replace all the variables with constants in a ConcreteFunction.
  FREEZE_CONCRETE_FUNCTION = SubComponentItem(
      name="FREEZE_CONCRETE_FUNCTION", component=Component.PREPARE_TF_MODEL)

  # Run grappler optimization.
  OPTIMIZE_TF_MODEL = SubComponentItem(
      name="OPTIMIZE_TF_MODEL", component=Component.PREPARE_TF_MODEL)

  # Convert using the old TOCO converter.
  CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
      name="CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
      component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a GraphDef to a TFLite model.
  CONVERT_GRAPHDEF = SubComponentItem(
      name="CONVERT_GRAPHDEF",
      component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a SavedModel to a TFLite model.
  CONVERT_SAVED_MODEL = SubComponentItem(
      name="CONVERT_SAVED_MODEL",
      component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Quantize with the deprecated quantizer.
  QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
      name="QUANTIZE_USING_DEPRECATED_QUANTIZER",
      component=Component.OPTIMIZE_TFLITE_MODEL)

  # Run calibration.
  CALIBRATE = SubComponentItem(
      name="CALIBRATE", component=Component.OPTIMIZE_TFLITE_MODEL)

  # Quantize with MLIR.
  QUANTIZE = SubComponentItem(
      name="QUANTIZE", component=Component.OPTIMIZE_TFLITE_MODEL)

  # Sparsify with MLIR.
  SPARSIFY = SubComponentItem(
      name="SPARSIFY", component=Component.OPTIMIZE_TFLITE_MODEL)
131
132
class ConverterError(Exception):
  """Raised when an error occurs during model conversion.

  Attributes:
    errors: List of `ConverterErrorData` protos attached to this exception.
      Starts empty, except when the message matches a known pattern (see
      `_parse_error_message`).
  """

  def __init__(self, message):
    """Initializes the exception and derives error data from `message`.

    Args:
      message: The error message of this exception.
    """
    super().__init__(message)
    self.errors = []
    self._parse_error_message(message)

  def append_error(self,
                   error_data: converter_error_data_pb2.ConverterErrorData):
    """Appends `error_data` to the `errors` list of this exception."""
    self.errors.append(error_data)

  def _parse_error_message(self, message):
    """Assigns an error code when the message contains a known pattern.

    It is difficult to assign an error code to some errors on the MLIR side,
    e.g. errors thrown by components other than TFLite, or errors not raised
    via mlir::emitError. This function tries to detect them from the error
    message and records the corresponding error code. Only the first matching
    pattern is recorded.

    Args:
      message: The error message of this exception.
    """
    proto_cls = converter_error_data_pb2.ConverterErrorData
    # (substring to look for, error code to record) pairs.
    known_patterns = (
        ("Failed to functionalize Control Flow V1 ops. Consider using Control "
         "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
         "tf/compat/v1/enable_control_flow_v2.",
         proto_cls.ERROR_UNSUPPORTED_CONTROL_FLOW_V1),
    )
    for known_text, code in known_patterns:
      if known_text not in message:
        continue
      record = proto_cls()
      record.error_message = message
      record.error_code = code
      self.append_error(record)
      break
170
171
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
  """Builds a decorator that reports errors raised in a converter phase.

  The decorated function runs unchanged. If it raises, the error is recorded
  through `metrics.TFLiteConverterMetrics().set_converter_error`, tagged with
  the given component and subcomponent, and then re-raised.

  Args:
    component: `Component` member naming the converter component.
    subcomponent: `SubComponent` member naming the converter subcomponent.
      Unless it is `SubComponent.UNSPECIFIED`, its `component` must equal
      `component`.

  Returns:
    A decorator. Functions wrapped by it forward the result of the original
    function.

  Raises:
    ValueError: If `component` is not a `Component` member, if `subcomponent`
      is not a `SubComponent` member, or if the two do not match.
  """
  # `x in EnumClass` is not a dependable validity check: before Python 3.12 it
  # can raise TypeError for operands that are not Enum members, and from 3.12
  # on it also accepts raw member values, which lack the attributes used
  # below. isinstance() gives the documented ValueError on every version.
  if not isinstance(component, Component):
    raise ValueError("Given component name not found")
  if not isinstance(subcomponent, SubComponent):
    raise ValueError("Given subcomponent name not found")
  if (subcomponent != SubComponent.UNSPECIFIED and
      subcomponent.component != component):
    raise ValueError("component and subcomponent name don't match")

  def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
    """Tags `error_data` with this phase and sends it to the metrics."""
    # Always overwrites the component information, but only overwrites the
    # subcomponent if it is not available.
    error_data.component = component.value
    if not error_data.subcomponent:
      error_data.subcomponent = subcomponent.name
    tflite_metrics = metrics.TFLiteConverterMetrics()
    tflite_metrics.set_converter_error(error_data)

  def report_error_message(error_message: Text):
    """Reports an error for which only a text message is available."""
    error_data = converter_error_data_pb2.ConverterErrorData()
    error_data.error_message = error_message
    report_error(error_data)

  def actual_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      try:
        return func(*args, **kwargs)
      except ConverterError as converter_error:
        # Prefer the structured error data attached to the exception; fall
        # back to the plain message when none was attached.
        if converter_error.errors:
          for error_data in converter_error.errors:
            report_error(error_data)
        else:
          report_error_message(str(converter_error))
        raise converter_error from None  # Re-throws the exception.
      except Exception as error:
        report_error_message(str(error))
        raise error from None  # Re-throws the exception.

    return wrapper

  return actual_decorator
227