1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility to retrieve function args.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23import six 24 25from tensorflow.core.protobuf import config_pb2 26from tensorflow.python.util import tf_decorator 27from tensorflow.python.util import tf_inspect 28 29 30def _is_bounded_method(fn): 31 _, fn = tf_decorator.unwrap(fn) 32 return tf_inspect.ismethod(fn) and (fn.__self__ is not None) 33 34 35def _is_callable_object(obj): 36 return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__) 37 38 39def fn_args(fn): 40 """Get argument names for function-like object. 41 42 Args: 43 fn: Function, or function-like object (e.g., result of `functools.partial`). 44 45 Returns: 46 `tuple` of string argument names. 47 48 Raises: 49 ValueError: if partial function has positionally bound arguments 50 """ 51 if isinstance(fn, functools.partial): 52 args = fn_args(fn.func) 53 args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])] 54 else: 55 if _is_callable_object(fn): 56 fn = fn.__call__ 57 args = tf_inspect.getfullargspec(fn).args 58 if _is_bounded_method(fn): 59 args.remove('self') 60 return tuple(args) 61 62 63def has_kwargs(fn): 64 """Returns whether the passed callable has **kwargs in its signature. 65 66 Args: 67 fn: Function, or function-like object (e.g., result of `functools.partial`). 68 69 Returns: 70 `bool`: if `fn` has **kwargs in its signature. 71 72 Raises: 73 `TypeError`: If fn is not a Function, or function-like object. 74 """ 75 if isinstance(fn, functools.partial): 76 fn = fn.func 77 elif _is_callable_object(fn): 78 fn = fn.__call__ 79 elif not callable(fn): 80 raise TypeError( 81 'fn should be a function-like object, but is of type {}.'.format( 82 type(fn))) 83 return tf_inspect.getfullargspec(fn).varkw is not None 84 85 86def get_func_name(func): 87 """Returns name of passed callable.""" 88 _, func = tf_decorator.unwrap(func) 89 if callable(func): 90 if tf_inspect.isfunction(func): 91 return func.__name__ 92 elif tf_inspect.ismethod(func): 93 return '%s.%s' % (six.get_method_self(func).__class__.__name__, 94 six.get_method_function(func).__name__) 95 else: # Probably a class instance with __call__ 96 return str(type(func)) 97 else: 98 raise ValueError('Argument must be callable') 99 100 101def get_func_code(func): 102 """Returns func_code of passed callable, or None if not available.""" 103 _, func = tf_decorator.unwrap(func) 104 if callable(func): 105 if tf_inspect.isfunction(func) or tf_inspect.ismethod(func): 106 return six.get_function_code(func) 107 # Since the object is not a function or method, but is a callable, we will 108 # try to access the __call__method as a function. This works with callable 109 # classes but fails with functool.partial objects despite their __call__ 110 # attribute. 111 try: 112 return six.get_function_code(func.__call__) 113 except AttributeError: 114 return None 115 else: 116 raise ValueError('Argument must be callable') 117 118 119_rewriter_config_optimizer_disabled = None 120 121 122def get_disabled_rewriter_config(): 123 global _rewriter_config_optimizer_disabled 124 if _rewriter_config_optimizer_disabled is None: 125 config = config_pb2.ConfigProto() 126 rewriter_config = config.graph_options.rewrite_options 127 rewriter_config.disable_meta_optimizer = True 128 _rewriter_config_optimizer_disabled = config.SerializeToString() 129 return _rewriter_config_optimizer_disabled 130