• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Authoring tool package for TFLite compatibility.
16
17WARNING: The package is experimental and subject to change.
18
19This package provides a way to check TFLite compatibility at model authoring
20time.
21
22Example:
23    @tf.lite.experimental.authoring.compatible
24    @tf.function(input_signature=[
25        tf.TensorSpec(shape=[None], dtype=tf.float32)
26    ])
27    def f(x):
28      return tf.cosh(x)
29
30    result = f(tf.constant([0.0]))
31
32    > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
33    > conversion for TensorFlow Lite.
34    > Op: tf.Cosh
35    >   - tensorflow/python/framework/op_def_library.py:xxx
36    >   - tensorflow/python/ops/gen_math_ops.py:xxx
37    >   - simple_authoring.py:xxx
38"""
39import functools
40
41
42# pylint: disable=g-import-not-at-top
43from tensorflow.lite.python import convert
44from tensorflow.lite.python import lite
45from tensorflow.lite.python.metrics import converter_error_data_pb2
46from tensorflow.python.util.tf_export import tf_export as _tf_export
47
48
# Line prefixes that introduce the op lists in the text of a legacy
# ConverterError (matched with `str.startswith` when decoding it).
_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "
# Headers prepended to the compatibility messages logged by this module.
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"
# Call-site dumps stop at the first frame whose filename contains this path.
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"
54
55
class CompatibilityError(Exception):
  """Signals that a TFLite compatibility problem was detected.

  Raised by the authoring decorator when it was created with
  `raise_exception=True` and the compatibility check logged any message.
  """
59
60
class _Compatible:
  """Decorator object that checks a function for TFLite compatibility.

  Instances are created by `lite.experimental.authoring.compatible`. The first
  call converts the decorated function with the TFLite converter and turns any
  `ConverterError` into logged compatibility messages. Every call, including
  the first, then forwards to the decorated function.
  """

  def __init__(self,
               target,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               raise_exception=False):
    """Initialize the decorator object.

    Here is the description of the object variables.
    - _func         : decorated function.
    - _obj_func     : for a class object, the bound version of `_func` that
                      provides the `self` instance as the first argument.
    - _verified     : whether the compatibility is checked or not.
    - _log_messages : messages logged by the compatibility check.

    Args:
      target: decorated function.
      converter_target_spec : target_spec of TFLite converter parameter.
      converter_allow_custom_ops : allow_custom_ops of TFLite converter
          parameter.
      raise_exception : to raise an exception on compatibility issues.
          Users need to use get_compatibility_log() to check details.
    """
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops

  def __get__(self, instance, cls):
    """A Python descriptor interface.

    Binds the decorated function to `instance` and returns this same decorator
    object, so the verification state and log stay reachable through
    attribute access.
    """
    # NOTE(review): the binding is stored on the shared decorator object, so
    # accessing the attribute through another instance replaces it -- confirm
    # this is acceptable for classes with more than one live instance.
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns decorated function object.

    For a class method, use self._obj_func to provide `self` instance.
    """
    if self._obj_func is not None:
      return self._obj_func
    return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls decorated function object.

    Also verifies if the function is compatible with TFLite. The check runs
    only on the first call that reaches the conversion step.

    Returns:
      The execution result of the decorated function.

    Raises:
      CompatibilityError: if `raise_exception` was set and the check logged
        any compatibility message.
    """
    if not self._verified:
      model = self._get_func()
      concrete_func = model.get_concrete_function(*args, **kwargs)
      converter = lite.TFLiteConverterV2.from_concrete_functions(
          [concrete_func], model)
      # Set provided converter parameters
      if self._converter_target_spec is not None:
        converter.target_spec = self._converter_target_spec
      if self._converter_allow_custom_ops is not None:
        converter.allow_custom_ops = self._converter_allow_custom_ops
      try:
        converter.convert()
      except convert.ConverterError as err:
        self._decode_error(err)
      finally:
        # Mark as verified even when the conversion or decoding raised, so
        # the check is not repeated on later calls.
        self._verified = True

    return self._get_func()(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  def _get_location_string(self, location):
    """Dump location of ConverterError.errors.location.

    Args:
      location: a location entry with `type` and `call` fields.

    Returns:
      The call entries joined by newlines; empty when there are no calls.
    """
    callstack = []
    for single_call in location.call:
      if (location.type ==
          converter_error_data_pb2.ConverterErrorData.CALLSITELOC):
        # Stop showing CallSite after func_graph.py which isn't meaningful.
        if _FUNC_GRAPH_SRC_PATH in single_call.source.filename:
          break
        callstack.append(
            f"  - {single_call.source.filename}:{single_call.source.line}")
      else:
        callstack.append(str(single_call))
    return "\n".join(callstack)

  def _dump_error_details(self, ops, locations):
    """Dump the list of ops and locations.

    Args:
      ops: op names.
      locations: locations paired with `ops` by position.
    """
    for op_name, location in zip(ops, locations):
      callstack_dump = self._get_location_string(location)
      self._log(f"Op: {op_name}\n{callstack_dump}\n")

  def _log_custom_ops_error(self, ops_str):
    """Logs the error saying `ops_str` needs custom operator implementations."""
    self._log(
        f"{_AUTHORING_ERROR_HDR}: op '{ops_str}' is(are) not natively "
        "supported by TensorFlow Lite. You need to provide a custom "
        "operator. https://www.tensorflow.org/lite/guide/ops_custom")

  def _log_select_tf_ops_warning(self, ops_str):
    """Logs the warning saying `ops_str` requires "Select TF Ops"."""
    self._log(
        f"{_AUTHORING_WARNING_HDR}: op '{ops_str}' require(s) \"Select TF "
        "Ops\" for model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select")

  def _decode_error_legacy(self, err):
    """Parses the given legacy ConverterError for OSS.

    The error text is scanned line by line for the custom-op and Select-TF-op
    headers; other lines are ignored.
    """
    for line in str(err).splitlines():
      # Check custom op usage error.
      if line.startswith(_CUSTOM_OPS_HDR):
        self._log_custom_ops_error(line[len(_CUSTOM_OPS_HDR):])
      # Check TensorFlow op usage error.
      elif line.startswith(_TF_OPS_HDR):
        self._log_select_tf_ops_warning(line[len(_TF_OPS_HDR):])

  def _decode_converter_error(self, err):
    """Parses the given ConverterError which has detailed error information.

    GPU-compatibility and uncategorized errors are logged as they are met;
    custom-op and Select-TF-op errors are collected and reported as one
    summary message each, followed by per-op location details.
    """
    error_data_cls = converter_error_data_pb2.ConverterErrorData
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    gpu_not_compatible_ops = []
    # Use a distinct loop name so the `err` argument is not rebound.
    for error_data in err.errors:
      error_code = error_data.error_code
      # Check custom op usage error.
      if error_code == error_data_cls.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error_data.operator.name)
        custom_ops_location.append(error_data.location)
      # Check TensorFlow op usage error.
      elif error_code == error_data_cls.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error_data.operator.name)
        tf_ops_location.append(error_data.location)
      # Check GPU delegate compatibility error.
      elif error_code == error_data_cls.ERROR_GPU_NOT_COMPATIBLE:
        gpu_not_compatible_ops.append(error_data.operator.name)
        # Log the first line of ConverterError.errors.error_message only
        # since the second line is "Error code: xxxx". An empty message is
        # logged as an empty line instead of raising IndexError.
        message_lines = error_data.error_message.splitlines()
        self._log(message_lines[0] if message_lines else "")
        self._log(self._get_location_string(error_data.location) + "\n")
      else:
        # Log other errors.
        self._log(f"{_AUTHORING_ERROR_HDR}: {error_data.error_message}")
        self._log(self._get_location_string(error_data.location) + "\n")

    if custom_ops:
      self._log_custom_ops_error(", ".join(sorted(custom_ops)))
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      self._log_select_tf_ops_warning(", ".join(sorted(tf_ops)))
      self._dump_error_details(tf_ops, tf_ops_location)

    if gpu_not_compatible_ops:
      not_compatible_ops_str = ", ".join(sorted(gpu_not_compatible_ops))
      self._log(
          f"{_AUTHORING_WARNING_HDR}: op '{not_compatible_ops_str}' aren't "
          "compatible with TensorFlow Lite GPU delegate. "
          "https://www.tensorflow.org/lite/performance/gpu")

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Args:
      err: the ConverterError raised by the converter.

    Raises:
      CompatibilityError: if `raise_exception` was set and any message has
        been logged.
    """
    # Errors carrying an `errors` attribute have structured details; anything
    # else is parsed from its text.
    if hasattr(err, "errors"):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      raise CompatibilityError(
          f"CompatibilityException at {self._func!r}") from err

  def _log(self, message):
    """Log and print authoring warning / error message."""
    self._log_messages.append(message)
    print(message)

  def get_compatibility_log(self):
    """Returns list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      The list of log messages by the recent compatibility check.
    Raises:
      RuntimeError: when the compatibility was NOT checked.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return self._log_messages
266
267
@_tf_export("lite.experimental.authoring.compatible")
def compatible(target=None, converter_target_spec=None, **kwargs):
  """Decorates a `tf.function` so that calling it checks TFLite compatibility.

  Usable both bare (`@compatible`) and with arguments (`@compatible(...)`).

  Example:

  ```python
  @tf.lite.experimental.authoring.compatible
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.float32)
  ])
  def f(x):
      return tf.cosh(x)

  result = f(tf.constant([0.0]))
  # COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
  # conversion for TensorFlow Lite.
  # Op: tf.Cosh
  #   - tensorflow/python/framework/op_def_library.py:748
  #   - tensorflow/python/ops/gen_math_ops.py:2458
  #   - <stdin>:6
  ```

  WARNING: Experimental interface, subject to change.

  Args:
    target: A `tf.function` to decorate. Left as None when the decorator is
      used with arguments.
    converter_target_spec : target_spec of TFLite converter parameter.
    **kwargs: The keyword arguments of the decorator class _Compatible.

  Returns:
     A callable object of `tf.lite.experimental.authoring._Compatible` when
     `target` is given; otherwise a decorator that builds one from the
     function it is applied to.
  """

  def wrapper(func):
    return _Compatible(func, converter_target_spec, **kwargs)

  # With no target yet, hand back the decorator itself; otherwise apply it.
  return wrapper if target is None else wrapper(target)
307