• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators corresponding to Python builtin functions.
16
17List of built-in functions: https://docs.python.org/3/library/functions.html
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import functools
25import inspect
26
27import numpy as np
28import six
29
30from tensorflow.python.autograph.utils import py_func
31from tensorflow.python.autograph.utils import tensors
32from tensorflow.python.data.experimental.ops import cardinality
33from tensorflow.python.data.ops import dataset_ops
34from tensorflow.python.data.ops import iterator_ops
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import check_ops
42from tensorflow.python.ops import control_flow_ops
43from tensorflow.python.ops import gen_parsing_ops
44from tensorflow.python.ops import gen_string_ops
45from tensorflow.python.ops import list_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import sort_ops
48from tensorflow.python.util import lazy_loader
49from tensorflow.python.util import nest
50
51
# TODO(b/145618471): Remove this dependency.
# Lazy import to work around circular dependencies
# `input_lib` is only touched by `enumerate_`, and `parallel_ops` only by
# `_tf_sorted`, so both modules are loaded on first use rather than at import.
input_lib = lazy_loader.LazyLoader(
    'input_lib', globals(),
    'tensorflow.python.distribute.input_lib')
parallel_ops = lazy_loader.LazyLoader(
    'parallel_ops', globals(),
    'tensorflow.python.ops.parallel_for.control_flow_ops')


# Sentinel used as the default of optional arguments in the overloads below.
# It is compared by identity (`is UNSPECIFIED`) so that an omitted argument can
# be told apart from any value the caller may legitimately pass.
UNSPECIFIED = object()
63
64
def overload_of(f):
  """Returns the overload registered in this module for `f`, if there is one.

  Args:
    f: Any object, typically a callable.

  Returns:
    The `BUILTIN_FUNCTIONS_MAP` entry keyed by `f.__name__` when `f` is listed
    in `SUPPORTED_BUILTINS`; otherwise `f` itself, unchanged.
  """
  if f not in SUPPORTED_BUILTINS:
    return f
  return BUILTIN_FUNCTIONS_MAP[f.__name__]
69
70
71def _find_originating_frame(caller_fn_scope, innermost=True):
72  """Locates the frame in which `caller_fn_scope` was defined."""
73  ctx_frame = inspect.currentframe()
74  result = None
75  while ctx_frame is not None:
76    # Note it should not be normally possible to get false positives this way
77    # because the function scope object is not accessible to user code (barring
78    # call stack introspection).
79    if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
80      result = ctx_frame
81      if innermost:
82        break
83    ctx_frame = ctx_frame.f_back
84
85  assert result is not None, (
86      'the conversion process should ensure the caller_fn_scope is always'
87      ' found somewhere on the call stack')
88
89  return result
90
91
def locals_in_original_context(caller_fn_scope):
  """Returns the locals of the frame where `caller_fn_scope` was defined.

  Args:
    caller_fn_scope: The function scope object identifying the original frame.

  Returns:
    The `f_locals` mapping of the innermost frame that holds `caller_fn_scope`.
  """
  origin = _find_originating_frame(caller_fn_scope, innermost=True)
  return origin.f_locals
95
96
def globals_in_original_context(caller_fn_scope):
  """Returns the globals of the frame where `caller_fn_scope` was defined.

  Args:
    caller_fn_scope: The function scope object identifying the original frame.

  Returns:
    The `f_globals` mapping of the innermost frame that holds
    `caller_fn_scope`.
  """
  origin = _find_originating_frame(caller_fn_scope, innermost=True)
  return origin.f_globals
