• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators corresponding to Python builtin functions.
16
17List of built-in functions: https://docs.python.org/3/library/functions.html
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import six
25
26from tensorflow.python.autograph.utils import py_func
27from tensorflow.python.autograph.utils import tensors
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import gen_parsing_ops
35from tensorflow.python.ops import gen_string_ops
36from tensorflow.python.ops import list_ops
37from tensorflow.python.ops import math_ops
38
39
# Sentinel for "argument not supplied by the caller". It is compared by
# identity (`is` / `is not`) in the overloads below, so a fresh `object()` can
# never collide with a value a caller actually passes (including None).
UNSPECIFIED = object()
41
42
def overload_of(f):
  """Returns the overload registered for builtin `f`, or `f` itself.

  Args:
    f: any object, typically a callable about to be invoked.

  Returns:
    The entry of `BUILTIN_FUINCTIONS_MAP` keyed by `f.__name__` when `f` is one
    of `SUPPORTED_BUILTINS`; otherwise `f`, unchanged.
  """
  if f not in SUPPORTED_BUILTINS:
    return f
  return BUILTIN_FUINCTIONS_MAP[f.__name__]
47
48
def abs_(x):
  """Overload of the `abs` builtin.

  Tensors are routed to `_tf_abs`; every other value goes to `_py_abs`.
  """
  use_tf = tensor_util.is_tensor(x)
  return _tf_abs(x) if use_tf else _py_abs(x)
53
54
def _tf_abs(x):
  """Tensor implementation of `abs_`, delegating to `math_ops.abs`."""
  result = math_ops.abs(x)
  return result
57
58
def _py_abs(x):
  """Non-tensor implementation of `abs_`: the Python builtin `abs`."""
  result = abs(x)
  return result
61
62
def float_(x=0):
  """Overload of the `float` builtin.

  Tensors are converted by `_tf_float`; all other values by `_py_float`.
  """
  if not tensor_util.is_tensor(x):
    return _py_float(x)
  return _tf_float(x)
67
68
def _tf_float(x):
  """Tensor implementation of `float_`.

  String tensors are parsed with `string_to_number`; tensors of any other
  dtype are cast. Both paths produce float32.
  """
  # TODO(mdan): We shouldn't assume float32.
  target = dtypes.float32
  is_string = x.dtype == dtypes.string
  if is_string:
    return gen_parsing_ops.string_to_number(x, out_type=target)
  return math_ops.cast(x, dtype=target)
74
75
def _py_float(x):
  """Non-tensor implementation of `float_`: the Python builtin `float`."""
  converted = float(x)
  return converted
78
79
def int_(x=0, base=UNSPECIFIED):
  """Overload of the `int` builtin.

  Args:
    x: the value to convert; either a tensor or anything the builtin accepts.
    base: optional radix; `UNSPECIFIED` means the caller did not pass one.

  Returns:
    The result of `_tf_int` for tensors, or of `_py_int` for other values.
  """
  impl = _tf_int if tensor_util.is_tensor(x) else _py_int
  return impl(x, base)
84
85
def _tf_int(x, base):
  """Tensor implementation of `int_`.

  String tensors are parsed with `string_to_number`; tensors of any other
  dtype are cast. Both paths produce int32.

  Raises:
    NotImplementedError: if `base` is neither 10 nor UNSPECIFIED.
  """
  supported_bases = (10, UNSPECIFIED)
  if base not in supported_bases:
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  target = dtypes.int32
  is_string = x.dtype == dtypes.string
  if is_string:
    return gen_parsing_ops.string_to_number(x, out_type=target)
  return math_ops.cast(x, dtype=target)
94
95
def _py_int(x, base):
  """Non-tensor implementation of `int_`, built on the Python builtin.

  `base` is forwarded to `int` only when the caller actually supplied one.
  """
  if base is not UNSPECIFIED:
    return int(x, base)
  return int(x)
100
101
def len_(s):
  """Overload of the `len` builtin.

  Checks, in order: TensorArray, TensorList, any other tensor. Falls back to
  the Python builtin for everything else. The order matters because the more
  specific checks must win over the generic tensor check.
  """
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  if tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  if tensor_util.is_tensor(s):
    return _tf_tensor_len(s)
  return _py_len(s)
110
111
def _tf_tensor_array_len(s):
  """Length of a TensorArray, as reported by its `size()` method."""
  size = s.size()
  return size
114
115
def _tf_tensor_list_len(s):
  """Length of a TensorList, computed by `list_ops.tensor_list_length`."""
  length = list_ops.tensor_list_length(s)
  return length
