# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for internal use."""
# pylint: disable=g-direct-tensorflow-import

import inspect
import numbers
import os
import re
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _canonicalize_axis(axis, rank):
  """Canonicalizes a single axis; see `_canonicalize_axes`.

  Args:
    axis: the axis to canonicalize.
    rank: the rank of the array `axis` refers to.

  Returns:
    `axis + rank` if `axis` is negative, else `axis`.
  """
  return _canonicalize_axes([axis], rank)[0]


def _canonicalize_axes(axes, rank):
  """Adds `rank` to every negative entry of `axes`.

  Args:
    axes: an iterable of axes.
    rank: the rank of the array the axes refer to. If its value cannot be
      determined statically (it is still a `core.Tensor` after
      `_maybe_static`), the negativity test goes through this module's `cond`
      helper instead of a Python `if`.

  Returns:
    A list with one canonicalized axis per entry of `axes`. Axes are not
    range-checked.
  """
  rank = _maybe_static(rank)

  if isinstance(rank, core.Tensor):
    canonicalizer = (
        lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
  else:
    canonicalizer = lambda axis: axis + rank if axis < 0 else axis

  return [canonicalizer(axis) for axis in axes]


def _supports_signature():
  """Returns whether the `inspect` module provides `signature`."""
  return hasattr(inspect, 'signature')


def _to_tf_type(dtype):
  """Converts a native python or numpy type to TF DType.

  Args:
    dtype: Could be a python type, a numpy type or a TF DType.

  Returns:
    A tensorflow `DType`.
  """
  return dtypes.as_dtype(dtype)


def _to_numpy_type(dtype):
  """Converts a native python, numpy or TF DType to numpy type.

  Args:
    dtype: Could be a python type, a numpy type or a TF DType.

  Returns:
    `dtype.as_numpy_dtype` for a TF `DType`, otherwise `np.dtype(dtype)`.
  """
  if isinstance(dtype, dtypes.DType):
    return dtype.as_numpy_dtype
  return np.dtype(dtype)


def isscalar(val):
  """Returns whether `val` is a scalar value or scalar Tensor.

  Args:
    val: an `np_arrays.ndarray` (unwrapped via its `.data`), a `core.Tensor`,
      or any value accepted by `np.isscalar`.

  Returns:
    A Python bool, except for a Tensor whose rank is not statically known, in
    which case the boolean Tensor `rank(val) == 0` is returned.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if isinstance(val, core.Tensor):
    ndims = val.shape.ndims
    if ndims is not None:
      return ndims == 0
    else:
      return math_ops.equal(array_ops.rank(val), 0)
  else:
    return np.isscalar(val)


def _has_docstring(f):
  """Returns a truthy value iff `f` is truthy and has a non-empty str doc."""
  return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
          f.__doc__)


def _add_blank_line(s):
  """Appends one or two newlines so that `s` ends with '\\n\\n'."""
  if s.endswith('\n'):
    return s + '\n'
  else:
    return s + '\n\n'


def _np_signature(f):
  """An enhanced inspect.signature that can handle numpy.ufunc.

  Args:
    f: a callable (typically a numpy function), or None.

  Returns:
    An `inspect.Signature`, or None if `f` is None, `inspect.signature` is
    unavailable, or it raises `ValueError` for `f`. For an `np.ufunc` the
    signature is synthesized from `f.nin`/`f.nout` plus the ufunc keyword
    arguments listed in `keyword_only_params` below.
  """
  # TODO(wangpeng): consider migrating away from inspect.signature.
  # inspect.signature is supported in Python 3.3.
  if not hasattr(inspect, 'signature'):
    return None
  if f is None:
    return None
  if not isinstance(f, np.ufunc):
    try:
      return inspect.signature(f)
    except ValueError:
      return None

  def names_from_num(prefix, n):
    # 0 -> [], 1 -> [prefix], n -> [prefix1, ..., prefixn].
    if n <= 0:
      return []
    elif n == 1:
      return [prefix]
    else:
      return [prefix + str(i + 1) for i in range(n)]

  input_names = names_from_num('x', f.nin)
  output_names = names_from_num('out', f.nout)
  keyword_only_params = [('where', True), ('casting', 'same_kind'),
                         ('order', 'K'), ('dtype', None), ('subok', True),
                         ('signature', None), ('extobj', None)]
  params = []
  params += [
      inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
      for name in input_names
  ]
  if f.nout > 1:
    params += [
        inspect.Parameter(
            name, inspect.Parameter.POSITIONAL_ONLY, default=None)
        for name in output_names
    ]
  params += [
      inspect.Parameter(
          'out',
          inspect.Parameter.POSITIONAL_OR_KEYWORD,
          default=None if f.nout == 1 else (None,) * f.nout)
  ]
  params += [
      inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
      for name, default in keyword_only_params
  ]
  return inspect.Signature(params)


# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't
# allow positional-only argument. So we conflate positional-only, keyword-only
# and positional-or-keyword arguments here.
def _is_compatible_param_kind(a, b):
  """Returns whether parameter kinds `a` and `b` match after conflation."""

  def relax(k):
    if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
      return inspect.Parameter.POSITIONAL_OR_KEYWORD
    return k

  return relax(a) == relax(b)


def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
  """Mutually propagates information between `np_fun_name` and `np_fun`.

  If one is None and the other is not, we'll try to make the former not None in
  a best effort.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: the numpy function whose docstring will be used.

  Returns:
    Processed `np_fun_name` and `np_fun`. `np_fun` stays None if `numpy` has
    no attribute named `np_fun_name`.
  """
  if np_fun_name is not None:
    assert isinstance(np_fun_name, str)
  if np_fun is not None:
    assert not isinstance(np_fun, str)
  if np_fun is None:
    assert np_fun_name is not None
    try:
      np_fun = getattr(np, str(np_fun_name))
    except AttributeError:
      np_fun = None
  if np_fun_name is None:
    assert np_fun is not None
    np_fun_name = np_fun.__name__
  return np_fun_name, np_fun


def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring for a tf-numpy symbol.

  The result is a header naming the numpy function, an optional list of
  unsupported arguments, `f`'s own docstring (if any), and then the numpy
  documentation appended by `_add_np_doc`.

  Args:
    f: the tf-numpy function being documented.
    np_f: the numpy function, or None.
    np_fun_name: the numpy function name; defaults to `np_f.__name__`.
    unsupported_params: (optional) names of numpy parameters `f` doesn't
      support.
    link: (optional) link override; see `np_doc`.

  Returns:
    The docstring.
  """
  assert np_f or np_fun_name
  if not np_fun_name:
    np_fun_name = np_f.__name__
  doc = 'TensorFlow variant of NumPy\'s `%s`.\n\n' % np_fun_name
  if unsupported_params:
    doc += 'Unsupported arguments: ' + ', '.join(
        '`' + name + '`' for name in unsupported_params) + '.\n\n'
  if _has_docstring(f):
    doc += f.__doc__
    doc = _add_blank_line(doc)
  # TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy
  #   doc according to some global switch.
  doc = _add_np_doc(doc, np_fun_name, np_f, link=link)
  return doc


# Initial value comes from the TF_NP_DOC_FORM environment variable.
_np_doc_form = os.getenv('TF_NP_DOC_FORM', '1.16')


def get_np_doc_form():
  """Gets the form of the original numpy docstrings.

  Returns:
    See `set_np_doc_form` for the list of valid values.
  """
  return _np_doc_form


def set_np_doc_form(value):
  r"""Selects the form of the original numpy docstrings.

  This function sets a global variable that controls how a tf-numpy symbol's
  docstring should refer to the original numpy docstring. If `value` is
  `'inlined'`, the numpy docstring will be verbatim copied into the tf-numpy
  docstring. Otherwise, a link to the original numpy docstring will be
  added. Which numpy version the link points to depends on `value`:
  * `'stable'`: the current stable version;
  * `'dev'`: the current development version;
  * pattern `\d+(\.\d+(\.\d+)?)?`: `value` will be treated as a version number,
    e.g. '1.16'.

  Args:
    value: the value to set the global variable to.
  """
  global _np_doc_form
  _np_doc_form = value


class Link:
  """`np_doc` link option: `value` is used verbatim as the whole URL."""

  def __init__(self, v):
    self.value = v


class AliasOf:
  """`np_doc` link option: `value` replaces `np_fun_name` in the URL."""

  def __init__(self, v):
    self.value = v


class NoLink:
  """`np_doc` link option: no link is added."""
  pass


def generate_link(flag, np_fun_name):
  """Generates link from numpy function name.

  Args:
    flag: the flag to control link form. See `set_np_doc_form`.
    np_fun_name: the numpy function name.

  Returns:
    A string, or None if `flag` is not `'dev'`, `'stable'` or a version
    number.
  """
  # A link is produced only for the flag forms recognized below.
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  elif re.match(r'\d+(\.\d+(\.\d+)?)?$', flag):
    # `flag` is the version number
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name


# Initial value comes from the TF_NP_CHECK_LINK environment variable.
_is_check_link = (os.getenv('TF_NP_CHECK_LINK', 'False') in
                  ('True', 'true', '1'))

# Upper bound for each link-check HTTP request so a stalled server cannot hang
# the caller forever (requests has no default timeout).
_CHECK_LINK_TIMEOUT_SECONDS = 60


def is_check_link():
  """Returns whether generated numpy links are verified over HTTP."""
  return _is_check_link


def set_check_link(value):
  """Sets whether generated numpy links are verified over HTTP."""
  global _is_check_link
  _is_check_link = value


def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with numpy docstring appended.

  Raises:
    ValueError: if `is_check_link()` is true and a HEAD request to the link
      does not return status 200.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`
        import requests  # pylint: disable=g-import-not-at-top
        r = requests.head(url, timeout=_CHECK_LINK_TIMEOUT_SECONDS)
        if r.status_code != 200:
          raise ValueError(
              f'Check link failed at [{url}] with status code {r.status_code}. '
              f'Argument `np_fun_name` is {np_fun_name}.')
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc


# Initial value comes from the TF_NP_SIG_MISMATCH_IS_ERROR environment variable.
_is_sig_mismatch_an_error = (
    os.getenv('TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1'))


def is_sig_mismatch_an_error():
  """Returns whether `np_doc` raises on signature mismatches."""
  return _is_sig_mismatch_an_error


def set_is_sig_mismatch_an_error(value):
  """Sets whether `np_doc` raises on signature mismatches."""
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value


def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attaches numpy docstring to a function.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) the list of parameters not supported
      by tf.numpy. It is copied and never modified. Parameters that appear in
      the numpy signature but not in the decorated function's signature are
      added (without duplicates) to the list shown in the docstring.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function. When `is_sig_mismatch_an_error()` is true, the
    decorator raises `TypeError` if the decorated function's signature does not
    match numpy's (unknown parameter, incompatible kind, or differing
    presence of a default value).
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)
  # Copy so that the caller's sequence is never mutated.
  unsupported_params = ([] if unsupported_params is None
                        else list(unsupported_params))

  def decorator(f):
    """The decorator."""
    # Per-decoration copy: names appended below must not leak into other
    # functions decorated by this same decorator object.
    unsupported = list(unsupported_params)
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if is_sig_mismatch_an_error():
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            else:
              continue
          if (is_sig_mismatch_an_error() and
              not _is_compatible_param_kind(param.kind, np_param.kind)):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          has_default = (param.default != inspect.Parameter.empty)
          np_has_default = (np_param.default != inspect.Parameter.empty)
          if is_sig_mismatch_an_error() and has_default != np_has_default:
            raise TypeError(
                'Parameter {} should{} have a default value. Argument '
                '`np_fun_name` is {}. Argument `np_fun` is {}.'.format(
                    name, '' if np_has_default else ' not', np_fun_name_orig,
                    np_fun_orig))
        # Numpy parameters that the tf-numpy function does not accept.
        for name in np_sig.parameters:
          if name not in sig.parameters and name not in unsupported:
            unsupported.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    else:
      return f

  return decorator


def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attaches numpy docstring to a function.

  This differs from np_doc in that it doesn't check for a match in signature.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must
      be a function directly under the `numpy` module, not under any submodule
      of `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name)
    if export:
      return np_export.np_export(np_fun_name)(f)
    else:
      return f

  return decorator


# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  return np.finfo(_to_numpy_type(dtype))
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args


def _maybe_get_dtype(x):
  """Returns a numpy type if available from x. Skips if x is numpy.ndarray.

  Args:
    x: a Python real number, a `core.Tensor`/`IndexedSlices`, a TF `DType`, or
      anything else accepted by `np_dtypes._result_type`.

  Returns:
    A numpy dtype for tensors and TF `DType`s; `x` itself otherwise (Python
    reals are returned as values so their value-based promotion is kept).

  Raises:
    ValueError: if `x` is a list or tuple.
  """
  # Don't put np.ndarray in this list, because np.result_type looks at the
  # value (not just dtype) of np.ndarray to decide the result type.
  if isinstance(x, numbers.Real):
    return x
  if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
    return _to_numpy_type(x.dtype)
  if isinstance(x, dtypes.DType):
    return x.as_numpy_dtype
  if isinstance(x, (list, tuple)):
    raise ValueError(
        f'Cannot find dtype for type inference from argument `x` of a sequence '
        f'type {type(x)}. For sequences, please call this function on each '
        f'element individually.')
  return x


# Can't use np_doc because np.result_type is a builtin function.
# (No docstring here on purpose: np_doc_only would splice it into the public
# documentation.)
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  arrays_and_dtypes = [
      _maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)
  ]
  if not arrays_and_dtypes:
    # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
    arrays_and_dtypes = [np.asarray([])]
  return np_dtypes._result_type(*arrays_and_dtypes)  # pylint: disable=protected-access