100
101
def eval_in_original_context(f, args, caller_fn_scope):
  """Executes the eval function in the context of a specified function.

  Args:
    f: Callable, typically the eval builtin.
    args: Sequence holding the original call arguments: the expression,
      optionally followed by explicit globals and then explicit locals.
    caller_fn_scope: The function scope object identifying the original frame.

  Returns:
    The result of `f(expression, globals, locals)`, where any mapping that the
    caller did not supply is taken from the original frame.
  """
  # When control flow is rewritten using functions, eval should use the
  # variables found in the same block where it was called. That is equivalent
  # to the innermost function call.
  frame = _find_originating_frame(caller_fn_scope, innermost=True)

  expression = args[0]
  eval_globals = args[1] if len(args) >= 2 else frame.f_globals
  eval_locals = args[2] if len(args) >= 3 else frame.f_locals
  return f(expression, eval_globals, eval_locals)
115
116
def super_in_original_context(f, args, caller_fn_scope):
  """Calls the super function as if from within a specified function.

  See https://docs.python.org/3/library/functions.html#super for the exact
  details

  Args:
    f: Callable, typically the super builtin
    args: List[Any], the original call arguments
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made

  Returns:
    The result of calling `f` as if it was called in the frame indicated by
      `caller_fn_scope`.
  """
  # Two kinds of call are forwarded untouched: anything under Python 2, which
  # doesn't support implicit argument super variants, and calls with explicit
  # arguments, because only the no-arg call is desugared.
  if six.PY2 or args:
    return f(*args)

  # Inner functions seem to include their closure in f_locals, so we need
  # to find the outermost frame.
  method_frame = _find_originating_frame(caller_fn_scope, innermost=False)

  # When super(..) is called without arguments, it looks for __class__ cell
  # variable and the first argument passed in the enclosing function according
  # to the spec https://www.python.org/dev/peps/pep-3135/ .
  #
  # We couldn't verify if `inspect.currentframe().f_code.co_varnames[0]` is
  # guaranteed to be the first argument from an official doc or PEP, however,
  # it's fairly stable and well established:
  # - An unofficial community doc mentions it.
  #   https://python-reference.readthedocs.io/en/latest/docs/code/varnames.html
  # - CPython has tests checking that order, which was merged in 2008, and
  #   unchanged since then.
  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
  #
  # Note: the name can be more reliably obtained by inspecting the calling
  # function's argspec.
  #
  # Even though methods can be declared using *args (def method(*args)),
  # that pattern is disallowed by super() -- it raises super() no arguments.
  # Method definitions using **kwargs are not allowed at all.
  # In other words, we can always assume that self is on the first positional
  # argument (for correct code).
  #
  # TODO(mdan): Consider additional checks in case the input code is incorrect.
  # For example, the error might be cryptic compared to what super() regularly
  # raises.

  frame_locals = method_frame.f_locals
  type_arg = frame_locals['__class__']
  first_arg_name = method_frame.f_code.co_varnames[0]
  return f(type_arg, frame_locals[first_arg_name])
177
178
def abs_(x):
  """Overload of the abs builtin.

  Dispatches to the Tensor implementation for TF types, to the dataset
  implementation for `DatasetV2` instances, and to the builtin otherwise.
  """
  if tensor_util.is_tf_type(x):
    impl = _tf_abs
  elif isinstance(x, dataset_ops.DatasetV2):
    impl = _tf_dataset_abs
  else:
    impl = _py_abs
  return impl(x)
185
186
def _tf_abs(x):
  """Overload of abs_ for TF types; delegates to `math_ops.abs`."""
  return math_ops.abs(x)
189
190
def _tf_dataset_abs(x):
  """Overload of abs_ for datasets: maps `math_ops.abs` over every element.

  Args:
    x: A dataset exposing `element_spec` and `map`.

  Returns:
    The result of `x.map(...)`. When the flattened element spec has exactly
    one component, `math_ops.abs` is mapped directly; otherwise it is applied
    to each component of the element structure.
  """
  flat_specs = nest.flatten(x.element_spec)
  if len(flat_specs) == 1:
    map_fn = math_ops.abs
  else:
    map_fn = lambda *e: nest.map_structure(math_ops.abs, e)
  return x.map(map_fn, num_parallel_calls=dataset_ops.AUTOTUNE)
