• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for internal use."""
16# pylint: disable=g-direct-tensorflow-import
17
18import inspect
19import numbers
20import os
21import re
22import numpy as np
23
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import indexed_slices
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.numpy_ops import np_arrays
31from tensorflow.python.ops.numpy_ops import np_dtypes
32from tensorflow.python.ops.numpy_ops import np_export
33from tensorflow.python.types import core
34from tensorflow.python.util import nest
35
36
def _canonicalize_axis(axis, rank):
  """Canonicalizes one (possibly negative) axis against `rank`.

  Single-axis convenience wrapper around `_canonicalize_axes`.
  """
  (canonical_axis,) = _canonicalize_axes([axis], rank)
  return canonical_axis
39
40
def _canonicalize_axes(axes, rank):
  """Maps each negative axis in `axes` to `axis + rank`.

  `rank` is first replaced by its static value when one is available. If it
  is still a Tensor afterwards, the adjustment goes through `cond`; otherwise
  plain Python arithmetic is used.

  Args:
    axes: an iterable of axes.
    rank: the rank to canonicalize against (Python int or Tensor).

  Returns:
    A list with one canonicalized axis per entry of `axes`.
  """
  rank = _maybe_static(rank)

  if not isinstance(rank, core.Tensor):
    def canonicalizer(axis):
      return axis + rank if axis < 0 else axis
  else:
    def canonicalizer(axis):
      return cond(axis < 0, lambda: axis + rank, lambda: axis)

  return list(map(canonicalizer, axes))
51
52
53def _supports_signature():
54  return hasattr(inspect, 'signature')
55
56
def _to_tf_type(dtype):
  """Converts a native Python or NumPy type to a TF `DType`.

  Args:
    dtype: a Python type, a NumPy type or a TF `DType`.

  Returns:
    The TensorFlow `DType` produced by `dtypes.as_dtype`.
  """
  tf_dtype = dtypes.as_dtype(dtype)
  return tf_dtype
67
68
def _to_numpy_type(dtype):
  """Converts a native Python type or TF DType to a NumPy type.

  Args:
    dtype: a Python type, a NumPy type or a TF `DType`.

  Returns:
    For a TF `DType`, its `as_numpy_dtype`; for anything else,
    `np.dtype(dtype)`.
  """
  if not isinstance(dtype, dtypes.DType):
    return np.dtype(dtype)
  return dtype.as_numpy_dtype
81
82
def isscalar(val):
  """Returns whether `val` is a scalar value or scalar Tensor.

  A tf-numpy `ndarray` is unwrapped to its `.data` first. For a Tensor whose
  rank is not statically known, the answer is a boolean Tensor computed at
  runtime; otherwise a Python/NumPy bool is returned.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if not isinstance(val, core.Tensor):
    return np.isscalar(val)
  static_ndims = val.shape.ndims
  if static_ndims is None:
    return math_ops.equal(array_ops.rank(val), 0)
  return static_ndims == 0