def result_type_unary(a, dtype):  # pylint: disable=missing-function-docstring
  """Find the result type from a single input and a dtype.

  Args:
    a: the input value.
    dtype: an explicit dtype; when truthy it alone decides the result.

  Returns:
    `result_type(dtype)` if `dtype` is truthy; `np.str_` / `np.bytes_` for a
    `str` / `bytes` input; otherwise `result_type(a)`.
  """
  if dtype:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    return result_type(dtype)

  # np_utils.result_type treats string inputs as dtype strings, not as strings.
  # but for unary we want to treat it as a string input.
  if isinstance(a, str):
    # `np.str_` is the type `np.unicode_` aliased in NumPy 1.x; the alias was
    # removed in NumPy 2.0.
    return np.str_
  elif isinstance(a, bytes):
    return np.bytes_

  # TF and numpy has different interpretations of Python types such as
  # `float`, so we let `np_utils.result_type` decide.
  return result_type(a)


def _result_type_binary(t1, t2):  # pylint: disable=missing-function-docstring
  """A specialization of result_type for 2 arguments for performance reasons."""
  try:
    return np_dtypes._result_type(_maybe_get_dtype(t1),  # pylint: disable=protected-access
                                  _maybe_get_dtype(t2))  # pylint: disable=protected-access
  except ValueError:
    # E.g. `_maybe_get_dtype` rejects sequences; the general path flattens them.
    return result_type(t1, t2)