118
119
def _tf_tensor_len(s):
  """Overload of len_ for Tensor arguments.

  Resolves the size of the first dimension of `s` as statically as possible:

    * rank >= 1 and first dimension statically known: returns a Python int;
    * rank statically 0 (scalar): raises ValueError;
    * rank statically >= 1 but first dimension unknown: returns a scalar
      tensor sliced from the dynamic shape;
    * rank unknown: returns the result of a `cond` which, when the rank turns
      out to be zero, fails a runtime `Assert`.

  Args:
    s: a Tensor.

  Returns:
    A Python int when the length is statically known, otherwise a scalar
    tensor.

  Raises:
    ValueError: if `s` is statically known to be a scalar.
  """
  # Statically shaped tensors: length is known ahead of time.
  if s.shape.ndims and s.shape.dims[0].value is not None:
    return s.shape.dims[0].value

  # Static shape of unknown dimensions: use dynamic shape but statically
  # check that it's not a scalar.
  shape = array_ops.shape(s)

  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)

  if shape.shape[0] == 0:
    # Report the static shape of the argument itself; formatting the `shape`
    # tensor here would only print the internal Shape op.
    raise ValueError(
        'len requires a non-scalar tensor, got one of shape {}'.format(s.shape))

  if shape.shape.dims[0].value is not None:
    # Rank is statically known and non-zero: reuse the Shape op built above
    # rather than creating a second one.
    return shape[0]

  # Fully dynamic shape: use ops.
  rank = array_ops.rank(s)

  def raise_zero_rank_error():
    msg = gen_string_ops.string_join(
        ['len requires non-zero rank, got ',
         gen_string_ops.as_string(rank)])
    with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
      return constant_op.constant(0, dtype=dtypes.int32)

  # The slice is only evaluated in the branch where rank > 0, so it cannot
  # index into an empty shape vector.
  return control_flow_ops.cond(rank > 0, lambda: shape[0],
                               raise_zero_rank_error)
151
152
def _py_len(s):
  """Non-tensor implementation of `len_`: the Python builtin `len`."""
  size = len(s)
  return size
155
156
def print_(*objects, **kwargs):
  """Overload of the print builtin.

  Accepts the builtin's keyword arguments `sep`, `end`, `file` and `flush`.
  If any positional argument is a tensor, the call is delegated to
  `_tf_py_func_print` and its result is returned. Otherwise the objects are
  printed immediately via `_py_print` and None is returned.

  Raises:
    ValueError: if any other keyword argument is passed.
  """
  # Note: Python 2.6 doesn't support explicit keywords after starargs.
  allowed = set(('sep', 'end', 'file', 'flush'))
  unknown_kwargs = tuple(set(kwargs.keys()) - allowed)
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # TODO(mdan): Use next.flatten(objects) instead?
  has_tensor = any(tensor_util.is_tensor(o) for o in objects)
  if not has_tensor:
    _py_print(*objects, **kwargs)
    return None

  # TODO(mdan): use tf.print instead.
  return _tf_py_func_print(objects, kwargs)
171
172
def _py_print(*objects, **kwargs):
  """Non-tensor implementation of `print_`: forwards to the builtin print.

  All positional and keyword arguments are passed through unchanged.
  """
  print(*objects, **kwargs)
175
176
def _tf_py_func_print(objects, kwargs):
  """Overload of print_ as a py_func implementation.

  Args:
    objects: the positional arguments given to `print_`.
    kwargs: the keyword arguments given to `print_`. Entries whose value is
      `UNSPECIFIED` are dropped; `flush` defaults to True when absent.

  Returns:
    Whatever `py_func.wrap_py_func` returns for the wrapped print call.
  """
  print_kwargs = dict(
      (key, value) for key, value in kwargs.items() if value is not UNSPECIFIED)
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    converted = []
    for v in vals:
      if tensor_util.is_tensor(v):
        v = v.numpy()
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func. This causes the print to add a "b'" wrapper to the output,
      # which is probably never what you want.
      if six.PY3 and isinstance(v, bytes):
        v = v.decode('utf-8')
      converted.append(v)
    six.print_(*converted, **print_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
197
198
def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  """Overload of the `range` builtin (also registered for `xrange`).

  Uses `_tf_range` when any argument is a tensor and `_py_range` otherwise.
  """
  args = (start_or_stop, stop, step)
  if any(tensor_util.is_tensor(a) for a in args):
    return _tf_range(*args)
  return _py_range(*args)
203
204
def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor.

  Handles the three call forms of the builtin. In the one- and two-argument
  forms the limit is clamped with `math_ops.maximum` so that an empty range
  yields an empty tensor instead of an error. With an explicit `step` the
  arguments are passed to `math_ops.range` unchanged.
  """
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Preventing the
  # graph construction error aligns the semantics with Python.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is UNSPECIFIED:
    if stop is UNSPECIFIED:
      return math_ops.range(math_ops.maximum(start_or_stop, 0))
    return math_ops.range(start_or_stop,
                          math_ops.maximum(start_or_stop, stop))
  # TODO(mdan): Add argument coercion similar to other cases.
  # NOTE(review): no clamping here, so e.g. a positive step with
  # start > stop reaches math_ops.range as-is.
  return math_ops.range(start_or_stop, stop, step)
220
221
def _py_range(start_or_stop, stop, step):
  """Non-tensor implementation of `range_`, built on the Python builtin.

  Calls `range` with one, two or three arguments depending on which of
  `stop` / `step` the caller supplied.
  """
  if step is UNSPECIFIED:
    if stop is UNSPECIFIED:
      return range(start_or_stop)
    return range(start_or_stop, stop)
  return range(start_or_stop, stop, step)
228
229
# Builtins for which `overload_of` returns a replacement from the map below.
SUPPORTED_BUILTINS = (abs, float, int, len, print, range)

if six.PY2:
  SUPPORTED_BUILTINS += (xrange,)  # pylint:disable=undefined-variable

# Maps a supported builtin's `__name__` to its overload.
BUILTIN_FUINCTIONS_MAP = {
    'abs': abs_,
    'float': float_,
    'int': int_,
    'len': len_,
    'print': print_,
    'range': range_,
    # TODO(mdan): This might make more sense as tf.data.range.
    'xrange': range_,
}

# Correctly spelled alias of the map above. The misspelled name is kept so
# that existing references continue to work.
BUILTIN_FUNCTIONS_MAP = BUILTIN_FUINCTIONS_MAP
245