# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operators corresponding to Python builtin functions.

List of built-in functions: https://docs.python.org/3/library/functions.html
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.autograph.utils import py_func
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops


# Sentinel marking an optional argument that the caller did not supply.
UNSPECIFIED = object()


def overload_of(f):
  """Returns the overload registered for builtin `f`, or `f` itself."""
  if f not in SUPPORTED_BUILTINS:
    return f
  return BUILTIN_FUINCTIONS_MAP[f.__name__]


def abs_(x):
  """Overload of the abs builtin; dispatches on whether `x` is a tensor."""
  return _tf_abs(x) if tensor_util.is_tensor(x) else _py_abs(x)


def _tf_abs(x):
  """Tensor implementation of abs_."""
  return math_ops.abs(x)


def _py_abs(x):
  """Plain Python implementation of abs_."""
  return abs(x)


def float_(x=0):
  """Overload of the float builtin; dispatches on whether `x` is a tensor."""
  return _tf_float(x) if tensor_util.is_tensor(x) else _py_float(x)


def _tf_float(x):
  """Tensor implementation of float_: parses strings, casts everything else."""
  # TODO(mdan): We shouldn't assume float32.
  target_dtype = dtypes.float32
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=target_dtype)
  return math_ops.cast(x, dtype=target_dtype)


def _py_float(x):
  """Plain Python implementation of float_."""
  return float(x)


def int_(x=0, base=UNSPECIFIED):
  """Overload of the int builtin; dispatches on whether `x` is a tensor."""
  if not tensor_util.is_tensor(x):
    return _py_int(x, base)
  return _tf_int(x, base)


def _tf_int(x, base):
  """Tensor implementation of int_; only base 10 (or no base) is accepted."""
  supported_bases = (10, UNSPECIFIED)
  if base not in supported_bases:
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  target_dtype = dtypes.int32
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=target_dtype)
  return math_ops.cast(x, dtype=target_dtype)


def _py_int(x, base):
  """Plain Python implementation of int_; forwards `base` only when given."""
  return int(x) if base is UNSPECIFIED else int(x, base)


def len_(s):
  """Overload of the len builtin.

  Tensor arrays and tensor lists are checked before generic tensors; anything
  else is handled by the Python builtin.
  """
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  if tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  if tensor_util.is_tensor(s):
    return _tf_tensor_len(s)
  return _py_len(s)


def _tf_tensor_array_len(s):
  """Overload of len_ for tensor array arguments."""
  return s.size()


def _tf_tensor_list_len(s):
  """Overload of len_ for tensor list arguments."""
  return list_ops.tensor_list_length(s)


def _tf_tensor_len(s):
  """Overload of len_ for Tensor arguments."""
  # Statically shaped tensors: length is known ahead of time.
  static_shape = s.shape
  if static_shape.ndims and static_shape.dims[0].value is not None:
    return static_shape.dims[0].value

  # Static shape of unknown dimensions: use dynamic shape but statically
  # check that it's a scalar.
  shape = array_ops.shape(s)

  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)

  if shape.shape[0] == 0:
    raise ValueError(
        'len requires a non-scalar tensor, got one of shape {}'.format(shape))

  if shape.shape.dims[0].value is not None:
    return array_ops.shape(s)[0]

  # Fully dynamic shape: use ops.
  rank = array_ops.rank(s)

  def first_dimension():
    return array_ops.shape(s)[0]

  def raise_zero_rank_error():
    rank_as_string = gen_string_ops.as_string(rank)
    msg = gen_string_ops.string_join(
        ['len requires non-zero rank, got ', rank_as_string])
    failing_assert = control_flow_ops.Assert(False, [msg])
    with ops.control_dependencies([failing_assert]):
      return constant_op.constant(0, dtype=dtypes.int32)

  return control_flow_ops.cond(rank > 0, first_dimension,
                               raise_zero_rank_error)


def _py_len(s):
  """Plain Python implementation of len_."""
  return len(s)


def print_(*objects, **kwargs):
  """Overload of the print builtin."""
  # Note: Python 2.6 doesn't support explicit keywords after starargs.
  accepted_kwargs = set(('sep', 'end', 'file', 'flush'))
  unknown_kwargs = tuple(set(kwargs.keys()) - accepted_kwargs)
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # TODO(mdan): Use next.flatten(objects) instead?
  if not any(tensor_util.is_tensor(o) for o in objects):
    _py_print(*objects, **kwargs)
    return None

  # TODO(mdan): use tf.print instead.
  return _tf_py_func_print(objects, kwargs)


def _py_print(*objects, **kwargs):
  """Plain Python implementation of print_."""
  print(*objects, **kwargs)


def _tf_py_func_print(objects, kwargs):
  """Overload of print_ as a py_func implementation."""
  override_kwargs = {}
  for key, value in kwargs.items():
    if value is not UNSPECIFIED:
      override_kwargs[key] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  override_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    printable = [v.numpy() if tensor_util.is_tensor(v) else v for v in vals]
    if six.PY3:
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func. This causes the print to add a "b'" wrapper to the output,
      # which is probably never what you want.
      printable = [
          v.decode('utf-8') if isinstance(v, bytes) else v for v in printable
      ]
    six.print_(*printable, **override_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)


def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  """Overload of the range builtin; uses TF ops if any argument is a tensor."""
  range_args = (start_or_stop, stop, step)
  if any(tensor_util.is_tensor(s) for s in range_args):
    return _tf_range(*range_args)
  return _py_range(*range_args)


def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor."""
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Preventing the
  # graph construction error aligns the semantics with Python.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is UNSPECIFIED:
    if stop is UNSPECIFIED:
      # Single-argument form: clamp the limit so it is never negative.
      limit = math_ops.maximum(start_or_stop, 0)
      return math_ops.range(limit)
    # Two-argument form: clamp stop so it is never below start.
    clamped_stop = math_ops.maximum(start_or_stop, stop)
    return math_ops.range(start_or_stop, clamped_stop)
  # TODO(mdan): Add argument coercion similar to other cases.
  return math_ops.range(start_or_stop, stop, step)


def _py_range(start_or_stop, stop, step):
  """Plain Python implementation of range_; forwards only supplied arguments."""
  if step is UNSPECIFIED:
    if stop is UNSPECIFIED:
      return range(start_or_stop)
    return range(start_or_stop, stop)
  return range(start_or_stop, stop, step)


# Builtins for which overload_of returns a replacement.
SUPPORTED_BUILTINS = (abs, float, int, len, print, range)

if six.PY2:
  SUPPORTED_BUILTINS += (xrange,)

# Maps a builtin's __name__ to its overload. The misspelled identifier is kept
# as-is because it is a module-level name that other code may reference.
BUILTIN_FUINCTIONS_MAP = {
    'abs': abs_,
    'float': float_,
    'int': int_,
    'len': len_,
    'print': print_,
    'range': range_,
    # TODO(mdan): This might make more sense as tf.data.range.
    'xrange': range_,
}