1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""XLA utility functions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from tensorflow.python.util import tf_inspect 24 25 26def is_flat(outputs): 27 """Checks if outputs is a flat structure. 28 29 Following structures and values are considered flat: 30 1) None 31 2) A single object 32 3) A list or tuple of Tensors/Operations 33 34 The only structures that this function understands are sequences and 35 dictionaries. E.g. this means that if outputs contains a single 36 user-defined Object, it is considered to be flat. Errors are raised later on 37 if that Object cannot be converted to a Tensor. 38 39 Args: 40 outputs: Output from `computation` inside `xla.compile`. 41 42 Returns: 43 A boolean indicates whether outputs is flat. 44 """ 45 # If outputs is a list or tuple, check if it has any nested structure. If 46 # there is, then outputs is non-flat. 47 if isinstance(outputs, collections.Sequence): 48 for o in outputs: 49 if isinstance(o, collections.Sequence) or isinstance(o, dict): 50 return False 51 52 # If outputs is a dict, it is non-flat. 53 if isinstance(outputs, dict): 54 return False 55 56 # Getting here means either outputs itself is a single non-structured value 57 # or it is a flat list of single non-structured values. 58 return True 59 60 61def check_function_argument_count(func, input_arity, infeed_queue): 62 """Validate the number of input arguments to an XLA function. 63 64 Args: 65 func: the Python function that will be called to generate the body of an XLA 66 computation graph. 67 input_arity: the number of explicit arguments supplied by the caller. 68 infeed_queue: if not None, the infeed queue that will supply 69 additional arguments to the function. 70 71 Returns: 72 None if function can be called with the supplied number of 73 arguments, or an error string if it cannot. 74 """ 75 def format_error(complaint, quantity): 76 return '%s %d argument%s' % (complaint, quantity, '' 77 if quantity == 1 else 's') 78 79 num_args_supplied = input_arity 80 if infeed_queue is not None: 81 num_args_supplied += infeed_queue.number_of_tuple_elements 82 arg_spec = tf_inspect.getargspec(func) 83 num_func_args = len(arg_spec.args) 84 if arg_spec.defaults is None: 85 num_func_defaults = 0 86 else: 87 num_func_defaults = len(arg_spec.defaults) 88 min_func_args = num_func_args - num_func_defaults 89 if num_args_supplied < min_func_args: 90 # The required number of arguments is not enough to call the function. 91 if num_func_defaults == 0 and arg_spec.varargs is None: 92 return format_error('exactly', num_func_args) 93 else: 94 return format_error('at least', min_func_args) 95 if arg_spec.varargs is None and num_args_supplied > num_func_args: 96 # The required number of arguments is too many to call the function. 97 if num_func_defaults == 0: 98 return format_error('exactly', num_func_args) 99 else: 100 return format_error('at most', num_func_args) 101 # Reaching here means either 102 # 1) There are varargs, func can accept any number of arguments greater than 103 # the minimum. 104 # 2) Number of supplied arguments falls in range of acceptable argument count 105 # of func. 106 return None 107