1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Python front-end supports for functions. 16 17NOTE: At this time, functions are experimental and subject to change!. Proceed 18with caution. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections 26import hashlib 27 28from tensorflow.core.framework import attr_value_pb2 29from tensorflow.core.framework import function_pb2 30from tensorflow.python.client import pywrap_tf_session as c_api 31from tensorflow.python.eager import context 32from tensorflow.python.framework import c_api_util 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import graph_to_function_def 35from tensorflow.python.framework import ops 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import resource_variable_ops 38from tensorflow.python.ops import variable_scope as vs 39from tensorflow.python.util import compat 40from tensorflow.python.util import function_utils 41from tensorflow.python.util import tf_contextlib 42from tensorflow.python.util import tf_inspect 43 44 45class Defun(object): 46 """Decorator used to define TensorFlow functions. 47 48 Use this decorator to make a Python function usable directly as a TensorFlow 49 function. 50 51 The decorated function must add ops to the default graph and return zero or 52 more `Tensor` objects. Call the decorator with named arguments, one for each 53 argument of the function to decorate, with the expected type of the argument 54 as value. 55 56 For example if the function to decorate accepts two `tf.float32` arguments 57 named `x` and `y`, call the decorator with: 58 59 @Defun(tf.float32, tf.float32) 60 def foo(x, y): 61 ... 62 63 When you call the decorated function, it adds the `call` ops to the 64 default graph. In addition, it adds the definition of the function into the 65 default graph. Because the addition of the function into the graph 66 is deferred, the decorator can be used anywhere in the program. 67 68 Any variables created inside of the function are hoisted into the outer graph. 69 Note that the variables are created in the variable scope that was active 70 during the first call to the function. Subsequent function calls will refer to 71 the same set of variables. 72 73 Definitions of functions in a graph are frozen as soon as the graph is used to 74 create a session. However, new functions and new calls to existing functions 75 may be added to the graph, with the new functions themselves becoming 76 immediately frozen. 77 78 Example, but also see the [How To on functions](link_needed). 79 80 ```python 81 # Defining the function. 82 @tf.Defun(tf.float32, tf.float32) 83 def MyFunc(x, y): 84 return x + y, x - y 85 86 # Building the graph. 87 a = tf.constant([1.0]) 88 b = tf.constant([2.0]) 89 c, d = MyFunc(a, b, name='mycall') 90 ``` 91 """ 92 93 def __init__(self, *input_types, **kwargs): 94 """Create a `Defun` decorator. 95 96 Args: 97 *input_types: A list of `tf.DType` 98 **kwargs: Optional keyword arguments, including 99 func_name - (optional). A python string, the name to use to 100 declare this `Function` in the graph. 101 102 grad_func - (optional). A function implementing the gradient 103 of the function-to-register. This is must be a 104 `_DefinedFunction` object. The gradient 105 function must satisfy the criterion defined in 106 function.proto:GradientDef. 107 108 python_grad_func - (optional). A function implementing the 109 gradient of the function python-side. This function must 110 take the current op and the gradients w.r.t. its outputs, 111 and return the gradients w.r.t. the inputs. That is it must 112 implement the interface expected by `tf.RegisterGradient`). 113 This will be called by tf.gradients to add the gradient ops 114 to the graph. At most one of grad_func and python_grad_func 115 can be specified. 116 117 out_names = (optional). A list of strings, one per output 118 tensor. 119 120 shape_func - (optional). A function taking the op and returning a list 121 of static shapes to set for the function's outputs. 122 """ 123 self._input_types = input_types 124 self._func_name = kwargs.pop("func_name", None) 125 self._grad_func = kwargs.pop("grad_func", None) 126 self._python_grad_func = kwargs.pop("python_grad_func", None) 127 self._out_names = kwargs.pop("out_names", None) 128 self._extra_kwargs = kwargs 129 130 def __call__(self, func): 131 # Various sanity checks on the callable func. 132 if not callable(func): 133 raise ValueError(f"Function {func} must be a callable.") 134 135 # Func should not use kwargs and defaults. 136 argspec = tf_inspect.getargspec(func) 137 if argspec.keywords or argspec.defaults: 138 raise ValueError( 139 "Functions with argument defaults or keywords arguments are not " 140 f"supported. {func} has defaults {argspec.defaults} and keywords " 141 f"{argspec.keywords}.") 142 143 # Computes how many arguments 'func' has. 144 min_args = len(argspec.args) 145 max_args = min_args 146 if argspec.varargs: 147 max_args = 1000000 148 argnames = argspec.args 149 if tf_inspect.ismethod(func): 150 # 1st argument is the "class" type. 151 min_args -= 1 152 argnames = argnames[1:] 153 154 if self._input_types: 155 # If Defun is given a list of types for the inputs, the number 156 # of input types should be compatible with 'func'. 157 num = len(self._input_types) 158 if num < min_args or num > max_args: 159 raise ValueError( 160 "The number of tf.function input types is not compatible with the " 161 f"allowed arguments of {func}. The tf.function have {num} input " 162 f"types, while the python function allows minimum {min_args} and " 163 f"maximum {max_args} arguments.") 164 return _DefinedFunction( 165 func, 166 argnames, 167 self._input_types, 168 self._func_name, 169 self._grad_func, 170 self._python_grad_func, 171 out_names=self._out_names, 172 **self._extra_kwargs) 173 174 # 'func' expects no arguments and input types is an empty list. 175 if min_args == 0 and max_args == 0: 176 return _DefinedFunction( 177 func, [], [], 178 self._func_name, 179 self._grad_func, 180 self._python_grad_func, 181 out_names=self._out_names, 182 **self._extra_kwargs) 183 184 # Input types are unknown. It's an overloaded function and hence 185 # its definition needs to be deferred until it's called. 186 return _OverloadedFunction( 187 func, 188 argnames, 189 self._func_name, 190 self._grad_func, 191 self._python_grad_func, 192 out_names=self._out_names, 193 **self._extra_kwargs) 194 195 196class _DefinedFunctionDeleter(object): 197 """Unregister function from eager context.""" 198 199 __slots__ = ["name"] 200 201 def __init__(self, name): 202 self.name = name 203 204 def __del__(self): 205 try: 206 context.remove_function(self.name) 207 except TypeError: 208 # Suppress some exceptions, mainly for the case when we're running on 209 # module deletion. Things that can go wrong include the context module 210 # already being unloaded, self._handle._handle_data no longer being 211 # valid, and so on. Printing warnings in these cases is silly 212 # (exceptions raised from __del__ are printed as warnings to stderr). 213 pass # 'NoneType' object is not callable when the handle has been 214 # partially unloaded. 215 except AttributeError: 216 pass # 'NoneType' object has no attribute 'eager_mode' when context has 217 # been unloaded. Will catch other module unloads as well. 218 219 220class _DefinedFunction(object): 221 """_DefinedFunction encapsulates a function definition and its properties. 222 223 Attributes: 224 name: The function name. 225 definition: The definition of this function. A FunctionDef proto. 226 grad_func_name: If not None, the name of this function's gradient function. 227 python_grad_func: A python callable implementing the gradient of 228 the function python-side. 229 """ 230 231 def __init__(self, 232 func, 233 argnames, 234 input_types, 235 func_name=None, 236 grad_func=None, 237 python_grad_func=None, 238 out_names=None, 239 shape_func=None, 240 capture_by_value=False, 241 allowlisted_stateful_ops=None, 242 capture_resource_var_by_value=True, 243 **kwargs): 244 """Creates _DefinedFunction. 245 246 Args: 247 func: A python callable which constructs a tf function body. 248 argnames: A list of strings for function argument names. 249 input_types: The function's argument types. Can be a tuple, list of 250 tf data types. 251 func_name: The function name. Defaults to None, in which derives from 252 'func'. 253 grad_func: This function's gradient function, if not None. Defaults 254 to None. 255 python_grad_func: A python callable implementing the gradient of 256 the function python-side. 257 out_names: An optional list of strings for the function return value 258 names. 259 shape_func: An optional function mapping an op to a list of static 260 output shapes. 261 capture_by_value: Boolean (defaults to False). If True, captured values 262 will be copied into the function body. 263 allowlisted_stateful_ops: A set of ops that if stateful we ignore and 264 copy into the function body, when `capture_by_value` is True. 265 capture_resource_var_by_value: Boolean (defaults to True). If False, 266 captured resource variable returns the handle instead of value. 267 **kwargs: The keyword arguments. **kwargs is passed to every call 268 site of this function. 269 270 Raises: 271 ValueError: The function definition is invalid. 272 273 """ 274 self._func = func 275 self._input_types = input_types 276 self._func_name = func_name 277 self._grad_func = grad_func 278 self._python_grad_func = python_grad_func 279 self._out_names = out_names 280 self._shape_func = shape_func 281 self._capture_by_value = capture_by_value 282 self._allowlisted_stateful_ops = allowlisted_stateful_ops 283 if self._allowlisted_stateful_ops is None: 284 self._allowlisted_stateful_ops = set() 285 self._capture_resource_var_by_value = capture_resource_var_by_value 286 self._extra_kwargs = kwargs 287 # Constructed only when C API is disabled, lazily 288 self._definition = None 289 # Constructed only when C API is enabled, lazily 290 self._c_func = None 291 self._function_deleter = None 292 self._sub_functions = {} # Constructed with _definition or _c_func 293 # pylint: disable=protected-access 294 device_funcs = ops.get_default_graph()._device_functions_outer_to_inner 295 # pylint: enable=protected-access 296 297 # Get the innermost device if possible. 298 self._caller_device = device_funcs[-1] if device_funcs else None 299 300 # Cached OpDef for this function. When C API is enabled, this is 301 # the only part of FunctionDef that we cache in Python. When C API 302 # is disabled the whole _definition is available and this is simply 303 # another reference to _definition.signature 304 self._op_def = None 305 306 assert isinstance(input_types, (list, tuple)) 307 self._arg_types = input_types 308 self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i) 309 for i in range(len(input_types))] 310 311 @property 312 def name(self): 313 """Function name.""" 314 self._create_definition_if_needed() 315 return self._func_name 316 317 @property 318 def definition(self): 319 """Function definition proto.""" 320 self._create_definition_if_needed() 321 if self._c_func: 322 with c_api_util.tf_buffer() as buf: 323 c_api.TF_FunctionToFunctionDef(self._c_func.func, buf) 324 fdef = function_pb2.FunctionDef() 325 proto_data = c_api.TF_GetBuffer(buf) 326 fdef.ParseFromString(compat.as_bytes(proto_data)) 327 with ops.init_scope(): 328 if context.executing_eagerly(): 329 context.add_function(self._c_func.func) 330 self._function_deleter = _DefinedFunctionDeleter( 331 fdef.signature.name) 332 return fdef 333 return self._definition 334 335 @property 336 def _signature(self): 337 self._create_definition_if_needed() 338 return self._op_def 339 340 def set_grad_func(self, grad_func): 341 """Specifies the gradient function of this function.""" 342 assert not self._grad_func 343 assert isinstance(grad_func, _DefinedFunction) 344 self._grad_func = grad_func 345 346 @property 347 def grad_func_name(self): 348 """Returns the name of the gradient function.""" 349 return self._grad_func.name if self._grad_func else None 350 351 @property 352 def python_grad_func(self): 353 """Python gradient function callable.""" 354 return self._python_grad_func 355 356 @property 357 def declared_input_types(self): 358 """Returns the list of data types of explicit declared inputs.""" 359 return self._input_types 360 361 @property 362 def captured_inputs(self): 363 """Returns the list of implicitly captured inputs.""" 364 self._create_definition_if_needed() 365 return self._extra_inputs 366 367 @property 368 def stateful_ops(self): 369 """Returns the list of stateful ops in function definition. 370 371 Returns: 372 A list of (op.name, op.type) pairs. 373 """ 374 self._create_definition_if_needed() 375 return self._stateful_ops 376 377 def _create_definition_if_needed(self): 378 """Creates the function definition if it's not created yet.""" 379 with context.graph_mode(): 380 self._create_definition_if_needed_impl() 381 382 def _create_definition_if_needed_impl(self): 383 """This is not what you want, see _create_definition_if_needed.""" 384 if self._definition is not None or self._c_func is not None: 385 return 386 387 # Copy variable collections (by reference) from the parent graph such that 388 # name based variable sharing (e.g. via tf.make_template) works between the 389 # func graph and parent graph. 390 variable_keys = [] 391 variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS) # pylint: disable=protected-access 392 variable_keys.append(vs._VARSTORE_KEY) # pylint: disable=protected-access 393 394 parent_graph = ops.get_default_graph() 395 collections_ref = { 396 key: parent_graph.get_collection_ref(key) for key in variable_keys} 397 398 temp_graph = func_graph_from_py_func( 399 self._func, 400 self._arg_names, 401 self._arg_types, 402 self._func_name, 403 self._capture_by_value, 404 self._caller_device, 405 collections_ref=collections_ref, 406 allowlisted_stateful_ops=self._allowlisted_stateful_ops, 407 capture_resource_var_by_value=self._capture_resource_var_by_value) 408 409 self._extra_inputs = temp_graph.extra_inputs 410 # pylint: disable=protected-access 411 self._sub_functions = temp_graph._functions 412 # pylint: enable=protected-access 413 414 # Extra kwargs are treated as attrs on the function def. 415 if self._func_name: 416 base_func_name = self._func_name 417 else: 418 base_func_name = function_utils.get_func_name(self._func) 419 if self._grad_func: 420 base_func_name += ("_%s" % self._grad_func.name) 421 kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs) 422 423 if not temp_graph._c_graph: # pylint: disable=protected-access 424 # Build the FunctionDef 425 self._definition = graph_to_function_def.graph_to_function_def( 426 temp_graph, 427 temp_graph.get_operations(), 428 temp_graph.inputs, 429 temp_graph.outputs, 430 out_names=self._out_names) 431 432 for k in kwargs_attr: 433 self._definition.attr[k].CopyFrom(kwargs_attr[k]) 434 435 # Hash the definition and its dependencies. 436 self._hash_str = self._create_hash_str( 437 self._definition.signature.input_arg, 438 self._definition.signature.output_arg, self._definition.node_def) 439 440 # Finally, we decide the function name to use. If not specified, 441 # make up something which is almost certainly unique (but deterministic). 442 if not self._func_name: 443 self._func_name = "_".join([base_func_name, self._hash_str]) 444 self._definition.signature.name = self._func_name 445 if self._func.__doc__: 446 self._definition.signature.description = self._func.__doc__ 447 448 self._op_def = self._definition.signature 449 else: # C API is enabled 450 output_names = ([compat.as_bytes(x) for x in self._out_names] 451 if self._out_names else []) 452 description = self._func.__doc__ or None 453 # pylint: disable=protected-access 454 c_func = c_api.TF_GraphToFunction_wrapper( 455 temp_graph._c_graph, 456 base_func_name, 457 self._func_name is None, # append_hash_to_fn_name 458 None, # opers 459 [t._as_tf_output() for t in temp_graph.inputs], 460 [t._as_tf_output() for t in temp_graph.outputs], 461 output_names, 462 [], # control_outputs 463 [], # control_output_names 464 None, # opts 465 description) 466 self._c_func = c_api_util.ScopedTFFunction(c_func) 467 # pylint: enable=protected-access 468 self._set_c_attrs(kwargs_attr) 469 470 # Set cached fields: _op_def and _func_name (if not already set) 471 self._op_def = self.definition.signature 472 if self._func_name: 473 assert self._func_name == self._op_def.name 474 else: 475 self._func_name = compat.as_str(self._op_def.name) 476 477 self._stateful_ops = [(op.name, op.type) 478 for op in temp_graph.get_operations() 479 if op._is_stateful] # pylint: disable=protected-access 480 481 def _set_c_attrs(self, attrs): 482 """Sets `attrs` as attributes of self._c_func. 483 484 Requires that self._c_func is not None. 485 486 Args: 487 attrs: a dictionary from attribute name to attribute proto value 488 """ 489 for name, attr_value in attrs.items(): 490 serialized = attr_value.SerializeToString() 491 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 492 # It might be worth creating a convenient way to re-use the same status. 493 c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name), 494 serialized) 495 496 def _create_hash_str(self, input_arg, output_arg, node_def): 497 """Creates an 8-character string unique to this input. 498 499 Args: 500 input_arg: the input_arg field of an OpDef 501 (e.g. self._definition.signature.input_arg) 502 output_arg: the output_arg field of an OpDef 503 (e.g. self._definition.signature.output_arg) 504 node_def: the node_def field of a FunctionDef 505 (e.g. self._definition.node_def) 506 507 Returns: 508 The unique string for this input 509 """ 510 hasher = hashlib.sha1() 511 512 def update_num(n): 513 hasher.update(compat.as_bytes("%x" % n)) 514 515 def update_str(s): 516 update_num(len(s)) 517 hasher.update(compat.as_bytes(s)) 518 519 def update_strs(slist): 520 update_num(len(slist)) 521 for s in slist: 522 update_str(s) 523 524 for adef in input_arg: 525 update_str(adef.SerializeToString()) 526 527 for adef in output_arg: 528 update_str(adef.SerializeToString()) 529 530 for n in sorted(node_def, key=lambda n: n.name): 531 update_str(n.name) 532 update_str(n.op) 533 update_strs(n.input) 534 update_num(len(n.attr)) 535 # NOTE: protobuf map serialization does not guarantee ordering. 536 for k in sorted(n.attr): 537 update_str(k) 538 update_str(n.attr[k].SerializeToString()) 539 540 return hasher.hexdigest()[:8] 541 542 def add_to_graph(self, g): 543 """Adds this function into the graph g.""" 544 self._create_definition_if_needed() 545 546 # Adds this function into 'g'. 547 # pylint: disable=protected-access 548 if context.executing_eagerly(): 549 context.context().add_function_def(self.definition) 550 else: 551 g._add_function(self) 552 # pylint: enable=protected-access 553 554 # Ensures related sub-routines are defined in 'g', too. 555 for f in self._sub_functions.values(): 556 f.add_to_graph(g) 557 558 # Adds its gradient function, too. 559 if self._grad_func: 560 self._grad_func.add_to_graph(g) 561 562 def __call__(self, *args, **kwargs): 563 self.add_to_graph(ops.get_default_graph()) 564 args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs 565 ret, op = _call(self._signature, *args, **kwargs) 566 567 # Set a hidden attr in 'op' so that gradients_impl can refer back 568 # to this _DefinedFunction instance to access python_grad_func. 569 assert isinstance(op, ops.Operation) 570 setattr(op, "__defun", self) 571 572 if self._shape_func is not None: 573 shapes = self._shape_func(op) 574 if len(shapes) != len(op.outputs): 575 raise ValueError(f"shape_func {self._shape_func} produced " 576 f"{len(shapes):d} shapes, which does not match " 577 f"{len(op.outputs)} outputs.") 578 for (t, shape) in zip(op.outputs, shapes): 579 t.set_shape(shape) 580 return ret 581 582 583class _OverloadedFunction(object): 584 """_OverloadedFunction encapsulates an overloaded function. 585 586 _OverloadedFunction maintains a mapping from input types to 587 instantiated _DefinedFunction in self._overload. 588 589 """ 590 591 def __init__(self, 592 func, 593 argnames, 594 func_name=None, 595 grad_func=None, 596 python_grad_func=None, 597 out_names=None, 598 **kwargs): 599 """Creates _DefinedFunction. 600 601 Args: 602 func: A python callable which constructs a tf function body. 603 argnames: A list of strings for function argument names. 604 func_name: The function name. Defaults to None, in which derives from 605 'func'. 606 grad_func: This function's gradient function, if not None. Defaults 607 to None. 608 python_grad_func: A python callable implementing the gradient of 609 the function python-side. 610 out_names: A list of strings for the function return value names. 611 **kwargs: The keyword arguments. **kwargs is passed to every call 612 site of this function. 613 614 Raises: 615 ValueError: The function definition is invalid. 616 617 """ 618 self._func = func 619 self._argnames = argnames 620 self._func_name = func_name 621 assert grad_func is None or isinstance(grad_func, _OverloadedFunction) 622 self._grad_func = grad_func 623 self._python_grad_func = python_grad_func 624 self._out_names = out_names 625 self._extra_kwargs = kwargs 626 self._overload = {} 627 628 def instantiate(self, input_types): 629 """Instantiate this function given input argument types. 630 631 Args: 632 input_types: A list of data types for the inputs. 633 634 Returns: 635 _DefinedFunction for the given input types. 636 637 """ 638 # Stringify the type list. 639 key = _type_list_to_str(input_types) 640 defined = self._overload.get(key) 641 if not defined: 642 # If not defined yet, define the function given the input types. 643 name = self._func_name 644 if name is not None: 645 name = "_".join([name, key]) 646 defined = _DefinedFunction( 647 self._func, 648 self._argnames, 649 input_types, 650 name, 651 None, 652 self._python_grad_func, 653 out_names=self._out_names, 654 **self._extra_kwargs) 655 _ = defined.name # Fully instantiate the function definition. 656 if self._grad_func: 657 # If _grad_func is given, it is another 658 # _OverloadedFunction. We need to instantiate it with the 659 # right input types. 660 output_types = [ 661 dtypes.DType(_.type) for _ in defined._signature.output_arg # pylint: disable=protected-access 662 ] 663 # pylint: disable=protected-access 664 defined._grad_func = self._grad_func.instantiate(input_types + 665 output_types) 666 # pylint: enable=protected-access 667 self._overload[key] = defined 668 return defined 669 670 def __call__(self, *args, **kwargs): 671 input_types = [] 672 args = list(args) 673 for (i, x) in enumerate(args): 674 x = ops.convert_to_tensor(x) 675 if not isinstance(x, ops.Tensor): 676 raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.") 677 input_types.append(x.dtype) 678 args[i] = x 679 return self.instantiate(input_types)(*args, **kwargs) 680 681 682class _FuncGraph(ops.Graph): 683 """A helper for constructing a function. 684 685 _FuncGraph overrides ops.Graph's create_op() so that we can keep 686 track of all inputs into every op created inside the function. If 687 any input is from other graphs, we keep track of it in self.capture 688 and substitute the input with a place holder. 689 690 Each captured input's corresponding place holder is converted into a 691 function argument and the caller passes in the captured tensor. 692 """ 693 694 def __init__(self, name, capture_by_value, allowlisted_stateful_ops, 695 capture_resource_var_by_value, *args, **kwargs): 696 super(_FuncGraph, self).__init__(*args, **kwargs) 697 self._capture_by_value = capture_by_value 698 self._allowlisted_stateful_ops = allowlisted_stateful_ops 699 self._capture_resource_var_by_value = capture_resource_var_by_value 700 self._building_function = True 701 self._outer_graph = ops.get_default_graph() 702 self._vscope = vs.get_variable_scope() 703 self._old_custom_getter = self._vscope.custom_getter 704 705 # The name of the function. 706 self.name = name 707 # Placeholder tensors representing the inputs to this function. The tensors 708 # are in this _FuncGraph. 709 self.inputs = [] 710 # Tensors that will be returned this function. The tensors are in this 711 # _FuncGraph. 712 self.outputs = [] 713 # Maps external tensor -> internal tensor (e.g. input placeholder). 714 self._captured = {} 715 # The external tensors that have been captured as inputs and must be passed 716 # to this function (empty if capturing by value, otherwise these are the 717 # keys of _captured). 718 self.extra_inputs = [] 719 # Input placeholders that been added for captured values (empty if capturing 720 # by value). 721 self.extra_args = [] 722 # Captured variables. 723 # TODO(skyewm): is this needed? 724 self.extra_vars = [] 725 726 # pylint: disable=g-doc-return-or-yield 727 728 @property 729 def outer_graph(self): 730 """The graph active when this _FuncGraph was created.""" 731 return self._outer_graph 732 733 @tf_contextlib.contextmanager 734 def container(self, container_name): 735 """Returns a context manager that specifies the resource container to use. 736 737 Overridden from `tf.Graph` to update both the init_scope container 738 and the present inner container. This is necessary to make sure setting 739 containers applies correctly both to created variables and to stateful 740 ops. 741 742 Args: 743 container_name: container name string. 744 745 Returns: 746 A context manager for defining resource containers for stateful ops, 747 yields the container name. 748 """ 749 original_container = self._container 750 # pylint: disable=protected-access 751 with ops.init_scope(): 752 original_init_container = ops.get_default_graph()._container 753 try: 754 self._container = container_name 755 with ops.init_scope(): 756 ops.get_default_graph()._container = container_name 757 yield self._container 758 finally: 759 self._container = original_container 760 with ops.init_scope(): 761 ops.get_default_graph()._container = original_init_container 762 # pylint: enable=protected-access 763 764 # pylint: enable=g-doc-return-or-yield 765 766 def getvar( 767 self, 768 getter, 769 name, 770 shape=None, 771 dtype=None, 772 initializer=None, 773 reuse=None, 774 trainable=True, 775 collections=None, # pylint: disable=redefined-outer-name 776 use_resource=None, 777 **kwargs): 778 """A custom variable getter.""" 779 # Here, we switch the default graph to the outer graph and ask the 780 # variable scope in which the function is defined to give us the 781 # variable. The variable is stashed in extra_vars and returned to 782 # the caller. 783 # 784 # We capture these variables so that the variable definition is 785 # hoisted upward to the outer most graph. 786 with self._outer_graph.as_default(): 787 # pylint: disable=protected-access 788 var = self._vscope.get_variable( 789 vs._get_default_variable_store(), 790 name, 791 shape=shape, 792 dtype=dtype, 793 initializer=initializer, 794 reuse=reuse, 795 trainable=trainable, 796 collections=collections, 797 use_resource=use_resource) 798 self.extra_vars.append(var) 799 if (isinstance(var, resource_variable_ops.BaseResourceVariable) and 800 self._capture_resource_var_by_value): 801 # For resource-based variables read the variable outside the function 802 # and pass in the value. This ensures that the function is pure and 803 # differentiable. TODO(apassos) this may have performance problems if 804 # the function will only do embedding lookups on the variable. 805 return var.value() 806 return var 807 808 def _create_op_internal( 809 self, 810 op_type, 811 inputs, 812 dtypes=None, # pylint: disable=redefined-outer-name 813 input_types=None, 814 name=None, 815 attrs=None, 816 op_def=None, 817 compute_device=True): 818 for i, x in enumerate(inputs): 819 if isinstance(x, ops.EagerTensor) or x.graph is not self: 820 inputs[i] = self.capture(x) 821 return super(_FuncGraph, self)._create_op_internal( 822 op_type, 823 inputs, 824 dtypes=dtypes, 825 input_types=input_types, 826 name=name, 827 attrs=attrs, 828 op_def=op_def, 829 compute_device=compute_device) 830 831 def capture(self, tensor, name=None): 832 """Adds the given tensor to this graph and returns the captured tensor.""" 833 if tensor.ref() in self._captured: 834 # Captured already. 835 return self._captured[tensor.ref()] 836 elif self._capture_by_value: 837 return self._add_tensor_and_parents(tensor) 838 else: 839 return self._capture_tensor_as_extra_input(tensor, name) 840 841 @property 842 def captures(self): 843 """Pairs of tensors and captured tensor.""" 844 return [(k.deref(), v) for k, v in self._captured.items()] 845 846 def _capture_tensor_as_extra_input(self, tensor, name=None): 847 # Substitute with a placeholder. 848 self.extra_inputs.append(tensor) 849 # Hoist the new input placeholder out of any control flow context 850 # we're currently in. 851 with ops.control_dependencies(None): 852 ph = array_ops.placeholder( 853 tensor.dtype, shape=tensor.get_shape(), name=name) 854 # pylint: disable=protected-access 855 if isinstance(tensor, ops.EagerTensor): 856 handle_data = tensor._handle_data 857 if handle_data: 858 handle_data = handle_data.SerializeToString() 859 else: 860 handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph, 861 tensor._as_tf_output()) 862 863 if handle_data: 864 c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(), 865 compat.as_bytes(handle_data)) 866 # pylint: enable=protected-access 867 self.inputs.append(ph) 868 self._captured[tensor.ref()] = ph 869 self.extra_args.append(ph) 870 if _is_guaranteed_const(tensor): 871 with ops.control_dependencies(None): 872 return array_ops.guarantee_const(ph) 873 else: 874 return ph 875 876 def _add_tensor_and_parents(self, tensor): 877 op = self._add_op_and_parents(tensor.op) 878 return op.outputs[tensor.value_index] 879 880 def _add_op_and_parents(self, op): 881 # pylint: disable=protected-access 882 op_def = graph_to_function_def._get_op_def(op) 883 if op._is_stateful and op not in self._allowlisted_stateful_ops: 884 raise ValueError(f"Cannot capture a stateful node (name:{op.name}, " 885 f"type:{op.type}) by value.") 886 elif op.type in ("Placeholder", "PlaceholderV2"): 887 raise ValueError(f"Cannot capture a placeholder (name:{op.name}, " 888 f"type:{op.type}) by value.") 889 # pylint: enable=protected-access 890 891 captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs] 892 893 captured_op = self._create_op_internal( 894 op.type, 895 captured_inputs, [o.dtype for o in op.outputs], 896 name=op.name, 897 attrs=op.node_def.attr, 898 op_def=op_def) 899 900 for t, captured_t in zip(op.outputs, captured_op.outputs): 901 self._captured[t.ref()] = captured_t 902 903 return captured_op 904 905 906def func_graph_from_py_func(func, 907 arg_names, 908 arg_types, 909 name=None, 910 capture_by_value=False, 911 device=None, 912 colocation_stack=None, 913 container=None, 914 collections_ref=None, 915 arg_shapes=None, 916 allowlisted_stateful_ops=None, 917 capture_resource_var_by_value=True): 918 """Returns a _FuncGraph generated from `func`. 919 920 Args: 921 func: A Python callable which constructs a TF function body. The arguments 922 must correspond to `arg_types`. Returns a value or list/tuple of values. 923 No returned value can be None. 924 arg_names: A sequence of strings for the function argument names. 925 arg_types: A sequence of the function's argument types. 926 name: The function name. If None, the name is derived from `func`. 927 capture_by_value: boolean. If True, captured values will be copied into the 928 function body. 929 device: device name or function. 930 colocation_stack: A colocation stack (list) the _FuncGraph should use. 931 container: A container name the _FuncGraph should start with. 932 collections_ref: A reference to a collections dict the _FuncGraph should 933 use internally. 934 arg_shapes: A sequence of the function's argument shapes. 935 allowlisted_stateful_ops: A set of ops that if stateful we ignore and 936 re-create. 937 capture_resource_var_by_value: Boolean (defaults to True). If False, 938 captured resource variable returns the handle instead of value. 939 940 Returns: 941 A _FuncGraph. 942 943 Raises: 944 ValueError: if func returns None. 945 """ 946 if not name: 947 name = function_utils.get_func_name(func) 948 func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops, 949 capture_resource_var_by_value) 950 951 with func_graph.as_default(), ops.device(device): 952 # pylint: disable=protected-access 953 if collections_ref is not None: 954 func_graph._collections = collections_ref 955 if container is not None: 956 func_graph._container = container 957 if colocation_stack is not None: 958 func_graph._colocation_stack = colocation_stack 959 # pylint: enable=protected-access 960 961 if arg_shapes is None: 962 arg_shapes = [None] * len(arg_types) 963 964 # Create placeholders for the function arguments. 965 for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes): 966 argholder = array_ops.placeholder(argtype, shape=argshape, name=argname) 967 func_graph.inputs.append(argholder) 968 # Call func and gather the output tensors. 969 with vs.variable_scope("", custom_getter=func_graph.getvar): 970 outputs = func(*func_graph.inputs) 971 972 # There is no way of distinguishing between a function not returning 973 # anything and a function returning None in Python. 974 # We need to allow the former and ideally want to forbid the latter as 975 # it is most likely user error. 976 # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to 977 # allow users to explicitly mark the function as not returning anything. 978 # For now, we allow a single None return and interpret it as a function 979 # with no output. 980 if outputs is None: 981 outputs = [] 982 else: 983 # If func only returned one value, make it a tuple. 984 if not isinstance(outputs, (list, tuple)): 985 outputs = (outputs,) 986 if any(_ is None for _ in outputs): 987 raise ValueError(f"Function {name} can not return None.") 988 # Ensures each output is a Tensor in the function graph. 989 outputs = [ops.convert_to_tensor(t) for t in outputs] 990 outputs = [func_graph.capture(t) if t.graph is not func_graph else t 991 for t in outputs] 992 func_graph.outputs = outputs 993 return func_graph 994 995 996def _is_guaranteed_const(tensor): 997 """Determines whether `tensor` is guaranteed to be a constant. 998 999 A tensor is guaranteed to be a constant if either it was produced by 1000 a `GuaranteeConst` op or if all of its children are guaranteed to be 1001 constants. 1002 1003 Args: 1004 tensor: The tensor for which to determine const-ness. 1005 1006 Returns: 1007 True if `tensor` is guaranteed to be a constant, False otherwise. 1008 """ 1009 1010 if isinstance(tensor, ops.EagerTensor): 1011 return False 1012 1013 class Work(object): 1014 1015 def __init__(self, op, leaving): 1016 self.op = op 1017 self.leaving = leaving 1018 1019 is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst" 1020 constants = set([]) 1021 def all_inputs_const(op): 1022 # If all inputs of an op are guaranteed constants, then we can infer that 1023 # the op produces a constant as well. 1024 return op.inputs and all(inp.op in constants for inp in op.inputs) 1025 1026 visited = set([]) 1027 stack = [Work(tensor.op, leaving=False)] 1028 while stack: 1029 work = stack.pop() 1030 if work.leaving: 1031 if all_inputs_const(work.op): 1032 constants.add(work.op) 1033 continue 1034 visited.add(work.op) 1035 if is_guaranteed_const(work.op): 1036 constants.add(work.op) 1037 continue 1038 1039 # This op will be revisited after all its inputs are checked for const-ness. 1040 stack.append(Work(work.op, leaving=True)) 1041 for inp in work.op.inputs: 1042 if inp.op not in visited: 1043 stack.append(Work(inp.op, leaving=False)) 1044 return tensor.op in constants 1045 1046 1047def _call(sig, *inputs, **kwargs): 1048 """Adds a node calling a function. 1049 1050 This adds a `call` op to the default graph that calls the function 1051 of signature `sig`, passing the tensors in `inputs` as arguments. 1052 It returns the outputs of the call, which are one or more tensors. 1053 1054 `sig` is OpDefArg.a `_DefinedFunction` object. 1055 1056 You can pass an optional keyword parameter `name=string` to name the 1057 added operation. 1058 1059 You can pass an optional keyword parameter `noinline=True|False` to 1060 instruct the runtime not to inline the function body into the call 1061 site. 1062 1063 Args: 1064 sig: OpDefArg. The signature of the function. 1065 *inputs: arguments to the function. 1066 **kwargs: Optional keyword arguments. Can only contain 'name' or 1067 'noinline'. 1068 1069 Returns: 1070 A 2-element tuple. First element: a Tensor if the function returns a single 1071 value; a list of Tensors if the function returns multiple value; the 1072 Operation if the function returns no values. Second element: the Operation. 1073 1074 Raises: 1075 ValueError: if the arguments are invalid. 1076 """ 1077 if len(inputs) != len(sig.input_arg): 1078 raise ValueError(f"Expected {len(sig.input_arg):d} arguments, got " 1079 f"{len(inputs):d}.") 1080 name = kwargs.pop("name", None) 1081 g = ops.get_default_graph() 1082 func_name = sig.name 1083 if name is None: 1084 name = func_name 1085 attrs = _parse_kwargs_as_attrs(func_name, **kwargs) 1086 output_types = [dtypes.DType(x.type) for x in sig.output_arg] 1087 op = g._create_op_internal( # pylint: disable=protected-access 1088 func_name, list(inputs), output_types, name=name, attrs=attrs, op_def=sig) 1089 if op.outputs: 1090 if len(op.outputs) == 1: 1091 ret = op.outputs[0] 1092 else: 1093 ret = tuple(op.outputs) 1094 else: 1095 ret = op 1096 return ret, op 1097 1098 1099def _from_definition(fdef, grad_func=None): 1100 """Creates a _DefinedFunction initialized from a FunctionDef proto. 1101 1102 Args: 1103 fdef: a FunctionDef 1104 grad_func: a _DefinedFunction or None 1105 1106 Returns: 1107 A _DefinedFunction representing fdef 1108 """ 1109 # TODO(iga): This method does major surgery on _DefinedFunction. 1110 # Make it a named constructor using @classmethod of _DefinedFunction. 1111 1112 # The Python callable is only needed to create a FunctionDef. Since we have 1113 # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we 1114 # have access to such a callable here). 1115 func = None 1116 argnames = [arg.name for arg in fdef.signature.input_arg] 1117 input_types = tuple( 1118 dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg) 1119 func_name = fdef.signature.name 1120 # Note: FunctionDefs do not include python gradient functions, so if the 1121 # original _DefinedFunction included one it will not be reflected here. 1122 python_grad_func = None 1123 out_names = [arg.name for arg in fdef.signature.output_arg] 1124 result = _DefinedFunction(func, argnames, input_types, func_name, grad_func, 1125 python_grad_func, out_names) 1126 # pylint: disable=protected-access 1127 serialized = fdef.SerializeToString() 1128 c_func = c_api.TF_FunctionImportFunctionDef(serialized) 1129 result._c_func = c_api_util.ScopedTFFunction(c_func) 1130 result._extra_inputs = [] 1131 result._op_def = fdef.signature 1132 # pylint: enable=protected-access 1133 1134 return result 1135 1136 1137def from_library(lib): 1138 """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto. 1139 1140 This method handles assigning the correct gradient functions to each 1141 function. 1142 1143 Args: 1144 lib: a FunctionDefLibrary 1145 1146 Returns: 1147 A list of _DefinedFunctions 1148 1149 Raises: 1150 ValueError: `lib` is invalid 1151 """ 1152 if not lib.function and not lib.gradient: 1153 return [] 1154 1155 # function name -> FunctionDef proto 1156 funcs = {fdef.signature.name: fdef for fdef in lib.function} 1157 1158 # Validate that all references function names have function defs 1159 for g in lib.gradient: 1160 if g.function_name not in funcs: 1161 raise ValueError(f"FunctionDefLibrary missing '{g.function_name}' " 1162 f"FunctionDef\n{lib}") 1163 if g.gradient_func not in funcs: 1164 raise ValueError(f"FunctionDefLibrary missing '{g.gradient_func}' " 1165 f"FunctionDef\n{lib}") 1166 1167 # function name -> gradient function name 1168 func_to_grad = collections.defaultdict(lambda: None) 1169 # gradient function name -> names of functions having that grad function 1170 grad_to_funcs = collections.defaultdict(list) 1171 1172 for gdef in lib.gradient: 1173 func_to_grad[gdef.function_name] = gdef.gradient_func 1174 grad_to_funcs[gdef.gradient_func].append(gdef.function_name) 1175 1176 # Start with functions without gradients 1177 ready = [ 1178 fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None 1179 ] 1180 if not ready: 1181 raise ValueError( 1182 f"FunctionDefLibrary contains cyclic gradient functions!\n{lib}") 1183 # function name -> _DefinedFunction 1184 initialized = {} 1185 1186 while ready: 1187 fdef = ready.pop() 1188 name = fdef.signature.name 1189 1190 grad = initialized.get(func_to_grad[name]) 1191 if func_to_grad[name]: 1192 assert grad 1193 defined_func = _from_definition(fdef, grad_func=grad) 1194 initialized[name] = defined_func 1195 1196 ready.extend(funcs[f] for f in grad_to_funcs[name]) 1197 1198 return initialized.values() 1199 1200 1201def _get_experimental_kwarg_as_attr(attr_name, value): 1202 """Creates an AttrValue for a python object.""" 1203 if isinstance(value, bool): 1204 return attr_value_pb2.AttrValue(b=value) 1205 elif isinstance(value, int): 1206 return attr_value_pb2.AttrValue(i=value) 1207 elif isinstance(value, float): 1208 return attr_value_pb2.AttrValue(f=value) 1209 elif isinstance(value, str): 1210 return attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 1211 else: 1212 raise ValueError(f"Attribute {attr_name} must be bool, int, float, or " 1213 f"str. Got {type(value)}.") 1214 1215 1216def _get_kwarg_as_str_attr(attr_name, value): 1217 """Creates an AttrValue for a python object.""" 1218 if isinstance(value, str): 1219 return attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 1220 else: 1221 raise ValueError(f"Attribute {attr_name} must be str. Got {type(value)}.") 1222 1223 1224def _parse_kwargs_as_attrs(func_name, **kwargs): 1225 """Parses **kwargs into a node's attributes.""" 1226 attrs = {} 1227 1228 noinline = kwargs.pop("noinline", None) 1229 if noinline is not None: 1230 attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline)) 1231 1232 # For compatibility with previous behavior, Defun does not perform shape 1233 # inference through its function call operations. 1234 attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True) 1235 1236 compiled = kwargs.pop("compiled", None) 1237 separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None) 1238 if compiled is not None: 1239 attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled)) 1240 attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue( 1241 b=bool(separate_compiled_gradients)) 1242 # Forward _XlaScope from enclosing context (if set), otherwise create new. 1243 # pylint: disable=protected-access 1244 if "_XlaScope" in ops.get_default_graph()._attr_scope_map: 1245 attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"] 1246 else: 1247 attrs["_XlaScope"] = attr_value_pb2.AttrValue( 1248 s=("function_%s" % func_name).encode()) 1249 # pylint: enable=protected-access 1250 1251 kwargs_keys = list(kwargs.keys()) 1252 for key in kwargs_keys: 1253 if key.startswith("experimental_"): 1254 attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key]) 1255 del kwargs[key] 1256 # Support for https://github.com/tensorflow/community/pull/113/files. 1257 elif key == "_implements" or key == "_reference": 1258 attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key]) 1259 del kwargs[key] 1260 if kwargs: 1261 raise ValueError(f"Unknown keyword arguments: {kwargs.keys()}.") 1262 return attrs 1263 1264 1265def get_extra_vars(): 1266 """Returns the captured variables by the function. 1267 1268 Returns: 1269 If the default graph is being used to define a function, the 1270 returned list of variables are those created inside the function 1271 body so far. Otherwise, returns an empty list. 1272 """ 1273 g = ops.get_default_graph() 1274 if isinstance(g, _FuncGraph): 1275 return g.extra_vars 1276 else: 1277 return [] 1278 1279 1280def get_extra_inputs(): 1281 """Returns the captured input tensors by the function. 1282 1283 Returns: 1284 If the default graph is being used to define a function, the 1285 returned list of tensors are those accessed inside the function body 1286 but defined outside the function body so far. Otherwise, returns an 1287 empty list. 1288 """ 1289 g = ops.get_default_graph() 1290 if isinstance(g, _FuncGraph): 1291 return g.extra_inputs 1292 else: 1293 return [] 1294 1295 1296def get_extra_args(): 1297 """Returns the corresponding function arguments for the captured inputs. 1298 1299 Returns: 1300 If the default graph is being used to define a function, the 1301 returned list of place holders are those used inside the function 1302 body corresponding those returned by get_extra_inputs(). Otherwise, 1303 returns an empty list. 1304 """ 1305 g = ops.get_default_graph() 1306 if isinstance(g, _FuncGraph): 1307 return g.extra_args 1308 else: 1309 return [] 1310 1311 1312def _type_list_to_str(types): 1313 if any(_ not in _DTYPE_TO_STR for _ in types): 1314 unsupported_types = [type_ for type_ in types if type_ not in _DTYPE_TO_STR] 1315 raise ValueError(f"Unsupported dtypes {unsupported_types} in " 1316 "`types`. Supported dtypes are " 1317 f"{_DTYPE_TO_STR.keys()}.") 1318 return "".join(_DTYPE_TO_STR[_] for _ in types) 1319 1320 1321# NOTE: The list needs to be extended when more data types are added. 1322_DTYPE_TO_STR = { 1323 dtypes.float16: "f16", 1324 dtypes.float32: "f32", 1325 dtypes.float64: "f64", 1326 dtypes.int32: "i32", 1327 dtypes.uint8: "i8", 1328 dtypes.uint16: "u16", 1329 dtypes.uint32: "u32", 1330 dtypes.uint64: "u64", 1331 dtypes.int16: "i16", 1332 dtypes.int8: "i8", 1333 dtypes.string: "s", 1334 dtypes.complex64: "c64", 1335 dtypes.complex128: "c128", 1336 dtypes.int64: "i64", 1337 dtypes.bool: "b", 1338 dtypes.qint8: "qi8", 1339 dtypes.quint8: "qu8", 1340 dtypes.qint16: "qi16", 1341 dtypes.quint16: "qu16", 1342 dtypes.qint32: "qi32", 1343 dtypes.bfloat16: "b16" 1344} 1345 1346 1347def function_def_from_tf_function(c_func): 1348 """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto.""" 1349 with c_api_util.tf_buffer() as buf: 1350 c_api.TF_FunctionToFunctionDef(c_func, buf) 1351 data = c_api.TF_GetBuffer(buf) 1352 fdef = function_pb2.FunctionDef() 1353 fdef.ParseFromString(compat.as_bytes(data)) 1354 return fdef 1355