1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Python front-end supports for functions. 16 17NOTE: functions are currently experimental and subject to change! 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import collections 25import hashlib 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.framework import function_pb2 29from tensorflow.python import pywrap_tensorflow as c_api 30from tensorflow.python.eager import context 31from tensorflow.python.framework import c_api_util 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors 34from tensorflow.python.framework import graph_to_function_def 35from tensorflow.python.framework import ops 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import resource_variable_ops 38from tensorflow.python.ops import variable_scope as vs 39from tensorflow.python.util import compat 40from tensorflow.python.util import tf_decorator 41from tensorflow.python.util import tf_inspect 42 43 44class Defun(object): 45 """Decorator used to define TensorFlow functions. 46 47 Use this decorator to make a Python function usable directly as a TensorFlow 48 function. 49 50 The decorated function must add ops to the default graph and return zero or 51 more `Tensor` objects. Call the decorator with named arguments, one for each 52 argument of the function to decorate, with the expected type of the argument 53 as value. 54 55 For example if the function to decorate accepts two `tf.float32` arguments 56 named `x` and `y`, call the decorator with: 57 58 @Defun(tf.float32, tf.float32) 59 def foo(x, y): 60 ... 61 62 When you call the decorated function it will add `call` ops to the 63 default graph and adds the definition of the function into the 64 default graph. Because the addition of the function into the graph 65 is deferred, the decorator can be used anywhere in the program. 66 67 Any variables created inside of the function are hoisted into the outer graph. 68 Note that the variables are created in the variable scope that was active 69 during the first call to the function. Subsequent function calls will refer to 70 the same set of variables. 71 72 Definitions of functions are frozen in a graph as soon as the graph is used to 73 create a session. Therefore, nodes using the function must be created in the 74 graph before the corresponding session is created. 75 76 Example, but also see the [How To on functions](link_needed). 77 78 ```python 79 # Defining the function. 80 @tf.Defun(tf.float32, tf.float32) 81 def MyFunc(x, y): 82 return x + y, x - y 83 84 # Building the graph. 85 a = tf.constant([1.0]) 86 b = tf.constant([2.0]) 87 c, d = MyFunc(a, b, name='mycall') 88 ``` 89 """ 90 91 def __init__(self, *input_types, **kwargs): 92 """Create a `Defun` decorator. 93 94 Args: 95 *input_types: A list of `tf.DType` 96 **kwargs: Optional keyword arguments, including 97 func_name - (optional). A python string, the name to use to 98 declare this `Function` in the graph. 99 100 grad_func - (optional). A function implementing the gradient 101 of the function-to-register. This is must be a 102 `_DefinedFunction` object. The gradient 103 function must satisfy the criterion defined in 104 function.proto:GradientDef. 105 106 python_grad_func - (optional). A function implementing the 107 gradient of the function python-side. This function must 108 take the current op and the gradients w.r.t. its outputs, 109 and return the gradients w.r.t. the inputs. That is it must 110 implement the interface expected by `tf.RegisterGradient`). 111 This will be called by tf.gradients to add the gradient ops 112 to the graph. At most one of grad_func and python_grad_func 113 can be specified. 114 115 out_names = (optional). A list of strings, one per output 116 tensor. 117 118 shape_func - (optional). A function taking the op and returning a list 119 of static shapes to set for the function's outputs. 120 """ 121 self._input_types = input_types 122 self._func_name = kwargs.pop("func_name", None) 123 self._grad_func = kwargs.pop("grad_func", None) 124 self._python_grad_func = kwargs.pop("python_grad_func", None) 125 self._out_names = kwargs.pop("out_names", None) 126 self._extra_kwargs = kwargs 127 128 def __call__(self, func): 129 # Various sanity checks on the callable func. 130 if not callable(func): 131 raise ValueError("func %s must be callable" % func) 132 133 # Func should not use kwargs and defaults. 134 argspec = tf_inspect.getargspec(func) 135 if argspec.keywords or argspec.defaults: 136 raise ValueError("Functions with argument defaults or keyword " 137 "arguments are not supported.") 138 139 # Computes how many arguments 'func' has. 140 min_args = len(argspec.args) 141 max_args = min_args 142 if argspec.varargs: 143 max_args = 1000000 144 argnames = argspec.args 145 if tf_inspect.ismethod(func): 146 # 1st argument is the "class" type. 147 min_args -= 1 148 argnames = argnames[1:] 149 150 if self._input_types: 151 # If Defun is given a list of types for the inputs, the number 152 # of input types should be compatible with 'func'. 153 num = len(self._input_types) 154 if num < min_args or num > max_args: 155 raise ValueError( 156 "The function has fewer arguments than the number of specified " 157 "input types.") 158 return _DefinedFunction( 159 func, 160 argnames, 161 self._input_types, 162 self._func_name, 163 self._grad_func, 164 self._python_grad_func, 165 out_names=self._out_names, 166 **self._extra_kwargs) 167 168 # 'func' expects no arguments and input types is an empty list. 169 if min_args == 0 and max_args == 0: 170 return _DefinedFunction( 171 func, [], [], 172 self._func_name, 173 self._grad_func, 174 self._python_grad_func, 175 out_names=self._out_names, 176 **self._extra_kwargs) 177 178 # Input types are unknown. It's an overloaded function and hence 179 # its definition needs to be deferred until it's called. 180 return _OverloadedFunction( 181 func, 182 argnames, 183 self._func_name, 184 self._grad_func, 185 self._python_grad_func, 186 out_names=self._out_names, 187 **self._extra_kwargs) 188 189 190class _DefinedFunction(object): 191 """_DefinedFunction encapsulates a function definition and its properties. 192 193 Attributes: 194 name: The function name. 195 definition: The definition of this function. A FunctionDef proto. 196 grad_func_name: If not None, the name of this function's gradient function. 197 python_grad_func: A python callable implementing the gradient of 198 the function python-side. 199 """ 200 201 def __init__(self, 202 func, 203 argnames, 204 input_types, 205 func_name=None, 206 grad_func=None, 207 python_grad_func=None, 208 out_names=None, 209 shape_func=None, 210 capture_by_value=False, 211 **kwargs): 212 """Creates _DefinedFunction. 213 214 Args: 215 func: A python callable which constructs a tf function body. 216 argnames: A list of strings for function argument names. 217 input_types: The function's argument types. Can be a tuple, list of 218 tf data types. 219 func_name: The function name. Defaults to None, in which derives from 220 'func'. 221 grad_func: This function's gradient function, if not None. Defaults 222 to None. 223 python_grad_func: A python callable implementing the gradient of 224 the function python-side. 225 out_names: An optional list of strings for the function return value 226 names. 227 shape_func: An optional function mapping an op to a list of static 228 output shapes. 229 capture_by_value: Boolean (defaults to False). If True, captured values 230 will be copied into the function body. 231 **kwargs: The keyword arguments. **kwargs is passed to every call 232 site of this function. 233 234 Raises: 235 ValueError: The function definition is invalid. 236 237 """ 238 self._func = func 239 self._input_types = input_types 240 self._func_name = func_name 241 self._grad_func = grad_func 242 self._python_grad_func = python_grad_func 243 self._out_names = out_names 244 self._shape_func = shape_func 245 self._capture_by_value = capture_by_value 246 self._extra_kwargs = kwargs 247 # Constructed only when C API is disabled, lazily 248 self._definition = None 249 # Constructed only when C API is enabled, lazily 250 self._c_func = None 251 self._sub_functions = dict() # Constructed with _definition or _c_func 252 253 # Cached OpDef for this function. When C API is enabled, this is 254 # the only part of FunctionDef that we cache in Python. When C API 255 # is disabled the whole _definition is available and this is simply 256 # another reference to _definition.signature 257 self._op_def = None 258 259 self._args = [] 260 assert isinstance(input_types, (list, tuple)) 261 for i in range(len(input_types)): 262 argname = argnames[i] if i < len(argnames) else ("arg%d" % i) 263 argtype = input_types[i] 264 self._args.append((argname, argtype)) 265 266 @property 267 def name(self): 268 """Function name.""" 269 self._create_definition_if_needed() 270 return self._func_name 271 272 @property 273 def definition(self): 274 """Function definition proto.""" 275 self._create_definition_if_needed() 276 if self._c_func: 277 with c_api_util.tf_buffer() as buf: 278 with errors.raise_exception_on_not_ok_status() as status: 279 c_api.TF_FunctionToFunctionDef(self._c_func, buf, status) 280 fdef = function_pb2.FunctionDef() 281 proto_data = c_api.TF_GetBuffer(buf) 282 fdef.ParseFromString(compat.as_bytes(proto_data)) 283 return fdef 284 return self._definition 285 286 @property 287 def _signature(self): 288 self._create_definition_if_needed() 289 return self._op_def 290 291 def set_grad_func(self, grad_func): 292 """Specifies the gradient function of this function.""" 293 assert not self._grad_func 294 assert isinstance(grad_func, _DefinedFunction) 295 self._grad_func = grad_func 296 297 @property 298 def grad_func_name(self): 299 """Its gradient function's name.""" 300 return self._grad_func.name if self._grad_func else None 301 302 @property 303 def python_grad_func(self): 304 """Python gradient function callable.""" 305 return self._python_grad_func 306 307 @property 308 def declared_input_types(self): 309 """Returns the list of data types of explicit declared inputs.""" 310 return self._input_types 311 312 @property 313 def captured_inputs(self): 314 """Returns the list of implicitly captured inputs.""" 315 self._create_definition_if_needed() 316 return self._extra_inputs 317 318 def _create_definition_if_needed(self): 319 """Creates the function definition if it's not created yet.""" 320 with context.graph_mode(): 321 self._create_definition_if_needed_impl() 322 323 def _create_definition_if_needed_impl(self): 324 """This is not what you want, see _create_definition_if_needed.""" 325 if self._definition is not None or self._c_func is not None: 326 return 327 328 # Create the func_def object. 329 temp_graph = _FuncGraph(capture_by_value=self._capture_by_value) 330 with temp_graph.as_default(): 331 # List of placeholders for the function_def. 332 inputs = [] 333 for (argname, argtype) in self._args: 334 argholder = array_ops.placeholder(argtype, name=argname) 335 inputs.append(argholder) 336 # Call func and gather the output tensors. 337 with vs.variable_scope("", custom_getter=temp_graph.getvar): 338 outputs = self._func(*inputs) 339 340 # There is no way of distinguishing between a function not returning 341 # anything and a function returning None in Python. 342 # We need to allow the former and ideally want to forbid the latter as 343 # it is most likely user error. 344 # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to 345 # allow users to explicitly mark the function as not returning anything. 346 # For now, we allow a single None return and interpret it as a function 347 # with no output. 348 if outputs is None: 349 outputs = [] 350 else: 351 # If func only returned one value, make it a tuple. 352 if not isinstance(outputs, (list, tuple)): 353 outputs = (outputs,) 354 if any([_ is None for _ in outputs]): 355 raise ValueError("Function can not return None.") 356 # Ensures each output is a Tensor. 357 outputs = [ops.convert_to_tensor(_) for _ in outputs] 358 self._extra_inputs = temp_graph.extra_inputs 359 inputs.extend(temp_graph.extra_args) 360 # pylint: disable=protected-access 361 self._sub_functions = temp_graph._functions 362 # pylint: enable=protected-access 363 364 # Extra kwargs are treated as attrs on the function def. 365 base_func_name = self._func_name or _get_func_name(self._func) 366 kwargs_attr = _parse_kwargs_as_attrs(base_func_name, 367 **self._extra_kwargs) 368 369 if not temp_graph._c_graph: # pylint: disable=protected-access 370 # Build the FunctionDef 371 self._definition = graph_to_function_def.graph_to_function_def( 372 temp_graph, 373 temp_graph.get_operations(), 374 inputs, 375 outputs, 376 out_names=self._out_names) 377 378 for k in kwargs_attr: 379 self._definition.attr[k].CopyFrom(kwargs_attr[k]) 380 381 # Hash the definition and its dependencies. 382 self._hash_str = self._create_hash_str( 383 self._definition.signature.input_arg, 384 self._definition.signature.output_arg, self._definition.node_def) 385 386 # Finally, we decide the function name to use. If not specified, 387 # make up something which is almost certainly unique (but deterministic). 388 if not self._func_name: 389 self._func_name = "_".join([base_func_name, self._hash_str]) 390 self._definition.signature.name = self._func_name 391 if self._func.__doc__: 392 self._definition.signature.description = self._func.__doc__ 393 394 self._op_def = self._definition.signature 395 else: # C API is enabled 396 output_names = ([compat.as_bytes(x) for x in self._out_names] 397 if self._out_names else []) 398 description = self._func.__doc__ or None 399 # pylint: disable=protected-access 400 with errors.raise_exception_on_not_ok_status() as status: 401 self._c_func = c_api.TF_GraphToFunction_wrapper( 402 temp_graph._c_graph, 403 base_func_name, 404 self._func_name is None, # append_hash_to_fn_name 405 None, # opers 406 [t._as_tf_output() for t in inputs], 407 [t._as_tf_output() for t in outputs], 408 output_names, 409 None, # opts 410 description, 411 status) 412 # pylint: enable=protected-access 413 self._set_c_attrs(kwargs_attr) 414 415 # Set cached fields: _op_def and _func_name (if not already set) 416 self._op_def = self.definition.signature 417 if self._func_name: 418 assert self._func_name == self._op_def.name 419 else: 420 self._func_name = compat.as_str(self._op_def.name) 421 422 def _set_c_attrs(self, attrs): 423 """Sets `attrs` as attributes of self._c_func. 424 425 Requires that self._c_func is not None. 426 427 Args: 428 attrs: a dictionary from attribute name to attribute proto value 429 """ 430 for name, attr_value in attrs.items(): 431 serialized = attr_value.SerializeToString() 432 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 433 # It might be worth creating a convenient way to re-use the same status. 434 with errors.raise_exception_on_not_ok_status() as status: 435 c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name), 436 serialized, status) 437 438 def _create_hash_str(self, input_arg, output_arg, node_def): 439 """Creates an 8-character string unique to this input. 440 441 Args: 442 input_arg: the input_arg field of an OpDef 443 (e.g. self._definition.signature.input_arg) 444 output_arg: the output_arg field of an OpDef 445 (e.g. self._definition.signature.output_arg) 446 node_def: the node_def field of a FunctionDef 447 (e.g. self._definition.node_def) 448 449 Returns: 450 The unique string for this input 451 """ 452 hasher = hashlib.sha1() 453 454 def update_num(n): 455 hasher.update(compat.as_bytes("%x" % n)) 456 457 def update_str(s): 458 update_num(len(s)) 459 hasher.update(compat.as_bytes(s)) 460 461 def update_strs(slist): 462 update_num(len(slist)) 463 for s in slist: 464 update_str(s) 465 466 for adef in input_arg: 467 update_str(adef.SerializeToString()) 468 469 for adef in output_arg: 470 update_str(adef.SerializeToString()) 471 472 for n in sorted(node_def, key=lambda n: n.name): 473 update_str(n.name) 474 update_str(n.op) 475 update_strs(n.input) 476 update_num(len(n.attr)) 477 # NOTE: protobuf map serialization does not guarantee ordering. 478 for k in sorted(n.attr): 479 update_str(k) 480 update_str(n.attr[k].SerializeToString()) 481 482 return hasher.hexdigest()[:8] 483 484 def add_to_graph(self, g): 485 """Adds this function into the graph g.""" 486 self._create_definition_if_needed() 487 488 # Adds this function into 'g'. 489 # pylint: disable=protected-access 490 if context.in_graph_mode(): 491 g._add_function(self) 492 else: 493 context.context().add_function_def(self.definition) 494 # pylint: enable=protected-access 495 496 # Ensures related sub-routines are defined in 'g', too. 497 for f in self._sub_functions.values(): 498 f.add_to_graph(g) 499 500 # Adds its gradient function, too. 501 if self._grad_func: 502 self._grad_func.add_to_graph(g) 503 504 def __call__(self, *args, **kwargs): 505 self.add_to_graph(ops.get_default_graph()) 506 args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs 507 ret, op = _call(self._signature, *args, **kwargs) 508 if self._shape_func is not None: 509 shapes = self._shape_func(op) 510 if len(shapes) != len(op.outputs): 511 raise ValueError("shape_func produced %d shapes for %d outputs" % 512 (len(shapes), len(op.outputs))) 513 for (t, shape) in zip(op.outputs, shapes): 514 t.set_shape(shape) 515 return ret 516 517 518class _OverloadedFunction(object): 519 """_OverloadedFunction encapsulates an overloaded function. 520 521 _OverloadedFunction maintains a mapping from input types to 522 instantiated _DefinedFunction in self._overload. 523 524 """ 525 526 def __init__(self, 527 func, 528 argnames, 529 func_name=None, 530 grad_func=None, 531 python_grad_func=None, 532 out_names=None, 533 **kwargs): 534 """Creates _DefinedFunction. 535 536 Args: 537 func: A python callable which constructs a tf function body. 538 argnames: A list of strings for function argument names. 539 func_name: The function name. Defaults to None, in which derives from 540 'func'. 541 grad_func: This function's gradient function, if not None. Defaults 542 to None. 543 python_grad_func: A python callable implementing the gradient of 544 the function python-side. 545 out_names: A list of strings for the function return value names. 546 **kwargs: The keyword arguments. **kwargs is passed to every call 547 site of this function. 548 549 Raises: 550 ValueError: The function definition is invalid. 551 552 """ 553 self._func = func 554 self._argnames = argnames 555 self._func_name = func_name 556 assert grad_func is None or isinstance(grad_func, _OverloadedFunction) 557 self._grad_func = grad_func 558 self._python_grad_func = python_grad_func 559 self._out_names = out_names 560 self._extra_kwargs = kwargs 561 self._overload = {} 562 563 def instantiate(self, input_types): 564 """Instantiate this function given input argument types. 565 566 Args: 567 input_types: A list of data types for the inputs. 568 569 Returns: 570 _DefinedFunction for the given input types. 571 572 """ 573 # Stringify the type list. 574 key = _type_list_to_str(input_types) 575 defined = self._overload.get(key) 576 if not defined: 577 # If not defined yet, define the function given the input types. 578 name = self._func_name 579 if name is not None: 580 name = "_".join([name, key]) 581 defined = _DefinedFunction( 582 self._func, 583 self._argnames, 584 input_types, 585 name, 586 None, 587 self._python_grad_func, 588 out_names=self._out_names, 589 **self._extra_kwargs) 590 _ = defined.name # Fully instantiate the function definition. 591 if self._grad_func: 592 # If _grad_func is given, it is another 593 # _OverloadedFunction. We need to instantiate it with the 594 # right input types. 595 output_types = [ 596 dtypes.DType(_.type) 597 for _ in defined._signature.output_arg # pylint: disable=protected-access 598 ] 599 # pylint: disable=protected-access 600 defined._grad_func = self._grad_func.instantiate( 601 input_types + output_types) 602 # pylint: enable=protected-access 603 self._overload[key] = defined 604 return defined 605 606 def __call__(self, *args, **kwargs): 607 input_types = [] 608 args = list(args) 609 for (i, x) in enumerate(args): 610 x = ops.convert_to_tensor(x) 611 if not isinstance(x, ops.Tensor): 612 raise ValueError("Expect a Tensor but get ", x) 613 input_types.append(x.dtype) 614 args[i] = x 615 return self.instantiate(input_types)(*args, **kwargs) 616 617 618class _FuncGraph(ops.Graph): 619 """A helper for constructing a function. 620 621 _FuncGraph overrides ops.Graph's create_op() so that we can keep 622 track of all inputs into every op created inside the function. If 623 any input is from other graphs, we keep track of it in self.capture 624 and substitute the input with a place holder. 625 626 Each captured input's corresponding place holder is converted into a 627 function argument and the caller passes in the captured tensor. 628 """ 629 630 def __init__(self, capture_by_value, *args, **kwargs): 631 super(_FuncGraph, self).__init__(*args, **kwargs) 632 self._capture_by_value = capture_by_value 633 self._building_function = True 634 self._outer_graph = ops.get_default_graph() 635 self._vscope = vs.get_variable_scope() 636 self._old_custom_getter = self._vscope.custom_getter 637 self._captured = {} 638 self.extra_inputs = [] 639 self.extra_args = [] 640 self.extra_vars = [] 641 642 def getvar( 643 self, 644 getter, 645 name, 646 shape=None, 647 dtype=None, 648 initializer=None, 649 reuse=None, 650 trainable=True, 651 collections=None, # pylint: disable=redefined-outer-name 652 use_resource=None, 653 **kwargs): 654 """A custom variable getter.""" 655 # Here, we switch the default graph to the outer graph and ask the 656 # variable scope in which the function is defined to give us the 657 # variable. The variable is stashed in extra_vars and returned to 658 # the caller. 659 # 660 # We capture these variables so that the variable definition is 661 # hoisted upward to the outer most graph. 662 with self._outer_graph.as_default(): 663 # pylint: disable=protected-access 664 var = self._vscope.get_variable( 665 vs._get_default_variable_store(), 666 name, 667 shape=shape, 668 dtype=dtype, 669 initializer=initializer, 670 reuse=reuse, 671 trainable=trainable, 672 collections=collections, 673 use_resource=use_resource) 674 self.extra_vars.append(var) 675 if isinstance(var, resource_variable_ops.ResourceVariable): 676 # For resource-based variables read the variable outside the function 677 # and pass in the value. This ensures that the function is pure and 678 # differentiable. TODO(apassos) this may have performance problems if 679 # the function will only do embedding lookups on the variable. 680 return var.value() 681 return var 682 683 def create_op(self, op_type, inputs, data_types, **kwargs): 684 for i, x in enumerate(inputs): 685 if isinstance(x, ops.EagerTensor) or x.graph is not self: 686 # Referring to a tensor from other graph. 687 if x in self._captured: 688 # Captured already. 689 inputs[i] = self._captured[x] 690 elif self._capture_by_value: 691 inputs[i] = self._add_tensor_and_parents(x) 692 else: 693 # Substitute with a placeholder. 694 self.extra_inputs.append(x) 695 # Hoist the new input placeholder out of any control flow context 696 # we're currently in. 697 with ops.control_dependencies(None): 698 ph = array_ops.placeholder(x.dtype, shape=x.get_shape()) 699 # pylint: disable=protected-access 700 ph._handle_data = x._handle_data 701 # pylint: enable=protected-access 702 inputs[i] = ph 703 self._captured[x] = ph 704 self.extra_args.append(ph) 705 return super(_FuncGraph, self).create_op(op_type, inputs, data_types, 706 **kwargs) 707 708 def _add_tensor_and_parents(self, tensor): 709 op = self._add_op_and_parents(tensor.op) 710 return op.outputs[tensor.value_index] 711 712 def _add_op_and_parents(self, op): 713 # pylint: disable=protected-access 714 op_def = graph_to_function_def._get_op_def(op) 715 # pylint: enable=protected-access 716 if op_def.is_stateful: 717 raise ValueError("Cannot capture a stateful node (name:%s, type:%s) " 718 "by value." % (op.name, op.type)) 719 elif op.type in ("Placeholder", "PlaceholderV2"): 720 raise ValueError("Cannot capture a placeholder (name:%s, type:%s) " 721 "by value." % (op.name, op.type)) 722 723 captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs] 724 725 captured_op = self.create_op( 726 op.type, 727 captured_inputs, [o.dtype for o in op.outputs], 728 name=op.name, 729 attrs=op.node_def.attr, 730 op_def=op_def) 731 732 for t, captured_t in zip(op.outputs, captured_op.outputs): 733 self._captured[t] = captured_t 734 735 return captured_op 736 737 738def _call(sig, *inputs, **kwargs): 739 """Adds a node calling a function. 740 741 This adds a `call` op to the default graph that calls the function 742 of signature `sig`, passing the tensors in `inputs` as arguments. 743 It returns the outputs of the call, which are one or more tensors. 744 745 `sig` is OpDefArg.a `_DefinedFunction` object. 746 747 You can pass an optional keyword parameter `name=string` to name the 748 added operation. 749 750 You can pass an optional keyword parameter `noinline=True|False` to 751 instruct the runtime not to inline the function body into the call 752 site. 753 754 Args: 755 sig: OpDefArg. The signature of the function. 756 *inputs: arguments to the function. 757 **kwargs: Optional keyword arguments. Can only contain 'name' or 758 'noinline'. 759 760 Returns: 761 A 2-element tuple. First element: a Tensor if the function returns a single 762 value; a list of Tensors if the function returns multiple value; the 763 Operation if the function returns no values. Second element: the Operation. 764 765 Raises: 766 ValueError: if the arguments are invalid. 767 """ 768 if len(inputs) != len(sig.input_arg): 769 raise ValueError("Expected number of arguments: %d, received: %d" % 770 (len(sig.input_arg), len(inputs))) 771 name = kwargs.pop("name", None) 772 g = ops.get_default_graph() 773 func_name = sig.name 774 attrs = _parse_kwargs_as_attrs(func_name, **kwargs) 775 output_types = [dtypes.DType(x.type) for x in sig.output_arg] 776 with ops.name_scope(name, func_name, inputs) as name: 777 op = g.create_op( 778 func_name, 779 list(inputs), 780 output_types, 781 name=name, 782 attrs=attrs, 783 op_def=sig, 784 compute_shapes=False) 785 if op.outputs: 786 if len(op.outputs) == 1: 787 ret = op.outputs[0] 788 else: 789 ret = tuple(op.outputs) 790 else: 791 ret = op 792 return ret, op 793 794 795def _from_definition(fdef, grad_func=None): 796 """Creates a _DefinedFunction initialized from a FunctionDef proto. 797 798 Args: 799 fdef: a FunctionDef 800 grad_func: a _DefinedFunction or None 801 802 Returns: 803 A _DefinedFunction representing fdef 804 """ 805 # TODO(iga): This method does major surgery on _DefinedFunction. 806 # Make it a named constructor using @classmethod of _DefinedFunction. 807 808 # The Python callable is only needed to create a FunctionDef. Since we have 809 # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we 810 # have access to such a callable here). 811 func = None 812 argnames = [arg.name for arg in fdef.signature.input_arg] 813 input_types = tuple( 814 dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg) 815 func_name = fdef.signature.name 816 # Note: FunctionDefs do not include python gradient functions, so if the 817 # original _DefinedFunction included one it will not be reflected here. 818 python_grad_func = None 819 out_names = [arg.name for arg in fdef.signature.output_arg] 820 result = _DefinedFunction(func, argnames, input_types, func_name, grad_func, 821 python_grad_func, out_names) 822 # pylint: disable=protected-access 823 if ops._USE_C_API: 824 serialized = fdef.SerializeToString() 825 with errors.raise_exception_on_not_ok_status() as status: 826 result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status) 827 result._extra_inputs = [] 828 else: 829 result._definition = fdef 830 # Captured inputs are added as regular inputs to a function when it's 831 # serialized, i.e. any extra inputs from the original function are now 832 # included in `result`._args 833 result._extra_inputs = [] 834 result._hash_str = result._create_hash_str( 835 result._definition.signature.input_arg, 836 result._definition.signature.output_arg, result._definition.node_def) 837 # pylint: enable=protected-access 838 839 return result 840 841 842def _from_library(lib): 843 """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto. 844 845 This method handles assigning the correct gradient functions to each 846 function. 847 848 Args: 849 lib: a FunctionDefLibrary 850 851 Returns: 852 A list of _DefinedFunctions 853 854 Raises: 855 ValueError: `lib` is invalid 856 """ 857 if not lib.function and not lib.gradient: 858 return [] 859 860 # function name -> FunctionDef proto 861 funcs = {fdef.signature.name: fdef for fdef in lib.function} 862 863 # Validate that all references function names have function defs 864 for g in lib.gradient: 865 if g.function_name not in funcs: 866 raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" % 867 (g.function_name, str(lib))) 868 if g.gradient_func not in funcs: 869 raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" % 870 (g.gradient_func, str(lib))) 871 872 # function name -> gradient function name 873 func_to_grad = collections.defaultdict(lambda: None) 874 # gradient function name -> names of functions having that grad function 875 grad_to_funcs = collections.defaultdict(list) 876 877 for gdef in lib.gradient: 878 func_to_grad[gdef.function_name] = gdef.gradient_func 879 grad_to_funcs[gdef.gradient_func].append(gdef.function_name) 880 881 # Start with functions without gradients 882 ready = [ 883 fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None 884 ] 885 if not ready: 886 raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n" 887 + str(lib)) 888 # function name -> _DefinedFunction 889 initialized = {} 890 891 while ready: 892 fdef = ready.pop() 893 name = fdef.signature.name 894 895 grad = initialized.get(func_to_grad[name]) 896 if func_to_grad[name]: 897 assert grad 898 defined_func = _from_definition(fdef, grad_func=grad) 899 initialized[name] = defined_func 900 901 ready.extend(funcs[f] for f in grad_to_funcs[name]) 902 903 return initialized.values() 904 905 906def _parse_kwargs_as_attrs(func_name, **kwargs): 907 """Parses **kwargs into a node's attributes.""" 908 attrs = {} 909 910 noinline = kwargs.pop("noinline", None) 911 if noinline is not None: 912 attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline)) 913 914 compiled = kwargs.pop("compiled", None) 915 separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None) 916 if compiled is not None: 917 attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled)) 918 attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue( 919 b=bool(separate_compiled_gradients)) 920 # Forward _XlaScope from enclosing context (if set), otherwise create new. 921 # pylint: disable=protected-access 922 if "_XlaScope" in ops.get_default_graph()._attr_scope_map: 923 attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"] 924 else: 925 attrs["_XlaScope"] = attr_value_pb2.AttrValue( 926 s=("function_%s" % func_name).encode()) 927 # pylint: enable=protected-access 928 929 if kwargs: 930 raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) 931 return attrs 932 933 934def _get_func_name(func): 935 _, func = tf_decorator.unwrap(func) 936 if callable(func): 937 if tf_inspect.isfunction(func): 938 return func.__name__ 939 elif tf_inspect.ismethod(func): 940 return "%s.%s" % (func.__self__.__name__, func.__name__) 941 else: # Probably a class instance with __call__ 942 return type(func) 943 else: 944 raise ValueError("Argument must be callable") 945 946 947def get_extra_vars(): 948 """Returns the captured variables by the function. 949 950 Returns: 951 If the default graph is being used to define a function, the 952 returned list of variables are those created inside the function 953 body so far. Otherwise, returns an empty list. 954 """ 955 g = ops.get_default_graph() 956 if isinstance(g, _FuncGraph): 957 return g.extra_vars 958 else: 959 return [] 960 961 962def get_extra_inputs(): 963 """Returns the captured input tensors by the function. 964 965 Returns: 966 If the default graph is being used to define a function, the 967 returned list of tensors are those accessed inside the function body 968 but defined outside the function body so far. Otherwise, returns an 969 empty list. 970 """ 971 g = ops.get_default_graph() 972 if isinstance(g, _FuncGraph): 973 return g.extra_inputs 974 else: 975 return [] 976 977 978def get_extra_args(): 979 """Returns the corresponding function arguments for the captured inputs. 980 981 Returns: 982 If the default graph is being used to define a function, the 983 returned list of place holders are those used inside the function 984 body corresponding those returned by get_extra_inputs(). Otherwise, 985 returns an empty list. 986 """ 987 g = ops.get_default_graph() 988 if isinstance(g, _FuncGraph): 989 return g.extra_args 990 else: 991 return [] 992 993 994def _type_list_to_str(types): 995 if any([_ not in _DTYPE_TO_STR for _ in types]): 996 raise ValueError("Unsupported dtypes: %s" % types) 997 return "".join([_DTYPE_TO_STR[_] for _ in types]) 998 999 1000# NOTE: The list needs to be extended when more data types are added. 1001_DTYPE_TO_STR = { 1002 dtypes.float16: "f16", 1003 dtypes.float32: "f32", 1004 dtypes.float64: "f64", 1005 dtypes.int32: "i32", 1006 dtypes.uint8: "i8", 1007 dtypes.uint16: "u16", 1008 dtypes.uint32: "u32", 1009 dtypes.uint64: "u64", 1010 dtypes.int16: "i16", 1011 dtypes.int8: "i8", 1012 dtypes.string: "s", 1013 dtypes.complex64: "c64", 1014 dtypes.complex128: "c128", 1015 dtypes.int64: "i64", 1016 dtypes.bool: "b", 1017 dtypes.qint8: "qi8", 1018 dtypes.quint8: "qu8", 1019 dtypes.qint16: "qi16", 1020 dtypes.quint16: "qu16", 1021 dtypes.qint32: "qi32", 1022 dtypes.bfloat16: "b16" 1023} 1024