# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Authoring tool package for TFLite compatibility.

WARNING: The package is experimental and subject to change.

This package provides a way to check TFLite compatibility at model authoring
time.

Example:
    @tf.lite.experimental.authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    result = f(tf.constant([0.0]))

    > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
    > conversion for TensorFlow Lite.
    > Op: tf.Cosh
    > - tensorflow/python/framework/op_def_library.py:xxx
    > - tensorflow/python/ops/gen_math_ops.py:xxx
    > - simple_authoring.py:xxx
"""
import functools


# pylint: disable=g-import-not-at-top
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python.metrics import converter_error_data_pb2
from tensorflow.python.util.tf_export import tf_export as _tf_export


_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"


class CompatibilityError(Exception):
  """Raised when an error occurs with TFLite compatibility."""


class _Compatible:
  """A decorator class to check TFLite compatibility created by `lite.experimental.authoring.compatible`."""

  def __init__(self,
               target,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               raise_exception=False):
    """Initialize the decorator object.

    Here is the description of the object variables.
    - _func     : decorated function.
    - _obj_func : for a class object, this bound function is used so that the
                  `self` instance is provided as the first argument.
    - _verified : whether the compatibility is checked or not.
    - _log_messages : messages recorded by `_log` during the check.

    Args:
      target: decorated function.
      converter_target_spec : target_spec of TFLite converter parameter.
      converter_allow_custom_ops : allow_custom_ops of TFLite converter
          parameter.
      raise_exception : to raise an exception on compatibility issues.
          Users need to use get_compatibility_log() to check details.
    """
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops

  def __get__(self, instance, cls):
    """A Python descriptor interface.

    Binds the decorated function to `instance` and returns this decorator
    object itself.
    """
    # NOTE(review): the binding is stored on this shared decorator object, so
    # the most recently accessed instance is the one used by later calls.
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns decorated function object.

    For a class method, use self._obj_func to provide `self` instance.
    """
    if self._obj_func is not None:
      return self._obj_func
    else:
      return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls decorated function object.

    Also verifies if the function is compatible with TFLite. The verification
    is attempted only on the first call that reaches the converter.

    Returns:
      An execution result of the decorated function.

    Raises:
      CompatibilityError: when `raise_exception` was set and the conversion
        reported compatibility issues.
    """

    if not self._verified:
      model = self._get_func()
      concrete_func = model.get_concrete_function(*args, **kwargs)
      converter = lite.TFLiteConverterV2.from_concrete_functions(
          [concrete_func], model)
      # Set provided converter parameters
      if self._converter_target_spec is not None:
        converter.target_spec = self._converter_target_spec
      if self._converter_allow_custom_ops is not None:
        converter.allow_custom_ops = self._converter_allow_custom_ops
      try:
        converter.convert()
      except convert.ConverterError as err:
        self._decode_error(err)
      finally:
        # Mark as verified even when conversion (or error decoding) raised, so
        # that the check is not repeated on every call.
        self._verified = True

    return self._get_func()(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  def _get_location_string(self, location):
    """Dump location of ConverterError.errors.location.

    Args:
      location: a location object exposing `type` and an iterable `call`.

    Returns:
      A newline-joined string with one entry per reported call.
    """
    # The location type does not change while iterating, so test it once.
    is_callsite = (
        location.type ==
        converter_error_data_pb2.ConverterErrorData.CALLSITELOC)
    callstack = []
    for single_call in location.call:
      if is_callsite:
        # Stop showing CallSite after func_graph.py which isn't meaningful.
        if _FUNC_GRAPH_SRC_PATH in single_call.source.filename:
          break
        callstack.append(
            f" - {single_call.source.filename}:{single_call.source.line}")
      else:
        callstack.append(str(single_call))
    return "\n".join(callstack)

  def _dump_error_details(self, ops, locations):
    """Dump the list of ops and locations.

    Args:
      ops: op names; `ops[i]` is reported together with `locations[i]`.
      locations: location objects, parallel to `ops`.
    """
    for op_name, location in zip(ops, locations):
      callstack_dump = self._get_location_string(location)
      self._log(f"Op: {op_name}\n{callstack_dump}\n")

  def _decode_error_legacy(self, err):
    """Parses the given legacy ConverterError for OSS.

    Scans the error text line by line and logs a message for each line that
    starts with the custom-op or Select-TF-op header.
    """
    for line in str(err).splitlines():
      # Check custom op usage error.
      if line.startswith(_CUSTOM_OPS_HDR):
        custom_ops = line[len(_CUSTOM_OPS_HDR):]
        err_string = (
            f"{_AUTHORING_ERROR_HDR}: op '{custom_ops}' is(are) not natively "
            "supported by TensorFlow Lite. You need to provide a custom "
            "operator. https://www.tensorflow.org/lite/guide/ops_custom")
        self._log(err_string)
      # Check TensorFlow op usage error.
      elif line.startswith(_TF_OPS_HDR):
        tf_ops = line[len(_TF_OPS_HDR):]
        err_string = (
            f"{_AUTHORING_WARNING_HDR}: op '{tf_ops}' require(s) \"Select TF "
            "Ops\" for model conversion for TensorFlow Lite. "
            "https://www.tensorflow.org/lite/guide/ops_select")
        self._log(err_string)

  def _decode_converter_error(self, err):
    """Parses the given ConverterError which has detailed error information.

    Args:
      err: an error object exposing an iterable `errors`; each entry provides
        `error_code`, `error_message`, `operator.name` and `location`.
    """
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    gpu_not_compatible_ops = []
    error_data = converter_error_data_pb2.ConverterErrorData
    # Use a distinct loop variable so the `err` argument is not shadowed.
    for error in err.errors:
      code = error.error_code
      # Check custom op usage error.
      if code == error_data.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error.operator.name)
        custom_ops_location.append(error.location)
      # Check TensorFlow op usage error.
      elif code == error_data.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error.operator.name)
        tf_ops_location.append(error.location)
      # Check GPU delegate compatibility error.
      elif code == error_data.ERROR_GPU_NOT_COMPATIBLE:
        gpu_not_compatible_ops.append(error.operator.name)
        # Log the first line of ConverterError.errors.error_message only
        # since the second line is "Error code: xxxx".
        # An empty message has no lines; fall back to an empty string instead
        # of failing with IndexError.
        message_lines = error.error_message.splitlines()
        self._log(message_lines[0] if message_lines else "")
        self._log(self._get_location_string(error.location) + "\n")
      else:
        # Log other errors.
        self._log(f"{_AUTHORING_ERROR_HDR}: {error.error_message}")
        self._log(self._get_location_string(error.location) + "\n")

    if custom_ops:
      custom_ops_str = ", ".join(sorted(custom_ops))
      err_string = (
          f"{_AUTHORING_ERROR_HDR}: op '{custom_ops_str}' is(are) not natively "
          "supported by TensorFlow Lite. You need to provide a custom "
          "operator. https://www.tensorflow.org/lite/guide/ops_custom")
      self._log(err_string)
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      tf_ops_str = ", ".join(sorted(tf_ops))
      err_string = (
          f"{_AUTHORING_WARNING_HDR}: op '{tf_ops_str}' require(s) \"Select TF"
          " Ops\" for model conversion for TensorFlow Lite. "
          "https://www.tensorflow.org/lite/guide/ops_select")
      self._log(err_string)
      self._dump_error_details(tf_ops, tf_ops_location)

    if gpu_not_compatible_ops:
      not_compatible_ops_str = ", ".join(sorted(gpu_not_compatible_ops))
      err_string = (
          f"{_AUTHORING_WARNING_HDR}: op '{not_compatible_ops_str}' aren't "
          "compatible with TensorFlow Lite GPU delegate. "
          "https://www.tensorflow.org/lite/performance/gpu")
      self._log(err_string)

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Raises:
      CompatibilityError: when `raise_exception` was set and at least one
        message has been logged.
    """
    # Errors carrying an `errors` attribute have structured details; anything
    # else is parsed from its text form.
    if hasattr(err, "errors"):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      raise CompatibilityError(f"CompatibilityException at {repr(self._func)}")

  def _log(self, message):
    """Log and print authoring warning / error message."""
    self._log_messages.append(message)
    print(message)

  def get_compatibility_log(self):
    """Returns list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      The list of log messages by the recent compatibility check.
    Raises:
      RuntimeError: when the compatibility was NOT checked.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return self._log_messages


@_tf_export("lite.experimental.authoring.compatible")
def compatible(target=None, converter_target_spec=None, **kwargs):
  """Wraps `tf.function` into a callable function with TFLite compatibility checking.

  Example:

  ```python
  @tf.lite.experimental.authoring.compatible
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.float32)
  ])
  def f(x):
      return tf.cosh(x)

  result = f(tf.constant([0.0]))
  # COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
  # conversion for TensorFlow Lite.
  # Op: tf.Cosh
  #   - tensorflow/python/framework/op_def_library.py:748
  #   - tensorflow/python/ops/gen_math_ops.py:2458
  #   - <stdin>:6
  ```

  WARNING: Experimental interface, subject to change.

  Args:
    target: A `tf.function` to decorate. When omitted, a decorator is returned
      so that the function can be used as `@compatible(...)`.
    converter_target_spec : target_spec of TFLite converter parameter.
    **kwargs: The keyword arguments of the decorator class _Compatible
      (`converter_allow_custom_ops`, `raise_exception`).

  Returns:
     A callable object of `tf.lite.experimental.authoring._Compatible`.
  """
  if target is None:
    def wrapper(target):
      return _Compatible(target, converter_target_spec, **kwargs)
    return wrapper
  else:
    return _Compatible(target, converter_target_spec, **kwargs)