• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for internal use."""
16# pylint: disable=g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import inspect
23import numbers
24import os
25import re
26import numpy as np
27
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import indexed_slices
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops.numpy_ops import np_arrays
35from tensorflow.python.ops.numpy_ops import np_dtypes
36from tensorflow.python.ops.numpy_ops import np_export
37from tensorflow.python.types import core
38from tensorflow.python.util import nest
39
40
def _canonicalize_axis(axis, rank):
  """Canonicalizes a single axis; see `_canonicalize_axes`."""
  (canonical,) = _canonicalize_axes([axis], rank)
  return canonical
43
44
def _canonicalize_axes(axes, rank):
  """Converts negative axes into their non-negative equivalents.

  Each negative `axis` becomes `axis + rank`; non-negative axes are returned
  unchanged. No bounds checking is performed.

  Args:
    axes: an iterable of axis indices.
    rank: the rank to canonicalize against; may be a Python value or a tensor.

  Returns:
    A list with one canonicalized axis per element of `axes`.
  """
  rank = _maybe_static(rank)
  rank_is_dynamic = isinstance(rank, core.Tensor)

  def canonicalize(axis):
    if rank_is_dynamic:
      # The rank has no static value, so select the branch with `cond`, which
      # falls back to a graph op when `axis < 0` cannot be evaluated.
      return cond(axis < 0, lambda: axis + rank, lambda: axis)
    return axis + rank if axis < 0 else axis

  return [canonicalize(axis) for axis in axes]
55
56
def _supports_signature():
  """Returns True iff the running Python's `inspect` module has `signature`."""
  try:
    inspect.signature  # pylint: disable=pointless-statement
  except AttributeError:
    return False
  return True
59
60
def _to_tf_type(dtype):
  """Converts a python type, a numpy type or a TF DType to a TF `DType`.

  Args:
    dtype: a python type, a numpy type or a TF DType.

  Returns:
    The corresponding tensorflow `DType`, as given by `dtypes.as_dtype`.
  """
  tf_dtype = dtypes.as_dtype(dtype)
  return tf_dtype
71
72
def _to_numpy_type(dtype):
  """Converts a python type, a numpy type or a TF DType to a NumPy dtype.

  Args:
    dtype: a python type, a numpy type or a TF DType.

  Returns:
    `dtype.as_numpy_dtype` when `dtype` is a TF `DType`, otherwise
    `np.dtype(dtype)`.
  """
  is_tf_dtype = isinstance(dtype, dtypes.DType)
  return dtype.as_numpy_dtype if is_tf_dtype else np.dtype(dtype)
