# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import functools
import math
import numbers
import numpy as np
import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.util import nest


newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)


@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.zeros(shape, dtype=dtype)


@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  if dtype is None:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    dtype = np_utils.result_type(a)
  else:
    # TF and numpy have different interpretations of Python types such as
    # `float`, so we let `np_utils.result_type` decide.
    dtype = np_utils.result_type(dtype)
  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)


@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  if dtype is None:
    dtype = np_utils.result_type(a)
  else:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones_like(a, dtype)


@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if M is None:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)


@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """order, subok and shape arguments mustn't be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))


def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, ops.Tensor):
    if not dtype:
      dtype = np_utils.result_type(result_t)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return result_t


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
  is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args


@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype):
    return a
  return array(a, dtype, copy=False)


@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)


# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, `start` is treated as 0 and
      `start` value is used as `stop`. If specified, it is not included in the
      range if `step` is an integer. When `step` is floating point, it may or
      may not be included.
    step: The difference between 2 consecutive values in the output range. It
      is recommended to use `linspace` instead of using non-integer values for
      `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and ((stop is not None and start > stop) or
                   (stop is None and start < 0)):
    return array([], dtype=dtype)
  if step < 0 and ((stop is not None and start < stop) or
                   (stop is None and start > 0)):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype)
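

# Illustrative usage note (commentary only, not part of the original module;
# `tnp` is an assumed alias for this API, e.g. via
# `import tensorflow.experimental.numpy as tnp`). `arange` mirrors numpy's
# empty-range behavior:
#
#   tnp.arange(5)        # [0, 1, 2, 3, 4]
#   tnp.arange(5, 2)     # []  -- start > stop with a positive step
#   tnp.arange(0, 6, 2)  # [0, 2, 4]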


# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_ops.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank])

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k))

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
  return result


@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (array_ops.zeros(
        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = array_ops.matrix_diag_part(a, k=offset)
  return a
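

# Illustrative usage note (commentary only, not part of the original module;
# `tnp` is an assumed alias for this API). Out-of-range offsets yield an empty
# diagonal instead of an error, matching numpy:
#
#   x = tnp.asarray([[1, 2], [3, 4]])
#   tnp.diagonal(x)     # [1, 4]
#   tnp.diagonal(x, 1)  # [2]
#   tnp.diagonal(x, 5)  # []  -- offset beyond the last column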


@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)


def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)
  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)
  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2


@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)


@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a


@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)


@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
  # return an ndarray.
  return math_ops.imag(val)


_TO_INT_ = 0
_TO_FLOAT = 1


def _reduce(tf_fn,
            a,
            axis=None,
            dtype=None,
            keepdims=None,
            promote_int=_TO_INT_,
            tf_bool_fn=None,
            preserve_bool=False):
  """A general reduction function.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
      `_TO_FLOAT` always promotes them to a float type (determined by
      dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
      is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
      is `np.bool_` (some reductions such as np.sum convert bools to integers,
      while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  if keepdims is None:
    keepdims = False
  a = asarray(a, dtype=dtype)
  if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
      tf_bool_fn is not None):
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
    if promote_int == _TO_INT_:
      # If a is an integer/bool type whose bit width is less than np.int_,
      # numpy up-casts it to np.int_ based on the documentation at
      # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
      if dtype == np.bool_:
        is_signed = True
        width = 8  # We can use any number here that is less than 64
      else:
        is_signed = np.issubdtype(dtype, np.signedinteger)
        width = np.iinfo(dtype).bits
      # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
      # should have the same bit width.
      if width < np.iinfo(np.int_).bits:
        if is_signed:
          dtype = np.int_
        else:
          dtype = np.uint
        a = math_ops.cast(a, dtype)
    elif promote_int == _TO_FLOAT:
      a = math_ops.cast(a, np_dtypes.default_float_type())

  if isinstance(axis, ops.Tensor) and axis.dtype not in (
      dtypes.int32, dtypes.int64):
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
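

# Illustrative note on the promotion rules above (commentary only, not part of
# the original module; `tnp` is an assumed alias for this API). With
# `promote_int` left at `_TO_INT_`, bool and narrow integer inputs are widened
# before reducing, so summing a bool array counts its True entries; `amax`,
# which passes `preserve_bool=True`, instead reduces bools with `reduce_any`:
#
#   tnp.sum(tnp.asarray([True, True, False]))  # 2 (promoted to int)
#   tnp.amax(tnp.asarray([True, False]))       # True (stays bool)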


# TODO (DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError('axis argument is not supported in the current '
                              '`np.size` implementation')
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if x.shape.is_fully_defined():
    return np.prod(x.shape.as_list(), dtype=int)
  else:
    return array_ops.size_v2(x)


@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  return _reduce(
      math_ops.reduce_sum,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_any)


@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_prod,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_all)


@np_utils.np_doc('mean')
def mean(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_mean,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('amax')
def amax(a, axis=None, keepdims=None):
  return _reduce(
      math_ops.reduce_max,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_any,
      preserve_bool=True)


@np_utils.np_doc('amin')
def amin(a, axis=None, keepdims=None):
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True)


@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result
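

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API). The `reduce_fn` branch above implements
# Bessel-style correction by dividing by `n - ddof` rather than `n`:
#
#   x = tnp.asarray([1., 2., 3., 4.])
#   tnp.var(x)          # 1.25    (population variance, ddof=0)
#   tnp.var(x, ddof=1)  # ~1.6667 (sample variance, divides by n - 1)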


@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])


@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)


@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result


@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_dtypes.default_float_type()
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)


setattr(np_arrays.ndarray, '__round__', around)
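

# Illustrative note on `around` (commentary only, not part of the original
# module; `tnp` is an assumed alias for this API). Because `decimals` may be
# negative, integer inputs are rounded in a float working dtype and cast back:
#
#   tnp.around(tnp.asarray([123, 567]), decimals=-2)  # [100, 600]
#   tnp.around(tnp.asarray([1.2345]), decimals=2)     # [1.23]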


@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """order argument can only be 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
  else:
    r = array_ops.reshape(a, newshape)

  return r


def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)


@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)


@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)
  def adjust_axes(axes, rank):
    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x
    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a


@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops.unstack(sort_ops.sort(array_ops.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a


@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)


@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)
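

# Illustrative note on `take` (commentary only, not part of the original
# module; `tnp` is an assumed alias for this API). Unlike numpy, the default
# mode here is 'clip', not 'raise':
#
#   x = tnp.asarray([10, 20, 30])
#   tnp.take(x, [-1, 5])               # [10, 30] -- indices clipped to [0, 2]
#   tnp.take(x, [-1, 5], mode='wrap')  # [30, 30] -- indices taken modulo 3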


@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output
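

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API). Because the loop above walks the lists in
# reverse, the first matching condition wins, as in numpy:
#
#   x = tnp.arange(4)                        # [0, 1, 2, 3]
#   tnp.select([x < 1, x < 3], [x, x * 10])  # [0, 10, 20, 0]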


@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  a = asarray(a)
  return a.shape


@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0


def _boundaries_to_sizes(a, boundaries, axis):
  """Converts boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split.

  Returns:
    A list of sizes of the splits, as in tf.split.
  """
  if axis >= len(a.shape):
    raise ValueError('axis %s is out of bounds for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    size = min(size, max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  sizes.append(max(0, total_size - sizes_sum))
  return sizes


@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  if not isinstance(indices_or_sections, six.integer_types):
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)
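

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API). Integer arguments mean "this many equal
# sections", while sequences are numpy-style boundaries that
# `_boundaries_to_sizes` converts into the per-piece sizes expected by
# `tf.split`:
#
#   x = tnp.arange(9)
#   tnp.split(x, 3)       # three pieces of size 3
#   tnp.split(x, [2, 5])  # boundaries -> sizes [2, 3, 4]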


def _split_on_axis(np_fun_name, axis):

  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    return split(ary, indices_or_sections, axis=axis)

  return f


vsplit = _split_on_axis('vsplit', axis=0)
hsplit = _split_on_axis('hsplit', axis=1)
dsplit = _split_on_axis('dsplit', axis=2)


@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  return full(shape, array)


@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    else:
      return swapaxes(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return asarray(array_ops.stack(unwrapped_arrays, axis))


@np_utils.np_doc('hstack')
def hstack(tup):
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  rank = array_ops.rank(unwrapped_arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(unwrapped_arrays, axis=0),
      lambda: array_ops.concat(unwrapped_arrays, axis=1))


@np_utils.np_doc('vstack')
def vstack(tup):
  arrays = [atleast_2d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=0)


@np_utils.np_doc('dstack')
def dstack(tup):
  arrays = [atleast_3d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=2)


def _pad_left_to(n, old_shape):
  old_shape = asarray(old_shape, dtype=np.int32)
  new_shape = array_ops.pad(
      old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
      constant_values=1)
  return asarray(new_shape)


def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1), lambda: array_ops.pad(
                old_shape, [[1, 1]], constant_values=1), lambda: array_ops.pad(
                    old_shape, [[0, 1]], constant_values=1)))

  return _atleast_nd(3, new_shape, *arys)


@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  return array_ops.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
      a.shape.rank,
      axis=1)


@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))

  return (math_ops.range(n),) * ndim


@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and invert
      # it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype)
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r
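

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API). `k` counts sub-diagonals below (negative)
# or super-diagonals above (positive) the main diagonal:
#
#   tnp.tri(3)        # ones on and below the main diagonal
#   tnp.tri(3, k=-1)  # strictly below the diagonal
#   tnp.tri(3, k=1)   # everything up to one diagonal above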


@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)


@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)


@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)


@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)
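

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API). With `axis=None` the array is rolled as
# if flattened, then reshaped back:
#
#   x = tnp.asarray([[1, 2], [3, 4]])
#   tnp.roll(x, 1)          # [[4, 1], [2, 3]]
#   tnp.roll(x, 1, axis=0)  # [[3, 4], [1, 2]]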


@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)


@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  N = N or x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_ops.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))


@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
            i, a_rank))
    else:
      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return output


@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)


@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError("tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result
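

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API). For complex inputs, the branch above
# takes the sign of the real part and casts it back to the complex dtype
# (numpy additionally falls back to the sign of the imaginary part when the
# real part is zero):
#
#   tnp.sign(tnp.asarray([-5., 0., 3.]))      # [-1., 0., 1.]
#   tnp.sign(tnp.asarray([3 + 4j, -1 - 1j]))  # [1+0j, -1+0j]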


# Note that np.take_along_axis may not be present in some supported versions of
# numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result
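

# Illustrative note (commentary only, not part of the original module; `tnp`
# is an assumed alias for this API, and the index array would typically come
# from an argsort-style op elsewhere in the API). The broadcasting above makes
# the classic sort-by-indices pattern work:
#
#   x = tnp.asarray([[30, 10, 20]])
#   idx = tnp.asarray([[1, 2, 0]])       # e.g. produced by an argsort
#   tnp.take_along_axis(x, idx, axis=1)  # [[10, 20, 30]]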


_SLICE_ERROR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (index, bool). The first item is the parsed index and can be a
    tensor, or scalar integer / Dimension. The second item is True if the rank
    is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0


class _UpdateMethod(enum.Enum):
  UPDATE = 0
  ADD = 1
  MIN = 2
  MAX = 3


def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
  """Helper function for __getitem__ and _with_index_update_helper.

  This function collects the indices in `slice_spec` into two buckets, which we
  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
  for `gather`. They also correspond to "basic indices" and "advanced indices"
  in numpy. This function supports both reading and writing at the indices. The
  reading path can be summarized as `gather(strided_slice(tensor, idx1),
  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`. (`gather` here
  means `tf.gather` or `tf.gather_nd`; `scatter` here means
  `tf.tensor_scatter_update`.) The writing path is inefficient because it needs
  to first read out a portion (probably much larger than `updates`) of `tensor`
  using `strided_slice`, update it, and then write the portion back. An
  alternative approach is to only use `scatter`, which amounts to using the
  indexing mechanism of gather/scatter to implement
  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
  not for TF gather/scatter because they don't support spans (except those that
  cover entire dimensions, i.e. `:`). If we materialize spans into individual
  indices, the size of the index tensor would explode. (Note that XLA
  Gather/Scatter have a similar problem for stride > 1 because they don't
  support strides. Indices such as `1:8:2` will need to be materialized into
  individual indices such as [1, 3, 5, 7].)

  Args:
    tensor: the tensor to be read from or written into.
    slice_spec: the indices.
    update_method: (optional) a member of `_UpdateMethod`, indicating how to
      update the values (replacement, add, etc.). `None` indicates just
      reading.
    updates: (optional) the new values to write into `tensor`. It must have the
      same dtype as `tensor`.

  Returns:
    The result of reading (if `update_method` is `None`) or the updated
    `tensor` after writing.
  """
  begin, end, strides = [], [], []
  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  advanced_indices = []
  shrink_indices = []
  for index, s in enumerate(slice_spec):
    if isinstance(s, slice):
      if s.start is not None:
        begin.append(_as_index(s.start)[0])
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None:
        end.append(_as_index(s.stop)[0])
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None:
        strides.append(_as_index(s.step)[0])
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      s, is_scalar = _as_index(s, False)
      if is_scalar:
        begin.append(s)
        end.append(s + 1)
        strides.append(1)
        shrink_axis_mask |= (1 << index)
        shrink_indices.append(index)
      else:
        begin.append(0)
        end.append(0)
        strides.append(1)
        begin_mask |= (1 << index)
        end_mask |= (1 << index)
        advanced_indices.append((index, s, ellipsis_mask != 0))

  # stack possibly involves no tensors, so we must use op_scope to select the
  # correct graph.
  with ops.name_scope(
      None,
      'strided_slice', [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
                                                  array_ops.stack(end),
                                                  array_ops.stack(strides))
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
      return array_ops.tensor_strided_slice_update(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          updates,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    else:
      # TODO(b/164251540): Find a better way to support update that does not
      # involve one read + two writes.
      if updates is not None:
        original_tensor = tensor
      # TODO(agarwal): set_shape on tensor to set rank.
      tensor = array_ops.strided_slice(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    if not advanced_indices:
      if update_method is None:
        return tensor
      assert update_method != _UpdateMethod.UPDATE
      # TF lacks TensorStridedSliceAdd and alike, so we need to do
      # read+add+update.
      if update_method == _UpdateMethod.ADD:
        update_op = math_ops.add
      elif update_method == _UpdateMethod.MIN:
        update_op = math_ops.minimum
      elif update_method == _UpdateMethod.MAX:
        update_op = math_ops.maximum
      return array_ops.tensor_strided_slice_update(
          original_tensor,
          packed_begin,
          packed_end,
          packed_strides,
          update_op(tensor, updates),
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name + '_2')
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
    dims = sorted(advanced_indices_map.keys())
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
    indices = [advanced_indices_map[x] for x in dims]
    indices = _promote_dtype(*indices)
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops.stack(indices, axis=-1)
    # Skip the contiguous-dims optimization for update because there is no
    # tf.*scatter* op that supports the `axis` argument.
    if not dims_contiguous or updates is not None:
      if list(range(len(dims))) != dims:
        tensor = moveaxis(tensor, dims, range(len(dims)))
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype)[:len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
          stacked_indices)
      if updates is None:
        return array_ops.gather_nd(tensor, stacked_indices)
      else:
        # We only need to move-axis `updates` in the contiguous case because
        # only in this case the result dimensions of advanced indexing are in
        # the middle of `updates`. In the non-contiguous case, those dimensions
        # are always at the front.
        if dims_contiguous:
          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
          # `updates`)
          if stacked_indices.shape.rank is None:
            raise NotImplementedError(
                'Rank of the advanced indices must currently be known')
          batch_size = stacked_indices.shape.rank - 1
          batch_start = dims[0]
          if batch_start < 0:
            batch_start += len(dims) - batch_size
          def range_(start, length):
            return range(start, start + length)
          updates = moveaxis(updates, range_(batch_start, batch_size),
                             range(batch_size))
        if update_method == _UpdateMethod.UPDATE:
          update_op = array_ops.tensor_scatter_update
        elif update_method == _UpdateMethod.ADD:
          update_op = array_ops.tensor_scatter_add
        elif update_method == _UpdateMethod.MIN:
          update_op = array_ops.tensor_scatter_min
        elif update_method == _UpdateMethod.MAX:
          update_op = array_ops.tensor_scatter_max
        tensor = update_op(
            tensor, stacked_indices, updates)
        if list(range(len(dims))) != dims:
          tensor = moveaxis(tensor, range(len(dims)), dims)
        return array_ops.tensor_strided_slice_update(
            original_tensor,
            packed_begin,
            packed_end,
            packed_strides,
            tensor,
            begin_mask=begin_mask,
            end_mask=end_mask,
            shrink_axis_mask=shrink_axis_mask,
            new_axis_mask=new_axis_mask,
            ellipsis_mask=ellipsis_mask,
            name=name + '_2')
    # Note that gather_nd does not support gathering from inside the array.
    # To avoid shuffling data back and forth, we transform the indices and
    # do a gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      stacked_indices = indices[0]
    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
    stacked_indices = array_ops.where_v2(stacked_indices < 0,
                                         stacked_indices + dim_sizes,
                                         stacked_indices)
    axis = dims[0]
    if len(dims) > 1:
      index_scaling = math_ops.cumprod(
          dim_sizes, reverse=True, exclusive=True)
      def _tensordot(a, b):
        # TODO(b/168657656): This function should be replaced by
        # tensordot(axis=1) once MatMul has int32 XLA kernel.
        b = array_ops.broadcast_to(b, array_ops.shape(a))
        return math_ops.reduce_sum(a * b, axis=-1)
      stacked_indices = _tensordot(stacked_indices, index_scaling)
      flat_shape = array_ops.concat(
          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
          axis=0)
      tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)
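

# Illustrative note (commentary only, not part of the original module). For a
# read such as `x[0, :, [2, 0]]`, `_slice_helper` first applies
# `strided_slice` for the basic indices (`0` and `:`), then gathers with the
# advanced index `[2, 0]` -- i.e. roughly gather(strided_slice(x, ...), ...),
# as the docstring above describes.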


def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
  if isinstance(slice_spec,
                (list, tuple)) and not isinstance(slice_spec, np.ndarray):
    is_index = True
    for s in slice_spec:
      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
        is_index = False
        break
      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
        is_index = False
        break
    if not is_index:
      return tuple(slice_spec)
  return (slice_spec,)


def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    return array_ops.boolean_mask(tensor=self, mask=slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  result_t = _slice_helper(self, slice_spec)
  return result_t


def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    slice_spec = nonzero(slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  a_dtype = a.dtype
  a, updates = _promote_dtype_binary(a, updates)
  result_t = _slice_helper(a, slice_spec, update_method, updates)
  return result_t.astype(a_dtype)


setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
setattr(np_arrays.ndarray, '_with_index_update',
        functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
setattr(np_arrays.ndarray, '_with_index_add',
        functools.partial(_with_index_update_helper, _UpdateMethod.ADD))
setattr(np_arrays.ndarray, '_with_index_min',
        functools.partial(_with_index_update_helper, _UpdateMethod.MIN))
setattr(np_arrays.ndarray, '_with_index_max',
        functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
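

# Illustrative note (commentary only, not part of the original module; these
# are private methods, and the usage below is hypothetical). The
# `_with_index_*` methods registered above are functional counterparts of
# numpy's in-place indexed assignment: rather than mutating, they return a new
# ndarray, e.g.
#
#   b = a._with_index_update((0, slice(None)), row)  # like a[0, :] = row
#   c = a._with_index_add(mask, 1)                   # like a[mask] += 1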