# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
from absl import logging

from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  """Returns True if `t` is a `Tensor` or a resource variable."""
  tensor_like_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  return isinstance(t, tensor_like_types)


# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
  """Invokes a restored ConcreteFunction using structured arguments.

  Unlike `function.__call__`, both the inputs and the outputs are structured,
  and inputs whose expected entry is a `TensorSpec` are converted to tensors.

  Note: non-tensor inputs are not checked here. That check is expected to have
  been done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the flat call produced an
    `Operation` rather than a value.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs, expand_composites=True)
  flat_specs = nest.flatten(signature, expand_composites=True)

  call_args = []
  for value, spec in zip(flat_args, flat_specs):
    if isinstance(spec, tensor_spec.TensorSpec):
      value = ops.convert_to_tensor(value, dtype_hint=spec.dtype)
    elif not isinstance(spec, resource_variable_ops.VariableSpec):
      # Entries that are neither tensor specs nor variable specs are not
      # forwarded to the flat call.
      continue
    call_args.append(value)

  outputs = function._call_flat(call_args, function._captured_inputs)  # pylint: disable=protected-access
  return None if isinstance(outputs, ops.Operation) else outputs


def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns the TensorSpec `arg` would have as a tensor, or None on failure."""
  try:
    # Note: the trial conversion runs inside a throwaway FuncGraph so that it
    # does not pollute the current context.
    scratch_graph = func_graph_lib.FuncGraph(name="guess_conversion")
    with scratch_graph.as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(
          shape=converted.shape, dtype=converted.dtype)
  except (TypeError, ValueError):
    return None


def _input_matches_spec(arg, expected, allow_conversion):
  """Returns whether one flattened input satisfies its signature entry.

  Args:
    arg: A single flattened input value.
    expected: The corresponding flattened entry of the structured signature.
    allow_conversion: If True and `expected` is a `TensorSpec`, `arg` is first
      replaced by the spec it would have after conversion to a tensor.

  Returns:
    A truthy value if `arg` is acceptable for `expected`, falsy otherwise.
  """
  if isinstance(expected, tensor_spec.TensorSpec):
    candidate = arg
    if allow_conversion:
      candidate = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
    if not (_is_tensor(candidate) or
            isinstance(candidate, tensor_spec.TensorSpec)):
      return False
    if candidate.dtype != expected.dtype:
      return False
    return expected.shape.is_compatible_with(candidate.shape)

  if isinstance(expected, type_spec.TypeSpec):
    return expected.is_compatible_with(arg)

  if _is_tensor(arg):
    # Tensors appearing outside of a spec must be the very same object.
    return arg is expected

  return not (arg != expected)


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  signature = function.graph.structured_input_signature
  try:
    flat_inputs = nest.flatten_up_to(signature, inputs)
  except (TypeError, ValueError):
    # The inputs do not even have the structure the signature requires.
    return False

  flat_expected = nest.flatten(signature)
  return all(
      _input_matches_spec(arg, expected, allow_conversion)
      for arg, expected in zip(flat_inputs, flat_expected))


