# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import math
import numbers
import numpy as np
import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.util import nest


newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)


@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.zeros(shape, dtype=dtype)


@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  if dtype is None:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    dtype = np_utils.result_type(a)
  else:
    # TF and numpy have different interpretations of Python types such as
    # `float`, so we let `np_utils.result_type` decide.
    dtype = np_utils.result_type(dtype)
  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)


@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  if dtype is None:
    dtype = np_utils.result_type(a)
  else:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones_like(a, dtype)


@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if not M:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)


@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """order, subok and shape arguments mustn't be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))
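

# Illustrative sketch (comments only, not executed): both `full` and
# `full_like` reduce to a broadcast of `fill_value`, so the fill value may
# itself be an array that broadcasts against the target shape. The values
# below assume NumPy-style broadcasting:
#
#   full([2, 3], 5)       # [[5, 5, 5], [5, 5, 5]]
#   full([2, 2], [1, 2])  # [[1, 2], [1, 2]]
#   full_like(zeros([2], dtype=np.int32), 7)  # [7, 7], input dtype preserved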


def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, ops.Tensor):
    if not dtype:
      dtype = np_utils.result_type(result_t)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return result_t


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
  is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args


@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype):
    return a
  return array(a, dtype, copy=False)


@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)
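

# Sketch of the copy semantics above (illustrative, not executed):
# `asarray` returns its argument unchanged when it is already an ndarray of
# the requested dtype, while `array` defaults to `copy=True` and inserts an
# identity op even for matching inputs:
#
#   a = asarray([1, 2, 3])
#   asarray(a) is a  # True: no conversion and no copy
#   array(a) is a    # False: a new (copied) tensor is built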


# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, `start` is treated as 0 and
      `start` value is used as `stop`. If specified, it is not included in the
      range if `step` is integer. When `step` is floating point, it may or may
      not be included.
    step: The difference between 2 consecutive values in the output range. It
      is recommended to use `linspace` instead of using non-integer values for
      `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and ((stop is not None and start > stop) or
                   (stop is None and start < 0)):
    return array([], dtype=dtype)
  if step < 0 and ((stop is not None and start < stop) or
                   (stop is None and start > 0)):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype)


# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_ops.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank])

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k))

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
  return result
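

# Worked example for `diag` (illustrative, not executed): the rank-1 branch
# builds a matrix and the rank-2 branch extracts a diagonal, mirroring
# np.diag:
#
#   diag([1, 2])                 # [[1, 0], [0, 2]]
#   diag([[1, 2], [3, 4]])       # [1, 4]
#   diag([[1, 2], [3, 4]], k=1)  # [2]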


@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (array_ops.zeros(
        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = array_ops.matrix_diag_part(a, k=offset)
  return a


@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)


def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)

  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)

  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2


@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)


@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a
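

# Worked example for `_maybe_promote_to_int` (illustrative): on most
# platforms np.int_ is the C `long`, so np.promote_types(np.int8, int) is
# np.int64 and an int8 input is upcast before a cumulative scan, while an
# int64 input already satisfies the promotion and passes through unchanged.
# This matches np.cumsum/np.cumprod, which accumulate narrow integers in the
# platform `int` to reduce overflow.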


@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)


@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we
  # always return an ndarray.
  return math_ops.imag(val)


_TO_INT_ = 0
_TO_FLOAT = 1


def _reduce(tf_fn,
            a,
            axis=None,
            dtype=None,
            keepdims=None,
            promote_int=_TO_INT_,
            tf_bool_fn=None,
            preserve_bool=False):
  """A general reduction function.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
      `_TO_FLOAT` always promotes them to a float type (determined by
      dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
      is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
      is `np.bool_` (some reductions such as np.sum convert bools to integers,
      while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  if keepdims is None:
    keepdims = False
  a = asarray(a, dtype=dtype)
  if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
      tf_bool_fn is not None):
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
    if promote_int == _TO_INT_:
      # If a is an integer/bool type whose bit width is less than that of
      # np.int_, numpy up-casts it to np.int_ based on the documentation at
      # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
      if dtype == np.bool_:
        is_signed = True
        width = 8  # We can use any number here that is less than 64
      else:
        is_signed = np.issubdtype(dtype, np.signedinteger)
        width = np.iinfo(dtype).bits
      # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
      # should have the same bit width.
      if width < np.iinfo(np.int_).bits:
        if is_signed:
          dtype = np.int_
        else:
          dtype = np.uint
        a = math_ops.cast(a, dtype)
    elif promote_int == _TO_FLOAT:
      a = math_ops.cast(a, np_dtypes.default_float_type())

  if isinstance(axis, ops.Tensor) and axis.dtype not in (
      dtypes.int32, dtypes.int64):
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
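

# Illustrative examples of the `promote_int` rules above (not executed):
#
#   sum([True, True, True])                # bools sum as integers -> 3
#   sum(np.array([1, 2], dtype=np.int8))   # _TO_INT_: result dtype np.int_
#   mean(np.array([1, 2], dtype=np.int8))  # _TO_FLOAT: result is float
#   amax([True, False])  # preserve_bool: dispatches to reduce_any -> True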


# TODO(DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError('axis argument is not supported in the current '
                              '`np.size` implementation')
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if x.shape.is_fully_defined():
    return np.prod(x.shape.as_list(), dtype=int)
  else:
    return array_ops.size_v2(x)


@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  return _reduce(
      math_ops.reduce_sum,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_any)


@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_prod,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_all)


@np_utils.np_doc('mean', unsupported_params=['out'])
def mean(a, axis=None, dtype=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_mean,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('amax', unsupported_params=['out'])
def amax(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_max,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_any,
      preserve_bool=True)


@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True)


@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result


@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])


@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)


@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result


@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_dtypes.default_float_type()
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)


setattr(np_arrays.ndarray, '__round__', around)
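

# Worked example for `around` (illustrative): for an integer input and
# decimals=-1, the working dtype is float, factor = 10**-1 = 0.1, and
# around(123, -1) computes round(123 * 0.1) / 0.1 = 120 before casting back
# to the integer input dtype.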


@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """order argument can only be 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
  else:
    r = array_ops.reshape(a, newshape)

  return r


def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)


@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)
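

# Sketch of the Fortran-order path in `reshape` above (illustrative, not
# executed): for a = [[1, 2, 3], [4, 5, 6]] of shape [2, 3],
# reshape(a, [3, 2], order='F') is computed as
# transpose(reshape(transpose(a), [2, 3])) and yields
# [[1, 5], [4, 3], [2, 6]], matching np.reshape(a, [3, 2], order='F').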


@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  def adjust_axes(axes, rank):

    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x

    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a


@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops.unstack(sort_ops.sort(array_ops.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a


@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)
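

# Illustrative example for `moveaxis` above (not executed): for an array of
# shape [2, 3, 4], moveaxis(a, 0, -1) resolves destination -1 to 2 and
# builds the static permutation perm = [1, 2, 0], giving result shape
# [3, 4, 2]; when the rank is unknown, the scatter-based branch constructs
# the same permutation dynamically.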


@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)


@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output


@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  a = asarray(a)
  return a.shape


@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0


def _boundaries_to_sizes(a, boundaries, axis):
  """Converts boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split.

  Returns:
    A list of sizes of the splits, as in tf.split.
  """
  if axis >= len(a.shape):
    raise ValueError('axis %s is out of bounds for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    size = min(size, max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  sizes.append(max(0, total_size - sizes_sum))
  return sizes
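

# Worked example for `_boundaries_to_sizes` (illustrative): splitting an
# axis of size 10 at boundaries [2, 7] yields sizes [2, 5, 3], i.e. the
# pieces a[:2], a[2:7] and a[7:], which is what `split` below passes to
# tf.split.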


@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  if not isinstance(indices_or_sections, six.integer_types):
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)


def _split_on_axis(np_fun_name, axis):

  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    return split(ary, indices_or_sections, axis=axis)

  return f


vsplit = _split_on_axis('vsplit', axis=0)
hsplit = _split_on_axis('hsplit', axis=1)
dsplit = _split_on_axis('dsplit', axis=2)


@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  return full(shape, array)


@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    else:
      return swapaxes(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return asarray(array_ops.stack(unwrapped_arrays, axis))


@np_utils.np_doc('hstack')
def hstack(tup):
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  rank = array_ops.rank(unwrapped_arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(unwrapped_arrays, axis=0),
      lambda: array_ops.concat(unwrapped_arrays, axis=1))


@np_utils.np_doc('vstack')
def vstack(tup):
  arrays = [atleast_2d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=0)


@np_utils.np_doc('dstack')
def dstack(tup):
  arrays = [atleast_3d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=2)


def _pad_left_to(n, old_shape):
  old_shape = asarray(old_shape, dtype=np.int32)
  new_shape = array_ops.pad(
      old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
      constant_values=1)
  return asarray(new_shape)
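

# Sketch of the shape padding above (illustrative): _pad_left_to(2, [4])
# left-pads the shape with ones to [1, 4], which is how atleast_2d (below)
# lifts a vector of shape [4] to shape [1, 4] while leaving higher-rank
# inputs untouched.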


def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1),
            lambda: array_ops.pad(old_shape, [[1, 1]], constant_values=1),
            lambda: array_ops.pad(old_shape, [[0, 1]], constant_values=1)))

  return _atleast_nd(3, new_shape, *arys)


@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  return array_ops.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
      a.shape.rank,
      axis=1)


@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))

  return (math_ops.range(n),) * ndim
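

# Illustrative example for `diag_indices` (not executed): diag_indices(3)
# returns two copies of the index vector [0, 1, 2], so a[diag_indices(3)]
# addresses the elements a[0, 0], a[1, 1] and a[2, 2].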


@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and
      # invert it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype)
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r


@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)


@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)


@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)
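

# Worked example for the band-part trick in `tri` above (illustrative):
# tri(3) keeps the lower triangle, [[1, 0, 0], [1, 1, 0], [1, 1, 1]]; for
# k < 0 the mask is built by inverting an upper band, e.g. tri(3, k=-1) is
# [[0, 0, 0], [1, 0, 0], [1, 1, 0]]. `tril` and `triu` use these masks with
# tf.where_v2 to zero out the complementary triangle.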


@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)


@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)


@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  N = N or x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_ops.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))


@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
            i, a_rank))
    else:
      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return output


@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)
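

# Illustrative example for `vander` above (not executed): vander([1, 2, 3])
# uses decreasing powers by default, giving
# [[1, 1, 1], [4, 2, 1], [9, 3, 1]]; with increasing=True the power range
# runs 0..N-1 instead: [[1, 1, 1], [1, 2, 4], [1, 3, 9]].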


@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError(
        "tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result


# Note that np.take_along_axis may not be present in some supported versions
# of numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result


_SLICE_ERROR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')
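

# Worked example for `take_along_axis` above (illustrative): with
# arr = [[10, 30, 20]] and indices = [[0, 2, 1]] (e.g. from np.argsort),
# gathering along axis=1 returns [[10, 20, 30]]; the shape juggling above
# exists so that `arr` and `indices` only need to broadcast against each
# other outside the gathered axis.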


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (index, bool). First one is the parsed index and can be a tensor,
    or scalar integer / Dimension. Second one is True if rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0


class _UpdateMethod(enum.Enum):
  UPDATE = 0
  ADD = 1
  MIN = 2
  MAX = 3
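

# Illustrative example for `_as_index` above (not executed): a bool mask
# index such as [True, False, True] is converted via tf.where to the
# integer indices [0, 2] (rank 1, so the returned flag is False), while a
# Python int like 5 is returned as-is with the flag True so the caller can
# turn it into a shrink-axis strided slice.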


def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
  """Helper function for __getitem__ and _with_index_update_helper.

  This function collects the indices in `slice_spec` into two buckets, which we
  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
  for `gather`. They also correspond to "basic indices" and "advanced indices"
  in numpy. This function supports both reading and writing at the indices. The
  reading path can be summarized as `gather(stride_slice(tensor, idx1),
  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`. (`gather` here
  means `tf.gather` or `tf.gather_nd`; `scatter` here means
  `tf.tensor_scatter_update`.) The writing path is inefficient because it needs
  to first read out a portion (probably much larger than `updates`) of `tensor`
  using `strided_slice`, update it, and then write the portion back. An
  alternative approach is to only use `scatter`, which amounts to using the
  indexing mechanism of gather/scatter to implement
  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
  not TF gather/scatter because they don't support spans (except those that
  cover entire dimensions, i.e. `:`). If we materialize spans into individual
  indices, the size of the index tensor would explode. (Note that XLA
  Gather/Scatter have a similar problem for stride > 1 because they don't
  support strides. Indices such as `1:2:8` will need to be materialized into
  individual indices such as [1, 3, 5, 7].)

  Args:
    tensor: the tensor to be read from or write into.
    slice_spec: the indices.
    update_method: (optional) a member of `_UpdateMethod`, indicating how to
      update the values (replacement, add, etc.). `None` indicates just
      reading.
    updates: (optional) the new values to write into `tensor`. It must have the
      same dtype as `tensor`.

  Returns:
    The result of reading (if `update_method` is `None`) or the updated
    `tensor` after writing.
  """
  begin, end, strides = [], [], []
  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  advanced_indices = []
  shrink_indices = []
  for index, s in enumerate(slice_spec):
    if isinstance(s, slice):
      if s.start is not None:
        begin.append(_as_index(s.start)[0])
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None:
        end.append(_as_index(s.stop)[0])
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None:
        strides.append(_as_index(s.step)[0])
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      s, is_scalar = _as_index(s, False)
      if is_scalar:
        begin.append(s)
        end.append(s + 1)
        strides.append(1)
        shrink_axis_mask |= (1 << index)
        shrink_indices.append(index)
      else:
        begin.append(0)
        end.append(0)
        strides.append(1)
        begin_mask |= (1 << index)
        end_mask |= (1 << index)
        advanced_indices.append((index, s, ellipsis_mask != 0))

  # stack possibly involves no tensors, so we must use op_scope correct graph.
  with ops.name_scope(
      None,
      'strided_slice', [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
                                                  array_ops.stack(end),
                                                  array_ops.stack(strides))
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
      return array_ops.tensor_strided_slice_update(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          updates,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    else:
      # TODO(b/164251540): Find a better way to support update that does not
      # involve one read + two writes.
      if updates is not None:
        original_tensor = tensor
      # TODO(agarwal): set_shape on tensor to set rank.
      tensor = array_ops.strided_slice(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    if not advanced_indices:
      if update_method is None:
        return tensor
      assert update_method != _UpdateMethod.UPDATE
      # TF lacks TensorStridedSliceAdd and alike, so we need to do
      # read+add+update.
      if update_method == _UpdateMethod.ADD:
        update_op = math_ops.add
      elif update_method == _UpdateMethod.MIN:
        update_op = math_ops.minimum
      elif update_method == _UpdateMethod.MAX:
        update_op = math_ops.maximum
      return array_ops.tensor_strided_slice_update(
          original_tensor,
          packed_begin,
          packed_end,
          packed_strides,
          update_op(tensor, updates),
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name + '_2')
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
    dims = sorted(advanced_indices_map.keys())
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
    indices = [advanced_indices_map[x] for x in dims]
    indices = _promote_dtype(*indices)
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops.stack(indices, axis=-1)
    # Skip the contiguous-dims optimization for update because there is no
    # tf.*scatter* op that supports the `axis` argument.
    if not dims_contiguous or updates is not None:
      if list(range(len(dims))) != dims:
        tensor = moveaxis(tensor, dims, range(len(dims)))
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype)[:len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
          stacked_indices)
      if updates is None:
        return array_ops.gather_nd(tensor, stacked_indices)
      else:
        # We only need to move-axis `updates` in the contiguous case because
        # only in this case the result dimensions of advanced indexing are in
        # the middle of `updates`. In the non-contiguous case, those dimensions
        # are always at the front.
        if dims_contiguous:
          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
          # `updates`)
          if stacked_indices.shape.rank is None:
            raise NotImplementedError(
                'Rank of the advanced indices must currently be known')
          batch_size = stacked_indices.shape.rank - 1
          batch_start = dims[0]
          if batch_start < 0:
            batch_start += len(dims) - batch_size

          def range_(start, length):
            return range(start, start + length)

          updates = moveaxis(updates, range_(batch_start, batch_size),
                             range(batch_size))
        if update_method == _UpdateMethod.UPDATE:
          update_op = array_ops.tensor_scatter_update
        elif update_method == _UpdateMethod.ADD:
          update_op = array_ops.tensor_scatter_add
        elif update_method == _UpdateMethod.MIN:
          update_op = array_ops.tensor_scatter_min
        elif update_method == _UpdateMethod.MAX:
          update_op = array_ops.tensor_scatter_max
        tensor = update_op(tensor, stacked_indices, updates)
        if list(range(len(dims))) != dims:
          tensor = moveaxis(tensor, range(len(dims)), dims)
        return array_ops.tensor_strided_slice_update(
            original_tensor,
            packed_begin,
            packed_end,
            packed_strides,
            tensor,
            begin_mask=begin_mask,
            end_mask=end_mask,
            shrink_axis_mask=shrink_axis_mask,
            new_axis_mask=new_axis_mask,
            ellipsis_mask=ellipsis_mask,
            name=name + '_2')
    # Note that gather_nd does not support gathering from inside the array.
    # To avoid shuffling data back and forth, we transform the indices and
    # do a gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      stacked_indices = indices[0]
    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
    stacked_indices = array_ops.where_v2(stacked_indices < 0,
                                         stacked_indices + dim_sizes,
                                         stacked_indices)
    axis = dims[0]
    if len(dims) > 1:
      index_scaling = math_ops.cumprod(
          dim_sizes, reverse=True, exclusive=True)

      def _tensordot(a, b):
        # TODO(b/168657656): This function should be replaced by
        # tensordot(axis=1) once MatMul has int32 XLA kernel.
        b = array_ops.broadcast_to(b, array_ops.shape(a))
        return math_ops.reduce_sum(a * b, axis=-1)

      stacked_indices = _tensordot(stacked_indices, index_scaling)
      flat_shape = array_ops.concat(
          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
          axis=0)
      tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)


def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
  if isinstance(slice_spec,
                (list, tuple)) and not isinstance(slice_spec, np.ndarray):
    is_index = True
    for s in slice_spec:
      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
        is_index = False
        break
      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
        is_index = False
        break
    if not is_index:
      return tuple(slice_spec)
  return (slice_spec,)


def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    return array_ops.boolean_mask(tensor=self, mask=slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  result_t = _slice_helper(self, slice_spec)
  return result_t


def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    slice_spec = nonzero(slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  a_dtype = a.dtype
  a, updates = _promote_dtype_binary(a, updates)
  result_t = _slice_helper(a, slice_spec, update_method, updates)
  return result_t.astype(a_dtype)


setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
setattr(np_arrays.ndarray, '_with_index_update',
        functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
setattr(np_arrays.ndarray, '_with_index_add',
        functools.partial(_with_index_update_helper, _UpdateMethod.ADD))
setattr(np_arrays.ndarray, '_with_index_min',
        functools.partial(_with_index_update_helper, _UpdateMethod.MIN))
setattr(np_arrays.ndarray, '_with_index_max',
        functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
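

# Illustrative sketch of the numpy-style indexing wired up above (not
# executed): `_numpy_style_getitem` handles reads such as a[1:, [0, 2]]
# (a strided slice for `1:` combined with a gather for the advanced index
# [0, 2]), while the `_with_index_*` helpers implement functional updates:
# they return a new tensor with the selected elements replaced, added to,
# or min/max-combined, since tensors themselves are immutable.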