# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import math
import numbers
import numpy as np
import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.util import nest


newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)

@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.zeros(shape, dtype=dtype)

@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  if dtype is None:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like.
    dtype = np_utils.result_type(a)
  else:
    # TF and numpy have different interpretations of Python types such as
    # `float`, so we let `np_utils.result_type` decide.
    dtype = np_utils.result_type(dtype)
  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)

@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  if dtype is None:
    dtype = np_utils.result_type(a)
  else:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones_like(a, dtype)

@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if M is None:
    M = N
  # Make sure N, M and k are `int`.
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error.
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k < 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)

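# Example (illustrative): for an off-diagonal `k`, the code above computes the
# exact diagonal length that `tf.linalg.diag` expects. With N=3, M=4, k=2:
# diag_len starts at min(3, 4) = 3; since N < M and N + k > M (3 + 2 > 4),
# diag_len = M - k = 2, i.e. ones at positions (0, 2) and (1, 3):
#
#   eye(3, 4, k=2)  # => [[0., 0., 1., 0.],
#                   #     [0., 0., 0., 1.],
#                   #     [0., 0., 0., 0.]]
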
@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since the np full_like signature doesn't seem to have
# the shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """The order, subok and shape arguments must not be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))

def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, ops.Tensor):
    if not dtype:
      dtype = np_utils.result_type(result_t)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return result_t


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndmin`. This is used only if
  `val` is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args

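# Example (illustrative): `ndmin` left-pads the shape with ones, mirroring
# np.array. E.g. array(5., ndmin=2) has shape (1, 1) and
# array([1, 2, 3], ndmin=2) has shape (1, 3).
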
@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype):
    return a
  return array(a, dtype, copy=False)


@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)

# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, 0 is used as the start of the
      interval and the `start` value is used as `stop`. If specified, it is
      not included in the range if `step` is an integer. When `step` is
      floating point, it may or may not be included.
    step: The difference between two consecutive values in the output range.
      It is recommended to use `linspace` instead of using non-integer values
      for `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and ((stop is not None and start > stop) or
                   (stop is None and start < 0)):
    return array([], dtype=dtype)
  if step < 0 and ((stop is not None and start < stop) or
                   (stop is None and start > 0)):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype)

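# Example (illustrative): the one-argument form treats `start` as `stop`, as
# in numpy. E.g. arange(5) yields [0, 1, 2, 3, 4]; arange(1, 7, 2) yields
# [1, 3, 5]; arange(5, 1) is empty because step > 0 and start > stop.
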
# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_ops.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank])

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k))

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
  return result

@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (array_ops.zeros(
        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = array_ops.matrix_diag_part(a, k=offset)
  return a


@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)

def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)

  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)

  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2

@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)

@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)

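# Example (illustrative): a short condition is padded with False, as in
# np.compress. E.g. compress([True, False], [[1, 2], [3, 4], [5, 6]], axis=0)
# keeps only row 0, returning [[1, 2]].
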
@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a

@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)

@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we
  # always return an ndarray.
  return math_ops.imag(val)


_TO_INT_ = 0
_TO_FLOAT = 1

def _reduce(tf_fn,
            a,
            axis=None,
            dtype=None,
            keepdims=None,
            promote_int=_TO_INT_,
            tf_bool_fn=None,
            preserve_bool=False):
  """A general reduction function.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
      `_TO_FLOAT` always promotes them to a float type (determined by
      dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
      is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
      is `np.bool_` (some reductions such as np.sum convert bools to integers,
      while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  if keepdims is None:
    keepdims = False
  a = asarray(a, dtype=dtype)
  if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
      tf_bool_fn is not None):
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
    if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
      if promote_int == _TO_INT_:
        # If a is an integer/bool type whose bit width is less than that of
        # np.int_, numpy up-casts it to np.int_ based on the documentation at
        # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
        if dtype == np.bool_:
          is_signed = True
          width = 8  # We can use any number here that is less than 64
        else:
          is_signed = np.issubdtype(dtype, np.signedinteger)
          width = np.iinfo(dtype).bits
        # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
        # they should have the same bit width.
        if width < np.iinfo(np.int_).bits:
          if is_signed:
            dtype = np.int_
          else:
            dtype = np.uint
          a = math_ops.cast(a, dtype)
      elif promote_int == _TO_FLOAT:
        a = math_ops.cast(a, np_dtypes.default_float_type())

  if isinstance(axis, ops.Tensor) and axis.dtype not in (
      dtypes.int32, dtypes.int64):
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)

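# Example (illustrative): these knobs mirror numpy's reduction semantics.
# With no explicit dtype, summing bools promotes them to np.int_ (so
# sum([True, False]) == 1), while amax sets preserve_bool=True and routes
# bool inputs to reduce_any (so amax([True, False]) == True). Requesting
# dtype=np.bool_ from sum routes to reduce_any as well.
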
# TODO(DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError('axis argument is not supported in the current '
                              '`np.size` implementation')
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if x.shape.is_fully_defined():
    return np.prod(x.shape.as_list(), dtype=int)
  else:
    return array_ops.size_v2(x)


@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  return _reduce(
      math_ops.reduce_sum,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_any)

@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_prod,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_all)


@np_utils.np_doc('mean', unsupported_params=['out'])
def mean(a, axis=None, dtype=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_mean,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('amax', unsupported_params=['out'])
def amax(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_max,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_any,
      preserve_bool=True)


@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True)

@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result

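# Example (illustrative): with ddof=1 the divisor is n - 1, i.e. the unbiased
# sample variance. For [1., 2., 3., 4.]: mean = 2.5, sum of squared
# deviations = 5.0, so var(..., ddof=0) == 5/4 == 1.25 and
# var(..., ddof=1) == 5/3, about 1.6667.
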
@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])

@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)

@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best-effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result

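# Example (illustrative): the static shape is recovered where possible. E.g.
# repeat([[1, 2], [3, 4]], 2, axis=0) has shape (4, 2), and
# repeat([1, 2, 3], [1, 0, 2]) yields [1, 3, 3] with recovered shape (3,).
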
@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_dtypes.default_float_type()
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)


setattr(np_arrays.ndarray, '__round__', around)

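# Example (illustrative): scaling by 10**decimals makes negative `decimals`
# work on integer inputs via a float round-trip. E.g. around(1234, -2):
# 1234 * 0.01 = 12.34, round -> 12.0, divide -> 1200.0, cast back -> 1200.
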
@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """The order argument can only be 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
  else:
    r = array_ops.reshape(a, newshape)

  return r

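# Example (illustrative): Fortran order is emulated by transposing, reshaping
# into the reversed shape, then transposing back. E.g. for
# a = [[1, 2, 3], [4, 5, 6]], reshape(a, (3, 2), order='F') reads columns
# first, giving [[1, 5], [4, 3], [2, 6]].
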
def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)

@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)


@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  def adjust_axes(axes, rank):

    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x

    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a

@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops.unstack(sort_ops.sort(array_ops.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a

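# Example (illustrative): the permutation is built by deleting the source
# axes and re-inserting them at their destinations. For an array of shape
# (3, 4, 5), moveaxis(a, 0, -1) uses perm [1, 2, 0] and yields shape
# (4, 5, 3).
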
@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)

@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """The out argument is not supported, and the default mode is 'clip'."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)

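# Example (illustrative): unlike numpy, the default mode is 'clip', so
# out-of-range indices are clamped rather than raising. E.g.
# take([10, 20, 30], [0, 4]) == [10, 30] since 4 clips to 2, while
# mode='wrap' gives [10, 20] since 4 mod 3 == 1.
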
@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output

@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  a = asarray(a)
  return a.shape


@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0

def _boundaries_to_sizes(a, boundaries, axis):
  """Converts boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split.

  Returns:
    A list of sizes of the splits, as in tf.split.
  """
  if axis >= len(a.shape):
    raise ValueError('axis %s is out of bounds for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    size = min(size, max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  sizes.append(max(0, total_size - sizes_sum))
  return sizes

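# Example (illustrative): np.split-style boundaries become tf.split-style
# sizes. For an axis of length 8, boundaries [2, 5] yield sizes [2, 3, 3];
# boundaries past the end are clamped, so [2, 5, 100] yields [2, 3, 3, 0].
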
@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  if not isinstance(indices_or_sections, six.integer_types):
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)


def _split_on_axis(np_fun_name, axis):

  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    return split(ary, indices_or_sections, axis=axis)

  return f


vsplit = _split_on_axis('vsplit', axis=0)
hsplit = _split_on_axis('hsplit', axis=1)
dsplit = _split_on_axis('dsplit', axis=2)

@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  return full(shape, array)


@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    else:
      return swapaxes(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)
  return asarray(array_ops.stack(arrays, axis))

@np_utils.np_doc('hstack')
def hstack(tup):
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)
  rank = array_ops.rank(arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(arrays, axis=0),
      lambda: array_ops.concat(arrays, axis=1))


@np_utils.np_doc('vstack')
def vstack(tup):
  arrays = [atleast_2d(a) for a in tup]
  arrays = _promote_dtype(*arrays)
  return array_ops.concat(arrays, axis=0)


@np_utils.np_doc('dstack')
def dstack(tup):
  arrays = [atleast_3d(a) for a in tup]
  arrays = _promote_dtype(*arrays)
  return array_ops.concat(arrays, axis=2)

def _pad_left_to(n, old_shape):
  old_shape = asarray(old_shape, dtype=np.int32)
  new_shape = array_ops.pad(
      old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
      constant_values=1)
  return asarray(new_shape)

def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1),
            lambda: array_ops.pad(old_shape, [[1, 1]], constant_values=1),
            lambda: array_ops.pad(old_shape, [[0, 1]], constant_values=1)))

  return _atleast_nd(3, new_shape, *arys)

@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  return array_ops.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)), a.shape.rank, axis=1)


@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))

  return (math_ops.range(n),) * ndim

@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and invert
      # it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype)
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r

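# Example (illustrative): tri builds the lower-triangular mask used by tril
# and triu below. E.g. tri(3, k=0) is
#   [[1., 0., 0.],
#    [1., 1., 0.],
#    [1., 1., 1.]]
# and tri(3, k=-1) shifts the boundary down one diagonal, zeroing the main
# diagonal.
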
@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)


@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)

@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)


@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)

@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)

@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  N = N if N is not None else x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_ops.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))

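# Example (illustrative): columns are decreasing powers by default, as in
# numpy. E.g. vander([1, 2, 3], N=3) is
#   [[1, 1, 1],
#    [4, 2, 1],
#    [9, 3, 1]]
# and increasing=True reverses the power order.
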
@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
            i, a_rank))
    else:
      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return output


@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)

@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError(
        "tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result

# Note that np.take_along_axis may not be present in some supported versions
# of numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensuring that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result

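# Example (illustrative): a common use is gathering along a sorted order.
# E.g. for a = [[4, 2], [1, 3]], take_along_axis(a, [[1, 0], [0, 1]], axis=1)
# yields [[2, 4], [1, 3]], i.e. each row reordered by its index row.
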
_SLICE_ERORR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (index, bool). The first one is the parsed index and can be a
    tensor, or scalar integer / Dimension. The second one is True if the rank
    is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0

1506class _UpdateMethod(enum.Enum):
1507  UPDATE = 0
1508  ADD = 1
1509  MIN = 2
1510  MAX = 3
1511
1512
1513def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
1514  """Helper function for __getitem__ and _with_index_update_helper.
1515
1516  This function collects the indices in `slice_spec` into two buckets, which we
1517  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
1518  `gather`.  They also correspond to "basic indices" and "advanced indices" in
1519  numpy.  This function supports both reading and writing at the indices. The
1520  reading path can be summarized as `gather(stride_slice(tensor, idx1),
1521  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
1522  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`.  (`gather` here
1523  means `tf.gather` or `tf.gather_nd`; `scatter` here means
1524  `tf.tensor_scatter_update`.)  The writing path is inefficient because it needs
1525  to first read out a portion (probably much larger than `updates`) of `tensor`
1526  using `strided_slice`, update it, and then write the portion back. An
1527  alternative approach is to only use `scatter`, which amounts to using the
1528  indexing mechanism of gather/scatter to implement
1529  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
1530  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
1531  not TF gather/scatter because they don't support spans (except those that
1532  cover entire dimensions, i.e. `:`).  If we materialize spans into individual
1533  indices, the size of the index tensor would explode.  (Note that XLA
1534  Gather/Scatter have a similar problem for stride > 1 because they don't
1535  support strides.  Indices such as `1:2:8` will need to be materialized into
1536  individual indices such as [1, 3, 5, 7].)
1537
1538  Args:
1539    tensor: the tensor to be read from or write into.
1540    slice_spec: the indices.
1541    update_method: (optional) a member of `_UpdateMethod`, indicating how to
1542      update the values (replacement, add, etc.). `None` indicates just reading.
1543    updates: (optional) the new values to write into `tensor`. It must have the
1544      same dtype as `tensor`.
1545
1546  Returns:
1547    The result of reading (if `update_method` is `None`) or the updated `tensor`
1548    after writing.
1549  """
1550  begin, end, strides = [], [], []
1551  new_axis_mask, shrink_axis_mask = 0, 0
1552  begin_mask, end_mask = 0, 0
1553  ellipsis_mask = 0
1554  advanced_indices = []
1555  shrink_indices = []
1556  for index, s in enumerate(slice_spec):
1557    if isinstance(s, slice):
1558      if s.start is not None:
1559        begin.append(_as_index(s.start)[0])
1560      else:
1561        begin.append(0)
1562        begin_mask |= (1 << index)
1563      if s.stop is not None:
1564        end.append(_as_index(s.stop)[0])
1565      else:
1566        end.append(0)
1567        end_mask |= (1 << index)
1568      if s.step is not None:
1569        strides.append(_as_index(s.step)[0])
1570      else:
1571        strides.append(1)
1572    elif s is Ellipsis:
1573      begin.append(0)
1574      end.append(0)
1575      strides.append(1)
1576      ellipsis_mask |= (1 << index)
1577    elif s is array_ops.newaxis:
1578      begin.append(0)
1579      end.append(0)
1580      strides.append(1)
1581      new_axis_mask |= (1 << index)
1582    else:
1583      s, is_scalar = _as_index(s, False)
1584      if is_scalar:
1585        begin.append(s)
1586        end.append(s + 1)
1587        strides.append(1)
1588        shrink_axis_mask |= (1 << index)
1589        shrink_indices.append(index)
1590      else:
1591        begin.append(0)
1592        end.append(0)
1593        strides.append(1)
1594        begin_mask |= (1 << index)
1595        end_mask |= (1 << index)
1596        advanced_indices.append((index, s, ellipsis_mask != 0))
1597
1598  # stack possibly involves no tensors, so we must use op_scope correct graph.
  with ops.name_scope(
      None,
      'strided_slice', [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
                                                  array_ops.stack(end),
                                                  array_ops.stack(strides))
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
      return array_ops.tensor_strided_slice_update(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          updates,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    else:
      # TODO(b/164251540): Find a better way to support update that does not
      #   involve one read + two writes.
      if updates is not None:
        original_tensor = tensor
      # TODO(agarwal): set_shape on tensor to set rank.
      tensor = array_ops.strided_slice(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    if not advanced_indices:
      if update_method is None:
        return tensor
      assert update_method != _UpdateMethod.UPDATE
      # TF lacks TensorStridedSliceAdd and the like, so we need to do
      # read + combine + write.
      if update_method == _UpdateMethod.ADD:
        update_op = math_ops.add
      elif update_method == _UpdateMethod.MIN:
        update_op = math_ops.minimum
      elif update_method == _UpdateMethod.MAX:
        update_op = math_ops.maximum
      return array_ops.tensor_strided_slice_update(
          original_tensor,
          packed_begin,
          packed_end,
          packed_strides,
          update_op(tensor, updates),
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name + '_2')
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
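    # Illustrative example (not from the original source): for
    # `slice_spec = (0, Ellipsis, idx)` with an index array `idx` at position
    # 2, `had_ellipsis` is True and `dim = 2 - 3 + 0 = -1`, i.e. advanced
    # indices that come after an ellipsis are recorded as negative,
    # from-the-end dimension numbers.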
    dims = sorted(advanced_indices_map.keys())
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
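    # Illustrative example (not from the original source): dims [1, 2] and
    # [-2, -1] are contiguous, [0, 2] is not, and mixed-sign dims such as
    # [-1, 1] are conservatively treated as non-contiguous because their
    # adjacency cannot be decided without knowing the rank.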
    indices = [advanced_indices_map[x] for x in dims]
    indices = _promote_dtype(*indices)
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops.stack(indices, axis=-1)
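    # Illustrative example (not from the original source): two index arrays of
    # shapes (3, 1) and (1, 4) are broadcast to (3, 4) each and stacked into a
    # single index tensor of shape (3, 4, 2), the layout that gather_nd and
    # tensor_scatter_* expect.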
    # Skip the contiguous-dims optimization for update because there is no
    # tf.*scatter* op that supports the `axis` argument.
    if not dims_contiguous or updates is not None:
      # Note: `range(...)` must be materialized as a list for the comparison
      # with `dims` (a list) to be meaningful.
      if list(range(len(dims))) != dims:
        tensor = moveaxis(tensor, dims, range(len(dims)))
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype)[:len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
          stacked_indices)
      if updates is None:
        return array_ops.gather_nd(tensor, stacked_indices)
      else:
        # We only need to move-axis `updates` in the contiguous case, because
        # only in that case do the result dimensions of advanced indexing land
        # in the middle of `updates`. In the non-contiguous case, those
        # dimensions are always at the front.
        if dims_contiguous:
          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
          #   `updates`)
          if stacked_indices.shape.rank is None:
            raise NotImplementedError(
                'Rank of the advanced indices must currently be known')
          batch_size = stacked_indices.shape.rank - 1
          batch_start = dims[0]
          if batch_start < 0:
            batch_start += len(dims) - batch_size

          def range_(start, length):
            return range(start, start + length)

          updates = moveaxis(updates, range_(batch_start, batch_size),
                             range(batch_size))
        if update_method == _UpdateMethod.UPDATE:
          update_op = array_ops.tensor_scatter_update
        elif update_method == _UpdateMethod.ADD:
          update_op = array_ops.tensor_scatter_add
        elif update_method == _UpdateMethod.MIN:
          update_op = array_ops.tensor_scatter_min
        elif update_method == _UpdateMethod.MAX:
          update_op = array_ops.tensor_scatter_max
        tensor = update_op(tensor, stacked_indices, updates)
        if list(range(len(dims))) != dims:
          tensor = moveaxis(tensor, range(len(dims)), dims)
        return array_ops.tensor_strided_slice_update(
            original_tensor,
            packed_begin,
            packed_end,
            packed_strides,
            tensor,
            begin_mask=begin_mask,
            end_mask=end_mask,
            shrink_axis_mask=shrink_axis_mask,
            new_axis_mask=new_axis_mask,
            ellipsis_mask=ellipsis_mask,
            name=name + '_2')
    # Note that gather_nd only gathers along leading dimensions; it does not
    # support gathering from inside the array. To avoid shuffling data back
    # and forth, we transform the indices and do a gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      stacked_indices = indices[0]
    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
    stacked_indices = array_ops.where_v2(stacked_indices < 0,
                                         stacked_indices + dim_sizes,
                                         stacked_indices)
    axis = dims[0]
    if len(dims) > 1:
      index_scaling = math_ops.cumprod(
          dim_sizes, reverse=True, exclusive=True)
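      # Illustrative example (not from the original source): for
      # dims = [1, 2] and a tensor of shape [4, 5, 6], dim_sizes = [5, 6] and
      # index_scaling = [6, 1], so an index pair (i, j) is linearized to
      # i * 6 + j; the tensor is then reshaped to [4, 30] below and gathered
      # along axis 1.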
      def _tensordot(a, b):
        # TODO(b/168657656): This function should be replaced by
        # tensordot(axis=1) once MatMul has an int32 XLA kernel.
        b = array_ops.broadcast_to(b, array_ops.shape(a))
        return math_ops.reduce_sum(a * b, axis=-1)

      stacked_indices = _tensordot(stacked_indices, index_scaling)
      flat_shape = array_ops.concat(
          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
          axis=0)
      tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)


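# Illustrative behavior of `_as_spec_tuple` below (not from the original
# source): a list of scalars such as `[1, 2]` is treated as a single advanced
# index and wrapped as `([1, 2],)`, while a list containing a slice, such as
# `[1, slice(None)]`, is converted element-wise to the tuple
# `(1, slice(None))`.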
def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
  if isinstance(slice_spec,
                (list, tuple)) and not isinstance(slice_spec, np.ndarray):
    is_index = True
    for s in slice_spec:
      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
        is_index = False
        break
      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
        is_index = False
        break
    if not is_index:
      return tuple(slice_spec)
  return (slice_spec,)


def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    return array_ops.boolean_mask(tensor=self, mask=slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  result_t = _slice_helper(self, slice_spec)
  return result_t
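
# Illustrative example (not from the original source): for
# `a = array([[1, 2], [3, 4]])`, `a[a > 2]` takes the boolean-mask path in
# `_getitem` above and evaluates to `[3, 4]`, while `a[0, 1]` is routed
# through `_slice_helper` and evaluates to `2`.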


def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    slice_spec = nonzero(slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  a_dtype = a.dtype
  a, updates = _promote_dtype_binary(a, updates)
  result_t = _slice_helper(a, slice_spec, update_method, updates)
  return result_t.astype(a_dtype)


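# Attach the numpy-style indexing helpers to `np_arrays.ndarray`: reads go
# through `_getitem`, while the `_with_index_*` update methods below are
# partial applications of `_with_index_update_helper` with the matching
# `_UpdateMethod`.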
setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
setattr(np_arrays.ndarray, '_with_index_update',
        functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
setattr(np_arrays.ndarray, '_with_index_add',
        functools.partial(_with_index_update_helper, _UpdateMethod.ADD))
setattr(np_arrays.ndarray, '_with_index_min',
        functools.partial(_with_index_update_helper, _UpdateMethod.MIN))
setattr(np_arrays.ndarray, '_with_index_max',
        functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
