# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Authoring tool package for TFLite compatibility.

WARNING: The package is experimental and subject to change.

This package provides a way to check TFLite compatibility at model authoring
time.

Example:
  @tf.lite.experimental.authoring.compatible
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.float32)
  ])
  def f(x):
    return tf.cosh(x)

  result = f(tf.constant([0.0]))

  > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
  > conversion for TensorFlow Lite.
  > Op: tf.Cosh
  > - tensorflow/python/framework/op_def_library.py:xxx
  > - tensorflow/python/ops/gen_math_ops.py:xxx
  > - simple_authoring.py:xxx
"""
import functools


# pylint: disable=g-import-not-at-top
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2


_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"


class CompatibilityError(Exception):
  """Raised when an error occurs with TFLite compatibility."""


class _Compatible:
  """A decorator class to check TFLite compatibility created by `lite.experimental.authoring.compatible`."""

  def __init__(self,
               target,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               raise_exception=False):
    """Initialize the decorator object.

    Here is the description of the object variables.
    - _func : decorated function.
    - _obj_func : for a method of a class, the decorated function bound to an
      instance, so that the instance is supplied as the first (`self`)
      argument.
    - _verified : whether the compatibility is checked or not.
    - _log_messages : compatibility messages collected by the check.

    Args:
      target: decorated function.
      converter_target_spec : target_spec of TFLite converter parameter.
      converter_allow_custom_ops : allow_custom_ops of TFLite converter
        parameter.
      raise_exception : if True, raise `CompatibilityError` when the check
        finds compatibility issues. Use get_compatibility_log() to inspect
        the details.
    """
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops

  def __get__(self, instance, cls):
    """A Python descriptor interface.

    Binds the decorated function to `instance` so that calls made through a
    class attribute receive the instance as `self`.
    """
    # NOTE(review): the bound function is stored on this single decorator
    # object, which is shared by every instance of `cls`. The most recent
    # attribute access wins, and the verification state is shared as well.
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns decorated function object.

    For a class method, use self._obj_func to provide `self` instance.
    """
    if self._obj_func is not None:
      return self._obj_func
    return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls decorated function object.

    Also verifies if the function is compatible with TFLite. The check is
    performed only on the first call.

    Returns:
      A execution result of the decorated function.

    Raises:
      CompatibilityError: on the first call, when `raise_exception` was set and
        compatibility issues were found.
    """
    if not self._verified:
      model = self._get_func()
      concrete_func = model.get_concrete_function(*args, **kwargs)
      converter = lite.TFLiteConverterV2.from_concrete_functions(
          [concrete_func], model)
      # Set provided converter parameters
      if self._converter_target_spec is not None:
        converter.target_spec = self._converter_target_spec
      if self._converter_allow_custom_ops is not None:
        converter.allow_custom_ops = self._converter_allow_custom_ops
      try:
        converter.convert()
      except convert.ConverterError as err:
        self._decode_error(err)
      finally:
        # Mark as verified even if conversion or decoding raised, so the check
        # is not repeated on later calls.
        self._verified = True

    return self._get_func()(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  def _dump_error_details(self, ops, locations):
    """Dump the list of ops and locations.

    Args:
      ops: op names, parallel to `locations`.
      locations: location messages from the converter error data, one per op.
    """
    callsite_type = converter_error_data_pb2.ConverterErrorData.CALLSITELOC
    for op, location in zip(ops, locations):
      # The location type is the same for every call in this location.
      is_callsite = location.type == callsite_type
      callstack = []
      for single_call in location.call:
        if is_callsite:
          # Stop showing CallSite after func_graph.py which isn't meaningful.
          if _FUNC_GRAPH_SRC_PATH in single_call.source.filename:
            break
          callstack.append(
              f" - {single_call.source.filename}:{single_call.source.line}")
        else:
          callstack.append(str(single_call))
      callstack_dump = "\n".join(callstack)
      self._log(f"Op: {op}\n{callstack_dump}\n")

  def _log_custom_ops_error(self, ops_str):
    """Logs the error for ops that need a custom operator implementation."""
    self._log(
        f"{_AUTHORING_ERROR_HDR}: op '{ops_str}' is(are) not natively "
        "supported by TensorFlow Lite. You need to provide a custom "
        "operator. https://www.tensorflow.org/lite/guide/ops_custom")

  def _log_select_tf_ops_warning(self, ops_str):
    """Logs the warning for ops that need "Select TF Ops" (Flex ops)."""
    self._log(
        f"{_AUTHORING_WARNING_HDR}: op '{ops_str}' require(s) \"Select TF "
        "Ops\" for model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select")

  def _decode_error_legacy(self, err):
    """Parses the given legacy ConverterError for OSS.

    Only the error text is used: ops are reported on lines starting with
    `_CUSTOM_OPS_HDR` or `_TF_OPS_HDR`.
    """
    for line in str(err).splitlines():
      # Check custom op usage error.
      if line.startswith(_CUSTOM_OPS_HDR):
        self._log_custom_ops_error(line[len(_CUSTOM_OPS_HDR):])
      # Check TensorFlow op usage error.
      elif line.startswith(_TF_OPS_HDR):
        self._log_select_tf_ops_warning(line[len(_TF_OPS_HDR):])

  def _decode_converter_error(self, err):
    """Parses the given ConverterError which has detailed error information."""
    error_data_cls = converter_error_data_pb2.ConverterErrorData
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    # Use a distinct loop name so the `err` parameter is not shadowed.
    for error_data in err.errors:
      error_code = error_data.error_code
      # Check custom op usage error.
      if error_code == error_data_cls.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error_data.operator.name)
        custom_ops_location.append(error_data.location)
      # Check TensorFlow op usage error.
      elif error_code == error_data_cls.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error_data.operator.name)
        tf_ops_location.append(error_data.location)

    if custom_ops:
      self._log_custom_ops_error(", ".join(sorted(custom_ops)))
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      self._log_select_tf_ops_warning(", ".join(sorted(tf_ops)))
      self._dump_error_details(tf_ops, tf_ops_location)

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Args:
      err: the `convert.ConverterError` raised by the TFLite converter.

    Raises:
      CompatibilityError: when `raise_exception` was requested and at least one
        compatibility message has been logged.
    """
    # Prefer the structured error data. Fall back to parsing the error text
    # when it is absent or empty; otherwise nothing would be reported.
    if getattr(err, "errors", None):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      raise CompatibilityError(
          f"CompatibilityException at {repr(self._func)}") from err

  def _log(self, message):
    """Log and print authoring warning / error message."""
    self._log_messages.append(message)
    print(message)

  def get_compatibility_log(self):
    """Returns list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      A copy of the list of log messages by the recent compatibility check.
    Raises:
      RuntimeError: when the compatibility was NOT checked.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return list(self._log_messages)


def compatible(target=None, converter_target_spec=None, **kwargs):
  """Wraps `tf.function` into a callable function with TFLite compatibility checking.

  Can be used either as `@compatible` or as `@compatible(...)` with arguments.

  Args:
    target: A `tf.function` to decorate.
    converter_target_spec : target_spec of TFLite converter parameter.
    **kwargs: The keyword arguments of the decorator class _Compatible
      (`converter_allow_custom_ops`, `raise_exception`).

  Returns:
    A callable object of `tf.lite.experimental.authoring._Compatible`, or,
    when `target` is None, a decorator that produces one.
  """
  if target is None:
    def wrapper(target):
      return _Compatible(target, converter_target_spec, **kwargs)

    return wrapper
  return _Compatible(target, converter_target_spec, **kwargs)