# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to retrieve function args."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_bound_method(fn):
  """Returns whether `fn`, after unwrapping TF decorators, is a bound method.

  The `__self__ is not None` check presumably exists to reject Python 2
  unbound methods, which `ismethod` also accepts.
  """
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)


def _is_callable_object(obj):
  """Returns whether `obj` has a `__call__` attribute that is a method.

  This is true e.g. for instances of classes that define `__call__`.
  """
  return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)


def fn_args(fn):
  """Get argument names for function-like object.

  For a `functools.partial`, names consumed by its positional arguments are
  dropped from the front, and names bound by its keyword arguments are
  removed. For a callable object, the signature of its `__call__` is used.
  For bound methods, the leading positional argument (`self`/`cls`) is
  omitted.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    # Positionally bound args consume the leading names; keyword-bound args
    # are removed wherever they appear in the remaining list.
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if _is_callable_object(fn):
      fn = fn.__call__
    # Copy the list: it belongs to the argspec object returned by
    # getfullargspec, which may be shared (e.g. an argspec stored on a
    # decorator), so popping from it in place could corrupt later lookups.
    args = list(tf_inspect.getfullargspec(fn).args)
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
      args.pop(0)
  return tuple(args)


def has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  `functools.partial` objects are unwrapped recursively (like `fn_args`), and
  for callable objects the signature of `__call__` is inspected.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
     `TypeError`: If fn is not a Function, or function-like object.
  """
  if isinstance(fn, functools.partial):
    # Recurse so the wrapped callable gets the same resolution as a top-level
    # argument (nested partials, callable objects); `partial.func` is always
    # callable, so this never hits the TypeError branch.
    return has_kwargs(fn.func)
  if _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        'fn should be a function-like object, but is of type {}.'.format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None


def get_func_name(func):
  """Returns name of passed callable.

  Functions yield `__name__`; methods yield `'<ClassOfSelf>.<method name>'`;
  any other callable yields `str(type(func))`.

  Raises:
    ValueError: If `func` (after unwrapping TF decorators) is not callable.
  """
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func):
      return func.__name__
    elif tf_inspect.ismethod(func):
      # NOTE(review): for a bound classmethod `__self__` is the class itself,
      # so `__class__.__name__` is its metaclass name (e.g. 'type'), not the
      # class name. Left unchanged in case callers depend on these names.
      return '%s.%s' % (six.get_method_self(func).__class__.__name__,
                        six.get_method_function(func).__name__)
    else:  # Probably a class instance with __call__
      return str(type(func))
  else:
    raise ValueError('Argument must be callable')


def get_func_code(func):
  """Returns func_code of passed callable, or None if not available.

  Raises:
    ValueError: If `func` (after unwrapping TF decorators) is not callable.
  """
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
      return six.get_function_code(func)
    # Since the object is not a function or method, but is a callable, we will
    # try to access the __call__ method as a function. This works with
    # callable classes but fails with functools.partial objects despite their
    # __call__ attribute.
    try:
      return six.get_function_code(func.__call__)
    except AttributeError:
      return None
  else:
    raise ValueError('Argument must be callable')


# Lazily built, serialized ConfigProto; see get_disabled_rewriter_config().
_rewriter_config_optimizer_disabled = None


def get_disabled_rewriter_config():
  """Returns a serialized `ConfigProto` with the meta optimizer disabled.

  The proto sets `graph_options.rewrite_options.disable_meta_optimizer` to
  True. The serialized bytes are built on the first call and cached in the
  module-level `_rewriter_config_optimizer_disabled`.
  """
  global _rewriter_config_optimizer_disabled
  if _rewriter_config_optimizer_disabled is None:
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.disable_meta_optimizer = True
    _rewriter_config_optimizer_disabled = config.SerializeToString()
  return _rewriter_config_optimizer_disabled