• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Authoring tool package for TFLite compatibility.
16
17WARNING: The package is experimental and subject to change.
18
19This package provides a way to check TFLite compatibility at model authoring
20time.
21
22Example:
23    @tf.lite.experimental.authoring.compatible
24    @tf.function(input_signature=[
25        tf.TensorSpec(shape=[None], dtype=tf.float32)
26    ])
27    def f(x):
28      return tf.cosh(x)
29
30    result = f(tf.constant([0.0]))
31
32    > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
33    > conversion for TensorFlow Lite.
34    > Op: tf.Cosh
35    >   - tensorflow/python/framework/op_def_library.py:xxx
36    >   - tensorflow/python/ops/gen_math_ops.py:xxx
37    >   - simple_authoring.py:xxx
38"""
39import functools
40
41
42# pylint: disable=g-import-not-at-top
43from tensorflow.lite.python import convert
44from tensorflow.lite.python import lite
45from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
46
47
# Line prefixes searched for in str(err) when decoding a legacy ConverterError.
_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "

# Headers that start each logged compatibility message.
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"

# Call-site frames are reported only up to (not including) the first frame
# whose filename contains this path.
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"
53
54
class CompatibilityError(Exception):
  """Signals that a decorated function is not compatible with TFLite.

  In this module it is raised by the `compatible` decorator when it was
  created with `raise_exception=True` and the compatibility check logged at
  least one message; inspect `get_compatibility_log()` for the details.
  """
58
59
class _Compatible:
  """A decorator class to check TFLite compatibility.

  Instances are created by `lite.experimental.authoring.compatible`. Until the
  check has run, calling the object traces the decorated function into a
  concrete function and runs it through the TFLite converter. Compatibility
  problems reported by the converter are logged and, optionally, raised. The
  decorated function is then executed normally.
  """

  def __init__(self,
               target,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               raise_exception=False):
    """Initialize the decorator object.

    Object variables:
    - _func         : decorated function.
    - _obj_func     : for a method, the decorated function bound to the
                      instance it was accessed through, so that `self` is
                      passed as the first argument.
    - _verified     : whether the compatibility check has been performed.
    - _log_messages : messages logged by the compatibility check.

    Args:
      target: decorated function.
      converter_target_spec: target_spec of TFLite converter parameter.
      converter_allow_custom_ops: allow_custom_ops of TFLite converter
          parameter.
      raise_exception: whether to raise `CompatibilityError` on compatibility
          issues. Use get_compatibility_log() to inspect the details.
    """
    # Copy __name__, __doc__, etc. from the target and set __wrapped__.
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops

  def __get__(self, instance, cls):
    """A Python descriptor interface.

    Binds the decorated function to `instance` so that it can be called as a
    method, and returns this decorator object.

    NOTE(review): the binding is stored on this shared decorator object, so
    the most recent attribute access determines which instance is used.
    """
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns decorated function object.

    For a class method, use self._obj_func to provide `self` instance.
    """
    if self._obj_func is not None:
      return self._obj_func
    return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls decorated function object.

    Until the compatibility check has run, also verifies that the function is
    compatible with TFLite.

    Returns:
      The result of executing the decorated function.

    Raises:
      CompatibilityError: when the check runs, `raise_exception` was set and
        compatibility issues were logged.
    """
    if not self._verified:
      self._check_compatibility(*args, **kwargs)
    return self._get_func()(*args, **kwargs)

  def _check_compatibility(self, *args, **kwargs):
    """Runs the TFLite converter on the traced function and logs any issues.

    `_verified` is set once `convert()` has been attempted, even if it raised,
    so later calls skip the check.
    """
    model = self._get_func()
    concrete_func = model.get_concrete_function(*args, **kwargs)
    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [concrete_func], model)
    # Only override converter parameters the caller explicitly provided.
    if self._converter_target_spec is not None:
      converter.target_spec = self._converter_target_spec
    if self._converter_allow_custom_ops is not None:
      converter.allow_custom_ops = self._converter_allow_custom_ops
    try:
      converter.convert()
    except convert.ConverterError as err:
      self._decode_error(err)
    finally:
      self._verified = True

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  @staticmethod
  def _custom_ops_message(ops_str):
    """Returns the error message for ops that need a custom implementation."""
    return (
        f"{_AUTHORING_ERROR_HDR}: op '{ops_str}' is(are) not natively "
        "supported by TensorFlow Lite. You need to provide a custom "
        "operator. https://www.tensorflow.org/lite/guide/ops_custom")

  @staticmethod
  def _tf_ops_message(ops_str):
    """Returns the warning message for ops that need "Select TF Ops"."""
    return (
        f"{_AUTHORING_WARNING_HDR}: op '{ops_str}' require(s) \"Select TF "
        "Ops\" for model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select")

  def _get_location_string(self, location):
    """Formats one converter error location as a multi-line call stack."""
    is_callsite = (
        location.type ==
        converter_error_data_pb2.ConverterErrorData.CALLSITELOC)
    callstack = []
    for single_call in location.call:
      if not is_callsite:
        callstack.append(str(single_call))
        continue
      # Stop showing CallSite after func_graph.py which isn't meaningful.
      if _FUNC_GRAPH_SRC_PATH in single_call.source.filename:
        break
      callstack.append(
          f"  - {single_call.source.filename}:{single_call.source.line}")
    return "\n".join(callstack)

  def _dump_error_details(self, ops, locations):
    """Logs each op name with the location the converter reported for it."""
    for op, location in zip(ops, locations):
      callstack_dump = self._get_location_string(location)
      self._log(f"Op: {op}\n{callstack_dump}\n")

  def _decode_error_legacy(self, err):
    """Parses the given legacy ConverterError for OSS.

    Used when the error has no `errors` attribute: op names are recovered
    from lines of its string form that start with a known header.
    """
    for line in str(err).splitlines():
      # Check custom op usage error.
      if line.startswith(_CUSTOM_OPS_HDR):
        self._log(self._custom_ops_message(line[len(_CUSTOM_OPS_HDR):]))
      # Check TensorFlow op usage error.
      elif line.startswith(_TF_OPS_HDR):
        self._log(self._tf_ops_message(line[len(_TF_OPS_HDR):]))

  def _decode_converter_error(self, err):
    """Parses the given ConverterError which has detailed error information."""
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    error_data = converter_error_data_pb2.ConverterErrorData
    # Use a distinct loop name; rebinding `err` would shadow the argument.
    for error in err.errors:
      # Check custom op usage error.
      if error.error_code == error_data.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error.operator.name)
        custom_ops_location.append(error.location)
      # Check TensorFlow op usage error.
      elif error.error_code == error_data.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error.operator.name)
        tf_ops_location.append(error.location)

    if custom_ops:
      self._log(self._custom_ops_message(", ".join(sorted(custom_ops))))
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      self._log(self._tf_ops_message(", ".join(sorted(tf_ops))))
      self._dump_error_details(tf_ops, tf_ops_location)

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Raises:
      CompatibilityError: if `raise_exception` was set and any compatibility
        message has been logged.
    """
    if hasattr(err, "errors"):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      # Keep the converter error as the explicit cause for debugging.
      raise CompatibilityError(
          f"CompatibilityException at {repr(self._func)}") from err

  def _log(self, message):
    """Records an authoring warning / error message and prints it."""
    self._log_messages.append(message)
    print(message)

  def get_compatibility_log(self):
    """Returns list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      The list of log messages by the recent compatibility check.
    Raises:
      RuntimeError: when the compatibility was NOT checked.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return self._log_messages
240
241
def compatible(target=None, converter_target_spec=None, **kwargs):
  """Wraps `tf.function` into a callable function with TFLite compatibility checking.

  Can be applied directly (`@compatible`) or with arguments
  (`@compatible(raise_exception=True)`).

  Args:
    target: A `tf.function` to decorate. If None, a decorator that applies
      the given settings to its argument is returned instead.
    converter_target_spec : target_spec of TFLite converter parameter.
    **kwargs: Additional keyword arguments forwarded to the decorator class
      `_Compatible` (`converter_allow_custom_ops`, `raise_exception`).

  Returns:
    A callable object of `tf.lite.experimental.authoring._Compatible` when
    `target` is given; otherwise a decorator that produces one.
  """
  if target is None:
    # Called with arguments only: return a decorator capturing them.
    def wrapper(func):
      return _Compatible(func, converter_target_spec, **kwargs)

    return wrapper
  return _Compatible(target, converter_target_spec, **kwargs)
260