85
86
def isscalar(val):
  """Returns whether `val` is a scalar value or scalar Tensor.

  Args:
    val: a tf-numpy ndarray (its `.data` is inspected), a Tensor, or any value
      accepted by `np.isscalar`.

  Returns:
    A Python bool, except for a Tensor whose rank is not statically known, in
    which case a boolean scalar tensor computed from its dynamic rank.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if not isinstance(val, core.Tensor):
    return np.isscalar(val)
  static_rank = val.shape.ndims
  if static_rank is None:
    # Rank unknown until run time: answer with a tensor.
    return math_ops.equal(array_ops.rank(val), 0)
  return static_rank == 0
99
100
def _has_docstring(f):
  """Returns `f`'s docstring if it is a non-empty string, else a falsy value.

  Returns `f` itself when `f` is falsy (e.g. None), `False` when `f` has no
  string `__doc__`, and otherwise `f.__doc__` (which may be the empty string).
  """
  if not f:
    return f
  if not hasattr(f, '__doc__'):
    return False
  doc = f.__doc__
  if not isinstance(doc, str):
    return False
  return doc
104
105
def _add_blank_line(s):
  """Appends newlines to `s` so that it ends with exactly one blank line.

  Adds one newline if `s` already ends with a newline, otherwise two.
  """
  terminator = '\n' if s.endswith('\n') else '\n\n'
  return s + terminator
111
112
def _np_signature(f):
  """An enhanced `inspect.signature` that can also handle `numpy.ufunc`.

  Returns None if `inspect.signature` is unavailable, if `f` is None, or if
  `inspect.signature(f)` raises ValueError. For a ufunc, a signature is
  synthesized from `f.nin` / `f.nout`:
    * positional-only inputs `x` (one input) or `x1`, `x2`, ...;
    * positional-only outputs `out1`, `out2`, ... (only when `f.nout > 1`);
    * an `out` parameter defaulting to None (one output) or a tuple of Nones;
    * the ufunc keyword-only arguments with their defaults.
  """
  # TODO(wangpeng): consider migrating away from inspect.signature.
  # inspect.signature is supported in Python 3.3.
  if not hasattr(inspect, 'signature') or f is None:
    return None
  if not isinstance(f, np.ufunc):
    try:
      return inspect.signature(f)
    except ValueError:
      return None

  parameter = inspect.Parameter

  def numbered(prefix, count):
    # A single slot keeps the bare prefix; several get prefix1..prefixN.
    if count <= 0:
      return []
    if count == 1:
      return [prefix]
    return ['%s%d' % (prefix, i) for i in range(1, count + 1)]

  params = [
      parameter(name, parameter.POSITIONAL_ONLY)
      for name in numbered('x', f.nin)
  ]
  if f.nout > 1:
    params.extend(
        parameter(name, parameter.POSITIONAL_ONLY, default=None)
        for name in numbered('out', f.nout))
  out_default = None if f.nout == 1 else (None,) * f.nout
  params.append(
      parameter('out', parameter.POSITIONAL_OR_KEYWORD, default=out_default))
  keyword_only_defaults = (('where', True), ('casting', 'same_kind'),
                           ('order', 'K'), ('dtype', None), ('subok', True),
                           ('signature', None), ('extobj', None))
  params.extend(
      parameter(name, parameter.KEYWORD_ONLY, default=value)
      for name, value in keyword_only_defaults)
  return inspect.Signature(params)
162
163
# Python 2 doesn't allow keyword-only arguments, and Python prior to 3.8
# doesn't allow positional-only arguments, so positional-only, keyword-only
# and positional-or-keyword parameters are all treated as compatible here.
def _is_compatible_param_kind(a, b):
  """Returns whether parameter kinds `a` and `b` are considered compatible.

  Identical kinds are always compatible. In addition, any two of
  POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD and KEYWORD_ONLY are compatible.
  """
  named_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                 inspect.Parameter.POSITIONAL_OR_KEYWORD,
                 inspect.Parameter.KEYWORD_ONLY)
  return a == b or (a in named_kinds and b in named_kinds)
175
176
def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
  """Fills in whichever of `np_fun_name` / `np_fun` is missing, best-effort.

  When only the name is given, the function is looked up as an attribute of
  the top-level `numpy` module; it is left as None if there is no such
  attribute. When only the function is given, its `__name__` is used.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: the numpy function whose docstring will be used.

  Returns:
    A `(np_fun_name, np_fun)` tuple.
  """
  if np_fun_name is not None:
    assert isinstance(np_fun_name, str)
  if np_fun is not None:
    assert not isinstance(np_fun, str)

  if np_fun is None:
    assert np_fun_name is not None
    # `getattr` with a default only swallows AttributeError.
    np_fun = getattr(np, str(np_fun_name), None)
  if np_fun_name is None:
    assert np_fun is not None
    np_fun_name = np_fun.__name__
  return np_fun_name, np_fun
205
206
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring for the tf-numpy function `f`.

  The result consists of a header naming the NumPy function, an optional list
  of unsupported arguments, `f`'s own docstring (if it has one), and finally
  whatever `_add_np_doc` appends (the NumPy docstring or a link to it).

  Args:
    f: the tf-numpy function being documented.
    np_f: the NumPy counterpart of `f`; may be None if `np_fun_name` is set.
    np_fun_name: (optional) name of the NumPy function; defaults to
      `np_f.__name__`.
    unsupported_params: (optional) names of parameters `f` does not support.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    The assembled docstring.
  """
  assert np_f or np_fun_name
  name = np_fun_name if np_fun_name else np_f.__name__
  doc = "TensorFlow variant of NumPy's `%s`.\n\n" % name
  if unsupported_params:
    quoted = ', '.join('`' + param + '`' for param in unsupported_params)
    doc += 'Unsupported arguments: ' + quoted + '.\n\n'
  own_doc = _has_docstring(f)
  if own_doc:
    # `own_doc` is non-empty, so padding it alone pads the whole string.
    doc += _add_blank_line(own_doc)
  return _add_np_doc(doc, name, np_f, link=link)
224
225
# Current docstring form (see `set_np_doc_form`); initialized from the
# TF_NP_DOC_FORM environment variable, defaulting to NumPy version '1.16'.
_np_doc_form = os.environ.get('TF_NP_DOC_FORM', '1.16')
227
228
def get_np_doc_form():
  """Gets the form of the original numpy docstrings.

  Returns:
    The value last passed to `set_np_doc_form`, or, if it was never called,
    the TF_NP_DOC_FORM environment variable (default '1.16'). See
    `set_np_doc_form` for the list of valid values.
  """
  current_form = _np_doc_form
  return current_form