198
199
def _py_abs(x):
  """Fallback of abs_ for plain Python values; delegates to builtin `abs`."""
  return abs(x)
202
203
def float_(x=0):
  """Overload of the float builtin.

  TF types are converted by `_tf_float`; everything else goes to `_py_float`.
  """
  convert = _tf_float if tensor_util.is_tf_type(x) else _py_float
  return convert(x)
208
209
def _tf_float(x):
  """Overload of float_ for TF types; always produces float32.

  String tensors are parsed with `string_to_number`; every other dtype is
  converted with `math_ops.cast`.
  """
  # TODO(mdan): We shouldn't assume float32.
  target_dtype = dtypes.float32
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=target_dtype)
  return math_ops.cast(x, dtype=target_dtype)
215
216
def _py_float(x):
  """Fallback of float_ for plain Python values; delegates to `float`."""
  return float(x)
219
220
def int_(x=0, base=UNSPECIFIED):
  """Overload of the int builtin.

  TF types are converted by `_tf_int`; everything else goes to `_py_int`.
  `base` is forwarded as received, including the `UNSPECIFIED` sentinel.
  """
  convert = _tf_int if tensor_util.is_tf_type(x) else _py_int
  return convert(x, base)
225
226
def _tf_int(x, base):
  """Overload of int_ for TF types; always produces int32.

  Args:
    x: Value exposing a `dtype`; string tensors are parsed, others are cast.
    base: Either 10 or `UNSPECIFIED`.

  Returns:
    The int32 conversion of `x`.

  Raises:
    NotImplementedError: if `base` is anything other than 10 or `UNSPECIFIED`.
  """
  if base not in (10, UNSPECIFIED):
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  target_dtype = dtypes.int32
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=target_dtype)
  return math_ops.cast(x, dtype=target_dtype)
235
236
def _py_int(x, base):
  """Fallback of int_ for plain Python values.

  `base` is only forwarded to the builtin when the caller supplied one.
  """
  if base is not UNSPECIFIED:
    return int(x, base)
  return int(x)
241
242
def len_(s):
  """Overload of the len builtin.

  The checks run in this order: tensor array, tensor list, any other TF type,
  dataset. Anything that matches none of them goes to the builtin.
  """
  if tensors.is_tensor_array(s):
    impl = _tf_tensor_array_len
  elif tensors.is_tensor_list(s):
    impl = _tf_tensor_list_len
  elif tensor_util.is_tf_type(s):
    impl = _tf_tensor_len
  elif isinstance(s, dataset_ops.DatasetV2):
    impl = _tf_dataset_len
  else:
    impl = _py_len
  return impl(s)
253
254
def _tf_tensor_array_len(s):
  """Overload of len_ for tensor arrays; returns `s.size()`."""
  return s.size()
257
258
def _tf_tensor_list_len(s):
  """Overload of len_ for tensor lists; uses `list_ops.tensor_list_length`."""
  return list_ops.tensor_list_length(s)
261
262
def _tf_tensor_len(s):
  """Overload of len_ for Tensor arguments.

  Args:
    s: Value exposing a static `shape`, accepted by `array_ops.shape` and
      `array_ops.rank`.

  Returns:
    The size of the first dimension of `s`: the static value when it is known,
    otherwise a value computed from the dynamic shape.

  Raises:
    ValueError: if the shape tensor of `s` is statically known to be empty,
      which means `s` is a scalar.
  """
  # Statically shaped tensors: length is known ahead of time.
  if s.shape.ndims and s.shape.dims[0].value is not None:
    return s.shape.dims[0].value

  # Static shape of unknown dimensions: use dynamic shape but statically
  # check that it's a scalar.
  shape = array_ops.shape(s)

  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)

  # The shape tensor has one entry per dimension of `s`, so a shape tensor of
  # static length 0 means `s` has no dimensions at all.
  if shape.shape[0] == 0:
    raise ValueError(
        'len requires a non-scalar tensor, got one of shape {}'.format(shape))

  # The rank is known statically (and was just checked to be non-zero), so
  # the first entry of the dynamic shape can be read without a runtime check.
  if shape.shape.dims[0].value is not None:
    return array_ops.shape(s)[0]

  # Fully dynamic shape: use ops.
  rank = array_ops.rank(s)

  def raise_zero_rank_error():
    # Branch taken when the rank is 0: the Assert always fails, and the
    # constant only exists so that this branch has a return value.
    msg = gen_string_ops.string_join(
        ['len requires non-zero rank, got ',
         gen_string_ops.as_string(rank)])
    with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
      return constant_op.constant(0, dtype=dtypes.int32)

  return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
                               raise_zero_rank_error)
