1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions for internal use.""" 16# pylint: disable=g-direct-tensorflow-import 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import inspect 23import numbers 24import os 25import re 26import numpy as np 27 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import indexed_slices 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops.numpy_ops import np_arrays 35from tensorflow.python.ops.numpy_ops import np_dtypes 36from tensorflow.python.ops.numpy_ops import np_export 37from tensorflow.python.types import core 38from tensorflow.python.util import nest 39 40 41def _canonicalize_axis(axis, rank): 42 return _canonicalize_axes([axis], rank)[0] 43 44 45def _canonicalize_axes(axes, rank): 46 rank = _maybe_static(rank) 47 48 if isinstance(rank, core.Tensor): 49 canonicalizer = ( 50 lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis)) 51 else: 52 canonicalizer = lambda axis: axis + rank if axis < 0 else axis 53 54 return [canonicalizer(axis) for axis in axes] 55 56 57def 
_supports_signature(): 58 return hasattr(inspect, 'signature') 59 60 61def _to_tf_type(dtype): 62 """Converts a native python or numpy type to TF DType. 63 64 Args: 65 dtype: Could be a python type, a numpy type or a TF DType. 66 67 Returns: 68 A tensorflow `DType`. 69 """ 70 return dtypes.as_dtype(dtype) 71 72 73def _to_numpy_type(dtype): 74 """Converts a native python or TF DType to numpy type. 75 76 Args: 77 dtype: Could be a python type, a numpy type or a TF DType. 78 79 Returns: 80 A NumPy `dtype`. 81 """ 82 if isinstance(dtype, dtypes.DType): 83 return dtype.as_numpy_dtype 84 return np.dtype(dtype) 85 86 87def isscalar(val): 88 """Returns whether `val` is a scalar value or scalar Tensor.""" 89 if isinstance(val, np_arrays.ndarray): 90 val = val.data 91 if isinstance(val, core.Tensor): 92 ndims = val.shape.ndims 93 if ndims is not None: 94 return ndims == 0 95 else: 96 return math_ops.equal(array_ops.rank(val), 0) 97 else: 98 return np.isscalar(val) 99 100 101def _has_docstring(f): 102 return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and 103 f.__doc__) 104 105 106def _add_blank_line(s): 107 if s.endswith('\n'): 108 return s + '\n' 109 else: 110 return s + '\n\n' 111 112 113def _np_signature(f): 114 """An enhanced inspect.signature that can handle numpy.ufunc.""" 115 # TODO(wangpeng): consider migrating away from inspect.signature. 116 # inspect.signature is supported in Python 3.3. 
117 if not hasattr(inspect, 'signature'): 118 return None 119 if f is None: 120 return None 121 if not isinstance(f, np.ufunc): 122 try: 123 return inspect.signature(f) 124 except ValueError: 125 return None 126 127 def names_from_num(prefix, n): 128 if n <= 0: 129 return [] 130 elif n == 1: 131 return [prefix] 132 else: 133 return [prefix + str(i + 1) for i in range(n)] 134 135 input_names = names_from_num('x', f.nin) 136 output_names = names_from_num('out', f.nout) 137 keyword_only_params = [('where', True), ('casting', 'same_kind'), 138 ('order', 'K'), ('dtype', None), ('subok', True), 139 ('signature', None), ('extobj', None)] 140 params = [] 141 params += [ 142 inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY) 143 for name in input_names 144 ] 145 if f.nout > 1: 146 params += [ 147 inspect.Parameter( 148 name, inspect.Parameter.POSITIONAL_ONLY, default=None) 149 for name in output_names 150 ] 151 params += [ 152 inspect.Parameter( 153 'out', 154 inspect.Parameter.POSITIONAL_OR_KEYWORD, 155 default=None if f.nout == 1 else (None,) * f.nout) 156 ] 157 params += [ 158 inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default) 159 for name, default in keyword_only_params 160 ] 161 return inspect.Signature(params) 162 163 164# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't 165# allow positional-only argument. So we conflate positional-only, keyword-only 166# and positional-or-keyword arguments here. 167def _is_compatible_param_kind(a, b): 168 169 def relax(k): 170 if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY): 171 return inspect.Parameter.POSITIONAL_OR_KEYWORD 172 return k 173 174 return relax(a) == relax(b) 175 176 177def _prepare_np_fun_name_and_fun(np_fun_name, np_fun): 178 """Mutually propagates information between `np_fun_name` and `np_fun`. 179 180 If one is None and the other is not, we'll try to make the former not None in 181 a best effort. 
182 183 Args: 184 np_fun_name: name for the np_fun symbol. At least one of np_fun or 185 np_fun_name shoud be set. 186 np_fun: the numpy function whose docstring will be used. 187 188 Returns: 189 Processed `np_fun_name` and `np_fun`. 190 """ 191 if np_fun_name is not None: 192 assert isinstance(np_fun_name, str) 193 if np_fun is not None: 194 assert not isinstance(np_fun, str) 195 if np_fun is None: 196 assert np_fun_name is not None 197 try: 198 np_fun = getattr(np, str(np_fun_name)) 199 except AttributeError: 200 np_fun = None 201 if np_fun_name is None: 202 assert np_fun is not None 203 np_fun_name = np_fun.__name__ 204 return np_fun_name, np_fun 205 206 207def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None, 208 link=None): 209 """Helper to get docs.""" 210 assert np_f or np_fun_name 211 if not np_fun_name: 212 np_fun_name = np_f.__name__ 213 doc = 'TensorFlow variant of NumPy\'s `%s`.\n\n' % np_fun_name 214 if unsupported_params: 215 doc += 'Unsupported arguments: ' + ', '.join( 216 '`' + name + '`' for name in unsupported_params) + '.\n\n' 217 if _has_docstring(f): 218 doc += f.__doc__ 219 doc = _add_blank_line(doc) 220 # TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy 221 # doc according to some global switch. 222 doc = _add_np_doc(doc, np_fun_name, np_f, link=link) 223 return doc 224 225 226_np_doc_form = os.getenv('TF_NP_DOC_FORM', '1.16') 227 228 229def get_np_doc_form(): 230 """Gets the form of the original numpy docstrings. 231 232 Returns: 233 See `set_np_doc_form` for the list of valid values. 234 """ 235 return _np_doc_form 236 237 238def set_np_doc_form(value): 239 r"""Selects the form of the original numpy docstrings. 240 241 This function sets a global variable that controls how a tf-numpy symbol's 242 docstring should refer to the original numpy docstring. If `value` is 243 `'inlined'`, the numpy docstring will be verbatim copied into the tf-numpy 244 docstring. 
Otherwise, a link to the original numpy docstring will be 245 added. Which numpy version the link points to depends on `value`: 246 * `'stable'`: the current stable version; 247 * `'dev'`: the current development version; 248 * pattern `\d+(\.\d+(\.\d+)?)?`: `value` will be treated as a version number, 249 e.g. '1.16'. 250 251 Args: 252 value: the value to set the global variable to. 253 """ 254 global _np_doc_form 255 _np_doc_form = value 256 257 258class Link: 259 260 def __init__(self, v): 261 self.value = v 262 263 264class AliasOf: 265 266 def __init__(self, v): 267 self.value = v 268 269 270class NoLink: 271 pass 272 273 274def generate_link(flag, np_fun_name): 275 """Generates link from numpy function name. 276 277 Args: 278 flag: the flag to control link form. See `set_np_doc_form`. 279 np_fun_name: the numpy function name. 280 281 Returns: 282 A string. 283 """ 284 # Only adds link in this case 285 if flag == 'dev': 286 template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html' 287 elif flag == 'stable': 288 template = ( 289 'https://numpy.org/doc/stable/reference/generated/numpy.%s.html') 290 elif re.match(r'\d+(\.\d+(\.\d+)?)?$', flag): 291 # `flag` is the version number 292 template = ('https://numpy.org/doc/' + flag + 293 '/reference/generated/numpy.%s.html') 294 else: 295 return None 296 return template % np_fun_name 297 298 299_is_check_link = (os.getenv('TF_NP_CHECK_LINK', 'False') in 300 ('True', 'true', '1')) 301 302 303def is_check_link(): 304 return _is_check_link 305 306 307def set_check_link(value): 308 global _is_check_link 309 _is_check_link = value 310 311 312def _add_np_doc(doc, np_fun_name, np_f, link): 313 """Appends the numpy docstring to `doc`, according to `set_np_doc_form`. 314 315 See `set_np_doc_form` for how it controls the form of the numpy docstring. 316 317 Args: 318 doc: the docstring to be appended to. 319 np_fun_name: the name of the numpy function. 320 np_f: (optional) the numpy function. 
321 link: (optional) which link to use. See `np_doc` for details. 322 323 Returns: 324 `doc` with numpy docstring appended. 325 """ 326 flag = get_np_doc_form() 327 if flag == 'inlined': 328 if _has_docstring(np_f): 329 doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name 330 # TODO(wangpeng): It looks like code snippets in numpy doc don't work 331 # correctly with doctest. Fix that and remove the reformatting of the np_f 332 # comment. 333 doc += np_f.__doc__.replace('>>>', '>') 334 elif isinstance(flag, str): 335 if link is None: 336 url = generate_link(flag, np_fun_name) 337 elif isinstance(link, AliasOf): 338 url = generate_link(flag, link.value) 339 elif isinstance(link, Link): 340 url = link.value 341 else: 342 url = None 343 if url is not None: 344 if is_check_link(): 345 # Imports locally because some builds may not have `requests` 346 import requests # pylint: disable=g-import-not-at-top 347 r = requests.head(url) 348 if r.status_code != 200: 349 raise ValueError("Can't open link for %s: %s" % (np_fun_name, url)) 350 doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % ( 351 np_fun_name, url) 352 return doc 353 354 355_is_sig_mismatch_an_error = ( 356 os.getenv('TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1')) 357 358 359def is_sig_mismatch_an_error(): 360 return _is_sig_mismatch_an_error 361 362 363def set_is_sig_mismatch_an_error(value): 364 global _is_sig_mismatch_an_error 365 _is_sig_mismatch_an_error = value 366 367 368def np_doc(np_fun_name, np_fun=None, export=True, link=None): 369 """Attachs numpy docstring to a function. 370 371 Args: 372 np_fun_name: name for the np_fun symbol. At least one of np_fun or 373 np_fun_name shoud be set. 374 np_fun: (optional) the numpy function whose docstring will be used. 375 export: whether to export this symbol under module 376 `tf.experimental.numpy`. 
Note that if `export` is `True`, `np_fun` must be 377 a function directly under the `numpy` module, not under any submodule of 378 `numpy` (e.g. `numpy.random`). 379 link: (optional) which link to use. If `None`, a default link generated from 380 `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will 381 be used in place of `np_fun_name` for the link generation. If an instance 382 of `Link`, `link.value` will be used as the whole link. If an instance of 383 `NoLink`, no link will be added. 384 385 Returns: 386 A function decorator that attaches the docstring from `np_fun` to the 387 decorated function. 388 """ 389 np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun) 390 np_sig = _np_signature(np_fun) 391 392 def decorator(f): 393 """The decorator.""" 394 unsupported_params = [] 395 if hasattr(inspect, 'signature') and np_sig is not None: 396 try: 397 sig = inspect.signature(f) 398 except ValueError: 399 sig = None 400 if sig is not None: 401 for name, param in sig.parameters.items(): 402 np_param = np_sig.parameters.get(name) 403 if np_param is None: 404 if is_sig_mismatch_an_error(): 405 raise TypeError( 406 'Cannot find parameter "%s" in the numpy function\'s ' 407 'signature (which has these parameters: %s)' % 408 (name, list(np_sig.parameters.keys()))) 409 else: 410 continue 411 if (is_sig_mismatch_an_error() and 412 not _is_compatible_param_kind(param.kind, np_param.kind)): 413 raise TypeError( 414 'Parameter "%s" is of kind %s while in numpy it is of ' 415 'kind %s' % (name, param.kind, np_param.kind)) 416 has_default = (param.default != inspect.Parameter.empty) 417 np_has_default = (np_param.default != inspect.Parameter.empty) 418 if is_sig_mismatch_an_error() and has_default != np_has_default: 419 raise TypeError('Parameter "%s" should%s have a default value' % 420 (name, '' if np_has_default else ' not')) 421 for name in np_sig.parameters: 422 if name not in sig.parameters: 423 unsupported_params.append(name) 424 
f.__doc__ = _np_doc_helper( 425 f, np_fun, np_fun_name=np_fun_name, 426 unsupported_params=unsupported_params, link=link) 427 if export: 428 return np_export.np_export(np_fun_name)(f) 429 else: 430 return f 431 432 return decorator 433 434 435def np_doc_only(np_fun_name, np_fun=None, export=True): 436 """Attachs numpy docstring to a function. 437 438 This differs from np_doc in that it doesn't check for a match in signature. 439 440 Args: 441 np_fun_name: name for the np_fun symbol. At least one of np_fun or 442 np_fun_name shoud be set. 443 np_fun: (optional) the numpy function whose docstring will be used. 444 export: whether to export this symbol under module 445 `tf.experimental.numpy`. Note that if `export` is `True`, `np_f` must be a 446 function directly under the `numpy` module, not under any submodule of 447 `numpy` (e.g. `numpy.random`). 448 449 Returns: 450 A function decorator that attaches the docstring from `np_fun` to the 451 decorated function. 452 """ 453 np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun) 454 455 def decorator(f): 456 f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name) 457 if export: 458 return np_export.np_export(np_fun_name)(f) 459 else: 460 return f 461 462 return decorator 463 464 465# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args 466@np_doc('finfo') 467def finfo(dtype): 468 """Note that currently it just forwards to the numpy namesake, while 469 tensorflow and numpy dtypes may have different properties.""" 470 return np.finfo(_to_numpy_type(dtype)) 471# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args 472 473 474def _maybe_get_dtype(x): 475 """Returns a numpy type if available from x. 
Skips if x is numpy.ndarray.""" 476 # Don't put np.ndarray in this list, because np.result_type looks at the 477 # value (not just dtype) of np.ndarray to decide the result type. 478 if isinstance(x, numbers.Real): 479 return x 480 if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)): 481 return _to_numpy_type(x.dtype) 482 if isinstance(x, dtypes.DType): 483 return x.as_numpy_dtype 484 if isinstance(x, (list, tuple)): 485 raise ValueError('Got sequence') 486 return x 487 488 489# Can't use np_doc because np.result_type is a builtin function. 490@np_doc_only('result_type') 491def result_type(*arrays_and_dtypes): # pylint: disable=missing-function-docstring 492 arrays_and_dtypes = [ 493 _maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes) 494 ] 495 if not arrays_and_dtypes: 496 # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is. 497 arrays_and_dtypes = [np.asarray([])] 498 return np_dtypes._result_type(*arrays_and_dtypes) # pylint: disable=protected-access 499 500 501def _result_type_binary(t1, t2): # pylint: disable=missing-function-docstring 502 """A specialization of result_type for 2 arguments for performance reasons.""" 503 try: 504 return np_dtypes._result_type(_maybe_get_dtype(t1), # pylint: disable=protected-access 505 _maybe_get_dtype(t2)) # pylint: disable=protected-access 506 except ValueError: 507 return result_type(t1, t2) 508 509 510@np_doc('promote_types') 511def promote_types(type1, type2): # pylint: disable=missing-function-docstring 512 type1 = _to_numpy_type(type1) 513 type2 = _to_numpy_type(type2) 514 return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2)) 515 516 517def tf_broadcast(*args): 518 """Broadcast tensors. 519 520 Args: 521 *args: a list of tensors whose shapes are broadcastable against each other. 522 523 Returns: 524 Tensors broadcasted to the common shape. 
525 """ 526 if len(args) <= 1: 527 return args 528 sh = array_ops.shape(args[0]) 529 for arg in args[1:]: 530 sh = array_ops.broadcast_dynamic_shape(sh, array_ops.shape(arg)) 531 return [array_ops.broadcast_to(arg, sh) for arg in args] 532 533 534# TODO(wangpeng): Move the following functions to a separate file and check for 535# float dtypes in each of them. 536 537 538def get_static_value(x): 539 """A version of tf.get_static_value that returns None on float dtypes. 540 541 It returns None on float dtypes in order to avoid breaking gradients. 542 543 Args: 544 x: a tensor. 545 546 Returns: 547 Same as `tf.get_static_value`, except that it returns None when `x` has a 548 float dtype. 549 """ 550 if isinstance(x, core.Tensor) and (x.dtype.is_floating or x.dtype.is_complex): 551 return None 552 return tensor_util.constant_value(x) 553 554 555def _maybe_static(x): 556 value = get_static_value(x) 557 if value is None: 558 return x 559 else: 560 return value 561 562 563# All the following functions exist becaues get_static_value can't handle 564# their TF counterparts. 
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition.

  Args:
    pred: the predicate.
    true_fn: zero-argument callable invoked when the predicate holds.
    false_fn: zero-argument callable invoked when it does not.

  Returns:
    The result of the selected branch when the predicate has a static value,
    otherwise the result of `control_flow_ops.cond`.
  """
  static_pred = get_static_value(pred)
  if static_pred is None:
    return control_flow_ops.cond(pred, true_fn, false_fn)
  # The predicate is known now, so only the chosen branch is ever called.
  branch = true_fn if static_pred else false_fn
  return branch()


def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs + rhs


def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs - rhs


def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs > rhs


def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs >= rhs


def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs <= rhs


def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible."""
  static_a = get_static_value(a)
  if static_a is None:
    return a & _maybe_static(b)
  if not np.isscalar(static_a):
    return static_a & _maybe_static(b)
  # Scalar short-circuit: a false `a` decides the result without touching `b`.
  if not static_a:
    return static_a
  return _maybe_static(b)


def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible."""
  static_a = get_static_value(a)
  if static_a is None:
    return a | _maybe_static(b)
  if not np.isscalar(static_a):
    return static_a | _maybe_static(b)
  # Scalar short-circuit: a true `a` decides the result without touching `b`.
  if static_a:
    return static_a
  return _maybe_static(b)


def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  target = _maybe_static(a)
  return target[slice_spec]


def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible."""
  static_input = get_static_value(input_tensor)
  if static_input is not None:
    return static_input.all(axis=axis, keepdims=keepdims)
  return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)


def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible."""
  static_input = get_static_value(input_tensor)
  if static_input is not None:
    return static_input.any(axis=axis, keepdims=keepdims)
  return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)


def tf_rank(t):
  """Returns the static rank of `t` if known, otherwise its dynamic rank."""
  static_rank = t.shape.rank
  if static_rank is None:
    return array_ops.rank(t)
  return static_rank