def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
  """Builds a non-method `FunctionSpec` from its proto representation.

  Args:
    function_spec_proto: Proto holding `fullargspec`, `is_method`,
      `input_signature` and `jit_compile`.
    coder: Object whose `decode_proto` decodes the nested-structure fields.

  Returns:
    A `FunctionSpec` with `is_method=False`. If the proto describes a method,
    its first positional argument is dropped.

  Raises:
    NotImplementedError: If the proto describes a method that has no named
      positional arguments.
  """
  decoded_argspec = coder.decode_proto(function_spec_proto.fullargspec)

  # Convert a method function into a non method.
  positional_names = decoded_argspec.args
  if function_spec_proto.is_method:
    if not positional_names:
      raise NotImplementedError(
          "Missing support to deserialize a method function without a named "
          "'self' argument.")
    positional_names = positional_names[1:]

  fullargspec = tf_inspect.FullArgSpec(
      args=positional_names,
      varargs=decoded_argspec.varargs,
      varkw=decoded_argspec.varkw,
      defaults=decoded_argspec.defaults,
      kwonlyargs=decoded_argspec.kwonlyargs,
      kwonlydefaults=decoded_argspec.kwonlydefaults,
      annotations=decoded_argspec.annotations)
  input_signature = coder.decode_proto(function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details.
  jit_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
  jit_compile_by_enum = {
      jit_enum.DEFAULT: None,
      jit_enum.ON: True,
      jit_enum.OFF: False,
  }
  jit_compile = jit_compile_by_enum.get(function_spec_proto.jit_compile)

  return function_lib.FunctionSpec(
      fullargspec=fullargspec,
      is_method=False,
      input_signature=input_signature,
      jit_compile=jit_compile)


# TODO(allenl): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable.

  Args:
    saved_bare_concrete_function: Proto naming the concrete function and
      carrying its argument keywords, allowed positional argument count and,
      optionally, a function spec.
    concrete_functions: Map from function name to `ConcreteFunction`.

  Returns:
    The `ConcreteFunction` looked up by name, updated in place with the calling
    convention stored in the proto.
  """
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec, coder)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    """Initializes the wrapper.

    Args:
      python_function: Python callable passed to the base `Function`.
      name: Name passed to the base `Function`.
      function_spec: Function spec; its `jit_compile` value is forwarded to
        the base class and the spec itself is stored on the instance.
      concrete_functions: The restored concrete functions backing this object.
    """
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function, name, autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Prevent RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    """Always False: restored functions never take the eager call path."""
    # We do not have access to the original python function, and thus, we
    # cannot meaningfully do anything but call our concrete function graphs
    # under the hood.
    #
    # Attempting to call our bespoke python function (i.e.
    # `restored_function_body`) will work so long as the user passes in all
    # required and optional arguments. If an optional argument is missing,
    # however, the call will break. For this reason, we instead skip the
    # eager call path altogether if a user has enabled eager function execution
    # via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions_for_serialization(self):
    """Returns the restored concrete functions held by this object."""
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    """Delegates to the base class, then attaches the restored spec."""
    func = super(RestoredFunction, self)._defun_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func


def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.
      As a side effect of this function, the `FunctionSpec` from
      `saved_function` is added to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.
  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally since restored functions do
  # not behave as methods i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value on
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec, coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    # Nothing matched: build an error listing every available signature.
    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n * {}".format(
          len(positional), "\n * ".join(str(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}"
          .format(index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n {}\n Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  concrete_function_objects = [
      concrete_functions[concrete_function_name]
      for concrete_function_name in saved_function.concrete_functions
  ]

  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)


def load_function_def_library(library,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Functions names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.
    wrapper_function: An object that will be wrapped on newly created functions.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {fdef.signature.name for fdef in library.function}
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)

      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {
      fdef.signature.name: _list_function_deps(
          fdef, library_function_names, library_gradient_names)
      for fdef in library.function
  }

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    fixed_fdef = _fix_fdef(fdef, functions, load_shared_name_suffix,
                           new_gradient_op_types)

    # There is no need to copy all functions into the function def graph. It
    # leads to a O(n^2) increase of memory when importing functions and the
    # extra function definitions are a no-op since they already imported as a
    # function before and passed in explicitly (due to the topologic sort
    # import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(fixed_fdef)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients)
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[fdef.signature.name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunction that are part of a saved
    # function is set up later by recreate_function(); and bare ConcreteFunction
    # is set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding the "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in fixed_fdef.attr:
      del fixed_fdef.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=fixed_fdef.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[fdef.signature.name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if fdef.signature.name in gradients_to_register:
      gradient_op_type = gradients_to_register[fdef.signature.name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions


def _gen_gradient_func(func):
  """Wraps a deserialized function so it can be registered as a gradient.

  Args:
    func: The deserialized gradient function; it is called with the incoming
      gradients only.

  Returns:
    A callable taking `(op, *result_grads)` that ignores `op` and calls `func`.
  """

  def gradient_func(unused_op, *result_grads):
    # Replace all `None` arguments, because the traced custom gradient function
    # expects tensors. Replacing with zeros is correct since the `None` values
    # occur when the gradient is unconnected, and thus the gradient is
    # "statically proven to be zero." See `tf.UnconnectedGradients` for details.
    result_grads = [x if x is not None else default_gradient.zeros_like(t)
                    for (x, t) in zip(result_grads, func.graph.inputs)]

    return func(*result_grads)

  return gradient_func


def _restore_gradient_functions(func_graph, renamed_functions,
                                loaded_gradients):
  """Populate function op's _gradient_function with default gradient.

  Args:
    func_graph: Graph whose operations are inspected.
    renamed_functions: Map from (renamed) function name to the loaded function.
    loaded_gradients: Map from gradient op type to the loaded gradient
      function; matching entries get their calling convention set from the op.
  """
  for op in func_graph.get_operations():
    # TODO(andresp): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access
    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      # The op carries no "_gradient_op_type" attribute; nothing to restore.
      pass
    else:
      if gradient_op_type in loaded_gradients:
        grad_fn = loaded_gradients[gradient_op_type]
        grad_fn._num_positional_args = len(op.inputs)  # pylint: disable=protected-access
        grad_fn._arg_keywords = [inp.name for inp in op.inputs]  # pylint: disable=protected-access


def _sort_function_defs(library, function_deps):
  """Return a topologic sort of FunctionDefs in a library.

  Args:
    library: FunctionDefLibrary proto message.
    function_deps: Map from function name to the names it depends on.

  Returns:
    The library's FunctionDefs ordered so that each function comes after the
    functions it depends on.

  Raises:
    ValueError: If not every function could be ordered, e.g. because of a
      cyclic dependency.
  """
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(int)

  for fname, deps in function_deps.items():
    for dep in deps:
      edges[dep].append(fname)
      in_count[fname] += 1
  # Functions with no pending dependency can be emitted right away.
  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count) - set(output))
    # The two literals are concatenated into one message; passing them as
    # separate arguments would make the error render as a tuple.
    raise ValueError(
        "There is a cyclic-dependency between functions. "
        "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]


def _get_gradient_op_type(node_def):
  """Returns the custom gradient op type, or None.

  None is returned when the node has no "_gradient_op_type" attribute or when
  the node is a function-call op.
  """
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    return node_def.attr["_gradient_op_type"].s
  return None


def fix_node_def(node_def, functions, shared_name_suffix):
  """Replace functions calls and shared names in `node_def`.

  Args:
    node_def: NodeDef proto; it is modified in place.
    functions: Map from original function name to the loaded function, whose
      `name` replaces references to the original name.
    shared_name_suffix: Suffix appended to the node's `shared_name` when the
      node's op defines such an attribute.
  """
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      # Preference order: the node's own value, the op's default value, and
      # finally the node name.
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))


def _fix_fdef(orig_fdef, functions, shared_name_suffix, new_gradient_op_types):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated
      op type.

  Returns:
    A fixed copy of the original FunctionDef
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  contains_unsaved_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    op_type = _get_gradient_op_type(node_def)
    if op_type is not None:
      if op_type in new_gradient_op_types:
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            new_gradient_op_types[op_type])
      else:
        contains_unsaved_custom_gradients = True
  if contains_unsaved_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with unsaved custom gradients. Will"
        " likely fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef


def _list_function_deps(fdef, library_function_names, library_gradient_names):
  """Find functions referenced in `fdef`.

  Args:
    fdef: FunctionDef proto to inspect.
    library_function_names: Names of the functions defined in the library.
    library_gradient_names: Map from gradient op type to the name of the
      library function implementing that gradient.

  Returns:
    A set with the names of the functions `fdef` refers to.
  """
  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
  # when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    grad_op_type = _get_gradient_op_type(node_def)
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    elif grad_op_type and grad_op_type in library_gradient_names:
      deps.add(library_gradient_names[grad_op_type])
    else:
      for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)

  return deps


_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (function_lib._INFERENCE_PREFIX
                                                 )  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible."""
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  return match.group(1) if match else name