294
295
def _tf_dataset_len(s):
  """Overload of len_ for datasets, based on the dataset cardinality.

  Args:
    s: The dataset whose cardinality is queried.

  Returns:
    The cardinality of `s`, guarded by an assertion that it is neither
    `cardinality.INFINITE` nor `cardinality.UNKNOWN`.
  """
  size = cardinality.cardinality(s)
  error_text = gen_string_ops.string_join([
      'len requires dataset with definitive cardinality, got ',
      gen_string_ops.as_string(size)
  ])
  # TODO (yongtang): UNKNOWN is treated as an error.
  # In case there are more UNKNOWN cases for dataset, we could
  # use dataset.reduce() to find out the length (in an expensive way).
  is_finite = math_ops.not_equal(size, cardinality.INFINITE)
  is_known = math_ops.not_equal(size, cardinality.UNKNOWN)
  has_definite_size = math_ops.logical_and(is_finite, is_known)
  size_check = control_flow_ops.Assert(has_definite_size, [error_text])
  with ops.control_dependencies([size_check]):
    return array_ops.identity(size)
314
315
def _py_len(s):
  """Fallback of len_ for plain Python values; delegates to builtin `len`."""
  return len(s)
318
319
def print_(*objects, **kwargs):
  """Overload of the print builtin.

  Args:
    *objects: The values to print.
    **kwargs: Only `sep`, `end`, `file` and `flush` are accepted.

  Returns:
    The result of `_tf_py_func_print` when at least one of `objects` is a TF
    type; None when the call is forwarded to the builtin.

  Raises:
    ValueError: if any other keyword argument is passed.
  """
  # Note: Python 2.6 doesn't support explicit keywords after starargs.
  supported = set(('sep', 'end', 'file', 'flush'))
  unknown_kwargs = tuple(set(kwargs.keys()) - supported)
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # TODO(mdan): Use nest.flatten(objects) instead?
  contains_tensor = any(tensor_util.is_tf_type(o) for o in objects)
  if not contains_tensor:
    _py_print(*objects, **kwargs)
    return None
  # TODO(mdan): use tf.print instead.
  return _tf_py_func_print(objects, kwargs)
334
335
def _py_print(*objects, **kwargs):
  """Fallback of print_; forwards all arguments to the builtin `print`."""
  print(*objects, **kwargs)
338
339
def _tf_py_func_print(objects, kwargs):
  """Overload of print_ as a py_func implementation.

  Args:
    objects: The values to print, passed as inputs to the wrapped function.
    kwargs: Keyword arguments of the original print call. Entries whose value
      is `UNSPECIFIED` are dropped, and `flush` defaults to True.

  Returns:
    The result of `py_func.wrap_py_func` for the printing function.
  """
  print_kwargs = {}
  for name, value in kwargs.items():
    if value is not UNSPECIFIED:
      print_kwargs[name] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_kwargs.setdefault('flush', True)

  def print_wrapper(*vals):
    printable = []
    for v in vals:
      if tensor_util.is_tf_type(v):
        v = v.numpy()
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func. This causes the print to add a "b'" wrapper to the output,
      # which is probably never what you want.
      if not six.PY2 and isinstance(v, bytes):
        v = v.decode('utf-8')
      printable.append(v)
    six.print_(*printable, **print_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)
