# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Python front-end support for functions.

NOTE: At this time, functions are experimental and subject to change! Proceed
with caution.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect


class Defun(object):
  """Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects. Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside of the function are hoisted into the outer
  graph. Note that the variables are created in the variable scope that was
  active during the first call to the function. Subsequent function calls will
  refer to the same set of variables.

  Definitions of functions in a graph are frozen as soon as the graph is used
  to create a session. However, new functions and new calls to existing
  functions may be added to the graph, with the new functions themselves
  becoming immediately frozen.

  Example (see also the [How To on functions](link_needed)):

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`
      **kwargs: Optional keyword arguments, including
        func_name - (optional). A python string, the name to use to
          declare this `Function` in the graph.

        grad_func - (optional). A function implementing the gradient
          of the function-to-register. This must be a
          `_DefinedFunction` object. The gradient
          function must satisfy the criterion defined in
          function.proto:GradientDef.

        python_grad_func - (optional). A function implementing the
          gradient of the function python-side. This function must
          take the current op and the gradients w.r.t. its outputs,
          and return the gradients w.r.t. the inputs. That is, it must
          implement the interface expected by `tf.RegisterGradient`.
          This will be called by tf.gradients to add the gradient ops
          to the graph. At most one of grad_func and python_grad_func
          can be specified.

        out_names - (optional). A list of strings, one per output
          tensor.

        shape_func - (optional). A function taking the op and returning a list
          of static shapes to set for the function's outputs.
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("function %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keyword arguments are not"
          " supported. {} has defaults {} and keywords {}.".format(
              func, argspec.defaults, argspec.keywords))

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types (%d) is not compatible with "
            "the number of arguments the function accepts (between %d and "
            "%d)." % (num, min_args, max_args))
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
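    # Illustrative sketch (not part of this module's tests): an un-typed
    # decorator such as
    #
    #   @Defun()
    #   def Add(x, y):
    #     return x + y
    #
    # returns an _OverloadedFunction; a concrete _DefinedFunction is
    # instantiated lazily for each distinct input-type signature on the
    # first call with those types (see _OverloadedFunction.instantiate).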
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)


class _DefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               allowlisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      allowlisted_stateful_ops: A set of ops that, if stateful, are ignored
        and copied into the function body when `capture_by_value` is True.
      capture_resource_var_by_value: Boolean (defaults to True). If False, a
        captured resource variable yields its handle instead of its value.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.
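
    Example (a sketch; `_DefinedFunction` is normally created through the
    `Defun` decorator rather than constructed directly):

    ```python
    @Defun(tf.float32)
    def Square(x):
      return x * x
    # `Square` is now a _DefinedFunction; calling it adds both the function
    # definition and a call op to the default graph.
    ```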

    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily
    self._definition = None
    # Constructed only when C API is enabled, lazily
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
        fdef = function_pb2.FunctionDef()
        proto_data = c_api.TF_GetBuffer(buf)
        fdef.ParseFromString(compat.as_bytes(proto_data))
        with ops.init_scope():
          if context.executing_eagerly():
            context.add_function(self._c_func.func)
            self._function_deleter = _DefinedFunctionDeleter(
                fdef.signature.name)
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in function definition.

    Returns:
      A list of (op.name, op.type) pairs.
371 """ 372 self._create_definition_if_needed() 373 return self._stateful_ops 374 375 def _create_definition_if_needed(self): 376 """Creates the function definition if it's not created yet.""" 377 with context.graph_mode(): 378 self._create_definition_if_needed_impl() 379 380 def _create_definition_if_needed_impl(self): 381 """This is not what you want, see _create_definition_if_needed.""" 382 if self._definition is not None or self._c_func is not None: 383 return 384 385 # Copy variable collections (by reference) from the parent graph such that 386 # name based variable sharing (e.g. via tf.make_template) works between the 387 # func graph and parent graph. 388 variable_keys = [] 389 variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS) # pylint: disable=protected-access 390 variable_keys.append(vs._VARSTORE_KEY) # pylint: disable=protected-access 391 392 parent_graph = ops.get_default_graph() 393 collections_ref = { 394 key: parent_graph.get_collection_ref(key) for key in variable_keys} 395 396 temp_graph = func_graph_from_py_func( 397 self._func, 398 self._arg_names, 399 self._arg_types, 400 self._func_name, 401 self._capture_by_value, 402 self._caller_device, 403 collections_ref=collections_ref, 404 allowlisted_stateful_ops=self._allowlisted_stateful_ops, 405 capture_resource_var_by_value=self._capture_resource_var_by_value) 406 407 self._extra_inputs = temp_graph.extra_inputs 408 # pylint: disable=protected-access 409 self._sub_functions = temp_graph._functions 410 # pylint: enable=protected-access 411 412 # Extra kwargs are treated as attrs on the function def. 413 if self._func_name: 414 base_func_name = self._func_name 415 else: 416 base_func_name = function_utils.get_func_name(self._func) 417 if self._grad_func: 418 base_func_name += ("_%s" % self._grad_func.name) 419 kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs) 420 421 if not temp_graph._c_graph: # pylint: disable=protected-access 422 # Build the FunctionDef 423 self._definition = graph_to_function_def.graph_to_function_def( 424 temp_graph, 425 temp_graph.get_operations(), 426 temp_graph.inputs, 427 temp_graph.outputs, 428 out_names=self._out_names) 429 430 for k in kwargs_attr: 431 self._definition.attr[k].CopyFrom(kwargs_attr[k]) 432 433 # Hash the definition and its dependencies. 434 self._hash_str = self._create_hash_str( 435 self._definition.signature.input_arg, 436 self._definition.signature.output_arg, self._definition.node_def) 437 438 # Finally, we decide the function name to use. If not specified, 439 # make up something which is almost certainly unique (but deterministic). 
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [],  # control_outputs
          [],  # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                         serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
        (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
        (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
        (e.g. self._definition.node_def)

    Returns:
      The unique string for this input
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      f.add_to_graph(g)

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError("shape_func produced %d shapes for %d outputs" %
                         (len(shapes), len(op.outputs)))
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret


class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  instantiated _DefinedFunction in self._overload.

  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates _OverloadedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiate this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.

    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
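        # E.g. for input_types [f32] and a function producing one f32 output,
        # the gradient overload is instantiated with [f32, f32] (illustrative).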
        output_types = [
            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(input_types +
                                                         output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError("Expected a Tensor but got %s" % x)
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)


class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function. If
  any input is from another graph, we keep track of it in self._captured
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument and the caller passes in the captured tensor.
  """

  def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
               capture_resource_var_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter

    # The name of the function.
    self.name = name
    # Placeholder tensors representing the inputs to this function. The tensors
    # are in this _FuncGraph.
    self.inputs = []
    # Tensors that will be returned by this function. The tensors are in this
    # _FuncGraph.
    self.outputs = []
    # Maps external tensor -> internal tensor (e.g. input placeholder).
    self._captured = {}
    # The external tensors that have been captured as inputs and must be passed
    # to this function (empty if capturing by value, otherwise these are the
    # keys of _captured).
    self.extra_inputs = []
    # Input placeholders that have been added for captured values (empty if
    # capturing by value).
    self.extra_args = []
    # Captured variables.
    # TODO(skyewm): is this needed?
    self.extra_vars = []

  # pylint: disable=g-doc-return-or-yield

  @property
  def outer_graph(self):
    """The graph active when this _FuncGraph was created."""
    return self._outer_graph

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
      yields the container name.
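
    Example (an illustrative sketch of typical usage):

    ```python
    with func_graph.container('experiment0'):
      # Stateful ops and variables created here are placed in the
      # resource container 'experiment0'.
      v = tf.compat.v1.get_variable('v', [1])
    ```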
745 """ 746 original_container = self._container 747 # pylint: disable=protected-access 748 with ops.init_scope(): 749 original_init_container = ops.get_default_graph()._container 750 try: 751 self._container = container_name 752 with ops.init_scope(): 753 ops.get_default_graph()._container = container_name 754 yield self._container 755 finally: 756 self._container = original_container 757 with ops.init_scope(): 758 ops.get_default_graph()._container = original_init_container 759 # pylint: enable=protected-access 760 761 # pylint: enable=g-doc-return-or-yield 762 763 def getvar( 764 self, 765 getter, 766 name, 767 shape=None, 768 dtype=None, 769 initializer=None, 770 reuse=None, 771 trainable=True, 772 collections=None, # pylint: disable=redefined-outer-name 773 use_resource=None, 774 **kwargs): 775 """A custom variable getter.""" 776 # Here, we switch the default graph to the outer graph and ask the 777 # variable scope in which the function is defined to give us the 778 # variable. The variable is stashed in extra_vars and returned to 779 # the caller. 780 # 781 # We capture these variables so that the variable definition is 782 # hoisted upward to the outer most graph. 783 with self._outer_graph.as_default(): 784 # pylint: disable=protected-access 785 var = self._vscope.get_variable( 786 vs._get_default_variable_store(), 787 name, 788 shape=shape, 789 dtype=dtype, 790 initializer=initializer, 791 reuse=reuse, 792 trainable=trainable, 793 collections=collections, 794 use_resource=use_resource) 795 self.extra_vars.append(var) 796 if (isinstance(var, resource_variable_ops.BaseResourceVariable) and 797 self._capture_resource_var_by_value): 798 # For resource-based variables read the variable outside the function 799 # and pass in the value. This ensures that the function is pure and 800 # differentiable. TODO(apassos) this may have performance problems if 801 # the function will only do embedding lookups on the variable. 802 return var.value() 803 return var 804 805 def _create_op_internal( 806 self, 807 op_type, 808 inputs, 809 dtypes=None, # pylint: disable=redefined-outer-name 810 input_types=None, 811 name=None, 812 attrs=None, 813 op_def=None, 814 compute_device=True): 815 for i, x in enumerate(inputs): 816 if isinstance(x, ops.EagerTensor) or x.graph is not self: 817 inputs[i] = self.capture(x) 818 return super(_FuncGraph, self)._create_op_internal( 819 op_type, 820 inputs, 821 dtypes=dtypes, 822 input_types=input_types, 823 name=name, 824 attrs=attrs, 825 op_def=op_def, 826 compute_device=compute_device) 827 828 def capture(self, tensor, name=None): 829 """Adds the given tensor to this graph and returns the captured tensor.""" 830 if tensor.ref() in self._captured: 831 # Captured already. 832 return self._captured[tensor.ref()] 833 elif self._capture_by_value: 834 return self._add_tensor_and_parents(tensor) 835 else: 836 return self._capture_tensor_as_extra_input(tensor, name) 837 838 @property 839 def captures(self): 840 """Pairs of tensors and captured tensor.""" 841 return [(k.deref(), v) for k, v in self._captured.items()] 842 843 def _capture_tensor_as_extra_input(self, tensor, name=None): 844 # Substitute with a placeholder. 845 self.extra_inputs.append(tensor) 846 # Hoist the new input placeholder out of any control flow context 847 # we're currently in. 
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
                                                tensor._as_tf_output())

    if handle_data:
      c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
                                  compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor.ref()] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    if op._is_stateful and op not in self._allowlisted_stateful_ops:
      raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    # pylint: enable=protected-access

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self._create_op_internal(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t.ref()] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    allowlisted_stateful_ops: A set of ops that, if stateful, are ignored and
      re-created when captured by value.
    capture_resource_var_by_value: Boolean (defaults to True). If False, a
      captured resource variable yields its handle instead of its value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
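
  Example (a minimal sketch; this is an internal helper, shown only to
  illustrate the contract between `func`, `arg_names`, and `arg_types`):

  ```python
  def body(x, y):
    return x + y, x - y

  fg = func_graph_from_py_func(
      body, ["x", "y"], [dtypes.float32, dtypes.float32])
  # fg.inputs holds the two argument placeholders; fg.outputs the results.
  ```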
942 """ 943 if not name: 944 name = function_utils.get_func_name(func) 945 func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops, 946 capture_resource_var_by_value) 947 948 with func_graph.as_default(), ops.device(device): 949 # pylint: disable=protected-access 950 if collections_ref is not None: 951 func_graph._collections = collections_ref 952 if container is not None: 953 func_graph._container = container 954 if colocation_stack is not None: 955 func_graph._colocation_stack = colocation_stack 956 # pylint: enable=protected-access 957 958 if arg_shapes is None: 959 arg_shapes = [None] * len(arg_types) 960 961 # Create placeholders for the function arguments. 962 for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes): 963 argholder = array_ops.placeholder(argtype, shape=argshape, name=argname) 964 func_graph.inputs.append(argholder) 965 # Call func and gather the output tensors. 966 with vs.variable_scope("", custom_getter=func_graph.getvar): 967 outputs = func(*func_graph.inputs) 968 969 # There is no way of distinguishing between a function not returning 970 # anything and a function returning None in Python. 971 # We need to allow the former and ideally want to forbid the latter as 972 # it is most likely user error. 973 # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to 974 # allow users to explicitly mark the function as not returning anything. 975 # For now, we allow a single None return and interpret it as a function 976 # with no output. 977 if outputs is None: 978 outputs = [] 979 else: 980 # If func only returned one value, make it a tuple. 981 if not isinstance(outputs, (list, tuple)): 982 outputs = (outputs,) 983 if any(_ is None for _ in outputs): 984 raise ValueError("Function %s can not return None." % name) 985 # Ensures each output is a Tensor in the function graph. 986 outputs = [ops.convert_to_tensor(t) for t in outputs] 987 outputs = [func_graph.capture(t) if t.graph is not func_graph else t 988 for t in outputs] 989 func_graph.outputs = outputs 990 return func_graph 991 992 993def _is_guaranteed_const(tensor): 994 """Determines whether `tensor` is guaranteed to be a constant. 995 996 A tensor is guaranteed to be a constant if either it was produced by 997 a `GuaranteeConst` op or if all of its children are guaranteed to be 998 constants. 999 1000 Args: 1001 tensor: The tensor for which to determine const-ness. 1002 1003 Returns: 1004 True if `tensor` is guaranteed to be a constant, False otherwise. 1005 """ 1006 1007 if isinstance(tensor, ops.EagerTensor): 1008 return False 1009 1010 class Work(object): 1011 1012 def __init__(self, op, leaving): 1013 self.op = op 1014 self.leaving = leaving 1015 1016 is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst" 1017 constants = set([]) 1018 def all_inputs_const(op): 1019 # If all inputs of an op are guaranteed constants, then we can infer that 1020 # the op produces a constant as well. 1021 return op.inputs and all(inp.op in constants for inp in op.inputs) 1022 1023 visited = set([]) 1024 stack = [Work(tensor.op, leaving=False)] 1025 while stack: 1026 work = stack.pop() 1027 if work.leaving: 1028 if all_inputs_const(work.op): 1029 constants.add(work.op) 1030 continue 1031 visited.add(work.op) 1032 if is_guaranteed_const(work.op): 1033 constants.add(work.op) 1034 continue 1035 1036 # This op will be revisited after all its inputs are checked for const-ness. 
    stack.append(Work(work.op, leaving=True))
    for inp in work.op.inputs:
      if inp.op not in visited:
        stack.append(Work(inp.op, leaving=False))
  return tensor.op in constants


def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the `OpDef` signature of a `_DefinedFunction` object.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments. Can only contain 'name' or
      'noinline'.

  Returns:
    A 2-element tuple. First element: a Tensor if the function returns a single
    value; a list of Tensors if the function returns multiple values; the
    Operation if the function returns no values. Second element: the Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError("Expected number of arguments: %d, received: %d" % (len(
        sig.input_arg), len(inputs)))
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  if name is None:
    name = func_name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  op = g._create_op_internal(  # pylint: disable=protected-access
      func_name, list(inputs), output_types, name=name, attrs=attrs, op_def=sig)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op


def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do
  # we have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  serialized = fdef.SerializeToString()
  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
  result._c_func = c_api_util.ScopedTFFunction(c_func)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result


def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
                       (g.function_name, str(lib)))
    if g.gradient_func not in funcs:
      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
                       (g.gradient_func, str(lib)))

  # function name -> gradient function name
  func_to_grad = collections.defaultdict(lambda: None)
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients
  ready = [
      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
  ]
  if not ready:
    raise ValueError(
        "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))
  # function name -> _DefinedFunction
  initialized = {}

  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad = initialized.get(func_to_grad[name])
    if func_to_grad[name]:
      assert grad
    defined_func = _from_definition(fdef, grad_func=grad)
    initialized[name] = defined_func

    ready.extend(funcs[f] for f in grad_to_funcs[name])

  return initialized.values()


def _get_experimental_kwarg_as_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, bool):
    return attr_value_pb2.AttrValue(b=value)
  elif isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  elif isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  elif isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError("Unsupported attribute type for %s with type %s" %
                     (attr_name, type(value)))


def _get_kwarg_as_str_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError("Unsupported attribute type for %s with type %s" %
                     (attr_name, type(value)))


def _parse_kwargs_as_attrs(func_name, **kwargs):
1223 """Parses **kwargs into a node's attributes.""" 1224 attrs = {} 1225 1226 noinline = kwargs.pop("noinline", None) 1227 if noinline is not None: 1228 attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline)) 1229 1230 # For compatibility with previous behavior, Defun does not perform shape 1231 # inference through its function call operations. 1232 attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True) 1233 1234 compiled = kwargs.pop("compiled", None) 1235 separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None) 1236 if compiled is not None: 1237 attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled)) 1238 attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue( 1239 b=bool(separate_compiled_gradients)) 1240 # Forward _XlaScope from enclosing context (if set), otherwise create new. 1241 # pylint: disable=protected-access 1242 if "_XlaScope" in ops.get_default_graph()._attr_scope_map: 1243 attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"] 1244 else: 1245 attrs["_XlaScope"] = attr_value_pb2.AttrValue( 1246 s=("function_%s" % func_name).encode()) 1247 # pylint: enable=protected-access 1248 1249 kwargs_keys = list(kwargs.keys()) 1250 for key in kwargs_keys: 1251 if key.startswith("experimental_"): 1252 attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key]) 1253 del kwargs[key] 1254 # Support for https://github.com/tensorflow/community/pull/113/files. 1255 elif key == "_implements" or key == "_reference": 1256 attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key]) 1257 del kwargs[key] 1258 if kwargs: 1259 raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) 1260 return attrs 1261 1262 1263def get_extra_vars(): 1264 """Returns the captured variables by the function. 1265 1266 Returns: 1267 If the default graph is being used to define a function, the 1268 returned list of variables are those created inside the function 1269 body so far. Otherwise, returns an empty list. 1270 """ 1271 g = ops.get_default_graph() 1272 if isinstance(g, _FuncGraph): 1273 return g.extra_vars 1274 else: 1275 return [] 1276 1277 1278def get_extra_inputs(): 1279 """Returns the captured input tensors by the function. 1280 1281 Returns: 1282 If the default graph is being used to define a function, the 1283 returned list of tensors are those accessed inside the function body 1284 but defined outside the function body so far. Otherwise, returns an 1285 empty list. 1286 """ 1287 g = ops.get_default_graph() 1288 if isinstance(g, _FuncGraph): 1289 return g.extra_inputs 1290 else: 1291 return [] 1292 1293 1294def get_extra_args(): 1295 """Returns the corresponding function arguments for the captured inputs. 1296 1297 Returns: 1298 If the default graph is being used to define a function, the 1299 returned list of place holders are those used inside the function 1300 body corresponding those returned by get_extra_inputs(). Otherwise, 1301 returns an empty list. 1302 """ 1303 g = ops.get_default_graph() 1304 if isinstance(g, _FuncGraph): 1305 return g.extra_args 1306 else: 1307 return [] 1308 1309 1310def _type_list_to_str(types): 1311 if any(_ not in _DTYPE_TO_STR for _ in types): 1312 raise ValueError("Unsupported dtypes: %s" % types) 1313 return "".join(_DTYPE_TO_STR[_] for _ in types) 1314 1315 1316# NOTE: The list needs to be extended when more data types are added. 
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16"
}


def function_def_from_tf_function(c_func):
  """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
  with c_api_util.tf_buffer() as buf:
    c_api.TF_FunctionToFunctionDef(c_func, buf)
    data = c_api.TF_GetBuffer(buf)
  fdef = function_pb2.FunctionDef()
  fdef.ParseFromString(compat.as_bytes(data))
  return fdef