95
96
97def _has_docstring(f):
98  return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
99          f.__doc__)
100
101
102def _add_blank_line(s):
103  if s.endswith('\n'):
104    return s + '\n'
105  else:
106    return s + '\n\n'
107
108
109def _np_signature(f):
110  """An enhanced inspect.signature that can handle numpy.ufunc."""
111  # TODO(wangpeng): consider migrating away from inspect.signature.
112  # inspect.signature is supported in Python 3.3.
113  if not hasattr(inspect, 'signature'):
114    return None
115  if f is None:
116    return None
117  if not isinstance(f, np.ufunc):
118    try:
119      return inspect.signature(f)
120    except ValueError:
121      return None
122
123  def names_from_num(prefix, n):
124    if n <= 0:
125      return []
126    elif n == 1:
127      return [prefix]
128    else:
129      return [prefix + str(i + 1) for i in range(n)]
130
131  input_names = names_from_num('x', f.nin)
132  output_names = names_from_num('out', f.nout)
133  keyword_only_params = [('where', True), ('casting', 'same_kind'),
134                         ('order', 'K'), ('dtype', None), ('subok', True),
135                         ('signature', None), ('extobj', None)]
136  params = []
137  params += [
138      inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
139      for name in input_names
140  ]
141  if f.nout > 1:
142    params += [
143        inspect.Parameter(
144            name, inspect.Parameter.POSITIONAL_ONLY, default=None)
145        for name in output_names
146    ]
147  params += [
148      inspect.Parameter(
149          'out',
150          inspect.Parameter.POSITIONAL_OR_KEYWORD,
151          default=None if f.nout == 1 else (None,) * f.nout)
152  ]
153  params += [
154      inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
155      for name, default in keyword_only_params
156  ]
157  return inspect.Signature(params)
158
159
160# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't
161# allow positional-only argument. So we conflate positional-only, keyword-only
162# and positional-or-keyword arguments here.
163def _is_compatible_param_kind(a, b):
164
165  def relax(k):
166    if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
167      return inspect.Parameter.POSITIONAL_OR_KEYWORD
168    return k
169
170  return relax(a) == relax(b)
171
172
173def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
174  """Mutually propagates information between `np_fun_name` and `np_fun`.
175
176  If one is None and the other is not, we'll try to make the former not None in
177  a best effort.
178
179  Args:
180    np_fun_name: name for the np_fun symbol. At least one of np_fun or
181      np_fun_name shoud be set.
182    np_fun: the numpy function whose docstring will be used.
183
184  Returns:
185    Processed `np_fun_name` and `np_fun`.
186  """
187  if np_fun_name is not None:
188    assert isinstance(np_fun_name, str)
189  if np_fun is not None:
190    assert not isinstance(np_fun, str)
191  if np_fun is None:
192    assert np_fun_name is not None
193    try:
194      np_fun = getattr(np, str(np_fun_name))
195    except AttributeError:
196      np_fun = None
197  if np_fun_name is None:
198    assert np_fun is not None
199    np_fun_name = np_fun.__name__
200  return np_fun_name, np_fun
201
202
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring for the tf-numpy symbol `f`.

  The result is, in order:
  - a header naming the NumPy symbol;
  - an optional list of unsupported arguments;
  - `f`'s own docstring (if it has one), followed by a blank line;
  - whatever `_add_np_doc` appends for the NumPy documentation.

  Args:
    f: the tf-numpy function being documented.
    np_f: the NumPy counterpart, or None.
    np_fun_name: the NumPy name. Defaults to `np_f.__name__`.
    unsupported_params: (optional) names of unsupported arguments.
    link: (optional) link selector forwarded to `_add_np_doc`.

  Returns:
    The assembled docstring.
  """
  assert np_f or np_fun_name
  name = np_fun_name if np_fun_name else np_f.__name__
  parts = ['TensorFlow variant of NumPy\'s `%s`.\n\n' % name]
  if unsupported_params:
    quoted = ', '.join('`' + param + '`' for param in unsupported_params)
    parts.append('Unsupported arguments: ' + quoted + '.\n\n')
  doc = ''.join(parts)
  if _has_docstring(f):
    doc = _add_blank_line(doc + f.__doc__)
  # The form of the NumPy part (inlined text vs. link) is chosen inside
  # `_add_np_doc` from the global doc-form setting.
  return _add_np_doc(doc, name, np_f, link=link)
220
221
# Initial doc form; may be overridden with `set_np_doc_form`.
_np_doc_form = os.environ.get('TF_NP_DOC_FORM', '1.16')


def get_np_doc_form():
  """Returns the current form of the original NumPy docstrings.

  Returns:
    The value last passed to `set_np_doc_form`. Initially this is the
    `TF_NP_DOC_FORM` environment variable, or `'1.16'` if it is unset. See
    `set_np_doc_form` for the list of valid values.
  """
  return _np_doc_form


def set_np_doc_form(value):
  r"""Selects the form of the original NumPy docstrings.

  The value is stored in a module-level global that controls how a tf-numpy
  symbol's docstring refers to the original NumPy docstring:
  * `'inlined'`: the NumPy docstring is copied verbatim into the tf-numpy
    docstring;
  * `'stable'`: a link to the current stable NumPy docs is added;
  * `'dev'`: a link to the current development NumPy docs is added;
  * pattern `\d+(\.\d+(\.\d+)?)?`: `value` is treated as a version number
    (e.g. '1.16') and a link to that version's docs is added.

  Args:
    value: the value to set the global variable to.
  """
  global _np_doc_form
  _np_doc_form = value
252
253
class Link(object):
  """Selects an explicit URL to use verbatim as the NumPy documentation link.

  Attributes:
    value: the whole link.
  """

  def __init__(self, v):
    super().__init__()
    self.value = v
258
259
class AliasOf(object):
  """Selects another NumPy name to generate the documentation link from.

  Attributes:
    value: the NumPy name used in place of `np_fun_name` for link generation.
  """

  def __init__(self, v):
    super().__init__()
    self.value = v
264
265
class NoLink(object):
  """Marker requesting that no NumPy documentation link be added."""
268
269
def generate_link(flag, np_fun_name):
  r"""Generates a link to the NumPy documentation page of `np_fun_name`.

  Args:
    flag: the flag to control link form. See `set_np_doc_form`. `'dev'` and
      `'stable'` select those doc versions. A string that is entirely a
      version number matching `\d+(\.\d+(\.\d+)?)?` (e.g. `'1.16'`) selects
      that version.
    np_fun_name: the numpy function name.

  Returns:
    A URL string, or None if `flag` does not select a link form (e.g.
    `'inlined'`).
  """
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  elif re.fullmatch(r'\d+(\.\d+(\.\d+)?)?', flag):
    # `flag` is the version number. `fullmatch` is used because `match` with a
    # trailing `$` also accepts a trailing newline (e.g. '1.16\n'), which
    # would end up inside the URL.
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name
293
294
# Initial check-link flag; may be overridden with `set_check_link`.
_is_check_link = (
    os.environ.get('TF_NP_CHECK_LINK', 'False') in ('True', 'true', '1'))


def is_check_link():
  """Returns the current check-link flag.

  When this is truthy, `_add_np_doc` sends an HTTP HEAD request to each link
  it generates and raises if the status code is not 200.
  """
  return _is_check_link


def set_check_link(value):
  """Sets the flag returned by `is_check_link`.

  Args:
    value: the new flag value. It is stored as given.
  """
  global _is_check_link
  _is_check_link = value
306
307
def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.
  With the `'inlined'` form, `np_f`'s docstring (if any) is appended. With
  any other string form, a link is appended when one can be determined. A
  non-string form appends nothing.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with numpy docstring appended.

  Raises:
    ValueError: if `is_check_link()` is true and the HEAD request to the
      link does not return status code 200.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      # `NoLink` (or any unrecognized value): no link is added.
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`
        import requests  # pylint: disable=g-import-not-at-top
        # Bound the request (seconds) so an unresponsive server makes the
        # check fail instead of blocking indefinitely; `requests` applies no
        # timeout by default.
        r = requests.head(url, timeout=30)
        if r.status_code != 200:
          raise ValueError(
              f'Check link failed at [{url}] with status code {r.status_code}. '
              f'Argument `np_fun_name` is {np_fun_name}.')
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc
351
352
# Initial signature-mismatch flag; may be overridden with
# `set_is_sig_mismatch_an_error`.
_is_sig_mismatch_an_error = (
    os.environ.get('TF_NP_SIG_MISMATCH_IS_ERROR', 'False')
    in ('True', 'true', '1'))


def is_sig_mismatch_an_error():
  """Returns whether `np_doc` raises TypeError on signature mismatches.

  When the flag is falsy, parameters missing from the NumPy signature are
  skipped and kind/default mismatches are not reported.
  """
  return _is_sig_mismatch_an_error


def set_is_sig_mismatch_an_error(value):
  """Sets the flag returned by `is_sig_mismatch_an_error`.

  Args:
    value: the new flag value. It is stored as given.
  """
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value
364
365
def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attaches numpy docstring to a function.

  When both signatures are available, the returned decorator compares the
  decorated function's signature against NumPy's. NumPy parameters missing
  from the decorated function are listed as unsupported in the generated
  docstring. If `is_sig_mismatch_an_error()` is true, mismatches raise.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) the list of parameters not supported
      by tf.numpy. It is copied, never modified.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.

  Raises:
    TypeError: (from the decorator) if `is_sig_mismatch_an_error()` is true
      and a parameter of the decorated function is missing from NumPy's
      signature, has an incompatible kind, or disagrees on having a default.
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)

  def decorator(f):
    """The decorator."""
    # Work on a private copy. Appending to the caller's list, or to a list
    # shared across applications of this decorator, would leak entries into
    # other docstrings and duplicate names.
    unsupported = list(unsupported_params) if unsupported_params else []
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if is_sig_mismatch_an_error():
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            continue
          if (is_sig_mismatch_an_error() and
              not _is_compatible_param_kind(param.kind, np_param.kind)):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          # Compare against the `empty` sentinel by identity. `!=` would call
          # the default value's own `__ne__`, which for array-like defaults
          # may be elementwise and not yield a plain bool.
          has_default = param.default is not inspect.Parameter.empty
          np_has_default = np_param.default is not inspect.Parameter.empty
          if is_sig_mismatch_an_error() and has_default != np_has_default:
            negation = '' if np_has_default else ' not'
            raise TypeError(
                f'Parameter {name} should{negation} have a default value. '
                f'Argument `np_fun_name` is {np_fun_name_orig}. Argument '
                f'`np_fun` is {np_fun_orig}.')
        # NumPy parameters that `f` lacks are reported as unsupported. Names
        # the caller already listed are not repeated.
        for name in np_sig.parameters:
          if name not in sig.parameters and name not in unsupported:
            unsupported.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    return f

  return decorator
441
442
def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attaches numpy docstring to a function.

  Unlike `np_doc`, this does not compare the decorated function's signature
  against NumPy's, so no unsupported arguments are derived from it.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must
      be a function directly under the `numpy` module, not under any
      submodule of `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    """Sets `f.__doc__` and, if requested, exports `f`."""
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name)
    if not export:
      return f
    return np_export.np_export(np_fun_name)(f)

  return decorator
471
472
473# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  # Normalize TF DTypes / Python types to a NumPy type, then defer to NumPy.
  numpy_dtype = _to_numpy_type(dtype)
  return np.finfo(numpy_dtype)
479# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
480
481
482def _maybe_get_dtype(x):
483  """Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
484  # Don't put np.ndarray in this list, because np.result_type looks at the
485  # value (not just dtype) of np.ndarray to decide the result type.
486  if isinstance(x, numbers.Real):
487    return x
488  if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
489    return _to_numpy_type(x.dtype)
490  if isinstance(x, dtypes.DType):
491    return x.as_numpy_dtype
492  if isinstance(x, (list, tuple)):
493    raise ValueError(
494        f'Cannot find dtype for type inference from argument `x` of a sequence '
495        f'type {type(x)}. For sequences, please call this function on each '
496        f'element individually.')
497  return x
498
499
500# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  # Flatten nested structures, then map each leaf to something numpy can
  # type-promote.
  leaves = nest.flatten(arrays_and_dtypes)
  arrays_and_dtypes = list(map(_maybe_get_dtype, leaves))
  if not arrays_and_dtypes:
    # With no inputs, let numpy decide the dtype from an empty array.
    arrays_and_dtypes = [np.asarray([])]
  return np_dtypes._result_type(*arrays_and_dtypes)  # pylint: disable=protected-access