# (No docstring here on purpose: np_doc would splice it into the public
# documentation.)
@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  type1 = _to_numpy_type(type1)
  type2 = _to_numpy_type(type2)
  return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2))


def tf_broadcast(*args):
  """Broadcast tensors.

  Args:
    *args: a list of tensors whose shapes are broadcastable against each other.

  Returns:
    Tensors broadcasted to the common shape: `args` itself (a tuple) when at
    most one tensor is given, otherwise a list.
  """
  if len(args) <= 1:
    return args
  sh = array_ops.shape(args[0])
  for arg in args[1:]:
    sh = array_ops.broadcast_dynamic_shape(sh, array_ops.shape(arg))
  return [array_ops.broadcast_to(arg, sh) for arg in args]


# TODO(wangpeng): Move the following functions to a separate file and check for
# float dtypes in each of them.


def get_static_value(x):
  """A version of tf.get_static_value that returns None on float dtypes.

  It returns None on float dtypes in order to avoid breaking gradients.

  Args:
    x: a tensor.

  Returns:
    Same as `tf.get_static_value`, except that it returns None when `x` has a
    float (or complex) dtype.
  """
  if isinstance(x, core.Tensor) and (x.dtype.is_floating or x.dtype.is_complex):
    return None
  return tensor_util.constant_value(x)


