# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operators corresponding to Python builtin functions.

List of built-in functions: https://docs.python.org/3/library/functions.html
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import inspect

import numpy as np
import six

from tensorflow.python.autograph.utils import py_func
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest


# TODO(b/145618471): Remove this dependency.
# Lazy import to work around circular dependencies.
input_lib = lazy_loader.LazyLoader(
    'input_lib', globals(),
    'tensorflow.python.distribute.input_lib')
parallel_ops = lazy_loader.LazyLoader(
    'parallel_ops', globals(),
    'tensorflow.python.ops.parallel_for.control_flow_ops')


UNSPECIFIED = object()


def overload_of(f):
  if f in SUPPORTED_BUILTINS:
    return BUILTIN_FUNCTIONS_MAP[f.__name__]
  return f
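

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): AutoGraph routes builtin calls in converted code through
# `overload_of`, so a call like `len(x)` effectively becomes
# `overload_of(len)(x)`. Supported builtins map to the `*_` overloads defined
# below; anything else passes through unchanged.
def _example_overload_dispatch():
  assert overload_of(len) is len_  # supported builtin -> AutoGraph overload
  assert overload_of(sorted) is sorted_
  assert overload_of(id) is id  # unsupported builtin -> returned unchanged

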
def _find_originating_frame(caller_fn_scope, innermost=True):
  """Locates the frame in which `caller_fn_scope` was defined."""
  ctx_frame = inspect.currentframe()
  result = None
  while ctx_frame is not None:
    # Note: it should not normally be possible to get false positives this way
    # because the function scope object is not accessible to user code (barring
    # call stack introspection).
    if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
      result = ctx_frame
      if innermost:
        break
    ctx_frame = ctx_frame.f_back

  assert result is not None, (
      'the conversion process should ensure the caller_fn_scope is always'
      ' found somewhere on the call stack')

  return result


def locals_in_original_context(caller_fn_scope):
  """Executes the locals function in the context of a specified function."""
  return _find_originating_frame(caller_fn_scope, innermost=True).f_locals


def globals_in_original_context(caller_fn_scope):
  """Executes the globals function in the context of a specified function."""
  return _find_originating_frame(caller_fn_scope, innermost=True).f_globals


def eval_in_original_context(f, args, caller_fn_scope):
  """Executes the eval function in the context of a specified function."""
  # When control flow is rewritten using functions, eval should use the
  # variables found in the same block where it was called. That is equivalent
  # to the innermost function call.
  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True)

  args = (
      args[0],
      ctx_frame.f_globals if len(args) < 2 else args[1],
      ctx_frame.f_locals if len(args) < 3 else args[2],
  )
  return f(*args)


def super_in_original_context(f, args, caller_fn_scope):
  """Executes the super function in the context of a specified function.

  See https://docs.python.org/3/library/functions.html#super for the exact
  details.

  Args:
    f: Callable, typically the super builtin.
    args: List[Any], the original call arguments.
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.

  Returns:
    The result of calling `f` as if it was called in the frame indicated by
    `caller_fn_scope`.
  """

  # Python 2 doesn't support implicit argument super variants.
  if six.PY2:
    return f(*args)

  # Only the no-arg call is desugared.
  if args:
    return f(*args)

  # Inner functions seem to include their closure in f_locals, so we need
  # to find the outermost frame.
  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False)

  # When super() is called without arguments, it looks for the __class__ cell
  # variable and the first argument passed to the enclosing function, according
  # to the spec https://www.python.org/dev/peps/pep-3135/ .
  #
  # We couldn't verify from an official doc or PEP that
  # `inspect.currentframe().f_code.co_varnames[0]` is guaranteed to be the
  # first argument; however, the behavior is fairly stable and well
  # established:
  #   - An unofficial community doc mentions it:
  #     https://python-reference.readthedocs.io/en/latest/docs/code/varnames.html
  #   - CPython has tests checking that order, merged in 2008 and unchanged
  #     since then:
  #     https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
  #     https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
  #
  # Note: the name can be more reliably obtained by inspecting the calling
  # function's argspec.
  #
  # Even though methods can be declared using *args (def method(*args)), that
  # pattern is disallowed by super() -- it raises a "super(): no arguments"
  # error in that case. Method definitions using only **kwargs are not allowed
  # at all. In other words, we can always assume that self is the first
  # positional argument (for correct code).
  #
  # TODO(mdan): Consider additional checks in case the input code is incorrect.
  # For example, the error might be cryptic compared to what super() regularly
  # raises.

  type_arg = ctx_frame.f_locals['__class__']
  self_arg_name = ctx_frame.f_code.co_varnames[0]
  self_arg = ctx_frame.f_locals[self_arg_name]
  return f(type_arg, self_arg)
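

# Minimal sketch of the frame introspection used above (hypothetical helper,
# assumes CPython): a method that references `super` gets an implicit
# `__class__` cell, and its first positional argument is the instance, which
# is exactly what the zero-argument `super()` desugaring relies on.
def _example_super_mechanics():
  class Base(object):
    def greet(self):
      return 'base'

  class Child(Base):
    def greet(self):
      frame = inspect.currentframe()
      type_arg = frame.f_locals['__class__']  # the implicit __class__ cell
      self_arg = frame.f_locals[frame.f_code.co_varnames[0]]  # i.e. `self`
      return super(type_arg, self_arg).greet()

  assert Child().greet() == 'base'

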
def abs_(x):
  if tensor_util.is_tf_type(x):
    return _tf_abs(x)
  if isinstance(x, dataset_ops.DatasetV2):
    return _tf_dataset_abs(x)
  return _py_abs(x)


def _tf_abs(x):
  return math_ops.abs(x)


def _tf_dataset_abs(x):
  specs = nest.flatten(x.element_spec)
  if len(specs) == 1:
    return x.map(math_ops.abs, num_parallel_calls=dataset_ops.AUTOTUNE)
  return x.map(
      lambda *e: nest.map_structure(math_ops.abs, e),
      num_parallel_calls=dataset_ops.AUTOTUNE)


def _py_abs(x):
  return abs(x)


def float_(x=0):
  if tensor_util.is_tf_type(x):
    return _tf_float(x)
  return _py_float(x)


def _tf_float(x):
  # TODO(mdan): We shouldn't assume float32.
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
  return math_ops.cast(x, dtype=dtypes.float32)


def _py_float(x):
  return float(x)


def int_(x=0, base=UNSPECIFIED):
  if tensor_util.is_tf_type(x):
    return _tf_int(x, base)
  return _py_int(x, base)


def _tf_int(x, base):
  if base not in (10, UNSPECIFIED):
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
  return math_ops.cast(x, dtype=dtypes.int32)


def _py_int(x, base):
  if base is UNSPECIFIED:
    return int(x)
  return int(x, base)


def len_(s):
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  elif tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  elif tensor_util.is_tf_type(s):
    return _tf_tensor_len(s)
  if isinstance(s, dataset_ops.DatasetV2):
    return _tf_dataset_len(s)
  return _py_len(s)


def _tf_tensor_array_len(s):
  return s.size()


def _tf_tensor_list_len(s):
  return list_ops.tensor_list_length(s)


def _tf_tensor_len(s):
  """Overload of len_ for Tensor arguments."""
  # Statically shaped tensors: length is known ahead of time.
  if s.shape.ndims and s.shape.dims[0].value is not None:
    return s.shape.dims[0].value

  # Static shape of unknown dimensions: use dynamic shape but statically
  # check that it's not a scalar.
  shape = array_ops.shape(s)

  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)

  if shape.shape[0] == 0:
    raise ValueError(
        'len requires a non-scalar tensor, got one of shape {}'.format(shape))

  if shape.shape.dims[0].value is not None:
    return array_ops.shape(s)[0]

  # Fully dynamic shape: use ops.
  rank = array_ops.rank(s)

  def raise_zero_rank_error():
    msg = gen_string_ops.string_join(
        ['len requires non-zero rank, got ',
         gen_string_ops.as_string(rank)])
    with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
      return constant_op.constant(0, dtype=dtypes.int32)

  return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
                               raise_zero_rank_error)
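

# Illustrative sketch (hypothetical helper, assumes eager execution): `len_`
# returns a plain Python int when the leading dimension is statically known,
# and rejects scalars the way the Python builtin rejects unsized objects.
def _example_len_dispatch():
  assert len_([1, 2, 3]) == 3  # plain Python objects use len()
  assert len_(constant_op.constant([1, 2, 3])) == 3  # static shape -> int
  try:
    len_(constant_op.constant(0))  # scalars have no length
  except ValueError:
    pass

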
def _tf_dataset_len(s):
  l = cardinality.cardinality(s)
  msg = gen_string_ops.string_join([
      'len requires dataset with definitive cardinality, got ',
      gen_string_ops.as_string(l)
  ])
  # TODO(yongtang): UNKNOWN is treated as an error.
  # In case there are more UNKNOWN cases for dataset, we could
  # use dataset.reduce() to find out the length (in an expensive way).
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.logical_and(
              math_ops.not_equal(l, cardinality.INFINITE),
              math_ops.not_equal(l, cardinality.UNKNOWN)), [msg])
  ]):
    l = array_ops.identity(l)

  return l


def _py_len(s):
  return len(s)


def print_(*objects, **kwargs):
  """Overload of the print builtin."""
  # Note: Python 2.6 doesn't support explicit keywords after starargs.
  unknown_kwargs = tuple(
      set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # TODO(mdan): Use nest.flatten(objects) instead?
  if any(tensor_util.is_tf_type(o) for o in objects):
    # TODO(mdan): Use tf.print instead.
    return _tf_py_func_print(objects, kwargs)
  else:
    _py_print(*objects, **kwargs)


def _py_print(*objects, **kwargs):
  print(*objects, **kwargs)


def _tf_py_func_print(objects, kwargs):
  """Overload of print_ as a py_func implementation."""
  override_kwargs = {k: v for k, v in kwargs.items() if v is not UNSPECIFIED}
  if 'flush' not in override_kwargs:
    # Default to flushing the console in graph mode, which helps reduce
    # garbled output in IPython.
    override_kwargs['flush'] = True

  def print_wrapper(*vals):
    vals = tuple(v.numpy() if tensor_util.is_tf_type(v) else v for v in vals)
    if not six.PY2:
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func. This causes the print to add a "b'" wrapper to the output,
      # which is probably never what you want.
      vals = tuple(
          v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
    six.print_(*vals, **override_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)


def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  if any(tensor_util.is_tf_type(s) for s in (start_or_stop, stop, step)):
    return _tf_range(start_or_stop, stop, step)
  return _py_range(start_or_stop, stop, step)
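

# Illustrative sketch (hypothetical helper, assumes eager execution): with
# Python ints, `range_` defers to the builtin; with a tensor argument it
# lowers to `math_ops.range`, clamping the limit so that a negative stop
# yields an empty tensor instead of a graph-construction error.
def _example_range_dispatch():
  assert list(range_(3)) == [0, 1, 2]  # Python ints -> builtin range
  assert range_(constant_op.constant(3)).shape == (3,)  # tensor -> TF range
  assert range_(constant_op.constant(-5)).shape == (0,)  # clamped, like range

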
def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor."""
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Preventing the
  # graph construction error aligns the semantics with Python.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is not UNSPECIFIED:
    # TODO(mdan): Add argument coercion similar to other cases.
    return math_ops.range(start_or_stop, stop, step)
  if stop is not UNSPECIFIED:
    stop = math_ops.maximum(start_or_stop, stop)
    return math_ops.range(start_or_stop, stop)
  start_or_stop = math_ops.maximum(start_or_stop, 0)
  return math_ops.range(start_or_stop)


def _py_range(start_or_stop, stop, step):
  if step is not UNSPECIFIED:
    return range(start_or_stop, stop, step)
  if stop is not UNSPECIFIED:
    return range(start_or_stop, stop)
  return range(start_or_stop)


def enumerate_(s, start=0):
  if isinstance(s, dataset_ops.DatasetV2):
    return _tf_dataset_enumerate(s, start)
  if isinstance(
      s, (input_lib.DistributedIterator, input_lib.DistributedDataset)):
    raise NotImplementedError(
        'use a for loop over the dataset and keep a separate counter')
  return _py_enumerate(s, start)


def _tf_dataset_enumerate(s, start=0):
  return s.enumerate(start)


def _py_enumerate(s, start=0):
  return enumerate(s, start)


def zip_(*iterables):
  if all(isinstance(x, dataset_ops.DatasetV2) for x in iterables):
    return _tf_dataset_zip(*iterables)
  return _py_zip(*iterables)


def _tf_dataset_zip(*iterables):
  return dataset_ops.DatasetV2.zip(iterables)


def _py_zip(*iterables):
  return zip(*iterables)


def map_(fn, *iterables):
  if all(isinstance(x, dataset_ops.DatasetV2) for x in iterables):
    return _tf_dataset_map(fn, *iterables)
  return _py_map(fn, *iterables)


def _tf_dataset_map(fn, *iterables):
  return dataset_ops.DatasetV2.zip(iterables).map(fn)


def _py_map(fn, *iterables):
  return map(fn, *iterables)


def next_(iterator, default=UNSPECIFIED):
  if isinstance(iterator, iterator_ops.OwnedIterator):
    return next_tf_iterator(iterator, default)
  return next_py(iterator, default)


# TODO(mdan): These checks should be easier. Fix the nest API.
def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the
  dtypes are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify.
    spec: TypeSpec that `input_` must be compatible with.

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the two types have been determined not to be compatible.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  if input_ is None:
    # TODO(mdan): raise from None when switching to Py3.
    raise ValueError('{} cannot be None'.format(input_name))

  # TODO(mdan): Use TensorCompatible when ready.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = ops.convert_to_tensor_v2(input_)

  input_dtype = getattr(input_, 'dtype', None)

  if input_dtype != spec.dtype:
    input_dtype_str = 'no dtype' if input_dtype is None else str(input_dtype)

    raise TypeError(
        '{} must have the same dtype as {}. Expected {}, got {}'.format(
            input_name, spec_name, spec.dtype, input_dtype_str))
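

# Illustrative sketch (hypothetical helper, assumes eager execution): Python
# scalars are first converted to tensors, after which the dtype must match
# the spec exactly.
def _example_spec_check():
  spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
  _verify_spec_compatible('x', 'spec', 1, spec)  # 1 -> int32: compatible
  try:
    _verify_spec_compatible('x', 'spec', 1.0, spec)  # float32 vs int32
  except TypeError:
    pass

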
def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Verifies that a possibly-structured symbol is type-compatible with another.

  See _verify_spec_compatible for a more concrete meaning of "compatible".
  Unlike _verify_spec_compatible, which handles singular Tensor-spec objects,
  _verify_structure_compatible can process structures recognized by tf.nest.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
      need to, be a structure.

  Raises:
    TypeError: if the two types have been determined not to be compatible.
  """
  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as e:
    raise TypeError(
        '{} must have the same element structure as {}.\n\n{}'.format(
            input_name, spec_name, str(e)))

  nest.map_structure(
      functools.partial(_verify_spec_compatible, input_name, spec_name),
      input_, spec)


def next_tf_iterator(iterator, default=UNSPECIFIED):
  if default is UNSPECIFIED:
    # Without a default, fall back to the "normal" behavior, which raises
    # a runtime exception.
    return next(iterator)
  opt_iterate = iterator.get_next_as_optional()
  _verify_structure_compatible(
      'the default argument', 'the iterate', default, iterator.element_spec)
  return control_flow_ops.cond(
      opt_iterate.has_value(), opt_iterate.get_value, lambda: default)


def next_py(iterator, default=UNSPECIFIED):
  if default is UNSPECIFIED:
    return next(iterator)
  return next(iterator, default)


def filter_(function, iterable):
  if isinstance(iterable, dataset_ops.DatasetV2):
    return _tf_dataset_filter(function, iterable)
  return _py_filter(function, iterable)


def _tf_dataset_filter(function, iterable):
  return iterable.filter(function)


def _py_filter(function, iterable):
  return filter(function, iterable)


def any_(iterable):
  if isinstance(iterable, dataset_ops.DatasetV2):
    return _tf_dataset_any(iterable)
  return _py_any(iterable)


# The any() operation is essentially "does a True element exist?". It can be
# translated to `filter(lambda x: x)`, keeping only the True elements, followed
# by `take(1)`. Because tf.data evaluates filter+take as a lazy pipeline,
# iteration stops as soon as `take(1)` produces an element.
def _tf_dataset_any(iterable):
  # Check that iterable.element_spec consists of a single tf.bool element.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "any" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(lambda x: x)
  ds = ds.take(1)
  ds = ds.reduce(constant_op.constant(False, dtype=dtypes.bool), lambda _, y: y)
  return ds


def _py_any(iterable):
  return any(iterable)
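

# Illustrative sketch (hypothetical helper, assumes eager execution): on a
# dataset of bool scalars, `any_` builds the filter -> take(1) -> reduce
# pipeline described above, so it stops at the first True element.
def _example_dataset_any():
  ds = dataset_ops.DatasetV2.from_tensor_slices([False, True, False])
  assert bool(any_(ds))  # a True element exists
  assert not bool(any_(dataset_ops.DatasetV2.from_tensor_slices([False])))

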
def all_(iterable):
  if isinstance(iterable, dataset_ops.DatasetV2):
    return _tf_dataset_all(iterable)
  return _py_all(iterable)


# The all() operation is similar to any(): it can be translated to
# `filter(lambda x: not x)` followed by `take(1)`, checking whether a False
# element exists.
def _tf_dataset_all(iterable):
  # Check that iterable.element_spec consists of a single tf.bool element.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "all" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(lambda x: math_ops.logical_not(x))
  ds = ds.take(1)
  ds = ds.reduce(constant_op.constant(True, dtype=dtypes.bool), lambda _, y: y)
  return ds


def _py_all(iterable):
  return all(iterable)


def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
  if tensor_util.is_tf_type(iterable):
    return _tf_sorted(iterable, key, reverse)
  return _py_sorted(iterable, key, reverse)


def _tf_sorted(iterable, key, reverse):
  """Overload of sorted_ for Tensor iterable."""
  # A falsy Python value for `reverse` sorts ascending, matching the builtin.
  # A tensor-valued `reverse` cannot be inspected at trace time and
  # conservatively sorts descending.
  if reverse is UNSPECIFIED or (not tensor_util.is_tf_type(reverse) and
                                not reverse):
    direction = 'ASCENDING'
  else:
    direction = 'DESCENDING'
  if key is not UNSPECIFIED:
    mapped = parallel_ops.vectorized_map(key, iterable)
    if mapped.shape.rank is not None and mapped.shape.rank != 1:
      raise ValueError('sort only supports 1D tensors')
    with ops.control_dependencies([
        check_ops.assert_rank_v2(mapped, 1, 'sort only supports 1D tensors')
    ]):
      order = sort_ops.argsort(mapped, direction=direction)
      return array_ops.gather_v2(iterable, order)
  if iterable.shape.rank is not None and iterable.shape.rank != 1:
    raise ValueError('sort only supports 1D tensors')
  with ops.control_dependencies([
      check_ops.assert_rank_v2(iterable, 1, 'sort only supports 1D tensors')
  ]):
    return sort_ops.sort(iterable, direction=direction)


def _py_sorted(iterable, key, reverse):
  if key is not UNSPECIFIED and reverse is UNSPECIFIED:
    return sorted(iterable, key=key)
  if key is UNSPECIFIED and reverse is not UNSPECIFIED:
    return sorted(iterable, reverse=reverse)
  if key is not UNSPECIFIED and reverse is not UNSPECIFIED:
    return sorted(iterable, key=key, reverse=reverse)
  return sorted(iterable)


SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map,
                      filter, any, all, sorted)

if six.PY2:
  SUPPORTED_BUILTINS += (xrange,)

BUILTIN_FUNCTIONS_MAP = {
    'abs': abs_,
    'any': any_,
    'all': all_,
    'enumerate': enumerate_,
    'filter': filter_,
    'float': float_,
    'int': int_,
    'len': len_,
    'map': map_,
    'next': next_,
    'print': print_,
    'range': range_,
    'sorted': sorted_,
    'xrange': range_,
    'zip': zip_,
}
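

# Illustrative sketch (hypothetical helper, assumes eager execution): the same
# `sorted` call site works for Python lists and for 1-D tensors, with
# `reverse` mapped to the sort direction.
def _example_sorted_dispatch():
  assert sorted_([3, 1, 2]) == [1, 2, 3]  # Python iterable -> builtin sorted
  t = sorted_(constant_op.constant([3.0, 1.0, 2.0]), reverse=True)
  assert t.numpy().tolist() == [3.0, 2.0, 1.0]  # tensor -> sort_ops, descending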