# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
from absl import logging

from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))


# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
        `function.graph.structured_input_signature`.

  Returns:
    The structured function output.
  """
  expected_structure = function.graph.structured_input_signature
  flatten_inputs = nest.flatten_up_to(
      expected_structure, inputs, expand_composites=True)
  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
  tensor_inputs = []
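  # Only tensor-like leaves are passed to the flat call: TensorSpec leaves are
  # converted to tensors, VariableSpec leaves (resource variables) are passed
  # through unchanged, and remaining leaves are Python values already baked
  # into the trace, so they are dropped here.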
  for arg, expected in zip(flatten_inputs, flatten_expected):
    if isinstance(expected, tensor_spec.TensorSpec):
      tensor_inputs.append(
          ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
    elif isinstance(expected, resource_variable_ops.VariableSpec):
      tensor_inputs.append(arg)
  result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  if isinstance(result, ops.Operation):
    return None
  return result


def _try_convert_to_tensor_spec(arg, dtype_hint):
86  """Returns None or TensorSpec obtained if `arg` is converted to tensor."""
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

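  # Matching rules per leaf: a TensorSpec accepts a tensor (or, when
  # `allow_conversion` is set, anything convertible to one) with the same
  # dtype and a compatible shape; any other TypeSpec uses its own
  # compatibility check; everything else must be the identical tensor object
  # or an equal Python value.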
  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      if id(arg) != id(expected):
        return False
    else:
      if arg != expected:
        return False
  return True


def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
  """Deserialize a FunctionSpec object from its proto representation."""
  typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)

  # Convert a method function into a non-method.
  if function_spec_proto.is_method:
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Missing support to deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = coder.decode_proto(function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details.
  jit_compile = {
      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
  }.get(function_spec_proto.jit_compile)

  return function_lib.FunctionSpec(fullargspec=fullargspec,
                                   is_method=False,
                                   input_signature=input_signature,
                                   jit_compile=jit_compile)


# TODO(allenl): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable."""
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
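  # `_arg_keywords` and `_num_positional_args` define the flat calling
  # convention of the restored ConcreteFunction: which keyword names its
  # inputs answer to, and how many of them may be passed positionally.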
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec,
        coder)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function, name, autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Prevent RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    # We do not have access to the original python function, and thus, we
    # cannot meaningfully do anything but call our concrete function graphs
    # under the hood.
    #
    # Attempting to call our bespoke python function (i.e.
    # `restored_function_body`) will work so long as the user passes in all
    # required and optional arguments. If an optional argument is missing,
    # however, the call will break. For this reason, we instead skip the
    # eager call path altogether if a user has enabled eager function execution
    # via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions_for_serialization(self):
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    func = super(RestoredFunction, self)._defun_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func


def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.
      As a side effect of this function, the `FunctionSpec` from
      `saved_function` is added to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.
  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods (i.e. they always use the same captured tensors
  # independent of the object they are bound to), there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec,
      coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n    * {}".format(
          len(positional), "\n    * ".join(str(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n  {}\n  Keyword arguments: {}"
          .format(index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n  {}\n  Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  concrete_function_objects = []
  for concrete_function_name in saved_function.concrete_functions:
    concrete_function_objects.append(concrete_functions[concrete_function_name])

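  # Attach the deserialized FunctionSpec to every concrete function so each of
  # them can also be called with the structured (Python) signature.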
  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)


def load_function_def_library(library,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load so that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.
    wrapper_function: An optional function that is applied to wrap each newly
      created `ConcreteFunction`.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if the function dependencies contain a cycle.
  """
  library_function_names = set(fdef.signature.name for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name.
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type.
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)

      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {}
  for fdef in library.function:
    function_deps[fdef.signature.name] = _list_function_deps(
        fdef, library_function_names, library_gradient_names)

  loaded_gradients = {}
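  # Functions are processed in dependency order so that, by the time a
  # FunctionDef is converted to a graph, every function it calls has already
  # been loaded and can be added to that graph.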
  for fdef in _sort_function_defs(library, function_deps):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix,
                     new_gradient_op_types)

    # There is no need to copy all functions into the function def graph. Doing
    # so leads to an O(n^2) increase in memory when importing functions, and
    # the extra function definitions are a no-op since they were already
    # imported as functions and are passed in explicitly (due to the
    # topological-sort import order).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(copy)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients)
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[fdef.signature.name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunctions that are part of a saved
    # function are set up later by recreate_function(); bare ConcreteFunctions
    # are set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
405    if "_input_shapes" in copy.attr:
406      del copy.attr["_input_shapes"]
407    func = function_lib.ConcreteFunction(func_graph, attrs=copy.attr)
408    if wrapper_function:
409      func = wrapper_function(func)
410    func.add_to_graph(graph)
411
412    functions[fdef.signature.name] = func
413    renamed_functions[func.name] = func
414    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
415      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
416      # is fixed. Currently it's leaking memory to maintain bug compatibility
417      # with previous behavior.
418      func.add_to_graph(ops.get_default_graph())
419
420    if fdef.signature.name in gradients_to_register:
421      gradient_op_type = gradients_to_register[fdef.signature.name]
422      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
423      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))
424
425  return functions
426
427
428def _gen_gradient_func(func):
429  """Wraps a deserialized function."""
430
431  def gradient_func(unused_op, *result_grads):
432    # Replace all `None` arguments, because the traced custom gradient function
433    # expects tensors. Replacing with zeros is correct since the `None` values
434    # occur when the gradient is unconnected, and thus the gradient is
435    # "statically proven to be zero." See `tf.UnconnectedGradients` for details.
436    result_grads = [x if x is not None else default_gradient.zeros_like(t)
437                    for (x, t) in zip(result_grads, func.graph.inputs)]
438
439    return func(*result_grads)
440
441  return gradient_func
442
443
444def _restore_gradient_functions(func_graph, renamed_functions,
445                                loaded_gradients):
446  """Populate function op's _gradient_function with default gradient."""
  for op in func_graph.get_operations():
    # TODO(andresp): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access
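    # The op may not carry a "_gradient_op_type" attr at all; get_attr raises
    # ValueError in that case, which simply means there is no loaded custom
    # gradient to hook up for this op.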
    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      pass
    else:
      if gradient_op_type in loaded_gradients:
        grad_fn = loaded_gradients[gradient_op_type]
        grad_fn._num_positional_args = len(op.inputs)  # pylint: disable=protected-access
        grad_fn._arg_keywords = [inp.name for inp in op.inputs]  # pylint: disable=protected-access


def _sort_function_defs(library, function_deps):
467  """Return a topologic sort of FunctionDefs in a library."""
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fname, deps in function_deps.items():
    for dep in deps:
      edges[dep].append(fname)
      in_count[fname] += 1
  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    raise ValueError("There is a cyclic dependency between functions. "
                     "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]


def _get_gradient_op_type(node_def):
  """Returns the custom gradient op type."""
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    return node_def.attr["_gradient_op_type"].s
  return None


def fix_node_def(node_def, functions, shared_name_suffix):
507  """Replace functions calls and shared names in `node_def`."""
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so we have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
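      # Precedence for the base shared name: an explicit value on the node,
      # then the op's non-empty default, then the node name itself.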
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))


def _fix_fdef(orig_fdef, functions, shared_name_suffix, new_gradient_op_types):
547  """Fixes a FunctionDef proto to be loaded in current context.
548
549  In particular, when loading a function library into an eager context, one
550  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated
      op type.

  Returns:
    A fixed copy of the original FunctionDef
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  contains_unsaved_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    op_type = _get_gradient_op_type(node_def)
    if op_type is not None:
      if op_type in new_gradient_op_types:
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            new_gradient_op_types[op_type])
      else:
        contains_unsaved_custom_gradients = True
  if contains_unsaved_custom_gradients:
    logging.warning(
580        "Importing a function (%s) with ops with unsaved custom gradients. Will"
581        " likely fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef


def _list_function_deps(fdef, library_function_names, library_gradient_names):
  """Find functions referenced in `fdef`."""
  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
  # when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    grad_op_type = _get_gradient_op_type(node_def)
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    elif grad_op_type and grad_op_type in library_gradient_names:
      deps.add(library_gradient_names[grad_op_type])
    else:
      for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)

  return deps


_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (
    function_lib._INFERENCE_PREFIX)  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible."""
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  if match:
    return match.group(1)
  else:
    return name