# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
from absl import logging

from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))


# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done before via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output.
  """
  expected_structure = function.graph.structured_input_signature
  flatten_inputs = nest.flatten_up_to(
      expected_structure, inputs, expand_composites=True)
  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
  tensor_inputs = []
  for arg, expected in zip(flatten_inputs, flatten_expected):
    if isinstance(expected, tensor_spec.TensorSpec):
      tensor_inputs.append(
          ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
  result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  if isinstance(result, ops.Operation):
    return None
  return result


def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns a TensorSpec for `arg` if it can be converted to a tensor, else None."""
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      if id(arg) != id(expected):
        return False
    else:
      if arg != expected:
        return False
  return True


def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
  """Deserialize a FunctionSpec object from its proto representation."""
  typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)

  # Convert a method function into a non-method.
  if function_spec_proto.is_method:
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Missing support to deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = coder.decode_proto(function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details.
  jit_compile = {
      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
  }.get(function_spec_proto.jit_compile)

  return function_lib.FunctionSpec(fullargspec=fullargspec,
                                   is_method=False,
                                   input_signature=input_signature,
                                   jit_compile=jit_compile)
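

# Illustrative note (hypothetical example, not taken from any real SavedModel):
# a method saved as `def call(self, x, training=False)` arrives with a
# fullargspec of ['self', 'x', 'training'];
# `_deserialize_function_spec_as_nonmethod` above strips the leading 'self' so
# the restored spec describes a plain function taking (x, training=False),
# which matches how restored functions are invoked after loading.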


# TODO(allenl): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable."""
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec,
        coder)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function, name, autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Prevent RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    # We do not have access to the original python function, and thus, we
    # cannot meaningfully do anything but call our concrete function graphs
    # under the hood.
    #
    # Attempting to call our bespoke python function (i.e.
    # `restored_function_body`) will work so long as the user passes in all
    # required and optional arguments. If an optional argument is missing,
    # however, the call will break. For this reason, we instead skip the
    # eager call path altogether if a user has enabled eager function execution
    # via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions_for_serialization(self):
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    func = super(RestoredFunction, self)._defun_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func


def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.
      As a side effect of this function, the `FunctionSpec` from
      `saved_function` is added to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. The current approach nests functions deeper with
  # each serialization cycle.
  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods, i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec,
      coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n * {}".format(
          len(positional), "\n * ".join(str(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}"
          .format(index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n {}\n Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  concrete_function_objects = []
  for concrete_function_name in saved_function.concrete_functions:
    concrete_function_objects.append(concrete_functions[concrete_function_name])

  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)


def load_function_def_library(library, load_shared_name_suffix=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  library_function_names = set(fdef.signature.name for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library, library_function_names):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

    # There is no need to copy all functions into the function def graph. It
    # leads to an O(n^2) increase in memory when importing functions, and the
    # extra function definitions are a no-op since they have already been
    # imported as functions and are passed in explicitly (due to the
    # topological-sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(copy)
    _restore_gradient_functions(func_graph, renamed_functions)

    for dep in _list_function_deps(fdef, library_function_names):
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunctions that are part of a saved
    # function are set up later by recreate_function(); bare ConcreteFunctions
    # are set up by setup_bare_concrete_function().
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph(graph)

    functions[fdef.signature.name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

  return functions


def _restore_gradient_functions(func_graph, renamed_functions):
  """Populate function op's _gradient_function with default gradient."""
  for op in func_graph.get_operations():
    # TODO(andresp): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access


def _sort_function_defs(library, library_function_names):
  """Return a topological sort of FunctionDefs in a library."""
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fdef in library.function:
    for dep in _list_function_deps(fdef, library_function_names):
      edges[dep].append(fdef.signature.name)
      in_count[fdef.signature.name] += 1

  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    raise ValueError("There is a cyclic dependency between functions. "
                     "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]


def _check_op_has_custom_gradients(node_def):
  """Returns True if op has custom gradients."""
  return ("_gradient_op_type" in node_def.attr and
          node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"])


def fix_node_def(node_def, functions, shared_name_suffix):
  """Replace function calls and shared names in `node_def`."""
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))


def _fix_fdef(orig_fdef, functions, shared_name_suffix):
  """Fixes a FunctionDef proto to be loaded in the current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existing functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  contains_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    if not contains_custom_gradients:
      contains_custom_gradients = _check_op_has_custom_gradients(node_def)
  if contains_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with custom gradients. Will likely "
        "fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef


def _list_function_deps(fdef, library_function_names):
  """Find functions referenced in `fdef`."""
  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
  # when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    else:
      for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)

  return deps


_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (function_lib._INFERENCE_PREFIX
                                                 )  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible."""
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  if match:
    return match.group(1)
  else:
    return name
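

# The sketch below is illustrative only and is not called anywhere in this
# module: the function name and the way the protos are walked are assumptions
# made for documentation purposes, not part of the loader API. It shows how the
# public helpers here are typically combined when a SavedModel is loaded: first
# every FunctionDef in the graph's library is imported as a ConcreteFunction,
# then each SavedFunction node in the object graph is wrapped into a callable
# `Function`.
def _example_recreate_saved_functions(object_graph_proto, function_library):
  """Illustrative sketch: restore all `Function`s from a SavedObjectGraph.

  Args:
    object_graph_proto: `SavedObjectGraph` proto message.
    function_library: `FunctionDefLibrary` proto message from the graph def.

  Returns:
    Map from object-graph node id to the restored `Function`.
  """
  # Import every FunctionDef as a ConcreteFunction, keyed by original name.
  concrete_functions = load_function_def_library(function_library)
  restored = {}
  for node_id, saved_object in enumerate(object_graph_proto.nodes):
    # Only SavedFunction nodes are wrapped here; bare concrete functions would
    # go through setup_bare_concrete_function() instead.
    if saved_object.WhichOneof("kind") == "function":
      restored[node_id] = recreate_function(saved_object.function,
                                            concrete_functions)
  return restored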