236
237
def set_np_doc_form(value):
  r"""Selects the form of the original numpy docstrings.

  This sets the module-level setting that controls how a tf-numpy symbol's
  docstring refers to the original numpy docstring:
  * `'inlined'`: the numpy docstring is copied verbatim into the tf-numpy
    docstring;
  * `'stable'`: a link to the current stable numpy documentation is added;
  * `'dev'`: a link to the current development numpy documentation is added;
  * a version number matching `\d+(\.\d+(\.\d+)?)?`, e.g. '1.16': a link to
    the documentation of that numpy version is added.

  Args:
    value: the new docstring form.
  """
  global _np_doc_form
  _np_doc_form = value
256
257
class Link(object):
  """A complete URL to use as the numpy-doc link in `np_doc(link=...)`."""

  def __init__(self, v):
    # The URL, used verbatim as the link target.
    self.value = v
262
263
class AliasOf(object):
  """Makes `np_doc(link=...)` generate its link for a different numpy name."""

  def __init__(self, v):
    # The numpy function name used for link generation instead of
    # `np_fun_name`.
    self.value = v
268
269
class NoLink(object):
  """Marker for `np_doc(link=...)`: add no numpy-doc link at all."""
272
273
def generate_link(flag, np_fun_name):
  """Generates a link to the NumPy documentation page of `np_fun_name`.

  Args:
    flag: the flag to control link form. See `set_np_doc_form`.
    np_fun_name: the numpy function name.

  Returns:
    A URL string, or None when `flag` is neither 'dev', 'stable' nor a version
    number such as '1.16'.
  """
  page = '/reference/generated/numpy.%s.html'
  named_prefixes = {
      'dev': 'https://numpy.org/devdocs',
      'stable': 'https://numpy.org/doc/stable',
  }
  if flag in named_prefixes:
    prefix = named_prefixes[flag]
  elif re.match(r'\d+(\.\d+(\.\d+)?)?$', flag):
    # `flag` is a version number.
    prefix = 'https://numpy.org/doc/' + flag
  else:
    return None
  return (prefix + page) % np_fun_name
297
298
# Whether generated numpy-doc links are checked over HTTP; initialized from
# the TF_NP_CHECK_LINK environment variable ('True', 'true' or '1' enable it).
_is_check_link = os.environ.get(
    'TF_NP_CHECK_LINK', 'False') in ('True', 'true', '1')
301
302
def is_check_link():
  """Returns whether `_add_np_doc` checks generated links with HTTP HEAD."""
  enabled = _is_check_link
  return enabled
305
306
def set_check_link(value):
  """Sets whether generated numpy-doc links are checked.

  Args:
    value: when truthy, `_add_np_doc` issues an HTTP HEAD request for each
      generated link and raises ValueError on a non-200 status.
  """
  global _is_check_link
  _is_check_link = value
