1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=unidiomatic-typecheck 16"""Prototype decorator for defining legacy-graph-mode functions.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import weakref 23 24from tensorflow.core.protobuf import meta_graph_pb2 25from tensorflow.core.protobuf import struct_pb2 26from tensorflow.python.eager import context 27from tensorflow.python.eager import function 28from tensorflow.python.eager import lift_to_graph 29from tensorflow.python.framework import composite_tensor 30from tensorflow.python.framework import func_graph 31from tensorflow.python.framework import importer 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import sparse_tensor 34from tensorflow.python.framework import tensor_shape 35from tensorflow.python.framework import tensor_spec 36from tensorflow.python.framework import tensor_util 37from tensorflow.python.ops import resource_variable_ops 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.platform import tf_logging as logging 40from tensorflow.python.saved_model import nested_structure_coder 41from tensorflow.python.training.tracking import data_structures 42from tensorflow.python.util import nest 43from tensorflow.python.util.tf_export import tf_export 44 45 46class VariableHolder(object): 47 """Holds variables for a python function.""" 48 49 def __init__(self, fn=None, share_variables=False): 50 self._fn = fn 51 52 self._share_variables = share_variables 53 self._variables_by_name = data_structures.Mapping() 54 55 @property 56 def variables(self): 57 return self._variables_by_name 58 59 def variable_creator_scope(self, next_creator, **kwargs): 60 """Creates variables & adds them to collections to match legacy code.""" 61 collections = kwargs.pop("collections", None) 62 v = None 63 64 # Get expected variable name. 65 with ops.name_scope( 66 kwargs.get("name", None), "Variable", skip_on_eager=False) as name: 67 variable_name = ops.name_from_scope_name(name) 68 kwargs["name"] = name 69 70 if self._share_variables: 71 v = self._variables_by_name.get(variable_name, None) 72 73 if v is None: 74 v = next_creator(**kwargs) 75 self._variables_by_name[variable_name] = v 76 77 if collections is None: 78 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 79 if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: 80 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] 81 82 ops.add_to_collections(collections, v) 83 84 return v 85 86 def __call__(self, *args, **kwargs): 87 return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs) 88 89 def call_with_variable_creator_scope(self, fn): 90 91 def wrapped(*args, **kwargs): 92 with variable_scope.variable_creator_scope(self.variable_creator_scope): 93 return fn(*args, **kwargs) 94 95 return wrapped 96 97 98def _get_element_from_tensor_info(tensor_info, graph): 99 """Simplified copy of the deprecated `get_tensor_from_tensor_info`.""" 100 encoding = tensor_info.WhichOneof("encoding") 101 if encoding == "name": 102 # We may get operations here in some cases. TensorInfo is a bit of a 103 # misnomer if so. 104 return graph.as_graph_element(tensor_info.name) 105 elif encoding == "coo_sparse": 106 return sparse_tensor.SparseTensor( 107 graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name), 108 graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name), 109 graph.get_tensor_by_name( 110 tensor_info.coo_sparse.dense_shape_tensor_name)) 111 elif encoding == "composite_tensor": 112 struct_coder = nested_structure_coder.StructureCoder() 113 spec_proto = struct_pb2.StructuredValue( 114 type_spec_value=tensor_info.composite_tensor.type_spec) 115 spec = struct_coder.decode_proto(spec_proto) 116 components = [graph.get_tensor_by_name(component.name) for component in 117 tensor_info.composite_tensor.components] 118 return spec._from_components(components) # pylint: disable=protected-access 119 else: 120 raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid " 121 "encodings are 'name', 'coo_sparse', and " 122 "'composite_tensor'.") 123 124 125def _lift_single_variable(old_variable, graph, variable_holder): 126 """Lifts `old_variable` out of the `FuncGraph` `graph`.""" 127 new_variable = resource_variable_ops.UninitializedVariable( 128 shape=old_variable.shape, 129 dtype=old_variable.dtype, 130 name=old_variable.op.name, 131 trainable=old_variable.trainable, 132 extra_handle_data=old_variable.handle) 133 new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access 134 graph.add_capture(new_variable.handle, old_variable.handle) 135 # Now that we've added the new variable to graph.captures, 136 # graph.capture will use that cached value and do some post-processing 137 # on the capture like recording it on the tape. 138 graph.capture(new_variable.handle) 139 # pylint: disable=protected-access 140 variable_name = new_variable.name.split(":")[0] 141 variable_holder._variables_by_name[variable_name] = new_variable 142 graph._weak_variables.append(weakref.ref(new_variable)) 143 # pylint: enable=protected-access 144 graph.watch_variable(new_variable) 145 return new_variable 146 147 148def _lift_unlifted_variables(graph, variable_holder): 149 """Finds resource variables and lifts them into the outer context. 150 151 When we import a GraphDef inside a wrap_function, no Python graph building 152 code runs. This means we get VarHandleOps which create variable resources, 153 but no corresponding Python objects. Leaving them like this works but gives 154 the user no way to interact with or modify the variables outside the graph. 155 156 This method searches for variables and lifts them out as regular variable 157 objects when possible, indicating to the FuncGraph that they are captures. 158 159 Args: 160 graph: The FuncGraph to lift variables from. 161 variable_holder: A VariableHolder to record the lifted variables in. 162 """ 163 with graph.as_default(): 164 global_collection_variables = ops.get_collection( 165 ops.GraphKeys.GLOBAL_VARIABLES) 166 local_collection_variables = ops.get_collection( 167 ops.GraphKeys.LOCAL_VARIABLES) 168 existing_captures = {id(c) for c in graph.internal_captures} 169 lifted_variables = {} 170 171 def _should_lift_variable(v): 172 return ((v._in_graph_mode # pylint: disable=protected-access 173 and v.graph.building_function) 174 and isinstance(v, resource_variable_ops.BaseResourceVariable) 175 and id(v.handle) not in existing_captures) 176 177 for old_variable in global_collection_variables: 178 if _should_lift_variable(old_variable): 179 new_variable = _lift_single_variable( 180 old_variable, graph, variable_holder) 181 lifted_variables[id(old_variable)] = new_variable 182 existing_captures.add(id(old_variable.handle)) 183 184 for old_variable in local_collection_variables: 185 if _should_lift_variable(old_variable): 186 new_variable = _lift_single_variable( 187 old_variable, graph, variable_holder) 188 lifted_variables[id(old_variable)] = new_variable 189 existing_captures.add(id(old_variable.handle)) 190 if new_variable._in_graph_mode: # pylint: disable=protected-access 191 outer_graph = new_variable.graph 192 # Variables are added to the global collection by default. In this 193 # case we only want the variable in the local collection, so we'll pop 194 # it out. 195 global_collection = outer_graph.get_collection_ref( 196 ops.GraphKeys.GLOBAL_VARIABLES) 197 global_collection.remove(new_variable) 198 outer_graph.add_to_collection( 199 ops.GraphKeys.LOCAL_VARIABLES, new_variable) 200 201 # Update the FuncGraph's collections, partly for the user and partly so this 202 # function is idempotent when it runs again in prune() calls. 203 for collection_name in [ 204 ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES 205 ]: 206 mutable_collection = ops.get_collection_ref(collection_name) 207 for index, current in enumerate(mutable_collection): 208 mutable_collection[index] = lifted_variables.get(id(current), current) 209 if not resource_variable_ops.is_resource_variable( 210 mutable_collection[index]): 211 logging.log_first_n( 212 logging.WARN, 213 "Unable to create a python object for variable {} because it is " 214 "a reference variable. It may not be visible to training APIs. " 215 "If this is a problem, consider rebuilding the SavedModel after " 216 "running tf.compat.v1.enable_resource_variables().".format( 217 mutable_collection[index]), 218 5) 219 220 221# TODO(allenl): make this trackable 222class WrappedFunction(function.ConcreteFunction): 223 """Wraps a tf V1 piece of code in a function.""" 224 225 def __init__(self, fn_graph, variable_holder, attrs=None, signature=None): 226 self._variable_holder = variable_holder 227 _lift_unlifted_variables(fn_graph, variable_holder) 228 # We call __init__ after lifting variables so that the function's signature 229 # properly reflects the new captured inputs. 230 for f in fn_graph.as_graph_def().library.function: 231 context.context().add_function_def(f) 232 self._signature = signature 233 super(WrappedFunction, self).__init__(fn_graph, attrs=attrs) 234 235 def _call_impl(self, args, kwargs, cancellation_manager=None): 236 if self._arg_keywords is None: 237 if kwargs: 238 raise NotImplementedError( 239 "Keyword arguments are not supported when calling a " 240 f"wrap_function-decorated function. Got {kwargs}.") 241 if self._signature is not None: 242 args = list(args) 243 for i, arg in enumerate(args): 244 if isinstance(self._signature[i], tensor_spec.DenseSpec): 245 args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype) 246 return self._call_flat(args, self.captured_inputs) 247 else: 248 return super(WrappedFunction, self)._call_impl( 249 args, kwargs, cancellation_manager) 250 251 def prune(self, feeds, fetches, name=None, input_signature=None): 252 """Extract a subgraph of this function's underlying graph. 253 254 Wraps the subgraph in a new `WrappedFunction` object. 255 256 Args: 257 feeds: Input tensors to the subgraph to extract, as `Tensor` objects. 258 fetches: Possibly-nested Python data structure containing information 259 about outputs of the target subgraph. Each entry can either be a 260 `Tensor` object (for data outputs), an `Operation` object (for control 261 outputs), or a `TensorInfo` proto. Any additional shape/dtype 262 information provided in a `TensorInfo` and not present in the original 263 graph will be added to the returned subgraph. 264 name: (optional) Name to give to the underlying `FuncGraph` of the 265 returned object. If no name is provided, the graph's name will be 266 `"pruned"`. 267 input_signature: (optional) possibly-nested Python data structure 268 containing `TensorSpec` objects, with which to populate the returned 269 functions's `FuncGraph`'s `structured_input_signature` field. 270 271 Returns: 272 A new `WrappedFunction` object containing a copy of the portion of this 273 object's graph that goes from `feeds` to `fetches`. 274 """ 275 # TODO(b/129646028): Add support for CompositeTensors. 276 name = name or "pruned" 277 flat_feeds = nest.flatten(feeds, expand_composites=True) 278 flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds] 279 for f in flat_feeds: 280 if not isinstance(f, ops.Tensor): 281 raise ValueError("All memebers of argument `feeds` must be tensors. " 282 f"Got {f} with type {type(f)}.") 283 284 # Ignoring all feeds that are captures allows prune to be called 285 # using wrapped_func.inputs even when it uses variables 286 internal_captures = {id(c) for c in self.graph.internal_captures} 287 flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures] 288 289 operation_fetches = [] 290 tensor_fetches = [] 291 tensor_infos = [] 292 293 def _fetch_preprocessing_callback(fetch): 294 """Extract out lists of ops, tensors, and tensor type info. 295 296 Turns TensorInfos into Tensors in the original `fetches` structure. 297 Also extracts ops from `fetches`. 298 299 Args: 300 fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or 301 string identifying a Tensor or Operation. 302 303 Returns: 304 `fetch` converted to a Tensor. 305 """ 306 if isinstance(fetch, ops.Operation): 307 operation_fetches.append(fetch) 308 return fetch 309 elif isinstance(fetch, meta_graph_pb2.TensorInfo): 310 tensor_infos.append(fetch) 311 decoded = _get_element_from_tensor_info(fetch, self._func_graph) 312 if (tensor_util.is_tf_type(decoded) or 313 isinstance(decoded, composite_tensor.CompositeTensor)): 314 tensor_fetches.append(decoded) 315 else: 316 operation_fetches.append(decoded) 317 return decoded 318 elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)): 319 tensor_fetches.append(fetch) 320 return fetch 321 else: 322 graph_element = self.graph.as_graph_element(fetch) 323 return _fetch_preprocessing_callback(graph_element) 324 325 fetches = nest.map_structure(_fetch_preprocessing_callback, fetches) 326 327 # Expand composite tensors into their component dense Tensors. 328 tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) 329 330 for f in flat_feeds + tensor_fetches + operation_fetches: 331 if f.graph is not self._func_graph: 332 raise ValueError("Can only prune function whose feeds and fetches " 333 f"from graph {self._func_graph}. Input " 334 f"{f} is from a different graph {f.graph}.") 335 with self._func_graph.as_default(): 336 pruned_graph = func_graph.FuncGraph(name) 337 lift_map = lift_to_graph.lift_to_graph( 338 operation_fetches + tensor_fetches, 339 pruned_graph, 340 sources=flat_feeds + self.graph.internal_captures, 341 base_graph=self._func_graph) 342 343 # Note that we add the component tensors of any composite tensors to the 344 # returned function's outputs list; the list must contain these component 345 # tensors, or the function's sparse outputs won't work properly. 346 pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) 347 pruned_graph.control_outputs.extend( 348 [lift_map[operation] for operation in operation_fetches]) 349 pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds) 350 for external_capture, internal_capture in self.graph.captures: 351 pruned_graph.add_capture(external_capture, lift_map[internal_capture]) 352 for ti in tensor_infos: 353 if ti.WhichOneof("encoding") == "name": # Dense tensors only 354 t = pruned_graph.as_graph_element(ti.name) 355 if tensor_util.is_tf_type(t): 356 t.set_shape(tensor_shape.TensorShape(ti.tensor_shape)) 357 # pylint: disable=protected-access 358 for f in self.graph._functions.values(): 359 pruned_graph._add_function(f) 360 # pylint: enable=protected-access 361 362 pruned_graph.variables = self.graph.variables 363 364 def _structured_output_mapping(fetched): 365 """callback for `nest.map_structure()`""" 366 lifted = lift_map[fetched] 367 if isinstance(lifted, ops.Operation): 368 return None 369 return lifted 370 371 # expand_composites=True here causes composite tensors to be expanded 372 # into their component dense Tensors, mapped to the new graph, and then 373 # reconstituted into their original composite form. 374 pruned_graph.structured_outputs = nest.map_structure( 375 _structured_output_mapping, fetches, expand_composites=True) 376 pruned_graph.structured_input_signature = input_signature 377 pruned_fn = WrappedFunction( 378 pruned_graph, variable_holder=self._variable_holder) 379 pruned_fn._num_positional_args = len(flat_feeds) # pylint: disable=protected-access 380 # TODO(kathywu): Enable keyword arguments if an input signature is specified 381 pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds] # pylint: disable=protected-access 382 return pruned_fn 383 384 385def _filter_returned_ops(fn): 386 """Filtering out any ops returned by function. 387 388 Args: 389 fn: a function 390 391 Returns: 392 A tuple of ( 393 Wrapped function that returns `None` in place of any ops, 394 dict that maps the index in the flat output structure to the returned op 395 ) 396 """ 397 returned_ops = {} 398 399 def wrap_and_filter_returned_ops(*args, **kwargs): 400 outputs = fn(*args, **kwargs) 401 flat_outputs = nest.flatten(outputs) 402 for n in range(len(flat_outputs)): 403 output = flat_outputs[n] 404 if isinstance(output, ops.Operation): 405 returned_ops[n] = output 406 flat_outputs[n] = None 407 return nest.pack_sequence_as(outputs, flat_outputs) 408 409 return wrap_and_filter_returned_ops, returned_ops 410 411 412class WrappedGraph(object): 413 """Class for wrapping multiple TF 1.X functions in a single graph. 414 415 Maintains a dictionary mapping names to wrapped functions. See 416 `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions. 417 418 Functions wrapped using this class have access to variables and collections 419 created in other wrapped functions, using the standard TF 1.X API ( 420 `tf.compat.v1.get_variable` or 421 `tf.compat.v1.get_default_graph().get_collection(...)`) 422 423 Outside a function, variables and collections may be accessed using the 424 `variables` and `graph` properties. 425 426 Example: 427 428 ``` 429 def add_v1(x): 430 with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE): 431 v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32) 432 return v + x 433 434 def increment_var_v1(x): 435 with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE): 436 v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32) 437 return v.assign_add(x) 438 439 g = WrappedGraph() 440 add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)]) 441 increment_var = g.wrap_function(increment_var_v1, 442 [tf.TensorSpec([], tf.int32)]) 443 444 assert len(g.variables) == 1 445 assert g.variables[0].numpy() == 0 446 increment_var(tf.constant(5)) 447 assert g.variables[0].numpy() == 5 448 449 ``` 450 """ 451 452 def __init__(self, variable_holder=None, **kwargs): 453 self._variable_holder = ( 454 variable_holder or VariableHolder(share_variables=True)) 455 456 name = kwargs.pop("name", "wrapped_function_graph") 457 # Always start with empty collections, unless otherwise specified. Setting 458 # `collections=None` will copy the collections from the outer graph. 459 collections = kwargs.pop("collections", {}) 460 self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs) 461 462 self._wrapped_function = WrappedFunction(self.graph, self._variable_holder) 463 self._functions = {} 464 465 @property 466 def functions(self): 467 return self._functions 468 469 @property 470 def variables(self): 471 return self._variable_holder.variables 472 473 def wrap_function(self, fn, signature, name=None): 474 """Wraps a TF 1.X function and returns an eager-compatible function. 475 476 All functions wrapped in the same `WrappedGraph` will have access to the 477 same graph (`tf.compat.v1.get_default_graph` to get the graph object 478 within a function, or `WrappedGraph.graph` to get the graph outside a 479 function). Variables created within the function will be added to the 480 `variables` list. 481 482 Function inputs: All inputs to the function must be tensors (nested ok), 483 with their shapes and dtypes defined in the `signature` argument. 484 485 Function outputs: 486 487 * The 1.X function may return tensors, variables, and ops. The wrapped 488 eager-compatible function will always return tensors in the same nested 489 structure. 490 * Variables are replaced with a tensor containing the latest read values. 491 * Returned ops are executed, and replaced with None. 492 * The order of op execution and variable reads in the return is 493 nondeterministic. For example: 494 495 ``` 496 def update_var(x): 497 v = tf.Variable(0) 498 op = tf.compat.v1.assign(v, x).op 499 return v, op 500 501 g = WrappedGraph() 502 fn = g.wrap_function(update_var) 503 read_value, _ = fn(tf.constant(3)) 504 print(read_value.numpy()) # could be 0 or 3 505 print(g.variables[0].numpy()) # always 3 506 ``` 507 508 To ensure that ops in the function are executed (e.g. ops added to the 509 `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns. 510 511 Args: 512 fn: a 1.X tensorflow function. 513 signature: a possibly nested sequence of `TensorSpecs` specifying the 514 shapes and dtypes of the arguments. 515 name: an optional string name for the function. The function will be saved 516 with key `name` in the `functions` dictionary. 517 518 Returns: 519 An eager-compatible function. 520 """ 521 return self._wrap_function(fn, signature=signature, name=name) 522 523 def _wrap_function(self, 524 fn, 525 args=None, 526 kwargs=None, 527 signature=None, 528 name=None): 529 """Internal wrap function method with extended func_graph arguments.""" 530 fn_with_filter_and_scope, returned_ops = _filter_returned_ops( 531 self._variable_holder.call_with_variable_creator_scope(fn)) 532 533 func_graph.func_graph_from_py_func( 534 None, # Name is unused. 535 fn_with_filter_and_scope, 536 args=args, 537 kwargs=kwargs, 538 signature=signature, 539 add_control_dependencies=False, 540 func_graph=self.graph) 541 542 # This code relies on questional behavior from `func_graph_from_py_func`. 543 # If an existing FuncGraph is passed into the `func_graph` arg, the inputs 544 # and structured outputs are overwritten. Pretty sure this is a bug, 545 # because structured outputs doesn't match up with the outputs... 546 fn_inputs = self.graph.inputs[:-len(self.graph.captures)] 547 548 # Return filtered ops to the flattened outputs. 549 flat_fn_outputs = nest.flatten(self.graph.structured_outputs) 550 for index, op in returned_ops.items(): 551 flat_fn_outputs[index] = op 552 fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs, 553 flat_fn_outputs) 554 555 name = name or fn.__name__ 556 wrapped_function = self._wrapped_function.prune( 557 fn_inputs, fn_outputs, name, self.graph.structured_input_signature) 558 self._functions[name] = wrapped_function 559 return wrapped_function 560 561 562@tf_export(v1=["wrap_function"]) 563def wrap_function(fn, signature, name=None): 564 """Wraps the TF 1.x function fn into a graph function. 565 566 The python function `fn` will be called once with symbolic arguments specified 567 in the `signature`, traced, and turned into a graph function. Any variables 568 created by `fn` will be owned by the object returned by `wrap_function`. The 569 resulting graph function can be called with tensors which match the 570 signature. 571 572 ```python 573 def f(x, do_add): 574 v = tf.Variable(5.0) 575 if do_add: 576 op = v.assign_add(x) 577 else: 578 op = v.assign_sub(x) 579 with tf.control_dependencies([op]): 580 return v.read_value() 581 582 f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True]) 583 584 assert float(f_add(1.0)) == 6.0 585 assert float(f_add(1.0)) == 7.0 586 587 # Can call tf.compat.v1.wrap_function again to get a new trace, a new set 588 # of variables, and possibly different non-template arguments. 589 f_sub= tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False]) 590 591 assert float(f_sub(1.0)) == 4.0 592 assert float(f_sub(1.0)) == 3.0 593 ``` 594 595 Both `tf.compat.v1.wrap_function` and `tf.function` create a callable 596 TensorFlow graph. But while `tf.function` runs all stateful operations 597 (e.g. `tf.print`) and sequences operations to provide the same semantics as 598 eager execution, `wrap_function` is closer to the behavior of `session.run` in 599 TensorFlow 1.x. It will not run any operations unless they are required to 600 compute the function's outputs, either through a data dependency or a control 601 dependency. Nor will it sequence operations. 602 603 Unlike `tf.function`, `wrap_function` will only trace the Python function 604 once. As with placeholders in TF 1.x, shapes and dtypes must be provided to 605 `wrap_function`'s `signature` argument. 606 607 Since it is only traced once, variables and state may be created inside the 608 function and owned by the function wrapper object. 609 610 Args: 611 fn: python function to be wrapped 612 signature: the placeholder and python arguments to be passed to the wrapped 613 function 614 name: Optional. The name of the function. 615 616 Returns: 617 the wrapped graph function. 618 """ 619 holder = VariableHolder(fn) 620 func_graph_name = "wrapped_function" 621 if name is not None: 622 func_graph_name = "wrapped_function_" + name 623 return WrappedFunction( 624 func_graph.func_graph_from_py_func( 625 func_graph_name, 626 holder, 627 args=None, 628 kwargs=None, 629 signature=signature, 630 add_control_dependencies=False, 631 collections={}), 632 variable_holder=holder, 633 signature=signature) 634 635 636def function_from_graph_def(graph_def, inputs, outputs): 637 """Creates a ConcreteFunction from a GraphDef. 638 639 Args: 640 graph_def: A GraphDef to make a function out of. 641 inputs: A Tensor name or nested structure of names in `graph_def` which 642 should be inputs to the function. 643 outputs: A Tensor name or nested structure of names in `graph_def` which 644 should be outputs of the function. 645 646 Returns: 647 A ConcreteFunction. 648 """ 649 650 def _imports_graph_def(): 651 importer.import_graph_def(graph_def, name="") 652 653 wrapped_import = wrap_function(_imports_graph_def, []) 654 import_graph = wrapped_import.graph 655 return wrapped_import.prune( 656 nest.map_structure(import_graph.as_graph_element, inputs), 657 nest.map_structure(import_graph.as_graph_element, outputs)) 658