def _maybe_static(x):
  """Returns `get_static_value(x)` if it is not None, else `x` unchanged."""
  value = get_static_value(x)
  if value is None:
    return x
  else:
    return value


# All the following functions exist because get_static_value can't handle
# their TF counterparts.
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition.

  If `pred` has a static value (see `get_static_value`), only the selected
  branch function is called; otherwise this defers to
  `control_flow_ops.cond`.
  """
  v = get_static_value(pred)
  if v is None:
    return control_flow_ops.cond(pred, true_fn, false_fn)
  if v:
    return true_fn()
  else:
    return false_fn()


def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  return _maybe_static(a) + _maybe_static(b)


def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  return _maybe_static(a) - _maybe_static(b)


def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  return _maybe_static(a) > _maybe_static(b)


def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  return _maybe_static(a) >= _maybe_static(b)


def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  return _maybe_static(a) <= _maybe_static(b)


def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  If `a` has a static scalar value, the result short-circuits: a falsy `a`
  value is returned unchanged and a truthy one yields `_maybe_static(b)`.
  Otherwise the result is `&` of the (possibly static) operands.
  """
  a_value = get_static_value(a)
  if a_value is not None:
    if np.isscalar(a_value):
      if a_value:
        return _maybe_static(b)
      else:
        return a_value
    else:
      return a_value & _maybe_static(b)
  else:
    return a & _maybe_static(b)


def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  If `a` has a static scalar value, the result short-circuits: a truthy `a`
  value is returned unchanged and a falsy one yields `_maybe_static(b)`.
  Otherwise the result is `|` of the (possibly static) operands.
  """
  a_value = get_static_value(a)
  if a_value is not None:
    if np.isscalar(a_value):
      if a_value:
        return a_value
      else:
        return _maybe_static(b)
    else:
      return a_value | _maybe_static(b)
  else:
    return a | _maybe_static(b)


def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  return _maybe_static(a)[slice_spec]


def _np_reduction_axis(axis):
  """Converts a list `axis` to the tuple form that NumPy reductions accept.

  NumPy's `all`/`any` take None, an int or a tuple of ints as `axis`; a list
  (as is commonly passed to the TF reductions) is turned into a tuple so the
  static path accepts the same arguments as the dynamic one.
  """
  return tuple(axis) if isinstance(axis, list) else axis


def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  A static value is reduced with `np.all` rather than its `.all()` method so
  that static values which are not NumPy arrays (e.g. Python bools or lists)
  are handled as well.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  else:
    return np.all(v, axis=_np_reduction_axis(axis), keepdims=keepdims)


def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  A static value is reduced with `np.any` rather than its `.any()` method so
  that static values which are not NumPy arrays (e.g. Python bools or lists)
  are handled as well.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  else:
    return np.any(v, axis=_np_reduction_axis(axis), keepdims=keepdims)


def tf_rank(t):
  """Returns the static rank of `t` as an int, else `array_ops.rank(t)`."""
  r = t.shape.rank
  if r is not None:
    return r
  return array_ops.rank(t)