310
311
def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring, or a link to it, to `doc`.

  See `set_np_doc_form` for how the current doc form controls the result.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with the numpy docstring or a link to the numpy docs appended.

  Raises:
    ValueError: if link checking is enabled (see `set_check_link`) and the
      link, after following redirects, does not answer with HTTP 200.
      Network errors (including timeouts) from `requests` propagate as-is.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      # `NoLink` (or anything unrecognized): add no link.
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`.
        import requests  # pylint: disable=g-import-not-at-top
        # `requests.head` does not follow redirects unless asked to, so a link
        # that redirects to a valid page would be reported as broken (3xx).
        # The timeout (seconds) keeps an unresponsive server from hanging the
        # decorator forever.
        timeout_secs = 60
        r = requests.head(url, allow_redirects=True, timeout=timeout_secs)
        if r.status_code != 200:
          raise ValueError(
              f'Check link failed at [{url}] with status code {r.status_code}. '
              f'Argument `np_fun_name` is {np_fun_name}.')
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc
355
356
# Whether `np_doc` raises on signature mismatches; initialized from the
# TF_NP_SIG_MISMATCH_IS_ERROR environment variable ('True', 'true' or '1').
_is_sig_mismatch_an_error = os.environ.get(
    'TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1')
359
360
def is_sig_mismatch_an_error():
  """Returns whether `np_doc` raises TypeError on signature mismatches."""
  raise_on_mismatch = _is_sig_mismatch_an_error
  return raise_on_mismatch
363
364
def set_is_sig_mismatch_an_error(value):
  """Sets whether `np_doc` raises TypeError on signature mismatches.

  Args:
    value: truthy to raise; falsy to silently tolerate mismatches.
  """
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value
368
369
def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attaches numpy docstring to a function.

  When the numpy function's signature can be determined, the decorated
  function's parameters are compared against it. Numpy parameters that the
  decorated function lacks are listed as unsupported in the docstring. When
  `is_sig_mismatch_an_error()` is true, the decorator raises `TypeError` for:
  a parameter numpy does not have; an incompatible parameter kind; or a
  default value present on only one side.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) the list of parameters not supported
      by tf.numpy. Any iterable of names is accepted; it is copied and never
      modified.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)
  # Private copy, so the caller's list is never appended to.
  explicit_unsupported = (
      [] if unsupported_params is None else list(unsupported_params))

  def decorator(f):
    """The decorator."""
    # Fresh per decorated function, so reusing the returned decorator does not
    # accumulate names detected for earlier functions.
    unsupported = list(explicit_unsupported)
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if is_sig_mismatch_an_error():
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            continue
          if (is_sig_mismatch_an_error() and
              not _is_compatible_param_kind(param.kind, np_param.kind)):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          # Identity test against the sentinel: `!=` would compare
          # element-wise when a default value is e.g. a numpy array, making
          # the `if` below raise on an ambiguous truth value.
          has_default = param.default is not inspect.Parameter.empty
          np_has_default = np_param.default is not inspect.Parameter.empty
          if is_sig_mismatch_an_error() and has_default != np_has_default:
            raise TypeError(
                'Parameter {} should{} have a default value. Argument '
                '`np_fun_name` is {}. Argument `np_fun` is {}.'.format(
                    name, '' if np_has_default else ' not', np_fun_name_orig,
                    np_fun_orig))
        # Numpy parameters the decorated function does not accept are
        # documented as unsupported; skip names already listed by the caller.
        for name in np_sig.parameters:
          if name not in sig.parameters and name not in unsupported:
            unsupported.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    return f

  return decorator
445
446
def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attaches numpy docstring to a function, without checking its signature.

  This differs from `np_doc` in that it doesn't check for a match in
  signature and does not list unsupported parameters.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must
      be a function directly under the `numpy` module, not under any submodule
      of `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name)
    return np_export.np_export(np_fun_name)(f) if export else f

  return decorator
475
476
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  # The docstring above is embedded in the user-facing doc by `np_doc`.
  np_dtype = _to_numpy_type(dtype)
  return np.finfo(np_dtype)
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
484
485
def _maybe_get_dtype(x):
  """Returns a numpy type if available from x. Skips if x is numpy.ndarray.

  Real scalars are returned unchanged; Tensors and IndexedSlices map to the
  numpy equivalent of their dtype; TF `DType`s map to their numpy dtype; any
  other non-sequence value (including `np.ndarray`) is returned unchanged.

  Raises:
    ValueError: if `x` is a list or tuple.
  """
  # np.ndarray is deliberately passed through untouched: np.result_type looks
  # at the value (not just dtype) of np.ndarray to decide the result type.
  if isinstance(x, numbers.Real):
    return x
  if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
    return _to_numpy_type(x.dtype)
  if isinstance(x, dtypes.DType):
    return x.as_numpy_dtype
  if not isinstance(x, (list, tuple)):
    return x
  raise ValueError(
      f'Cannot find dtype for type inference from argument `x` of a sequence '
      f'type {type(x)}. For sequences, please call this function on each '
      f'element individually.')
502
503
# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  # The docstring is generated by `np_doc_only`. Every (flattened) argument is
  # reduced to its dtype where possible, then the results are combined by
  # `np_dtypes._result_type`.
  operands = [_maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)]
  # With no operands at all, let numpy decide what the dtype is.
  operands = operands or [np.asarray([])]
  return np_dtypes._result_type(*operands)  # pylint: disable=protected-access
514
515
def _result_type_binary(t1, t2):
  """A specialization of result_type for 2 arguments for performance reasons.

  Applies `np_dtypes._result_type` directly to the dtype-mapped arguments. If
  that raises ValueError (as `_maybe_get_dtype` does for a list or tuple),
  falls back to the general `result_type`.
  """
  try:
    dtype1 = _maybe_get_dtype(t1)
    dtype2 = _maybe_get_dtype(t2)
    return np_dtypes._result_type(dtype1, dtype2)  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)
523
524
@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  # The docstring is generated by `np_doc`. Both inputs are converted to numpy
  # dtypes, promoted by NumPy, and the result is passed through
  # `np_dtypes.canonicalize_dtype`.
  promoted = np.promote_types(_to_numpy_type(type1), _to_numpy_type(type2))
  return np_dtypes.canonicalize_dtype(promoted)