510
511
def result_type_unary(a, dtype):
  """Find the result type from a single input and a dtype.

  Args:
    a: the single input value.
    dtype: an explicit dtype, or None.

  Returns:
    `result_type(dtype)` when `dtype` is given. Otherwise `np.str_` for a
    `str` input, `np.bytes_` for a `bytes` input, and `result_type(a)` for
    anything else.
  """
  # Compare with None instead of testing truthiness: `np.dtype` defines
  # `__len__` (its number of fields), so a plain dtype such as
  # `np.dtype('float32')` is falsy and would otherwise be silently ignored.
  if dtype is not None:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    return result_type(dtype)

  # np_utils.result_type treats string inputs as dtype strings, not as strings,
  # but for unary we want to treat it as a string input.
  if isinstance(a, str):
    # `np.str_` rather than its alias `np.unicode_`, which NumPy 2.0 removed.
    return np.str_
  elif isinstance(a, bytes):
    return np.bytes_

  # TF and numpy have different interpretations of Python types such as
  # `float`, so we let `np_utils.result_type` decide.
  return result_type(a)
528
529
def _result_type_binary(t1, t2):
  """A specialization of result_type for 2 arguments for performance reasons.

  Falls back to the general `result_type` whenever the fast path raises
  ValueError (e.g. `_maybe_get_dtype` rejects lists and tuples).
  """
  try:
    first = _maybe_get_dtype(t1)
    second = _maybe_get_dtype(t2)
    return np_dtypes._result_type(first, second)  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)
