1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Python front-end supports for functions. 16 17NOTE: At this time, functions are experimental and subject to change!. Proceed 18with caution. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections 26import hashlib 27 28from tensorflow.core.framework import attr_value_pb2 29from tensorflow.core.framework import function_pb2 30from tensorflow.python import pywrap_tensorflow as c_api 31from tensorflow.python.eager import context 32from tensorflow.python.framework import c_api_util 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import graph_to_function_def 35from tensorflow.python.framework import ops 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import resource_variable_ops 38from tensorflow.python.ops import variable_scope as vs 39from tensorflow.python.util import compat 40from tensorflow.python.util import function_utils 41from tensorflow.python.util import tf_contextlib 42from tensorflow.python.util import tf_inspect 43 44 45class Defun(object): 46 """Decorator used to define TensorFlow functions. 47 48 Use this decorator to make a Python function usable directly as a TensorFlow 49 function. 50 51 The decorated function must add ops to the default graph and return zero or 52 more `Tensor` objects. Call the decorator with named arguments, one for each 53 argument of the function to decorate, with the expected type of the argument 54 as value. 55 56 For example if the function to decorate accepts two `tf.float32` arguments 57 named `x` and `y`, call the decorator with: 58 59 @Defun(tf.float32, tf.float32) 60 def foo(x, y): 61 ... 62 63 When you call the decorated function, it adds the `call` ops to the 64 default graph. In addition, it adds the definition of the function into the 65 default graph. Because the addition of the function into the graph 66 is deferred, the decorator can be used anywhere in the program. 67 68 Any variables created inside of the function are hoisted into the outer graph. 69 Note that the variables are created in the variable scope that was active 70 during the first call to the function. Subsequent function calls will refer to 71 the same set of variables. 72 73 Definitions of functions in a graph are frozen as soon as the graph is used to 74 create a session. However, new functions and new calls to existing functions 75 may be added to the graph, with the new functions themselves becoming 76 immediately frozen. 77 78 Example, but also see the [How To on functions](link_needed). 79 80 ```python 81 # Defining the function. 82 @tf.Defun(tf.float32, tf.float32) 83 def MyFunc(x, y): 84 return x + y, x - y 85 86 # Building the graph. 87 a = tf.constant([1.0]) 88 b = tf.constant([2.0]) 89 c, d = MyFunc(a, b, name='mycall') 90 ``` 91 """ 92 93 def __init__(self, *input_types, **kwargs): 94 """Create a `Defun` decorator. 95 96 Args: 97 *input_types: A list of `tf.DType` 98 **kwargs: Optional keyword arguments, including 99 func_name - (optional). A python string, the name to use to 100 declare this `Function` in the graph. 101 102 grad_func - (optional). A function implementing the gradient 103 of the function-to-register. This is must be a 104 `_DefinedFunction` object. The gradient 105 function must satisfy the criterion defined in 106 function.proto:GradientDef. 107 108 python_grad_func - (optional). A function implementing the 109 gradient of the function python-side. This function must 110 take the current op and the gradients w.r.t. its outputs, 111 and return the gradients w.r.t. the inputs. That is it must 112 implement the interface expected by `tf.RegisterGradient`). 113 This will be called by tf.gradients to add the gradient ops 114 to the graph. At most one of grad_func and python_grad_func 115 can be specified. 116 117 out_names = (optional). A list of strings, one per output 118 tensor. 119 120 shape_func - (optional). A function taking the op and returning a list 121 of static shapes to set for the function's outputs. 122 """ 123 self._input_types = input_types 124 self._func_name = kwargs.pop("func_name", None) 125 self._grad_func = kwargs.pop("grad_func", None) 126 self._python_grad_func = kwargs.pop("python_grad_func", None) 127 self._out_names = kwargs.pop("out_names", None) 128 self._extra_kwargs = kwargs 129 130 def __call__(self, func): 131 # Various sanity checks on the callable func. 132 if not callable(func): 133 raise ValueError("function %s must be callable" % func) 134 135 # Func should not use kwargs and defaults. 136 argspec = tf_inspect.getargspec(func) 137 if argspec.keywords or argspec.defaults: 138 raise ValueError( 139 "function with argument defaults or keywords arguments are not" 140 " supported. {} has defaults {} and keywords {}.".format( 141 func, argspec.defaults, argspec.keywords)) 142 143 # Computes how many arguments 'func' has. 144 min_args = len(argspec.args) 145 max_args = min_args 146 if argspec.varargs: 147 max_args = 1000000 148 argnames = argspec.args 149 if tf_inspect.ismethod(func): 150 # 1st argument is the "class" type. 151 min_args -= 1 152 argnames = argnames[1:] 153 154 if self._input_types: 155 # If Defun is given a list of types for the inputs, the number 156 # of input types should be compatible with 'func'. 157 num = len(self._input_types) 158 if num < min_args or num > max_args: 159 raise ValueError( 160 "The function has fewer arguments than the number of specified " 161 "input types.") 162 return _DefinedFunction( 163 func, 164 argnames, 165 self._input_types, 166 self._func_name, 167 self._grad_func, 168 self._python_grad_func, 169 out_names=self._out_names, 170 **self._extra_kwargs) 171 172 # 'func' expects no arguments and input types is an empty list. 173 if min_args == 0 and max_args == 0: 174 return _DefinedFunction( 175 func, [], [], 176 self._func_name, 177 self._grad_func, 178 self._python_grad_func, 179 out_names=self._out_names, 180 **self._extra_kwargs) 181 182 # Input types are unknown. It's an overloaded function and hence 183 # its definition needs to be deferred until it's called. 184 return _OverloadedFunction( 185 func, 186 argnames, 187 self._func_name, 188 self._grad_func, 189 self._python_grad_func, 190 out_names=self._out_names, 191 **self._extra_kwargs) 192 193 194class _DefinedFunction(object): 195 """_DefinedFunction encapsulates a function definition and its properties. 196 197 Attributes: 198 name: The function name. 199 definition: The definition of this function. A FunctionDef proto. 200 grad_func_name: If not None, the name of this function's gradient function. 201 python_grad_func: A python callable implementing the gradient of 202 the function python-side. 203 """ 204 205 def __init__(self, 206 func, 207 argnames, 208 input_types, 209 func_name=None, 210 grad_func=None, 211 python_grad_func=None, 212 out_names=None, 213 shape_func=None, 214 capture_by_value=False, 215 whitelisted_stateful_ops=None, 216 capture_resource_var_by_value=True, 217 **kwargs): 218 """Creates _DefinedFunction. 219 220 Args: 221 func: A python callable which constructs a tf function body. 222 argnames: A list of strings for function argument names. 223 input_types: The function's argument types. Can be a tuple, list of 224 tf data types. 225 func_name: The function name. Defaults to None, in which derives from 226 'func'. 227 grad_func: This function's gradient function, if not None. Defaults 228 to None. 229 python_grad_func: A python callable implementing the gradient of 230 the function python-side. 231 out_names: An optional list of strings for the function return value 232 names. 233 shape_func: An optional function mapping an op to a list of static 234 output shapes. 235 capture_by_value: Boolean (defaults to False). If True, captured values 236 will be copied into the function body. 237 whitelisted_stateful_ops: A set of ops that if stateful we ignore and 238 copy into the function body, when `capture_by_value` is True. 239 capture_resource_var_by_value: Boolean (defaults to True). If False, 240 captured resource variable returns the handle instead of value. 241 **kwargs: The keyword arguments. **kwargs is passed to every call 242 site of this function. 243 244 Raises: 245 ValueError: The function definition is invalid. 246 247 """ 248 self._func = func 249 self._input_types = input_types 250 self._func_name = func_name 251 self._grad_func = grad_func 252 self._python_grad_func = python_grad_func 253 self._out_names = out_names 254 self._shape_func = shape_func 255 self._capture_by_value = capture_by_value 256 self._whitelisted_stateful_ops = whitelisted_stateful_ops 257 if self._whitelisted_stateful_ops is None: 258 self._whitelisted_stateful_ops = set() 259 self._capture_resource_var_by_value = capture_resource_var_by_value 260 self._extra_kwargs = kwargs 261 # Constructed only when C API is disabled, lazily 262 self._definition = None 263 # Constructed only when C API is enabled, lazily 264 self._c_func = None 265 self._sub_functions = dict() # Constructed with _definition or _c_func 266 # pylint: disable=protected-access 267 device_funcs = ops.get_default_graph()._device_functions_outer_to_inner 268 # pylint: enable=protected-access 269 270 # Get the innermost device if possbile. 271 self._caller_device = device_funcs[-1] if device_funcs else None 272 273 # Cached OpDef for this function. When C API is enabled, this is 274 # the only part of FunctionDef that we cache in Python. When C API 275 # is disabled the whole _definition is available and this is simply 276 # another reference to _definition.signature 277 self._op_def = None 278 279 assert isinstance(input_types, (list, tuple)) 280 self._arg_types = input_types 281 self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i) 282 for i in range(len(input_types))] 283 284 @property 285 def name(self): 286 """Function name.""" 287 self._create_definition_if_needed() 288 return self._func_name 289 290 @property 291 def definition(self): 292 """Function definition proto.""" 293 self._create_definition_if_needed() 294 if self._c_func: 295 with c_api_util.tf_buffer() as buf: 296 c_api.TF_FunctionToFunctionDef(self._c_func.func, buf) 297 fdef = function_pb2.FunctionDef() 298 proto_data = c_api.TF_GetBuffer(buf) 299 fdef.ParseFromString(compat.as_bytes(proto_data)) 300 return fdef 301 return self._definition 302 303 @property 304 def _signature(self): 305 self._create_definition_if_needed() 306 return self._op_def 307 308 def set_grad_func(self, grad_func): 309 """Specifies the gradient function of this function.""" 310 assert not self._grad_func 311 assert isinstance(grad_func, _DefinedFunction) 312 self._grad_func = grad_func 313 314 @property 315 def grad_func_name(self): 316 """Returns the name of the gradient function.""" 317 return self._grad_func.name if self._grad_func else None 318 319 @property 320 def python_grad_func(self): 321 """Python gradient function callable.""" 322 return self._python_grad_func 323 324 @property 325 def declared_input_types(self): 326 """Returns the list of data types of explicit declared inputs.""" 327 return self._input_types 328 329 @property 330 def captured_inputs(self): 331 """Returns the list of implicitly captured inputs.""" 332 self._create_definition_if_needed() 333 return self._extra_inputs 334 335 @property 336 def stateful_ops(self): 337 """Returns the list of stateful ops in function definition. 338 339 Returns: 340 A list of (op.name, op.type) pairs. 341 """ 342 self._create_definition_if_needed() 343 return self._stateful_ops 344 345 def _create_definition_if_needed(self): 346 """Creates the function definition if it's not created yet.""" 347 with context.graph_mode(): 348 self._create_definition_if_needed_impl() 349 350 def _create_definition_if_needed_impl(self): 351 """This is not what you want, see _create_definition_if_needed.""" 352 if self._definition is not None or self._c_func is not None: 353 return 354 355 temp_graph = func_graph_from_py_func( 356 self._func, 357 self._arg_names, 358 self._arg_types, 359 self._func_name, 360 self._capture_by_value, 361 self._caller_device, 362 whitelisted_stateful_ops=self._whitelisted_stateful_ops, 363 capture_resource_var_by_value=self._capture_resource_var_by_value) 364 365 self._extra_inputs = temp_graph.extra_inputs 366 # pylint: disable=protected-access 367 self._sub_functions = temp_graph._functions 368 # pylint: enable=protected-access 369 370 # Extra kwargs are treated as attrs on the function def. 371 if self._func_name: 372 base_func_name = self._func_name 373 else: 374 base_func_name = function_utils.get_func_name(self._func) 375 if self._grad_func: 376 base_func_name += ("_%s" % self._grad_func.name) 377 kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs) 378 379 if not temp_graph._c_graph: # pylint: disable=protected-access 380 # Build the FunctionDef 381 self._definition = graph_to_function_def.graph_to_function_def( 382 temp_graph, 383 temp_graph.get_operations(), 384 temp_graph.inputs, 385 temp_graph.outputs, 386 out_names=self._out_names) 387 388 for k in kwargs_attr: 389 self._definition.attr[k].CopyFrom(kwargs_attr[k]) 390 391 # Hash the definition and its dependencies. 392 self._hash_str = self._create_hash_str( 393 self._definition.signature.input_arg, 394 self._definition.signature.output_arg, self._definition.node_def) 395 396 # Finally, we decide the function name to use. If not specified, 397 # make up something which is almost certainly unique (but deterministic). 398 if not self._func_name: 399 self._func_name = "_".join([base_func_name, self._hash_str]) 400 self._definition.signature.name = self._func_name 401 if self._func.__doc__: 402 self._definition.signature.description = self._func.__doc__ 403 404 self._op_def = self._definition.signature 405 else: # C API is enabled 406 output_names = ([compat.as_bytes(x) for x in self._out_names] 407 if self._out_names else []) 408 description = self._func.__doc__ or None 409 # pylint: disable=protected-access 410 c_func = c_api.TF_GraphToFunction_wrapper( 411 temp_graph._c_graph, 412 base_func_name, 413 self._func_name is None, # append_hash_to_fn_name 414 None, # opers 415 [t._as_tf_output() for t in temp_graph.inputs], 416 [t._as_tf_output() for t in temp_graph.outputs], 417 output_names, 418 [], # control_outputs 419 [], # control_output_names 420 None, # opts 421 description) 422 self._c_func = c_api_util.ScopedTFFunction(c_func) 423 # pylint: enable=protected-access 424 self._set_c_attrs(kwargs_attr) 425 426 # Set cached fields: _op_def and _func_name (if not already set) 427 self._op_def = self.definition.signature 428 if self._func_name: 429 assert self._func_name == self._op_def.name 430 else: 431 self._func_name = compat.as_str(self._op_def.name) 432 433 self._stateful_ops = [(op.name, op.type) 434 for op in temp_graph.get_operations() 435 if op.op_def.is_stateful] 436 437 def _set_c_attrs(self, attrs): 438 """Sets `attrs` as attributes of self._c_func. 439 440 Requires that self._c_func is not None. 441 442 Args: 443 attrs: a dictionary from attribute name to attribute proto value 444 """ 445 for name, attr_value in attrs.items(): 446 serialized = attr_value.SerializeToString() 447 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 448 # It might be worth creating a convenient way to re-use the same status. 449 c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name), 450 serialized) 451 452 def _create_hash_str(self, input_arg, output_arg, node_def): 453 """Creates an 8-character string unique to this input. 454 455 Args: 456 input_arg: the input_arg field of an OpDef 457 (e.g. self._definition.signature.input_arg) 458 output_arg: the output_arg field of an OpDef 459 (e.g. self._definition.signature.output_arg) 460 node_def: the node_def field of a FunctionDef 461 (e.g. self._definition.node_def) 462 463 Returns: 464 The unique string for this input 465 """ 466 hasher = hashlib.sha1() 467 468 def update_num(n): 469 hasher.update(compat.as_bytes("%x" % n)) 470 471 def update_str(s): 472 update_num(len(s)) 473 hasher.update(compat.as_bytes(s)) 474 475 def update_strs(slist): 476 update_num(len(slist)) 477 for s in slist: 478 update_str(s) 479 480 for adef in input_arg: 481 update_str(adef.SerializeToString()) 482 483 for adef in output_arg: 484 update_str(adef.SerializeToString()) 485 486 for n in sorted(node_def, key=lambda n: n.name): 487 update_str(n.name) 488 update_str(n.op) 489 update_strs(n.input) 490 update_num(len(n.attr)) 491 # NOTE: protobuf map serialization does not guarantee ordering. 492 for k in sorted(n.attr): 493 update_str(k) 494 update_str(n.attr[k].SerializeToString()) 495 496 return hasher.hexdigest()[:8] 497 498 def add_to_graph(self, g): 499 """Adds this function into the graph g.""" 500 self._create_definition_if_needed() 501 502 # Adds this function into 'g'. 503 # pylint: disable=protected-access 504 if context.executing_eagerly(): 505 context.context().add_function_def(self.definition) 506 else: 507 g._add_function(self) 508 # pylint: enable=protected-access 509 510 # Ensures related sub-routines are defined in 'g', too. 511 for f in self._sub_functions.values(): 512 f.add_to_graph(g) 513 514 # Adds its gradient function, too. 515 if self._grad_func: 516 self._grad_func.add_to_graph(g) 517 518 def __call__(self, *args, **kwargs): 519 self.add_to_graph(ops.get_default_graph()) 520 args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs 521 ret, op = _call(self._signature, *args, **kwargs) 522 523 # Set a hidden attr in 'op' so that gradients_impl can refer back 524 # to this _DefinedFunction instance to access python_grad_func. 525 assert isinstance(op, ops.Operation) 526 setattr(op, "__defun", self) 527 528 if self._shape_func is not None: 529 shapes = self._shape_func(op) 530 if len(shapes) != len(op.outputs): 531 raise ValueError("shape_func produced %d shapes for %d outputs" % 532 (len(shapes), len(op.outputs))) 533 for (t, shape) in zip(op.outputs, shapes): 534 t.set_shape(shape) 535 return ret 536 537 538class _OverloadedFunction(object): 539 """_OverloadedFunction encapsulates an overloaded function. 540 541 _OverloadedFunction maintains a mapping from input types to 542 instantiated _DefinedFunction in self._overload. 543 544 """ 545 546 def __init__(self, 547 func, 548 argnames, 549 func_name=None, 550 grad_func=None, 551 python_grad_func=None, 552 out_names=None, 553 **kwargs): 554 """Creates _DefinedFunction. 555 556 Args: 557 func: A python callable which constructs a tf function body. 558 argnames: A list of strings for function argument names. 559 func_name: The function name. Defaults to None, in which derives from 560 'func'. 561 grad_func: This function's gradient function, if not None. Defaults 562 to None. 563 python_grad_func: A python callable implementing the gradient of 564 the function python-side. 565 out_names: A list of strings for the function return value names. 566 **kwargs: The keyword arguments. **kwargs is passed to every call 567 site of this function. 568 569 Raises: 570 ValueError: The function definition is invalid. 571 572 """ 573 self._func = func 574 self._argnames = argnames 575 self._func_name = func_name 576 assert grad_func is None or isinstance(grad_func, _OverloadedFunction) 577 self._grad_func = grad_func 578 self._python_grad_func = python_grad_func 579 self._out_names = out_names 580 self._extra_kwargs = kwargs 581 self._overload = {} 582 583 def instantiate(self, input_types): 584 """Instantiate this function given input argument types. 585 586 Args: 587 input_types: A list of data types for the inputs. 588 589 Returns: 590 _DefinedFunction for the given input types. 591 592 """ 593 # Stringify the type list. 594 key = _type_list_to_str(input_types) 595 defined = self._overload.get(key) 596 if not defined: 597 # If not defined yet, define the function given the input types. 598 name = self._func_name 599 if name is not None: 600 name = "_".join([name, key]) 601 defined = _DefinedFunction( 602 self._func, 603 self._argnames, 604 input_types, 605 name, 606 None, 607 self._python_grad_func, 608 out_names=self._out_names, 609 **self._extra_kwargs) 610 _ = defined.name # Fully instantiate the function definition. 611 if self._grad_func: 612 # If _grad_func is given, it is another 613 # _OverloadedFunction. We need to instantiate it with the 614 # right input types. 615 output_types = [ 616 dtypes.DType(_.type) for _ in defined._signature.output_arg # pylint: disable=protected-access 617 ] 618 # pylint: disable=protected-access 619 defined._grad_func = self._grad_func.instantiate(input_types + 620 output_types) 621 # pylint: enable=protected-access 622 self._overload[key] = defined 623 return defined 624 625 def __call__(self, *args, **kwargs): 626 input_types = [] 627 args = list(args) 628 for (i, x) in enumerate(args): 629 x = ops.convert_to_tensor(x) 630 if not isinstance(x, ops.Tensor): 631 raise ValueError("Expect a Tensor but get ", x) 632 input_types.append(x.dtype) 633 args[i] = x 634 return self.instantiate(input_types)(*args, **kwargs) 635 636 637class _FuncGraph(ops.Graph): 638 """A helper for constructing a function. 639 640 _FuncGraph overrides ops.Graph's create_op() so that we can keep 641 track of all inputs into every op created inside the function. If 642 any input is from other graphs, we keep track of it in self.capture 643 and substitute the input with a place holder. 644 645 Each captured input's corresponding place holder is converted into a 646 function argument and the caller passes in the captured tensor. 647 """ 648 649 def __init__(self, name, capture_by_value, whitelisted_stateful_ops, 650 capture_resource_var_by_value, *args, **kwargs): 651 super(_FuncGraph, self).__init__(*args, **kwargs) 652 self._capture_by_value = capture_by_value 653 self._whitelisted_stateful_ops = whitelisted_stateful_ops 654 self._capture_resource_var_by_value = capture_resource_var_by_value 655 self._building_function = True 656 self._outer_graph = ops.get_default_graph() 657 self._vscope = vs.get_variable_scope() 658 self._old_custom_getter = self._vscope.custom_getter 659 660 # The name of the function. 661 self.name = name 662 # Placeholder tensors representing the inputs to this function. The tensors 663 # are in this _FuncGraph. 664 self.inputs = [] 665 # Tensors that will be returned this function. The tensors are in this 666 # _FuncGraph. 667 self.outputs = [] 668 # Maps external tensor -> internal tensor (e.g. input placeholder). 669 self._captured = {} 670 # The external tensors that have been captured as inputs and must be passed 671 # to this function (empty if capturing by value, otherwise these are the 672 # keys of _captured). 673 self.extra_inputs = [] 674 # Input placeholders that been added for captured values (empty if capturing 675 # by value). 676 self.extra_args = [] 677 # Captured variables. 678 # TODO(skyewm): is this needed? 679 self.extra_vars = [] 680 681 # pylint: disable=g-doc-return-or-yield 682 683 @tf_contextlib.contextmanager 684 def container(self, container_name): 685 """Returns a context manager that specifies the resource container to use. 686 687 Overridden from `tf.Graph` to update both the init_scope container 688 and the present inner container. This is necessary to make sure setting 689 containers applies correctly both to created variables and to stateful 690 ops. 691 692 Args: 693 container_name: container name string. 694 695 Returns: 696 A context manager for defining resource containers for stateful ops, 697 yields the container name. 698 """ 699 original_container = self._container 700 # pylint: disable=protected-access 701 with ops.init_scope(): 702 original_init_container = ops.get_default_graph()._container 703 try: 704 self._container = container_name 705 with ops.init_scope(): 706 ops.get_default_graph()._container = container_name 707 yield self._container 708 finally: 709 self._container = original_container 710 with ops.init_scope(): 711 ops.get_default_graph()._container = original_init_container 712 # pylint: enable=protected-access 713 714 # pylint: enable=g-doc-return-or-yield 715 716 def getvar( 717 self, 718 getter, 719 name, 720 shape=None, 721 dtype=None, 722 initializer=None, 723 reuse=None, 724 trainable=True, 725 collections=None, # pylint: disable=redefined-outer-name 726 use_resource=None, 727 **kwargs): 728 """A custom variable getter.""" 729 # Here, we switch the default graph to the outer graph and ask the 730 # variable scope in which the function is defined to give us the 731 # variable. The variable is stashed in extra_vars and returned to 732 # the caller. 733 # 734 # We capture these variables so that the variable definition is 735 # hoisted upward to the outer most graph. 736 with self._outer_graph.as_default(): 737 # pylint: disable=protected-access 738 var = self._vscope.get_variable( 739 vs._get_default_variable_store(), 740 name, 741 shape=shape, 742 dtype=dtype, 743 initializer=initializer, 744 reuse=reuse, 745 trainable=trainable, 746 collections=collections, 747 use_resource=use_resource) 748 self.extra_vars.append(var) 749 if (isinstance(var, resource_variable_ops.ResourceVariable) and 750 self._capture_resource_var_by_value): 751 # For resource-based variables read the variable outside the function 752 # and pass in the value. This ensures that the function is pure and 753 # differentiable. TODO(apassos) this may have performance problems if 754 # the function will only do embedding lookups on the variable. 755 return var.value() 756 return var 757 758 def create_op(self, op_type, inputs, dtypes=None, **kwargs): # pylint: disable=redefined-outer-name 759 for i, x in enumerate(inputs): 760 if isinstance(x, ops.EagerTensor) or x.graph is not self: 761 inputs[i] = self.capture(x) 762 return super(_FuncGraph, self).create_op(op_type, inputs, 763 dtypes=dtypes, **kwargs) 764 765 def capture(self, tensor, name=None): 766 """Adds the given tensor to this graph and returns the captured tensor.""" 767 if tensor in self._captured: 768 # Captured already. 769 return self._captured[tensor] 770 elif self._capture_by_value: 771 return self._add_tensor_and_parents(tensor) 772 else: 773 return self._capture_tensor_as_extra_input(tensor, name) 774 775 def _capture_tensor_as_extra_input(self, tensor, name=None): 776 # Substitute with a placeholder. 777 self.extra_inputs.append(tensor) 778 # Hoist the new input placeholder out of any control flow context 779 # we're currently in. 780 with ops.control_dependencies(None): 781 ph = array_ops.placeholder( 782 tensor.dtype, shape=tensor.get_shape(), name=name) 783 # pylint: disable=protected-access 784 if isinstance(tensor, ops.EagerTensor): 785 handle_data = tensor._handle_data 786 if handle_data: 787 handle_data = handle_data.SerializeToString() 788 else: 789 handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph, 790 tensor._as_tf_output()) 791 792 if handle_data: 793 c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(), 794 compat.as_bytes(handle_data)) 795 # pylint: enable=protected-access 796 self.inputs.append(ph) 797 self._captured[tensor] = ph 798 self.extra_args.append(ph) 799 if _is_guaranteed_const(tensor): 800 with ops.control_dependencies(None): 801 return array_ops.guarantee_const(ph) 802 else: 803 return ph 804 805 def _add_tensor_and_parents(self, tensor): 806 op = self._add_op_and_parents(tensor.op) 807 return op.outputs[tensor.value_index] 808 809 def _add_op_and_parents(self, op): 810 # pylint: disable=protected-access 811 op_def = graph_to_function_def._get_op_def(op) 812 # pylint: enable=protected-access 813 if op_def.is_stateful and op not in self._whitelisted_stateful_ops: 814 raise ValueError("Cannot capture a stateful node (name:%s, type:%s) " 815 "by value." % (op.name, op.type)) 816 elif op.type in ("Placeholder", "PlaceholderV2"): 817 raise ValueError("Cannot capture a placeholder (name:%s, type:%s) " 818 "by value." % (op.name, op.type)) 819 820 captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs] 821 822 captured_op = self.create_op( 823 op.type, 824 captured_inputs, [o.dtype for o in op.outputs], 825 name=op.name, 826 attrs=op.node_def.attr, 827 op_def=op_def) 828 829 for t, captured_t in zip(op.outputs, captured_op.outputs): 830 self._captured[t] = captured_t 831 832 return captured_op 833 834 835def func_graph_from_py_func(func, 836 arg_names, 837 arg_types, 838 name=None, 839 capture_by_value=False, 840 device=None, 841 colocation_stack=None, 842 container=None, 843 collections_ref=None, 844 arg_shapes=None, 845 whitelisted_stateful_ops=None, 846 capture_resource_var_by_value=True): 847 """Returns a _FuncGraph generated from `func`. 848 849 Args: 850 func: A Python callable which constructs a TF function body. The arguments 851 must correspond to `arg_types`. Returns a value or list/tuple of values. 852 No returned value can be None. 853 arg_names: A sequence of strings for the function argument names. 854 arg_types: A sequence of the function's argument types. 855 name: The function name. If None, the name is derived from `func`. 856 capture_by_value: boolean. If True, captured values will be copied into the 857 function body. 858 device: device name or function. 859 colocation_stack: A colocation stack (list) the _FuncGraph should use. 860 container: A container name the _FuncGraph should start with. 861 collections_ref: A reference to a collections dict the _FuncGraph should 862 use internally. 863 arg_shapes: A sequence of the function's argument shapes. 864 whitelisted_stateful_ops: A set of ops that if stateful we ignore and 865 re-create. 866 capture_resource_var_by_value: Boolean (defaults to True). If False, 867 captured resource variable returns the handle instead of value. 868 869 Returns: 870 A _FuncGraph. 871 872 Raises: 873 ValueError: if func returns None. 874 """ 875 if not name: 876 name = function_utils.get_func_name(func) 877 func_graph = _FuncGraph(name, capture_by_value, whitelisted_stateful_ops, 878 capture_resource_var_by_value) 879 880 with func_graph.as_default(), ops.device(device): 881 # pylint: disable=protected-access 882 if collections_ref is not None: 883 func_graph._collections = collections_ref 884 if container is not None: 885 func_graph._container = container 886 if colocation_stack is not None: 887 func_graph._colocation_stack = colocation_stack 888 # pylint: enable=protected-access 889 890 if arg_shapes is None: 891 arg_shapes = [None] * len(arg_types) 892 893 # Create placeholders for the function arguments. 894 for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes): 895 argholder = array_ops.placeholder(argtype, shape=argshape, name=argname) 896 func_graph.inputs.append(argholder) 897 # Call func and gather the output tensors. 898 with vs.variable_scope("", custom_getter=func_graph.getvar): 899 outputs = func(*func_graph.inputs) 900 901 # There is no way of distinguishing between a function not returning 902 # anything and a function returning None in Python. 903 # We need to allow the former and ideally want to forbid the latter as 904 # it is most likely user error. 905 # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to 906 # allow users to explicitly mark the function as not returning anything. 907 # For now, we allow a single None return and interpret it as a function 908 # with no output. 909 if outputs is None: 910 outputs = [] 911 else: 912 # If func only returned one value, make it a tuple. 913 if not isinstance(outputs, (list, tuple)): 914 outputs = (outputs,) 915 if any(_ is None for _ in outputs): 916 raise ValueError("Function %s can not return None." % name) 917 # Ensures each output is a Tensor in the function graph. 918 outputs = [ops.convert_to_tensor(t) for t in outputs] 919 outputs = [func_graph.capture(t) if t.graph is not func_graph else t 920 for t in outputs] 921 func_graph.outputs = outputs 922 return func_graph 923 924 925def _is_guaranteed_const(tensor): 926 """Determines whether `tensor` is guaranteed to be a constant. 927 928 A tensor is guaranteed to be a constant if either it was produced by 929 a `GuaranteeConst` op or if all of its children are guaranteed to be 930 constants. 931 932 Args: 933 tensor: The tensor for which to determine const-ness. 934 935 Returns: 936 True if `tensor` is guaranteed to be a constant, False otherwise. 937 """ 938 939 if isinstance(tensor, ops.EagerTensor): 940 return False 941 942 class Work(object): 943 944 def __init__(self, op, leaving): 945 self.op = op 946 self.leaving = leaving 947 948 is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst" 949 constants = set([]) 950 def all_inputs_const(op): 951 # If all inputs of an op are guaranteed constants, then we can infer that 952 # the op produces a constant as well. 953 return op.inputs and all(inp.op in constants for inp in op.inputs) 954 955 visited = set([]) 956 stack = [Work(tensor.op, leaving=False)] 957 while stack: 958 work = stack.pop() 959 if work.leaving: 960 if all_inputs_const(work.op): 961 constants.add(work.op) 962 continue 963 visited.add(work.op) 964 if is_guaranteed_const(work.op): 965 constants.add(work.op) 966 continue 967 968 # This op will be revisited after all its inputs are checked for const-ness. 969 stack.append(Work(work.op, leaving=True)) 970 for inp in work.op.inputs: 971 if inp.op not in visited: 972 stack.append(Work(inp.op, leaving=False)) 973 return tensor.op in constants 974 975 976def _call(sig, *inputs, **kwargs): 977 """Adds a node calling a function. 978 979 This adds a `call` op to the default graph that calls the function 980 of signature `sig`, passing the tensors in `inputs` as arguments. 981 It returns the outputs of the call, which are one or more tensors. 982 983 `sig` is OpDefArg.a `_DefinedFunction` object. 984 985 You can pass an optional keyword parameter `name=string` to name the 986 added operation. 987 988 You can pass an optional keyword parameter `noinline=True|False` to 989 instruct the runtime not to inline the function body into the call 990 site. 991 992 Args: 993 sig: OpDefArg. The signature of the function. 994 *inputs: arguments to the function. 995 **kwargs: Optional keyword arguments. Can only contain 'name' or 996 'noinline'. 997 998 Returns: 999 A 2-element tuple. First element: a Tensor if the function returns a single 1000 value; a list of Tensors if the function returns multiple value; the 1001 Operation if the function returns no values. Second element: the Operation. 1002 1003 Raises: 1004 ValueError: if the arguments are invalid. 1005 """ 1006 if len(inputs) != len(sig.input_arg): 1007 raise ValueError("Expected number of arguments: %d, received: %d" % (len( 1008 sig.input_arg), len(inputs))) 1009 name = kwargs.pop("name", None) 1010 g = ops.get_default_graph() 1011 func_name = sig.name 1012 if name is None: 1013 name = func_name 1014 attrs = _parse_kwargs_as_attrs(func_name, **kwargs) 1015 output_types = [dtypes.DType(x.type) for x in sig.output_arg] 1016 op = g.create_op( 1017 func_name, 1018 list(inputs), 1019 output_types, 1020 name=name, 1021 attrs=attrs, 1022 op_def=sig, 1023 compute_shapes=False) 1024 if op.outputs: 1025 if len(op.outputs) == 1: 1026 ret = op.outputs[0] 1027 else: 1028 ret = tuple(op.outputs) 1029 else: 1030 ret = op 1031 return ret, op 1032 1033 1034def _from_definition(fdef, grad_func=None): 1035 """Creates a _DefinedFunction initialized from a FunctionDef proto. 1036 1037 Args: 1038 fdef: a FunctionDef 1039 grad_func: a _DefinedFunction or None 1040 1041 Returns: 1042 A _DefinedFunction representing fdef 1043 """ 1044 # TODO(iga): This method does major surgery on _DefinedFunction. 1045 # Make it a named constructor using @classmethod of _DefinedFunction. 1046 1047 # The Python callable is only needed to create a FunctionDef. Since we have 1048 # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we 1049 # have access to such a callable here). 1050 func = None 1051 argnames = [arg.name for arg in fdef.signature.input_arg] 1052 input_types = tuple( 1053 dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg) 1054 func_name = fdef.signature.name 1055 # Note: FunctionDefs do not include python gradient functions, so if the 1056 # original _DefinedFunction included one it will not be reflected here. 1057 python_grad_func = None 1058 out_names = [arg.name for arg in fdef.signature.output_arg] 1059 result = _DefinedFunction(func, argnames, input_types, func_name, grad_func, 1060 python_grad_func, out_names) 1061 # pylint: disable=protected-access 1062 serialized = fdef.SerializeToString() 1063 c_func = c_api.TF_FunctionImportFunctionDef(serialized) 1064 result._c_func = c_api_util.ScopedTFFunction(c_func) 1065 result._extra_inputs = [] 1066 result._op_def = fdef.signature 1067 # pylint: enable=protected-access 1068 1069 return result 1070 1071 1072def from_library(lib): 1073 """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto. 1074 1075 This method handles assigning the correct gradient functions to each 1076 function. 1077 1078 Args: 1079 lib: a FunctionDefLibrary 1080 1081 Returns: 1082 A list of _DefinedFunctions 1083 1084 Raises: 1085 ValueError: `lib` is invalid 1086 """ 1087 if not lib.function and not lib.gradient: 1088 return [] 1089 1090 # function name -> FunctionDef proto 1091 funcs = {fdef.signature.name: fdef for fdef in lib.function} 1092 1093 # Validate that all references function names have function defs 1094 for g in lib.gradient: 1095 if g.function_name not in funcs: 1096 raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" % 1097 (g.function_name, str(lib))) 1098 if g.gradient_func not in funcs: 1099 raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" % 1100 (g.gradient_func, str(lib))) 1101 1102 # function name -> gradient function name 1103 func_to_grad = collections.defaultdict(lambda: None) 1104 # gradient function name -> names of functions having that grad function 1105 grad_to_funcs = collections.defaultdict(list) 1106 1107 for gdef in lib.gradient: 1108 func_to_grad[gdef.function_name] = gdef.gradient_func 1109 grad_to_funcs[gdef.gradient_func].append(gdef.function_name) 1110 1111 # Start with functions without gradients 1112 ready = [ 1113 fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None 1114 ] 1115 if not ready: 1116 raise ValueError( 1117 "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib)) 1118 # function name -> _DefinedFunction 1119 initialized = {} 1120 1121 while ready: 1122 fdef = ready.pop() 1123 name = fdef.signature.name 1124 1125 grad = initialized.get(func_to_grad[name]) 1126 if func_to_grad[name]: 1127 assert grad 1128 defined_func = _from_definition(fdef, grad_func=grad) 1129 initialized[name] = defined_func 1130 1131 ready.extend(funcs[f] for f in grad_to_funcs[name]) 1132 1133 return initialized.values() 1134 1135 1136def _get_experimental_kwarg_as_attr(attr_name, value): 1137 """Creates an AttrValue for a python object.""" 1138 if isinstance(value, bool): 1139 return attr_value_pb2.AttrValue(b=value) 1140 elif isinstance(value, int): 1141 return attr_value_pb2.AttrValue(i=value) 1142 elif isinstance(value, float): 1143 return attr_value_pb2.AttrValue(f=value) 1144 elif isinstance(value, str): 1145 return attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 1146 else: 1147 raise ValueError("Unsupported attribute type for %s with type %s" % 1148 (attr_name, type(value))) 1149 1150 1151def _parse_kwargs_as_attrs(func_name, **kwargs): 1152 """Parses **kwargs into a node's attributes.""" 1153 attrs = {} 1154 1155 noinline = kwargs.pop("noinline", None) 1156 if noinline is not None: 1157 attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline)) 1158 1159 compiled = kwargs.pop("compiled", None) 1160 separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None) 1161 if compiled is not None: 1162 attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled)) 1163 attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue( 1164 b=bool(separate_compiled_gradients)) 1165 # Forward _XlaScope from enclosing context (if set), otherwise create new. 1166 # pylint: disable=protected-access 1167 if "_XlaScope" in ops.get_default_graph()._attr_scope_map: 1168 attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"] 1169 else: 1170 attrs["_XlaScope"] = attr_value_pb2.AttrValue( 1171 s=("function_%s" % func_name).encode()) 1172 # pylint: enable=protected-access 1173 1174 kwargs_keys = list(kwargs.keys()) 1175 for key in kwargs_keys: 1176 if key.startswith("experimental_"): 1177 attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key]) 1178 del kwargs[key] 1179 1180 if kwargs: 1181 raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) 1182 return attrs 1183 1184 1185def get_extra_vars(): 1186 """Returns the captured variables by the function. 1187 1188 Returns: 1189 If the default graph is being used to define a function, the 1190 returned list of variables are those created inside the function 1191 body so far. Otherwise, returns an empty list. 1192 """ 1193 g = ops.get_default_graph() 1194 if isinstance(g, _FuncGraph): 1195 return g.extra_vars 1196 else: 1197 return [] 1198 1199 1200def get_extra_inputs(): 1201 """Returns the captured input tensors by the function. 1202 1203 Returns: 1204 If the default graph is being used to define a function, the 1205 returned list of tensors are those accessed inside the function body 1206 but defined outside the function body so far. Otherwise, returns an 1207 empty list. 1208 """ 1209 g = ops.get_default_graph() 1210 if isinstance(g, _FuncGraph): 1211 return g.extra_inputs 1212 else: 1213 return [] 1214 1215 1216def get_extra_args(): 1217 """Returns the corresponding function arguments for the captured inputs. 1218 1219 Returns: 1220 If the default graph is being used to define a function, the 1221 returned list of place holders are those used inside the function 1222 body corresponding those returned by get_extra_inputs(). Otherwise, 1223 returns an empty list. 1224 """ 1225 g = ops.get_default_graph() 1226 if isinstance(g, _FuncGraph): 1227 return g.extra_args 1228 else: 1229 return [] 1230 1231 1232def _type_list_to_str(types): 1233 if any(_ not in _DTYPE_TO_STR for _ in types): 1234 raise ValueError("Unsupported dtypes: %s" % types) 1235 return "".join([_DTYPE_TO_STR[_] for _ in types]) 1236 1237 1238# NOTE: The list needs to be extended when more data types are added. 1239_DTYPE_TO_STR = { 1240 dtypes.float16: "f16", 1241 dtypes.float32: "f32", 1242 dtypes.float64: "f64", 1243 dtypes.int32: "i32", 1244 dtypes.uint8: "i8", 1245 dtypes.uint16: "u16", 1246 dtypes.uint32: "u32", 1247 dtypes.uint64: "u64", 1248 dtypes.int16: "i16", 1249 dtypes.int8: "i8", 1250 dtypes.string: "s", 1251 dtypes.complex64: "c64", 1252 dtypes.complex128: "c128", 1253 dtypes.int64: "i64", 1254 dtypes.bool: "b", 1255 dtypes.qint8: "qi8", 1256 dtypes.quint8: "qu8", 1257 dtypes.qint16: "qi16", 1258 dtypes.quint16: "qu16", 1259 dtypes.qint32: "qi32", 1260 dtypes.bfloat16: "b16" 1261} 1262 1263 1264def function_def_from_tf_function(c_func): 1265 """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto.""" 1266 with c_api_util.tf_buffer() as buf: 1267 c_api.TF_FunctionToFunctionDef(c_func, buf) 1268 data = c_api.TF_GetBuffer(buf) 1269 fdef = function_pb2.FunctionDef() 1270 fdef.ParseFromString(compat.as_bytes(data)) 1271 return fdef 1272