530
531
def tf_broadcast(*args):
  """Broadcast tensors.

  Args:
    *args: a list of tensors whose shapes are broadcastable against each other.

  Returns:
    With fewer than two arguments, `args` itself (a tuple), unchanged.
    Otherwise a list of the tensors broadcast to their common (dynamic) shape.
  """
  if len(args) < 2:
    return args
  common_shape = array_ops.shape(args[0])
  for other in args[1:]:
    common_shape = array_ops.broadcast_dynamic_shape(
        common_shape, array_ops.shape(other))
  return [array_ops.broadcast_to(t, common_shape) for t in args]
547
548
549# TODO(wangpeng): Move the following functions to a separate file and check for
550#   float dtypes in each of them.
551
552
def get_static_value(x):
  """A version of tf.get_static_value that returns None on inexact dtypes.

  It returns None for tensors of floating or complex dtype in order to avoid
  breaking gradients.

  Args:
    x: a tensor.

  Returns:
    Same as `tensor_util.constant_value(x)`, except that it returns None when
    `x` is a tensor with a floating or complex dtype.
  """
  if isinstance(x, core.Tensor):
    dtype = x.dtype
    if dtype.is_floating or dtype.is_complex:
      return None
  return tensor_util.constant_value(x)
568
569
def _maybe_static(x):
  """Returns the static value of `x` if `get_static_value` finds one, else `x`."""
  static_value = get_static_value(x)
  return x if static_value is None else static_value
576
577
# All the following functions exist because get_static_value can't handle
# their TF counterparts.
580
581
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition.

  If `pred` has a static value, the chosen branch is called directly;
  otherwise `control_flow_ops.cond` is used.
  """
  static_pred = get_static_value(pred)
  if static_pred is not None:
    return true_fn() if static_pred else false_fn()
  return control_flow_ops.cond(pred, true_fn, false_fn)
591
592
def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs + rhs
596
597
def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs - rhs
601
602
def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs > rhs
606
607
def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs >= rhs
611
612
def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs <= rhs
616
617
def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  When `a` has a static value for which `np.isscalar` is true, this
  short-circuits: if that value is truthy, `b` (statically evaluated if
  possible) is returned; otherwise `a`'s value is returned and `b` is not
  evaluated.
  """
  a_value = get_static_value(a)
  if a_value is None:
    return a & _maybe_static(b)
  if not np.isscalar(a_value):
    return a_value & _maybe_static(b)
  return _maybe_static(b) if a_value else a_value
631
632
def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  When `a` has a static value for which `np.isscalar` is true, this
  short-circuits: if that value is truthy, it is returned and `b` is not
  evaluated; otherwise `b` (statically evaluated if possible) is returned.
  """
  a_value = get_static_value(a)
  if a_value is None:
    return a | _maybe_static(b)
  if not np.isscalar(a_value):
    return a_value | _maybe_static(b)
  return a_value if a_value else _maybe_static(b)
646
647
def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  source = _maybe_static(a)
  return source[slice_spec]
651
652
def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: (optional) axis or axes to reduce over; None reduces over all.
    keepdims: whether to retain reduced dimensions with length 1.

  Returns:
    `np.all` of the static value when one is available, otherwise the result
    of `math_ops.reduce_all`.
  """
  static_value = get_static_value(input_tensor)
  if static_value is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  # `np.all` rather than `static_value.all(...)`: the static value is not
  # guaranteed to be a NumPy object (presumably non-tensor inputs such as a
  # Python bool or list come back unchanged), and only NumPy objects have an
  # `.all` method. For ndarrays the two are equivalent.
  return np.all(static_value, axis=axis, keepdims=keepdims)
660
661
def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: (optional) axis or axes to reduce over; None reduces over all.
    keepdims: whether to retain reduced dimensions with length 1.

  Returns:
    `np.any` of the static value when one is available, otherwise the result
    of `math_ops.reduce_any`.
  """
  static_value = get_static_value(input_tensor)
  if static_value is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  # `np.any` rather than `static_value.any(...)`: the static value is not
  # guaranteed to be a NumPy object (presumably non-tensor inputs such as a
  # Python bool or list come back unchanged), and only NumPy objects have an
  # `.any` method. For ndarrays the two are equivalent.
  return np.any(static_value, axis=axis, keepdims=keepdims)
669
670
def tf_rank(t):
  """Returns the rank of `t`.

  Returns `t.shape.rank` when it is known (not None), otherwise the dynamic
  rank from `array_ops.rank(t)`.
  """
  static_rank = t.shape.rank
  return array_ops.rank(t) if static_rank is None else static_rank
676