537
538
@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  # Normalize both arguments to NumPy types, promote with NumPy, then
  # canonicalize the promoted dtype.
  promoted = np.promote_types(_to_numpy_type(type1), _to_numpy_type(type2))
  return np_dtypes.canonicalize_dtype(promoted)
544
545
def tf_broadcast(*args):
  """Broadcasts tensors to their common shape.

  Args:
    *args: tensors whose shapes are broadcastable against each other.

  Returns:
    With fewer than two arguments, the `args` tuple itself. Otherwise, a
    list of the tensors broadcast to the common (dynamic) shape.
  """
  if len(args) > 1:
    common_shape = array_ops.shape(args[0])
    for other in args[1:]:
      common_shape = array_ops.broadcast_dynamic_shape(
          common_shape, array_ops.shape(other))
    return [array_ops.broadcast_to(t, common_shape) for t in args]
  return args
561
562
563# TODO(wangpeng): Move the following functions to a separate file and check for
564#   float dtypes in each of them.
565
566
def get_static_value(x):
  """A version of tf.get_static_value that returns None on inexact dtypes.

  It returns None for Tensors of floating-point or complex dtype in order
  to avoid breaking gradients.

  Args:
    x: a tensor.

  Returns:
    Same as `tf.get_static_value`, except that it returns None when `x` is a
    Tensor with a floating-point or complex dtype.
  """
  is_inexact_tensor = isinstance(x, core.Tensor) and (
      x.dtype.is_floating or x.dtype.is_complex)
  if is_inexact_tensor:
    return None
  return tensor_util.constant_value(x)