360
361
def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  """Overload of the range builtin.

  Uses the TF implementation as soon as any of the three arguments is a TF
  type, and the builtin otherwise.
  """
  for bound in (start_or_stop, stop, step):
    if tensor_util.is_tf_type(bound):
      return _tf_range(start_or_stop, stop, step)
  return _py_range(start_or_stop, stop, step)
366
367
def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor.

  Args:
    start_or_stop: The stop value when it is the only argument given,
      otherwise the start value.
    stop: The stop value, or `UNSPECIFIED`.
    step: The step, or `UNSPECIFIED`.

  Returns:
    The result of `math_ops.range` for the given bounds.
  """
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Preventing the
  # graph construction error aligns the semantics with Python. This is why the
  # limit is clamped with `maximum` in the one- and two-argument forms.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is UNSPECIFIED:
    if stop is UNSPECIFIED:
      limit = math_ops.maximum(start_or_stop, 0)
      return math_ops.range(limit)
    limit = math_ops.maximum(start_or_stop, stop)
    return math_ops.range(start_or_stop, limit)
  # TODO(mdan): Add argument coercion similar to other cases.
  return math_ops.range(start_or_stop, stop, step)
383
384
def _py_range(start_or_stop, stop, step):
  """Fallback of range_ for plain Python values.

  Calls the builtin with one, two or three arguments depending on which of
  `stop` and `step` were supplied.
  """
  if step is UNSPECIFIED:
    if stop is UNSPECIFIED:
      return range(start_or_stop)
    return range(start_or_stop, stop)
  return range(start_or_stop, stop, step)
391
392
def enumerate_(s, start=0):
  """Overload of the enumerate builtin.

  Args:
    s: The value to enumerate.
    start: The first index.

  Returns:
    The dataset enumeration for `DatasetV2` inputs, otherwise the result of the
    builtin.

  Raises:
    NotImplementedError: if `s` is a distributed iterator or distributed
      dataset.
  """
  if isinstance(s, dataset_ops.DatasetV2):
    return _tf_dataset_enumerate(s, start)
  distributed_types = (
      input_lib.DistributedIterator, input_lib.DistributedDataset)
  if not isinstance(s, distributed_types):
    return _py_enumerate(s, start)
  raise NotImplementedError(
      'use a for loop over the dataset and keep a separate counter')
401
402
def _tf_dataset_enumerate(s, start=0):
  """Overload of enumerate_ for datasets; returns `s.enumerate(start)`."""
  return s.enumerate(start)
405
406
def _py_enumerate(s, start=0):
  """Fallback of enumerate_; delegates to the builtin `enumerate`."""
  return enumerate(s, start)
409
410
def zip_(*iterables):
  """Overload of the zip builtin.

  Args:
    *iterables: The values to zip.

  Returns:
    A zipped dataset when at least one argument is given and every argument is
    a `DatasetV2`; otherwise the result of the builtin `zip`.
  """
  # `all` is vacuously true for an empty sequence. Without the explicit
  # emptiness check, a bare `zip()` would be routed to the dataset overload
  # instead of producing the empty iterator that the builtin returns.
  if iterables and all(
      isinstance(x, dataset_ops.DatasetV2) for x in iterables):
    return _tf_dataset_zip(*iterables)
  return _py_zip(*iterables)
415
416
def _tf_dataset_zip(*iterables):
  """Overload of zip_ for datasets; zips them with `DatasetV2.zip`."""
  # The datasets are handed over as a single tuple argument.
  return dataset_ops.DatasetV2.zip(iterables)
419
420
def _py_zip(*iterables):
  """Fallback of zip_; delegates to the builtin `zip`."""
  return zip(*iterables)
423
424
def map_(fn, *iterables):
  """Overload of the map builtin.

  Args:
    fn: The function to apply.
    *iterables: The values to map over.

  Returns:
    A mapped dataset when at least one iterable is given and every iterable is
    a `DatasetV2`; otherwise the result of the builtin `map`.
  """
  # `all` is vacuously true for an empty sequence. Requiring at least one
  # iterable keeps `map_(fn)` on the builtin path, so the caller gets the
  # standard error of `map` rather than a failure inside the dataset code.
  if iterables and all(
      isinstance(x, dataset_ops.DatasetV2) for x in iterables):
    return _tf_dataset_map(fn, *iterables)
  return _py_map(fn, *iterables)
429
430
def _tf_dataset_map(fn, *iterables):
  """Overload of map_ for datasets: zips them, then maps `fn` over the zip."""
  return dataset_ops.DatasetV2.zip(iterables).map(fn)
433
434
def _py_map(fn, *iterables):
  """Fallback of map_; delegates to the builtin `map`."""
  return map(fn, *iterables)
437
438
def next_(iterator, default=UNSPECIFIED):
  """Overload of the next builtin.

  `OwnedIterator` instances go to `next_tf_iterator`; everything else goes to
  `next_py`. `default` is forwarded as received.
  """
  if not isinstance(iterator, iterator_ops.OwnedIterator):
    return next_py(iterator, default)
  return next_tf_iterator(iterator, default)
443
444
445# TODO(mdan): These checks should be easier. Fix the nest API.
def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the dtypes
  are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify.
    spec: TypeSpec that `input_` must be compatible with.

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the dtype of `input_` differs from `spec.dtype`, or if
      `input_` has no dtype at all.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  # Check the argument `input_`. The builtin function `input` is never None,
  # so testing it instead would silently disable this check.
  if input_ is None:
    # TODO(mdan): raise from None when switching to Py3.
    raise ValueError('{} cannot be None'.format(input_name))

  # Plain Python and NumPy values are converted first so that they are
  # compared by the dtype they would have as tensors.
  # TODO(mdan): Use TensorCompatible when ready.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = ops.convert_to_tensor_v2(input_)

  input_dtype = getattr(input_, 'dtype', None)

  if input_dtype != spec.dtype:
    input_dtype_str = 'no dtype' if input_dtype is None else str(input_dtype)

    raise TypeError(
        '{} must have the same dtype as {}. Expected {}, got {}'.format(
            input_name, spec_name, spec.dtype, input_dtype_str))
480
481
def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Verifies that a possibly-structured symbol is compatible with another.

  See _verify_spec_compatible for a more concrete meaning of "compatible".
  Unlike _verify_spec_compatible, which handles singular Tensor-spec objects,
  _verify_structure_compatible can process structures recognized by tf.nest.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
        need to, be a structure.

  Raises:
    TypeError: if `input_` and `spec` do not have the same structure. Errors
      raised by _verify_spec_compatible for individual leaves propagate
      unchanged.
  """
  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as e:
    message = '{} must have the same element structure as {}.\n\n{}'.format(
        input_name, spec_name, str(e))
    raise TypeError(message)

  # Structures match, so each leaf can be checked against its own spec.
  verify_leaf = functools.partial(
      _verify_spec_compatible, input_name, spec_name)
  nest.map_structure(verify_leaf, input_, spec)
