1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions for internal use.""" 16# pylint: disable=g-direct-tensorflow-import 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import inspect 23import numbers 24import os 25import re 26import numpy as np 27 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import indexed_slices 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops.numpy_ops import np_arrays 35from tensorflow.python.ops.numpy_ops import np_dtypes 36from tensorflow.python.ops.numpy_ops import np_export 37from tensorflow.python.types import core 38from tensorflow.python.util import nest 39 40 41def _canonicalize_axis(axis, rank): 42 return _canonicalize_axes([axis], rank)[0] 43 44 45def _canonicalize_axes(axes, rank): 46 rank = _maybe_static(rank) 47 48 if isinstance(rank, core.Tensor): 49 canonicalizer = ( 50 lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis)) 51 else: 52 canonicalizer = lambda axis: axis + rank if axis < 0 else axis 53 54 return [canonicalizer(axis) for axis in axes] 55 56 57def _supports_signature(): 58 return hasattr(inspect, 'signature') 59 60 61def _to_tf_type(dtype): 62 """Converts a native python or numpy type to TF DType. 63 64 Args: 65 dtype: Could be a python type, a numpy type or a TF DType. 66 67 Returns: 68 A tensorflow `DType`. 69 """ 70 return dtypes.as_dtype(dtype) 71 72 73def _to_numpy_type(dtype): 74 """Converts a native python or TF DType to numpy type. 75 76 Args: 77 dtype: Could be a python type, a numpy type or a TF DType. 78 79 Returns: 80 A NumPy `dtype`. 81 """ 82 if isinstance(dtype, dtypes.DType): 83 return dtype.as_numpy_dtype 84 return np.dtype(dtype) 85 86 87def isscalar(val): 88 """Returns whether `val` is a scalar value or scalar Tensor.""" 89 if isinstance(val, np_arrays.ndarray): 90 val = val.data 91 if isinstance(val, core.Tensor): 92 ndims = val.shape.ndims 93 if ndims is not None: 94 return ndims == 0 95 else: 96 return math_ops.equal(array_ops.rank(val), 0) 97 else: 98 return np.isscalar(val) 99 100 101def _has_docstring(f): 102 return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and 103 f.__doc__) 104 105 106def _add_blank_line(s): 107 if s.endswith('\n'): 108 return s + '\n' 109 else: 110 return s + '\n\n' 111 112 113def _np_signature(f): 114 """An enhanced inspect.signature that can handle numpy.ufunc.""" 115 # TODO(wangpeng): consider migrating away from inspect.signature. 116 # inspect.signature is supported in Python 3.3. 117 if not hasattr(inspect, 'signature'): 118 return None 119 if f is None: 120 return None 121 if not isinstance(f, np.ufunc): 122 try: 123 return inspect.signature(f) 124 except ValueError: 125 return None 126 127 def names_from_num(prefix, n): 128 if n <= 0: 129 return [] 130 elif n == 1: 131 return [prefix] 132 else: 133 return [prefix + str(i + 1) for i in range(n)] 134 135 input_names = names_from_num('x', f.nin) 136 output_names = names_from_num('out', f.nout) 137 keyword_only_params = [('where', True), ('casting', 'same_kind'), 138 ('order', 'K'), ('dtype', None), ('subok', True), 139 ('signature', None), ('extobj', None)] 140 params = [] 141 params += [ 142 inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY) 143 for name in input_names 144 ] 145 if f.nout > 1: 146 params += [ 147 inspect.Parameter( 148 name, inspect.Parameter.POSITIONAL_ONLY, default=None) 149 for name in output_names 150 ] 151 params += [ 152 inspect.Parameter( 153 'out', 154 inspect.Parameter.POSITIONAL_OR_KEYWORD, 155 default=None if f.nout == 1 else (None,) * f.nout) 156 ] 157 params += [ 158 inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default) 159 for name, default in keyword_only_params 160 ] 161 return inspect.Signature(params) 162 163 164# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't 165# allow positional-only argument. So we conflate positional-only, keyword-only 166# and positional-or-keyword arguments here. 167def _is_compatible_param_kind(a, b): 168 169 def relax(k): 170 if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY): 171 return inspect.Parameter.POSITIONAL_OR_KEYWORD 172 return k 173 174 return relax(a) == relax(b) 175 176 177def _prepare_np_fun_name_and_fun(np_fun_name, np_fun): 178 """Mutually propagates information between `np_fun_name` and `np_fun`. 179 180 If one is None and the other is not, we'll try to make the former not None in 181 a best effort. 182 183 Args: 184 np_fun_name: name for the np_fun symbol. At least one of np_fun or 185 np_fun_name shoud be set. 186 np_fun: the numpy function whose docstring will be used. 187 188 Returns: 189 Processed `np_fun_name` and `np_fun`. 190 """ 191 if np_fun_name is not None: 192 assert isinstance(np_fun_name, str) 193 if np_fun is not None: 194 assert not isinstance(np_fun, str) 195 if np_fun is None: 196 assert np_fun_name is not None 197 try: 198 np_fun = getattr(np, str(np_fun_name)) 199 except AttributeError: 200 np_fun = None 201 if np_fun_name is None: 202 assert np_fun is not None 203 np_fun_name = np_fun.__name__ 204 return np_fun_name, np_fun 205 206 207def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None, 208 link=None): 209 """Helper to get docs.""" 210 assert np_f or np_fun_name 211 if not np_fun_name: 212 np_fun_name = np_f.__name__ 213 doc = 'TensorFlow variant of NumPy\'s `%s`.\n\n' % np_fun_name 214 if unsupported_params: 215 doc += 'Unsupported arguments: ' + ', '.join( 216 '`' + name + '`' for name in unsupported_params) + '.\n\n' 217 if _has_docstring(f): 218 doc += f.__doc__ 219 doc = _add_blank_line(doc) 220 # TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy 221 # doc according to some global switch. 222 doc = _add_np_doc(doc, np_fun_name, np_f, link=link) 223 return doc 224 225 226_np_doc_form = os.getenv('TF_NP_DOC_FORM', '1.16') 227 228 229def get_np_doc_form(): 230 """Gets the form of the original numpy docstrings. 231 232 Returns: 233 See `set_np_doc_form` for the list of valid values. 234 """ 235 return _np_doc_form 236 237 238def set_np_doc_form(value): 239 r"""Selects the form of the original numpy docstrings. 240 241 This function sets a global variable that controls how a tf-numpy symbol's 242 docstring should refer to the original numpy docstring. If `value` is 243 `'inlined'`, the numpy docstring will be verbatim copied into the tf-numpy 244 docstring. Otherwise, a link to the original numpy docstring will be 245 added. Which numpy version the link points to depends on `value`: 246 * `'stable'`: the current stable version; 247 * `'dev'`: the current development version; 248 * pattern `\d+(\.\d+(\.\d+)?)?`: `value` will be treated as a version number, 249 e.g. '1.16'. 250 251 Args: 252 value: the value to set the global variable to. 253 """ 254 global _np_doc_form 255 _np_doc_form = value 256 257 258class Link: 259 260 def __init__(self, v): 261 self.value = v 262 263 264class AliasOf: 265 266 def __init__(self, v): 267 self.value = v 268 269 270class NoLink: 271 pass 272 273 274def generate_link(flag, np_fun_name): 275 """Generates link from numpy function name. 276 277 Args: 278 flag: the flag to control link form. See `set_np_doc_form`. 279 np_fun_name: the numpy function name. 280 281 Returns: 282 A string. 283 """ 284 # Only adds link in this case 285 if flag == 'dev': 286 template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html' 287 elif flag == 'stable': 288 template = ( 289 'https://numpy.org/doc/stable/reference/generated/numpy.%s.html') 290 elif re.match(r'\d+(\.\d+(\.\d+)?)?$', flag): 291 # `flag` is the version number 292 template = ('https://numpy.org/doc/' + flag + 293 '/reference/generated/numpy.%s.html') 294 else: 295 return None 296 return template % np_fun_name 297 298 299_is_check_link = (os.getenv('TF_NP_CHECK_LINK', 'False') in 300 ('True', 'true', '1')) 301 302 303def is_check_link(): 304 return _is_check_link 305 306 307def set_check_link(value): 308 global _is_check_link 309 _is_check_link = value 310 311 312def _add_np_doc(doc, np_fun_name, np_f, link): 313 """Appends the numpy docstring to `doc`, according to `set_np_doc_form`. 314 315 See `set_np_doc_form` for how it controls the form of the numpy docstring. 316 317 Args: 318 doc: the docstring to be appended to. 319 np_fun_name: the name of the numpy function. 320 np_f: (optional) the numpy function. 321 link: (optional) which link to use. See `np_doc` for details. 322 323 Returns: 324 `doc` with numpy docstring appended. 325 """ 326 flag = get_np_doc_form() 327 if flag == 'inlined': 328 if _has_docstring(np_f): 329 doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name 330 # TODO(wangpeng): It looks like code snippets in numpy doc don't work 331 # correctly with doctest. Fix that and remove the reformatting of the np_f 332 # comment. 333 doc += np_f.__doc__.replace('>>>', '>') 334 elif isinstance(flag, str): 335 if link is None: 336 url = generate_link(flag, np_fun_name) 337 elif isinstance(link, AliasOf): 338 url = generate_link(flag, link.value) 339 elif isinstance(link, Link): 340 url = link.value 341 else: 342 url = None 343 if url is not None: 344 if is_check_link(): 345 # Imports locally because some builds may not have `requests` 346 import requests # pylint: disable=g-import-not-at-top 347 r = requests.head(url) 348 if r.status_code != 200: 349 raise ValueError( 350 f'Check link failed at [{url}] with status code {r.status_code}. ' 351 f'Argument `np_fun_name` is {np_fun_name}.') 352 doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % ( 353 np_fun_name, url) 354 return doc 355 356 357_is_sig_mismatch_an_error = ( 358 os.getenv('TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1')) 359 360 361def is_sig_mismatch_an_error(): 362 return _is_sig_mismatch_an_error 363 364 365def set_is_sig_mismatch_an_error(value): 366 global _is_sig_mismatch_an_error 367 _is_sig_mismatch_an_error = value 368 369 370def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None, 371 link=None): 372 """Attachs numpy docstring to a function. 373 374 Args: 375 np_fun_name: name for the np_fun symbol. At least one of np_fun or 376 np_fun_name shoud be set. 377 np_fun: (optional) the numpy function whose docstring will be used. 378 export: whether to export this symbol under module 379 `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be 380 a function directly under the `numpy` module, not under any submodule of 381 `numpy` (e.g. `numpy.random`). 382 unsupported_params: (optional) the list of parameters not supported 383 by tf.numpy. 384 link: (optional) which link to use. If `None`, a default link generated from 385 `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will 386 be used in place of `np_fun_name` for the link generation. If an instance 387 of `Link`, `link.value` will be used as the whole link. If an instance of 388 `NoLink`, no link will be added. 389 390 Returns: 391 A function decorator that attaches the docstring from `np_fun` to the 392 decorated function. 393 """ 394 np_fun_name_orig, np_fun_orig = np_fun_name, np_fun 395 np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun) 396 np_sig = _np_signature(np_fun) 397 if unsupported_params is None: 398 unsupported_params = [] 399 400 def decorator(f): 401 """The decorator.""" 402 if hasattr(inspect, 'signature') and np_sig is not None: 403 try: 404 sig = inspect.signature(f) 405 except ValueError: 406 sig = None 407 if sig is not None: 408 for name, param in sig.parameters.items(): 409 np_param = np_sig.parameters.get(name) 410 if np_param is None: 411 if is_sig_mismatch_an_error(): 412 raise TypeError( 413 f'Cannot find parameter {name} in the numpy function\'s ' 414 f'signature (which has these parameters: ' 415 f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` ' 416 f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.') 417 else: 418 continue 419 if (is_sig_mismatch_an_error() and 420 not _is_compatible_param_kind(param.kind, np_param.kind)): 421 raise TypeError( 422 f'Parameter {name} is of kind {param.kind} while in numpy it ' 423 f'is of kind {np_param.kind}. Argument `np_fun_name` is ' 424 f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.') 425 has_default = (param.default != inspect.Parameter.empty) 426 np_has_default = (np_param.default != inspect.Parameter.empty) 427 if is_sig_mismatch_an_error() and has_default != np_has_default: 428 raise TypeError( 429 'Parameter {} should{} have a default value. Argument ' 430 '`np_fun_name` is {}. Argument `np_fun` is {}.'.format( 431 name, '' if np_has_default else ' not', np_fun_name_orig, 432 np_fun_orig)) 433 for name in np_sig.parameters: 434 if name not in sig.parameters: 435 unsupported_params.append(name) 436 f.__doc__ = _np_doc_helper( 437 f, np_fun, np_fun_name=np_fun_name, 438 unsupported_params=unsupported_params, link=link) 439 if export: 440 return np_export.np_export(np_fun_name)(f) 441 else: 442 return f 443 444 return decorator 445 446 447def np_doc_only(np_fun_name, np_fun=None, export=True): 448 """Attachs numpy docstring to a function. 449 450 This differs from np_doc in that it doesn't check for a match in signature. 451 452 Args: 453 np_fun_name: name for the np_fun symbol. At least one of np_fun or 454 np_fun_name shoud be set. 455 np_fun: (optional) the numpy function whose docstring will be used. 456 export: whether to export this symbol under module 457 `tf.experimental.numpy`. Note that if `export` is `True`, `np_f` must be a 458 function directly under the `numpy` module, not under any submodule of 459 `numpy` (e.g. `numpy.random`). 460 461 Returns: 462 A function decorator that attaches the docstring from `np_fun` to the 463 decorated function. 464 """ 465 np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun) 466 467 def decorator(f): 468 f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name) 469 if export: 470 return np_export.np_export(np_fun_name)(f) 471 else: 472 return f 473 474 return decorator 475 476 477# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args 478@np_doc('finfo') 479def finfo(dtype): 480 """Note that currently it just forwards to the numpy namesake, while 481 tensorflow and numpy dtypes may have different properties.""" 482 return np.finfo(_to_numpy_type(dtype)) 483# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args 484 485 486def _maybe_get_dtype(x): 487 """Returns a numpy type if available from x. Skips if x is numpy.ndarray.""" 488 # Don't put np.ndarray in this list, because np.result_type looks at the 489 # value (not just dtype) of np.ndarray to decide the result type. 490 if isinstance(x, numbers.Real): 491 return x 492 if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)): 493 return _to_numpy_type(x.dtype) 494 if isinstance(x, dtypes.DType): 495 return x.as_numpy_dtype 496 if isinstance(x, (list, tuple)): 497 raise ValueError( 498 f'Cannot find dtype for type inference from argument `x` of a sequence ' 499 f'type {type(x)}. For sequences, please call this function on each ' 500 f'element individually.') 501 return x 502 503 504# Can't use np_doc because np.result_type is a builtin function. 505@np_doc_only('result_type') 506def result_type(*arrays_and_dtypes): # pylint: disable=missing-function-docstring 507 arrays_and_dtypes = [ 508 _maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes) 509 ] 510 if not arrays_and_dtypes: 511 # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is. 512 arrays_and_dtypes = [np.asarray([])] 513 return np_dtypes._result_type(*arrays_and_dtypes) # pylint: disable=protected-access 514 515 516def _result_type_binary(t1, t2): # pylint: disable=missing-function-docstring 517 """A specialization of result_type for 2 arguments for performance reasons.""" 518 try: 519 return np_dtypes._result_type(_maybe_get_dtype(t1), # pylint: disable=protected-access 520 _maybe_get_dtype(t2)) # pylint: disable=protected-access 521 except ValueError: 522 return result_type(t1, t2) 523 524 525@np_doc('promote_types') 526def promote_types(type1, type2): # pylint: disable=missing-function-docstring 527 type1 = _to_numpy_type(type1) 528 type2 = _to_numpy_type(type2) 529 return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2)) 530 531 532def tf_broadcast(*args): 533 """Broadcast tensors. 534 535 Args: 536 *args: a list of tensors whose shapes are broadcastable against each other. 537 538 Returns: 539 Tensors broadcasted to the common shape. 540 """ 541 if len(args) <= 1: 542 return args 543 sh = array_ops.shape(args[0]) 544 for arg in args[1:]: 545 sh = array_ops.broadcast_dynamic_shape(sh, array_ops.shape(arg)) 546 return [array_ops.broadcast_to(arg, sh) for arg in args] 547 548 549# TODO(wangpeng): Move the following functions to a separate file and check for 550# float dtypes in each of them. 551 552 553def get_static_value(x): 554 """A version of tf.get_static_value that returns None on float dtypes. 555 556 It returns None on float dtypes in order to avoid breaking gradients. 557 558 Args: 559 x: a tensor. 560 561 Returns: 562 Same as `tf.get_static_value`, except that it returns None when `x` has a 563 float dtype. 564 """ 565 if isinstance(x, core.Tensor) and (x.dtype.is_floating or x.dtype.is_complex): 566 return None 567 return tensor_util.constant_value(x) 568 569 570def _maybe_static(x): 571 value = get_static_value(x) 572 if value is None: 573 return x 574 else: 575 return value 576 577 578# All the following functions exist becaues get_static_value can't handle 579# their TF counterparts. 580 581 582def cond(pred, true_fn, false_fn): 583 """A version of tf.cond that tries to evaluate the condition.""" 584 v = get_static_value(pred) 585 if v is None: 586 return control_flow_ops.cond(pred, true_fn, false_fn) 587 if v: 588 return true_fn() 589 else: 590 return false_fn() 591 592 593def add(a, b): 594 """A version of tf.add that eagerly evaluates if possible.""" 595 return _maybe_static(a) + _maybe_static(b) 596 597 598def subtract(a, b): 599 """A version of tf.subtract that eagerly evaluates if possible.""" 600 return _maybe_static(a) - _maybe_static(b) 601 602 603def greater(a, b): 604 """A version of tf.greater that eagerly evaluates if possible.""" 605 return _maybe_static(a) > _maybe_static(b) 606 607 608def greater_equal(a, b): 609 """A version of tf.greater_equal that eagerly evaluates if possible.""" 610 return _maybe_static(a) >= _maybe_static(b) 611 612 613def less_equal(a, b): 614 """A version of tf.less_equal that eagerly evaluates if possible.""" 615 return _maybe_static(a) <= _maybe_static(b) 616 617 618def logical_and(a, b): 619 """A version of tf.logical_and that eagerly evaluates if possible.""" 620 a_value = get_static_value(a) 621 if a_value is not None: 622 if np.isscalar(a_value): 623 if a_value: 624 return _maybe_static(b) 625 else: 626 return a_value 627 else: 628 return a_value & _maybe_static(b) 629 else: 630 return a & _maybe_static(b) 631 632 633def logical_or(a, b): 634 """A version of tf.logical_or that eagerly evaluates if possible.""" 635 a_value = get_static_value(a) 636 if a_value is not None: 637 if np.isscalar(a_value): 638 if a_value: 639 return a_value 640 else: 641 return _maybe_static(b) 642 else: 643 return a_value | _maybe_static(b) 644 else: 645 return a | _maybe_static(b) 646 647 648def getitem(a, slice_spec): 649 """A version of __getitem__ that eagerly evaluates if possible.""" 650 return _maybe_static(a)[slice_spec] 651 652 653def reduce_all(input_tensor, axis=None, keepdims=False): 654 """A version of tf.reduce_all that eagerly evaluates if possible.""" 655 v = get_static_value(input_tensor) 656 if v is None: 657 return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims) 658 else: 659 return v.all(axis=axis, keepdims=keepdims) 660 661 662def reduce_any(input_tensor, axis=None, keepdims=False): 663 """A version of tf.reduce_any that eagerly evaluates if possible.""" 664 v = get_static_value(input_tensor) 665 if v is None: 666 return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims) 667 else: 668 return v.any(axis=axis, keepdims=keepdims) 669 670 671def tf_rank(t): 672 r = t.shape.rank 673 if r is not None: 674 return r 675 return array_ops.rank(t) 676