582
583
def _maybe_static(x):
  """Returns `get_static_value(x)` if it is not None, else `x` unchanged."""
  static_value = get_static_value(x)
  return x if static_value is None else static_value
590
591
# All the following functions exist because get_static_value can't handle
# their TF counterparts.
594
595
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition.

  If `pred` has a static value (per `get_static_value`), the selected branch
  is called directly in Python. Otherwise this defers to
  `control_flow_ops.cond`.
  """
  static_pred = get_static_value(pred)
  if static_pred is not None:
    branch = true_fn if static_pred else false_fn
    return branch()
  return control_flow_ops.cond(pred, true_fn, false_fn)
605
606
def add(a, b):
  """A version of tf.add that eagerly evaluates if possible.

  Each operand is replaced by its static value when available, then `+` is
  applied.
  """
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs + rhs
610
611
def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible.

  Each operand is replaced by its static value when available, then `-` is
  applied.
  """
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs - rhs
615
616
def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible.

  Each operand is replaced by its static value when available, then `>` is
  applied.
  """
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs > rhs
620
621
def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible.

  Each operand is replaced by its static value when available, then `>=` is
  applied.
  """
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs >= rhs
625
626
def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible.

  Each operand is replaced by its static value when available, then `<=` is
  applied.
  """
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs <= rhs
630
631
def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  If `a` has a static scalar value, this short-circuits in Python: a falsy
  `a` is returned without looking at `b`. Otherwise `&` is applied.
  """
  a_value = get_static_value(a)
  if a_value is None:
    return a & _maybe_static(b)
  if not np.isscalar(a_value):
    return a_value & _maybe_static(b)
  return _maybe_static(b) if a_value else a_value
645
646
def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  If `a` has a static scalar value, this short-circuits in Python: a truthy
  `a` is returned without looking at `b`. Otherwise `|` is applied.
  """
  a_value = get_static_value(a)
  if a_value is None:
    return a | _maybe_static(b)
  if not np.isscalar(a_value):
    return a_value | _maybe_static(b)
  return a_value if a_value else _maybe_static(b)
660
661
def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible.

  `a` is replaced by its static value when available before indexing.
  """
  source = _maybe_static(a)
  return source[slice_spec]
665
666
def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: (optional) axis or axes to reduce over; None reduces over all.
    keepdims: whether to keep reduced dimensions with length 1.

  Returns:
    A NumPy result if `input_tensor` has a static value, otherwise the
    result of `math_ops.reduce_all`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  # `np.all` rather than `v.all(...)`: the static value is not guaranteed to
  # be an `np.ndarray` (e.g. a plain Python bool or list passed through
  # unchanged), and such objects have no `.all` method. For an ndarray the
  # two are equivalent.
  return np.all(v, axis=axis, keepdims=keepdims)
674
675
def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: (optional) axis or axes to reduce over; None reduces over all.
    keepdims: whether to keep reduced dimensions with length 1.

  Returns:
    A NumPy result if `input_tensor` has a static value, otherwise the
    result of `math_ops.reduce_any`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  # `np.any` rather than `v.any(...)`: the static value is not guaranteed to
  # be an `np.ndarray` (e.g. a plain Python bool or list passed through
  # unchanged), and such objects have no `.any` method. For an ndarray the
  # two are equivalent.
  return np.any(v, axis=axis, keepdims=keepdims)
683
684
def tf_rank(t):
  """Returns the rank of `t`.

  The static `t.shape.rank` is returned when known. Otherwise the result is
  the Tensor from `array_ops.rank(t)`.
  """
  static_rank = t.shape.rank
  return array_ops.rank(t) if static_rank is None else static_rank
690