509
510
def next_tf_iterator(iterator, default=UNSPECIFIED):
  """Overload of next_ for TF iterators.

  Args:
    iterator: The iterator to advance.
    default: Optional value to return when the iterator has no next element.
      It is verified against `iterator.element_spec`.

  Returns:
    The next element. When `default` is given, a conditional that yields the
    next element if there is one and `default` otherwise.
  """
  if default is not UNSPECIFIED:
    opt_iterate = iterator.get_next_as_optional()
    _verify_structure_compatible(
        'the default argument', 'the iterate', default, iterator.element_spec)
    return control_flow_ops.cond(
        opt_iterate.has_value(), opt_iterate.get_value, lambda: default)
  # Without a default, fall back to the "normal" behavior which raises
  # a runtime exception.
  return next(iterator)
521
522
def next_py(iterator, default=UNSPECIFIED):
  """Fallback of next_ for plain Python iterators.

  `default` is only forwarded to the builtin when the caller supplied one.
  """
  if default is not UNSPECIFIED:
    return next(iterator, default)
  return next(iterator)
527
528
def filter_(function, iterable):
  """Overload of the filter builtin.

  `DatasetV2` inputs use the dataset implementation; everything else goes to
  the builtin.
  """
  if isinstance(iterable, dataset_ops.DatasetV2):
    impl = _tf_dataset_filter
  else:
    impl = _py_filter
  return impl(function, iterable)
533
534
def _tf_dataset_filter(function, iterable):
  """Overload of filter_ for datasets; returns `iterable.filter(function)`."""
  return iterable.filter(function)
537
538
def _py_filter(function, iterable):
  """Fallback of filter_; delegates to the builtin `filter`."""
  return filter(function, iterable)
541
542
def any_(iterable):
  """Overload of the any builtin.

  `DatasetV2` inputs use the dataset implementation; everything else goes to
  the builtin.
  """
  if not isinstance(iterable, dataset_ops.DatasetV2):
    return _py_any(iterable)
  return _tf_dataset_any(iterable)
547
548
# any() essentially answers "does at least one True element exist?".
# It can therefore be expressed as a `filter` that keeps only the `True`
# elements, followed by `take(1)`. This works well in tf.data because
# filter+take run as a pipeline, so iteration stops as soon as `take(1)`
# has produced its element.
def _tf_dataset_any(iterable):
  """Overload of any_ for datasets of bool elements.

  Args:
    iterable: A dataset whose flattened element spec is a single bool
      component.

  Returns:
    The result of reducing, from an initial value of False, the dataset made
    of at most one element that passed the identity filter.

  Raises:
    ValueError: if the element spec does not consist of exactly one bool
      component.
  """
  # check and make sure iterable.element_spec only consists of one
  # element of tf.bool.
  flat_specs = nest.flatten(iterable.element_spec)
  is_single_bool = (
      len(flat_specs) == 1 and flat_specs[0].dtype == dtypes.bool)
  if not is_single_bool:
    raise ValueError('in graph mode, the "any" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  first_true = iterable.filter(lambda x: x).take(1)
  return first_true.reduce(
      constant_op.constant(False, dtype=dtypes.bool), lambda _, y: y)
566
567
def _py_any(iterable):
  """Fallback of any_; delegates to the builtin `any`."""
  return any(iterable)
570
571
def all_(iterable):
  """Overload of the all builtin.

  `DatasetV2` inputs use the dataset implementation; everything else goes to
  the builtin.
  """
  if not isinstance(iterable, dataset_ops.DatasetV2):
    return _py_all(iterable)
  return _tf_dataset_all(iterable)
576
577
# all() is analogous to any(): keep only the `False` elements with `filter`,
# then `take(1)`; the result is True only when no `False` element exists.
def _tf_dataset_all(iterable):
  """Overload of all_ for datasets of bool elements.

  Args:
    iterable: A dataset whose flattened element spec is a single bool
      component.

  Returns:
    The result of reducing, from an initial value of True, the dataset made of
    at most one element that passed the `logical_not` filter.

  Raises:
    ValueError: if the element spec does not consist of exactly one bool
      component.
  """
  # check and make sure iterable.element_spec only consists of one
  # element of tf.bool.
  flat_specs = nest.flatten(iterable.element_spec)
  is_single_bool = (
      len(flat_specs) == 1 and flat_specs[0].dtype == dtypes.bool)
  if not is_single_bool:
    raise ValueError('in graph mode, the "all" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  first_false = iterable.filter(lambda x: math_ops.logical_not(x)).take(1)
  return first_false.reduce(
      constant_op.constant(True, dtype=dtypes.bool), lambda _, y: y)
592
593
def _py_all(iterable):
  """Fallback of all_; delegates to the builtin `all`."""
  return all(iterable)
596
597
def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
  """Overload of the sorted builtin.

  TF types are sorted by `_tf_sorted`; everything else goes to `_py_sorted`.
  `key` and `reverse` are forwarded as received.
  """
  sort_impl = _tf_sorted if tensor_util.is_tf_type(iterable) else _py_sorted
  return sort_impl(iterable, key, reverse)
602
603
def _tf_sorted(iterable, key, reverse):
  """Overload of sorted_ for Tensor iterable.

  Args:
    iterable: The tensor to sort.
    key: `UNSPECIFIED`, or a function applied to the elements with
      `vectorized_map`; the elements are then ordered by the mapped values.
    reverse: `UNSPECIFIED`, or a value interpreted by Python truthiness like
      the argument of the builtin: truthy sorts descending, falsy ascending.

  Returns:
    The sorted tensor.

  Raises:
    ValueError: if the values being sorted (the mapped values when `key` is
      given, `iterable` otherwise) have a statically known rank other than 1.
  """
  # Mirror the builtin: only a truthy `reverse` flips the order. An explicit
  # `reverse=False` must sort ascending, exactly like an omitted argument.
  descending = reverse is not UNSPECIFIED and bool(reverse)
  direction = 'DESCENDING' if descending else 'ASCENDING'
  if key is not UNSPECIFIED:
    mapped = parallel_ops.vectorized_map(key, iterable)
    # Fail early when the rank is known statically; the assert op below covers
    # the case where it is only known at run time.
    if mapped.shape.rank is not None and mapped.shape.rank != 1:
      raise ValueError('sort only supports only 1D tensors')
    with ops.control_dependencies([
        check_ops.assert_rank_v2(mapped, 1,
                                 'sort only supports only 1D tensors')
    ]):
      # Sort the keys, then reorder the original elements accordingly.
      order = sort_ops.argsort(mapped, direction=direction)
      return array_ops.gather_v2(iterable, order)
  if iterable.shape.rank is not None and iterable.shape.rank != 1:
    raise ValueError('sort only supports only 1D tensors')
  with ops.control_dependencies([
      check_ops.assert_rank_v2(iterable, 1,
                               'sort only supports only 1D tensors')
  ]):
    return sort_ops.sort(iterable, direction=direction)
627
628
def _py_sorted(iterable, key, reverse):
  """Fallback of sorted_ for plain Python values.

  `key` and `reverse` are only forwarded to the builtin when the caller
  supplied them.
  """
  sort_kwargs = {}
  if key is not UNSPECIFIED:
    sort_kwargs['key'] = key
  if reverse is not UNSPECIFIED:
    sort_kwargs['reverse'] = reverse
  return sorted(iterable, **sort_kwargs)
637
638
# Builtins for which `overload_of` returns an overload from this module.
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map,
                      filter, any, all, sorted)

# `xrange` only exists as a builtin in Python 2.
if six.PY2:
  SUPPORTED_BUILTINS += (xrange,)

# Maps the `__name__` of a builtin to its overload; `overload_of` looks up
# entries by that name.
# NOTE(review): 'next' has an entry here but `next` is not listed in
# SUPPORTED_BUILTINS, so `overload_of` never returns `next_` -- confirm
# whether that is intentional.
BUILTIN_FUNCTIONS_MAP = {
    'abs': abs_,
    'any': any_,
    'all': all_,
    'enumerate': enumerate_,
    'filter': filter_,
    'float': float_,
    'int': int_,
    'len': len_,
    'map': map_,
    'next': next_,
    'print': print_,
    'range': range_,
    'sorted': sorted_,
    'xrange': range_,
    'zip': zip_,
}
662