1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Mathematical operations.""" 16# pylint: disable=g-direct-tensorflow-import 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import numbers 23import sys 24 25import numpy as np 26import six 27 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import errors 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import bitwise_ops 34from tensorflow.python.ops import clip_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import gen_math_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import nn_ops 39from tensorflow.python.ops import sort_ops 40from tensorflow.python.ops import special_math_ops 41from tensorflow.python.ops.numpy_ops import np_array_ops 42from tensorflow.python.ops.numpy_ops import np_arrays 43from tensorflow.python.ops.numpy_ops import np_dtypes 44from tensorflow.python.ops.numpy_ops import np_export 45from tensorflow.python.ops.numpy_ops import np_utils 46 47 48pi = np_export.np_export_constant(__name__, 'pi', np.pi) 49e = np_export.np_export_constant(__name__, 'e', 
np.e) 50inf = np_export.np_export_constant(__name__, 'inf', np.inf) 51 52 53@np_utils.np_doc_only('dot') 54def dot(a, b): # pylint: disable=missing-docstring 55 56 def f(a, b): # pylint: disable=missing-docstring 57 return np_utils.cond( 58 np_utils.logical_or( 59 math_ops.equal(array_ops.rank(a), 0), 60 math_ops.equal(array_ops.rank(b), 0)), 61 lambda: a * b, 62 lambda: np_utils.cond( # pylint: disable=g-long-lambda 63 math_ops.equal(array_ops.rank(b), 1), 64 lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]), 65 lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]]))) 66 67 return _bin_op(f, a, b) 68 69 70# TODO(wangpeng): Make element-wise ops `ufunc`s 71def _bin_op(tf_fun, a, b, promote=True): 72 if promote: 73 a, b = np_array_ops._promote_dtype_binary(a, b) # pylint: disable=protected-access 74 else: 75 a = np_array_ops.array(a) 76 b = np_array_ops.array(b) 77 return tf_fun(a, b) 78 79 80@np_utils.np_doc('add') 81def add(x1, x2): 82 83 def add_or_or(x1, x2): 84 if x1.dtype == dtypes.bool: 85 assert x2.dtype == dtypes.bool 86 return math_ops.logical_or(x1, x2) 87 return math_ops.add(x1, x2) 88 89 return _bin_op(add_or_or, x1, x2) 90 91 92@np_utils.np_doc('subtract') 93def subtract(x1, x2): 94 return _bin_op(math_ops.subtract, x1, x2) 95 96 97@np_utils.np_doc('multiply') 98def multiply(x1, x2): 99 100 def mul_or_and(x1, x2): 101 if x1.dtype == dtypes.bool: 102 assert x2.dtype == dtypes.bool 103 return math_ops.logical_and(x1, x2) 104 return math_ops.multiply(x1, x2) 105 106 return _bin_op(mul_or_and, x1, x2) 107 108 109@np_utils.np_doc('true_divide') 110def true_divide(x1, x2): # pylint: disable=missing-function-docstring 111 112 def _avoid_float64(x1, x2): 113 if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64): 114 x1 = math_ops.cast(x1, dtype=dtypes.float32) 115 x2 = math_ops.cast(x2, dtype=dtypes.float32) 116 return x1, x2 117 118 def f(x1, x2): 119 if x1.dtype == dtypes.bool: 120 assert x2.dtype == dtypes.bool 121 float_ = 
np_dtypes.default_float_type() 122 x1 = math_ops.cast(x1, float_) 123 x2 = math_ops.cast(x2, float_) 124 if not np_dtypes.is_allow_float64(): 125 # math_ops.truediv in Python3 produces float64 when both inputs are int32 126 # or int64. We want to avoid that when is_allow_float64() is False. 127 x1, x2 = _avoid_float64(x1, x2) 128 return math_ops.truediv(x1, x2) 129 130 return _bin_op(f, x1, x2) 131 132 133@np_utils.np_doc('divide') 134def divide(x1, x2): # pylint: disable=missing-function-docstring 135 return true_divide(x1, x2) 136 137 138@np_utils.np_doc('floor_divide') 139def floor_divide(x1, x2): # pylint: disable=missing-function-docstring 140 141 def f(x1, x2): 142 if x1.dtype == dtypes.bool: 143 assert x2.dtype == dtypes.bool 144 x1 = math_ops.cast(x1, dtypes.int8) 145 x2 = math_ops.cast(x2, dtypes.int8) 146 return math_ops.floordiv(x1, x2) 147 148 return _bin_op(f, x1, x2) 149 150 151@np_utils.np_doc('mod') 152def mod(x1, x2): # pylint: disable=missing-function-docstring 153 154 def f(x1, x2): 155 if x1.dtype == dtypes.bool: 156 assert x2.dtype == dtypes.bool 157 x1 = math_ops.cast(x1, dtypes.int8) 158 x2 = math_ops.cast(x2, dtypes.int8) 159 return math_ops.mod(x1, x2) 160 161 return _bin_op(f, x1, x2) 162 163 164@np_utils.np_doc('remainder') 165def remainder(x1, x2): # pylint: disable=missing-function-docstring 166 return mod(x1, x2) 167 168 169@np_utils.np_doc('divmod') 170def divmod(x1, x2): # pylint: disable=redefined-builtin 171 return floor_divide(x1, x2), mod(x1, x2) 172 173 174@np_utils.np_doc('maximum') 175def maximum(x1, x2): # pylint: disable=missing-function-docstring 176 177 # Fast path for when maximum is used as relu. 
178 if isinstance( 179 x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance( 180 x1, np_arrays.ndarray) and x1.dtype != dtypes.bool: 181 return nn_ops.relu(np_array_ops.asarray(x1)) 182 183 def max_or_or(x1, x2): 184 if x1.dtype == dtypes.bool: 185 assert x2.dtype == dtypes.bool 186 return math_ops.logical_or(x1, x2) 187 return math_ops.maximum(x1, x2) 188 189 return _bin_op(max_or_or, x1, x2) 190 191 192@np_utils.np_doc('minimum') 193def minimum(x1, x2): 194 195 def min_or_and(x1, x2): 196 if x1.dtype == dtypes.bool: 197 assert x2.dtype == dtypes.bool 198 return math_ops.logical_and(x1, x2) 199 return math_ops.minimum(x1, x2) 200 201 return _bin_op(min_or_and, x1, x2) 202 203 204@np_utils.np_doc('clip') 205def clip(a, a_min, a_max): # pylint: disable=missing-docstring 206 if a_min is None and a_max is None: 207 raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.') 208 if a_min is None: 209 return minimum(a, a_max) 210 elif a_max is None: 211 return maximum(a, a_min) 212 else: 213 a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max) # pylint: disable=protected-access 214 return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max)) 215 216 217@np_utils.np_doc('matmul') 218def matmul(x1, x2): # pylint: disable=missing-docstring 219 def f(x1, x2): 220 try: 221 if x1._rank() == 2 and x2._rank() == 2: # pylint: disable=protected-access 222 # Fast path for known ranks. 
223 return gen_math_ops.mat_mul(x1, x2) 224 return np_utils.cond( 225 math_ops.equal(np_utils.tf_rank(x2), 1), 226 lambda: math_ops.tensordot(x1, x2, axes=1), 227 lambda: np_utils.cond( # pylint: disable=g-long-lambda 228 math_ops.equal(np_utils.tf_rank(x1), 1), 229 lambda: math_ops.tensordot( # pylint: disable=g-long-lambda 230 x1, x2, axes=[[0], [-2]]), 231 lambda: math_ops.matmul(x1, x2))) 232 except errors.InvalidArgumentError as err: 233 six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2]) 234 235 return _bin_op(f, x1, x2) 236 237 238# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles 239# batched matmul as well, so simply including promotion in TF's current 240# __matmul__ implementation was not sufficient. 241setattr(np_arrays.ndarray, '_matmul', matmul) 242 243 244@np_utils.np_doc('tensordot') 245def tensordot(a, b, axes=2): 246 return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b) 247 248 249@np_utils.np_doc_only('inner') 250def inner(a, b): # pylint: disable=missing-function-docstring 251 252 def f(a, b): 253 return np_utils.cond( 254 np_utils.logical_or( 255 math_ops.equal(array_ops.rank(a), 0), 256 math_ops.equal(array_ops.rank(b), 0)), lambda: a * b, 257 lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]])) 258 259 return _bin_op(f, a, b) 260 261 262@np_utils.np_doc('cross') 263def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=missing-docstring 264 265 def f(a, b): # pylint: disable=missing-docstring 266 # We can't assign to captured variable `axisa`, so make a new variable 267 if axis is None: 268 axis_a = axisa 269 axis_b = axisb 270 axis_c = axisc 271 else: 272 axis_a = axis 273 axis_b = axis 274 axis_c = axis 275 if axis_a < 0: 276 axis_a = np_utils.add(axis_a, array_ops.rank(a)) 277 if axis_b < 0: 278 axis_b = np_utils.add(axis_b, array_ops.rank(b)) 279 280 def maybe_move_axis_to_last(a, axis): 281 282 def move_axis_to_last(a, axis): 283 return array_ops.transpose( 284 
a, 285 array_ops.concat([ 286 math_ops.range(axis), 287 math_ops.range(axis + 1, array_ops.rank(a)), [axis] 288 ], 289 axis=0)) 290 291 return np_utils.cond(axis == np_utils.subtract(array_ops.rank(a), 1), 292 lambda: a, lambda: move_axis_to_last(a, axis)) 293 294 a = maybe_move_axis_to_last(a, axis_a) 295 b = maybe_move_axis_to_last(b, axis_b) 296 a_dim = np_utils.getitem(array_ops.shape(a), -1) 297 b_dim = np_utils.getitem(array_ops.shape(b), -1) 298 299 def maybe_pad_0(a, size_of_last_dim): 300 301 def pad_0(a): 302 return array_ops.pad( 303 a, 304 array_ops.concat([ 305 array_ops.zeros([array_ops.rank(a) - 1, 2], dtypes.int32), 306 constant_op.constant([[0, 1]], dtypes.int32) 307 ], 308 axis=0)) 309 310 return np_utils.cond( 311 math_ops.equal(size_of_last_dim, 2), lambda: pad_0(a), lambda: a) 312 313 a = maybe_pad_0(a, a_dim) 314 b = maybe_pad_0(b, b_dim) 315 c = math_ops.cross(*np_utils.tf_broadcast(a, b)) 316 if axis_c < 0: 317 axis_c = np_utils.add(axis_c, array_ops.rank(c)) 318 319 def move_last_to_axis(a, axis): 320 r = array_ops.rank(a) 321 return array_ops.transpose( 322 a, 323 array_ops.concat( 324 [math_ops.range(axis), [r - 1], 325 math_ops.range(axis, r - 1)], 326 axis=0)) 327 328 c = np_utils.cond( 329 (a_dim == 2) & (b_dim == 2), 330 lambda: c[..., 2], 331 lambda: np_utils.cond( # pylint: disable=g-long-lambda 332 axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c, 333 lambda: move_last_to_axis(c, axis_c))) 334 return c 335 336 return _bin_op(f, a, b) 337 338 339@np_utils.np_doc_only('vdot') 340def vdot(a, b): # pylint: disable=missing-docstring 341 a, b = np_array_ops._promote_dtype(a, b) # pylint: disable=protected-access 342 a = np_array_ops.reshape(a, [-1]) 343 b = np_array_ops.reshape(b, [-1]) 344 if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64: 345 a = conj(a) 346 return dot(a, b) 347 348 349@np_utils.np_doc('power') 350def power(x1, x2): 351 return _bin_op(math_ops.pow, x1, x2) 352 353 
@np_utils.np_doc('float_power')
def float_power(x1, x2):  # pylint: disable=missing-function-docstring

  # NumPy's `float_power` always computes in floating point, unlike `power`,
  # which keeps integer dtypes. After promotion both operands share one dtype;
  # if it is not inexact (integer or bool), compute in the default float type.
  # NOTE(review): NumPy additionally widens float32 to float64; that is not
  # replicated here because float64 availability is governed by np_dtypes.
  def f(x1, x2):
    if not np.issubdtype(x1.dtype.as_numpy_dtype, np.inexact):
      float_ = np_dtypes.default_float_type()
      x1 = math_ops.cast(x1, float_)
      x2 = math_ops.cast(x2, float_)
    return math_ops.pow(x1, x2)

  return _bin_op(f, x1, x2)


@np_utils.np_doc('arctan2')
def arctan2(x1, x2):
  return _bin_op(math_ops.atan2, x1, x2)


@np_utils.np_doc('nextafter')
def nextafter(x1, x2):
  return _bin_op(math_ops.nextafter, x1, x2)


@np_utils.np_doc('heaviside')
def heaviside(x1, x2):  # pylint: disable=missing-function-docstring

  # 0 where x1 < 0, 1 where x1 > 0, and x2 where x1 == 0.
  def f(x1, x2):
    return array_ops.where_v2(
        x1 < 0, constant_op.constant(0, dtype=x2.dtype),
        array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2))

  y = _bin_op(f, x1, x2)
  # Non-inexact (integer/bool) results are returned as the default float type.
  if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact):
    y = y.astype(np_dtypes.default_float_type())
  return y


@np_utils.np_doc('hypot')
def hypot(x1, x2):
  return sqrt(square(x1) + square(x2))


@np_utils.np_doc('kron')
def kron(a, b):  # pylint: disable=missing-function-docstring
  # pylint: disable=protected-access,g-complex-comprehension
  a, b = np_array_ops._promote_dtype(a, b)
  # Left-pad the lower-rank operand's shape with ones so both have equal rank.
  t_a = np_utils.cond(
      a.ndim < b.ndim,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          a, np_array_ops._pad_left_to(b.ndim, a.shape)),
      lambda: a)
  t_b = np_utils.cond(
      b.ndim < a.ndim,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          b, np_array_ops._pad_left_to(a.ndim, b.shape)),
      lambda: b)

  # Interleaves `shape` with ones: [s0, 1, s1, 1, ...] when `prepend` is
  # False, [1, s0, 1, s1, ...] when True.
  def _make_shape(shape, prepend):
    ones = array_ops.ones_like(shape)
    if prepend:
      shapes = [ones, shape]
    else:
      shapes = [shape, ones]
    return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])

  a_shape = array_ops.shape(t_a)
  b_shape = array_ops.shape(t_b)
  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
  # Broadcasting the interleaved shapes forms every block product; the final
  # reshape collapses each (a_i, b_i) pair into one dimension of size a_i*b_i.
  out_shape = a_shape * b_shape
  return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
@np_utils.np_doc('outer')
def outer(a, b):

  def f(a, b):
    # `a` is flattened into a column and `b` into a flat row; broadcasting the
    # product gives the 2-D outer product.
    return array_ops.reshape(a, [-1, 1]) * array_ops.reshape(b, [-1])

  return _bin_op(f, a, b)


# This can also be implemented via tf.reduce_logsumexp
@np_utils.np_doc('logaddexp')
def logaddexp(x1, x2):
  # Stable form of log(exp(x1) + exp(x2)) = max(x1, x2) + log1p(exp(-|x1-x2|)).
  amax = maximum(x1, x2)
  delta = x1 - x2
  return np_array_ops.where(
      isnan(delta),
      x1 + x2,  # NaNs or infinities of the same sign.
      amax + log1p(exp(-abs(delta))))


@np_utils.np_doc('logaddexp2')
def logaddexp2(x1, x2):
  # Base-2 analogue: max + log2(1 + 2**-|x1-x2|), with log2 written as
  # log1p(...) / ln(2).
  amax = maximum(x1, x2)
  delta = x1 - x2
  return np_array_ops.where(
      isnan(delta),
      x1 + x2,  # NaNs or infinities of the same sign.
      amax + log1p(exp2(-abs(delta))) / np.log(2))


@np_utils.np_doc('polyval')
def polyval(p, x):  # pylint: disable=missing-function-docstring

  def f(p, x):
    # A scalar coefficient is treated as a length-1 coefficient list.
    if p.shape.rank == 0:
      p = array_ops.reshape(p, [1])
    p = array_ops.unstack(p)
    # TODO(wangpeng): Make tf version take a tensor for p instead of a list.
    y = math_ops.polyval(p, x)
    # If the polynomial is 0-order, numpy requires the result to be broadcast to
    # `x`'s shape.
    if len(p) == 1:
      y = array_ops.broadcast_to(y, x.shape)
    return y

  return _bin_op(f, p, x)


@np_utils.np_doc('isclose')
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):  # pylint: disable=missing-docstring

  def f(a, b):  # pylint: disable=missing-docstring
    dtype = a.dtype
    if np.issubdtype(dtype.as_numpy_dtype, np.inexact):
      # Tolerances use the real counterpart of the (possibly complex) dtype.
      rtol_ = ops.convert_to_tensor(rtol, dtype.real_dtype)
      atol_ = ops.convert_to_tensor(atol, dtype.real_dtype)
      # |a - b| <= atol + rtol * |b|  (asymmetric in `b`, as in NumPy).
      result = (math_ops.abs(a - b) <= atol_ + rtol_ * math_ops.abs(b))
      if equal_nan:
        result = result | (math_ops.is_nan(a) & math_ops.is_nan(b))
      return result
    else:
      # Integer/bool inputs are compared exactly.
      return a == b

  return _bin_op(f, a, b)


@np_utils.np_doc('allclose')
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  return np_array_ops.all(
      isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))


def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
  # Element-wise Euclid's algorithm run as a TF while_loop on |x1|, |x2|
  # (broadcast to a common shape). Raises ValueError for non-integer dtypes.

  # Keep iterating while any element still has a non-zero second value.
  def _gcd_cond_fn(_, x2):
    return math_ops.reduce_any(x2 != 0)

  # One Euclid step where x2 != 0: (x1, x2) <- (x2, x1 mod x2); elements with
  # x2 == 0 keep x1 and stay at x2 == 0. The pair is then reordered so the
  # first value is the element-wise max and the second the min.
  def _gcd_body_fn(x1, x2):
    # math_ops.mod will raise an error when any element of x2 is 0. To avoid
    # that, we change those zeros to ones. Their values don't matter because
    # they won't be used.
    x2_safe = array_ops.where_v2(x2 != 0, x2, constant_op.constant(1, x2.dtype))
    x1, x2 = (array_ops.where_v2(x2 != 0, x2, x1),
              array_ops.where_v2(x2 != 0, math_ops.mod(x1, x2_safe),
                                 constant_op.constant(0, x2.dtype)))
    return (array_ops.where_v2(x1 < x2, x2,
                               x1), array_ops.where_v2(x1 < x2, x1, x2))

  if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
      not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
    raise ValueError('Arguments to gcd must be integers.')
  shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(x1), array_ops.shape(x2))
  x1 = array_ops.broadcast_to(x1, shape)
  x2 = array_ops.broadcast_to(x2, shape)
  value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
                                         (math_ops.abs(x1), math_ops.abs(x2)))
  return value


# Note that np.gcd may not be present in some supported versions of numpy.
@np_utils.np_doc('gcd')
def gcd(x1, x2):
  return _bin_op(_tf_gcd, x1, x2)


# Note that np.lcm may not be present in some supported versions of numpy.
@np_utils.np_doc('lcm')
def lcm(x1, x2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    d = _tf_gcd(x1, x2)
    # Same as the `x2_safe` trick in `_tf_gcd`: avoid dividing by zero where
    # gcd == 0 (both inputs zero); those positions are overwritten with 0.
    d_safe = array_ops.where_v2(
        math_ops.equal(d, 0), constant_op.constant(1, d.dtype), d)
    # Divide before multiplying. `d` divides `x1` exactly, so `x1 // d_safe` is
    # exact. The product then only overflows if the lcm itself does, whereas
    # `abs(x1 * x2) // d` overflows for representable results (e.g. int32
    # lcm(50000, 50000)).
    return array_ops.where_v2(
        math_ops.equal(d, 0), constant_op.constant(0, d.dtype),
        math_ops.abs((x1 // d_safe) * x2))

  return _bin_op(f, x1, x2)


def _bitwise_binary_op(tf_fn, x1, x2):  # pylint: disable=missing-function-docstring
  # Applies a TF bitwise op; bool operands are evaluated as int8 and the result
  # is cast back to bool.

  def f(x1, x2):
    is_bool = (x1.dtype == dtypes.bool)
    if is_bool:
      assert x2.dtype == dtypes.bool
      x1 = math_ops.cast(x1, dtypes.int8)
      x2 = math_ops.cast(x2, dtypes.int8)
    r = tf_fn(x1, x2)
    if is_bool:
      r = math_ops.cast(r, dtypes.bool)
    return r

  return _bin_op(f, x1, x2)


@np_utils.np_doc('bitwise_and')
def bitwise_and(x1, x2):
  return _bitwise_binary_op(bitwise_ops.bitwise_and, x1, x2)


@np_utils.np_doc('bitwise_or')
def bitwise_or(x1, x2):
  return _bitwise_binary_op(bitwise_ops.bitwise_or, x1, x2)


@np_utils.np_doc('bitwise_xor')
def bitwise_xor(x1, x2):
  return _bitwise_binary_op(bitwise_ops.bitwise_xor, x1, x2)


@np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert'))
def bitwise_not(x):

  # Bool inversion is logical NOT; other dtypes use bitwise invert.
  def f(x):
    if x.dtype == dtypes.bool:
      return math_ops.logical_not(x)
    return bitwise_ops.invert(x)

  return _scalar(f, x)


def _scalar(tf_fn, x, promote_to_float=False):
  """Computes the tf_fn(x) for each element in `x`.

  Args:
    tf_fn: function that takes a single Tensor argument.
    x: array_like. Could be an ndarray, a Tensor or any object that can be
      converted to a Tensor using `ops.convert_to_tensor`.
    promote_to_float: whether to cast the argument to a float dtype
      (`np_dtypes.default_float_type`) if it is not already.

  Returns:
    The result of `tf_fn` applied to `x` after conversion with
    `np_array_ops.asarray`. If `promote_to_float` is True and the converted
    value's dtype is not inexact (floating or complex), it is first cast to
    `np_dtypes.default_float_type()`; otherwise its dtype is left unchanged.
  """
  x = np_array_ops.asarray(x)
  if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact):
    x = x.astype(np_dtypes.default_float_type())
  return tf_fn(x)


@np_utils.np_doc('log')
def log(x):
  return _scalar(math_ops.log, x, True)


@np_utils.np_doc('exp')
def exp(x):
  return _scalar(math_ops.exp, x, True)


@np_utils.np_doc('sqrt')
def sqrt(x):
  return _scalar(math_ops.sqrt, x, True)


@np_utils.np_doc('abs', link=np_utils.AliasOf('absolute'))
def abs(x):  # pylint: disable=redefined-builtin
  return _scalar(math_ops.abs, x)


@np_utils.np_doc('absolute')
def absolute(x):
  return abs(x)


@np_utils.np_doc('fabs')
def fabs(x):
  return abs(x)


@np_utils.np_doc('ceil')
def ceil(x):
  return _scalar(math_ops.ceil, x, True)


@np_utils.np_doc('floor')
def floor(x):
  return _scalar(math_ops.floor, x, True)


@np_utils.np_doc('conj')
def conj(x):
  return _scalar(math_ops.conj, x)


@np_utils.np_doc('negative')
def negative(x):
  return _scalar(math_ops.negative, x)


@np_utils.np_doc('reciprocal')
def reciprocal(x):
  return _scalar(math_ops.reciprocal, x)


@np_utils.np_doc('signbit')
def signbit(x):

  # Booleans are never negative; other dtypes test `x < 0`.
  def f(x):
    if x.dtype == dtypes.bool:
      return array_ops.fill(array_ops.shape(x), False)
    return x < 0

  return _scalar(f, x)


@np_utils.np_doc('sin')
def sin(x):
  return _scalar(math_ops.sin, x, True)


@np_utils.np_doc('cos')
def cos(x):
  return _scalar(math_ops.cos, x, True)


@np_utils.np_doc('tan')
def tan(x):
  return _scalar(math_ops.tan, x, True)


@np_utils.np_doc('sinh')
def sinh(x):
  return _scalar(math_ops.sinh, x, True)


@np_utils.np_doc('cosh')
def cosh(x):
  return _scalar(math_ops.cosh, x, True)


@np_utils.np_doc('tanh')
def tanh(x):
  return _scalar(math_ops.tanh, x, True)


@np_utils.np_doc('arcsin')
def arcsin(x):
  return _scalar(math_ops.asin, x, True)


@np_utils.np_doc('arccos')
def arccos(x):
  return _scalar(math_ops.acos, x, True)


@np_utils.np_doc('arctan')
def arctan(x):
  return _scalar(math_ops.atan, x, True)


@np_utils.np_doc('arcsinh')
def arcsinh(x):
  return _scalar(math_ops.asinh, x, True)


@np_utils.np_doc('arccosh')
def arccosh(x):
  return _scalar(math_ops.acosh, x, True)


@np_utils.np_doc('arctanh')
def arctanh(x):
  return _scalar(math_ops.atanh, x, True)


@np_utils.np_doc('deg2rad')
def deg2rad(x):

  def f(x):
    return x * (np.pi / 180.0)

  return _scalar(f, x, True)


@np_utils.np_doc('rad2deg')
def rad2deg(x):

  # Routed through `_scalar` like `deg2rad`, so that array_like inputs (e.g.
  # Python lists) are converted and integer inputs are promoted to float
  # before scaling. The old form `x * (180.0 / np.pi)` applied the operator to
  # the raw argument.
  def f(x):
    return x * (180.0 / np.pi)

  return _scalar(f, x, True)


_tf_float_types = [
    dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
]


@np_utils.np_doc('angle')
def angle(z, deg=False):  # pylint: disable=missing-function-docstring

  def f(x):
    if x.dtype in _tf_float_types:
      # Workaround for b/147515503
      # The angle of a real number is pi for negatives and 0 otherwise. The
      # branch values are built in `x`'s own dtype; bare Python scalars here
      # may be converted as float32 regardless of `x.dtype`, changing the
      # result dtype for float16/bfloat16/float64 inputs.
      return array_ops.where_v2(x < 0, constant_op.constant(np.pi, x.dtype),
                                constant_op.constant(0, x.dtype))
    else:
      return math_ops.angle(x)

  y = _scalar(f, z, True)
  if deg:
    y = rad2deg(y)
  return y


@np_utils.np_doc('cbrt')
def cbrt(x):

  def f(x):
    # __pow__ can't handle negative base, so we use `abs` here.
    rt = math_ops.abs(x)**(1.0 / 3)
    return array_ops.where_v2(x < 0, -rt, rt)

  return _scalar(f, x, True)


@np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj'))
def conjugate(x):
  return _scalar(math_ops.conj, x)


@np_utils.np_doc('exp2')
def exp2(x):

  def f(x):
    return 2**x

  return _scalar(f, x, True)


@np_utils.np_doc('expm1')
def expm1(x):
  return _scalar(math_ops.expm1, x, True)


@np_utils.np_doc('fix')
def fix(x):

  # Round toward zero: ceil for negatives, floor otherwise.
  def f(x):
    return array_ops.where_v2(x < 0, math_ops.ceil(x), math_ops.floor(x))

  return _scalar(f, x, True)


@np_utils.np_doc('iscomplex')
def iscomplex(x):
  return np_array_ops.imag(x) != 0


@np_utils.np_doc('isreal')
def isreal(x):
  return np_array_ops.imag(x) == 0


@np_utils.np_doc('iscomplexobj')
def iscomplexobj(x):
  # Checks the dtype only, not the values.
  x = np_array_ops.array(x)
  return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating)


@np_utils.np_doc('isrealobj')
def isrealobj(x):
  return not iscomplexobj(x)


@np_utils.np_doc('isnan')
def isnan(x):
  return _scalar(math_ops.is_nan, x, True)


def _make_nan_reduction(np_fun_name, reduction, init_val):
  """Helper to generate nan* functions.

  The generated function replaces NaNs in its input with `init_val` (in the
  input's dtype) and then applies `reduction`.
  """

  @np_utils.np_doc(np_fun_name)
  def nan_reduction(a, axis=None, dtype=None, keepdims=False):
    a = np_array_ops.array(a)
    v = np_array_ops.array(init_val, dtype=a.dtype)
    return reduction(
        np_array_ops.where(isnan(a), v, a),
        axis=axis,
        dtype=dtype,
        keepdims=keepdims)

  return nan_reduction


nansum = _make_nan_reduction('nansum', np_array_ops.sum, 0)
nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)


@np_utils.np_doc('nanmean')
def nanmean(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=missing-docstring
  a = np_array_ops.array(a)
  # Bool/integer arrays cannot contain NaN, so a plain mean suffices.
  if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype(
      a.dtype.as_numpy_dtype, np.integer):
    return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
  # Mean over non-NaN entries: nansum divided by the count of non-NaNs.
  nan_mask = logical_not(isnan(a))
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
  normalizer = np_array_ops.sum(
      nan_mask, axis=axis, dtype=dtype, keepdims=keepdims)
  return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer


@np_utils.np_doc('isfinite')
def isfinite(x):
  return _scalar(math_ops.is_finite, x, True)


@np_utils.np_doc('isinf')
def isinf(x):
  return _scalar(math_ops.is_inf, x, True)


@np_utils.np_doc('isneginf')
def isneginf(x):
  # `isinf` promotes non-inexact inputs to float before testing, so integer
  # and bool values are never reported as infinite. The previous comparison
  # against `full_like(x, -np.inf)` had to cast -inf into the integer dtype.
  return logical_and(isinf(x), less(x, 0))


@np_utils.np_doc('isposinf')
def isposinf(x):
  # See `isneginf`: avoids casting +inf into integer dtypes.
  return logical_and(isinf(x), greater(x, 0))


@np_utils.np_doc('log2')
def log2(x):
  return log(x) / np.log(2)


@np_utils.np_doc('log10')
def log10(x):
  return log(x) / np.log(10)


@np_utils.np_doc('log1p')
def log1p(x):
  return _scalar(math_ops.log1p, x, True)


@np_utils.np_doc('positive')
def positive(x):
  return _scalar(lambda x: x, x)


@np_utils.np_doc('sinc')
def sinc(x):

  # sin(pi*x)/(pi*x), with the removable singularity at 0 set to 1.
  def f(x):
    pi_x = x * np.pi
    return array_ops.where_v2(x == 0, array_ops.ones_like(x),
                              math_ops.sin(pi_x) / pi_x)

  return _scalar(f, x, True)


@np_utils.np_doc('square')
def square(x):
  return _scalar(math_ops.square, x)


@np_utils.np_doc('diff')
def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring

  def f(a):
    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
    # TODO(agarwal): avoid depending on static rank.
    nd = a.shape.rank
    if nd is None:
      raise ValueError('diff currently requires known rank for input `a`')
    # Valid axes are [-nd, nd). Checking only the upper bound let very negative
    # axes (and any axis of a 0-d input) reach the list indexing below, which
    # raised an IndexError instead of this ValueError.
    if not -nd <= axis < nd:
      raise ValueError('axis %s is out of bounds for array of dimension %s' %
                       (axis, nd))
    if n < 0:
      raise ValueError('order must be non-negative but got %s' % n)
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)
    # Booleans are differenced with `!=`; everything else with subtraction.
    op = math_ops.not_equal if a.dtype == dtypes.bool else math_ops.subtract
    for _ in range(n):
      a = op(a[slice1], a[slice2])
    return a

  return _scalar(f, a)


def _wrap(f, reverse=False):
  """Wraps binary ops so they can be added as operator overloads on ndarray."""

  def _f(a, b):
    if reverse:
      a, b = b, a

    # Defer to operands that declare a higher __array_priority__.
    if getattr(b, '__array_priority__',
               0) > np_arrays.ndarray.__array_priority__:
      return NotImplemented

    return f(a, b)

  return _f


def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
  """Helper function for comparison.

  Converts both operands to their common result type. If `cast_bool_to_int`
  is set and the operands are bool, both are cast to int32. Then `tf_fun` is
  applied.
  """
  dtype = np_utils.result_type(x1, x2)
  # Cast x1 and x2 to the result_type if needed.
  x1 = np_array_ops.array(x1, dtype=dtype)
  x2 = np_array_ops.array(x2, dtype=dtype)
  if cast_bool_to_int and x1.dtype == dtypes.bool:
    x1 = math_ops.cast(x1, dtypes.int32)
    x2 = math_ops.cast(x2, dtypes.int32)
  return tf_fun(x1, x2)


@np_utils.np_doc('equal')
def equal(x1, x2):
  return _comparison(math_ops.equal, x1, x2)


@np_utils.np_doc('not_equal')
def not_equal(x1, x2):
  return _comparison(math_ops.not_equal, x1, x2)


@np_utils.np_doc('greater')
def greater(x1, x2):
  return _comparison(math_ops.greater, x1, x2, True)


@np_utils.np_doc('greater_equal')
def greater_equal(x1, x2):
  return _comparison(math_ops.greater_equal, x1, x2, True)


@np_utils.np_doc('less')
def less(x1, x2):
  return _comparison(math_ops.less, x1, x2, True)


@np_utils.np_doc('less_equal')
def less_equal(x1, x2):
  return _comparison(math_ops.less_equal, x1, x2, True)


@np_utils.np_doc('array_equal')
def array_equal(a1, a2):  # pylint: disable=missing-function-docstring

  # True only if ranks match, shapes match, and all elements are equal.
  def f(x1, x2):
    return np_utils.cond(
        math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            np_utils.reduce_all(
                math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
            ),
            lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
            lambda: constant_op.constant(False)),
        lambda: constant_op.constant(False))

  return _comparison(f, a1, a2)


def _logical_binary_op(tf_fun, x1, x2):
  # Converts both operands to bool before applying `tf_fun`.
  x1 = np_array_ops.array(x1, dtype=np.bool_)
  x2 = np_array_ops.array(x2, dtype=np.bool_)
  return tf_fun(x1, x2)


@np_utils.np_doc('logical_and')
def logical_and(x1, x2):
  return _logical_binary_op(math_ops.logical_and, x1, x2)


@np_utils.np_doc('logical_or')
def logical_or(x1, x2):
  return _logical_binary_op(math_ops.logical_or, x1, x2)
# Element-wise logical XOR; operand conversion is delegated to the module
# helper `_logical_binary_op`.
@np_utils.np_doc('logical_xor')
def logical_xor(x1, x2):
  return _logical_binary_op(math_ops.logical_xor, x1, x2)


# Element-wise logical NOT. The input is first converted with
# `np_array_ops.array(..., dtype=np.bool_)` before negation.
@np_utils.np_doc('logical_not')
def logical_not(x):
  x = np_array_ops.array(x, dtype=np.bool_)
  return math_ops.logical_not(x)


# NumPy-compatible `linspace`.
# - `math_ops.linspace` always includes `stop`, so `endpoint=False` is emulated
#   by sampling `num` points from `start` to `stop - step`.
# - When `retstep=True` and the step is undefined (endpoint=True with num <= 1,
#   or endpoint=False with num == 0), NaN is returned as the step.
# - Raises ValueError for a negative `num`.
@np_utils.np_doc('linspace')
def linspace(  # pylint: disable=missing-docstring
    start,
    stop,
    num=50,
    endpoint=True,
    retstep=False,
    dtype=float,
    axis=0):
  if dtype:
    dtype = np_utils.result_type(dtype)
  start = np_array_ops.array(start, dtype=dtype)
  stop = np_array_ops.array(stop, dtype=dtype)
  if num < 0:
    raise ValueError('Number of samples {} must be non-negative.'.format(num))
  # Placeholder step, overwritten below whenever the step is well defined.
  step = ops.convert_to_tensor(np.nan)
  if endpoint:
    result = math_ops.linspace(start, stop, num, axis=axis)
    if num > 1:
      step = (stop - start) / (num - 1)
  else:
    # math_ops.linspace does not support endpoint=False so we manually handle it
    # here.
    if num > 0:
      step = ((stop - start) / num)
    if num > 1:
      new_stop = math_ops.cast(stop, step.dtype) - step
      start = math_ops.cast(start, new_stop.dtype)
      result = math_ops.linspace(start, new_stop, num, axis=axis)
    else:
      result = math_ops.linspace(start, stop, num, axis=axis)
  if dtype:
    if dtypes.as_dtype(dtype).is_integer and result.dtype.is_floating:
      # Since NumPy 1.20, integer `linspace` output is rounded towards -inf;
      # a bare cast would truncate towards zero instead (e.g. -2.5 -> -2).
      result = math_ops.floor(result)
    result = math_ops.cast(result, dtype)
  if retstep:
    return (result, step)
  else:
    return result


# NumPy-compatible `logspace`: returns `base ** linspace(start, stop, ...)`.
@np_utils.np_doc('logspace')
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
  dtype = np_utils.result_type(start, stop, dtype)
  # Sample the exponents in a floating (or complex) type. NumPy only casts the
  # final powers to `dtype`; casting the exponents themselves to an integer
  # `dtype` would collapse e.g. logspace(0, 1, 3, dtype=int) to [1, 1, 10]
  # instead of NumPy's [1, 3, 10].
  computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  result = linspace(
      start,
      stop,
      num=num,
      endpoint=endpoint,
      dtype=computation_dtype,
      axis=axis)
  result = math_ops.pow(math_ops.cast(base, result.dtype), result)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result


# NumPy-compatible `geomspace`, computed as a base-10 `logspace` between the
# logs of the endpoints, in at least float32 precision.
@np_utils.np_doc('geomspace')
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint: disable=missing-docstring
  dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type(
      start, stop, float(num), np_array_ops.zeros((), dtype))
  computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  start = np_array_ops.asarray(start, dtype=computation_dtype)
  stop = np_array_ops.asarray(stop, dtype=computation_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints:
  # `signflip` is -1 when both real parts are negative (the logs are then taken
  # of the negated endpoints and the sign is restored by the multiplication
  # below) and 1 when both real parts are positive.
  start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
  stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
  signflip = 1 - start_sign * stop_sign // 2
  res = signflip * logspace(
      log10(signflip * start),
      log10(signflip * stop),
      num,
      endpoint=endpoint,
      base=10.0,
      dtype=computation_dtype,
      axis=0)
  # Samples were generated along axis 0; move them to the requested axis.
  if axis != 0:
    res = np_array_ops.moveaxis(res, 0, axis)
  return math_ops.cast(res, dtype)


# Peak-to-peak range: maximum minus minimum along `axis`.
@np_utils.np_doc('ptp')
def ptp(a, axis=None, keepdims=None):
  return (np_array_ops.amax(a, axis=axis, keepdims=keepdims) -
          np_array_ops.amin(a, axis=axis, keepdims=keepdims))


# Promotes all inputs to a common dtype, then concatenates along `axis`.
# Raises ValueError when given an empty list/tuple.
# NOTE(review): a single non-list/tuple argument is wrapped in a one-element
# list and returned as-is, whereas NumPy iterates over the first axis of an
# array argument -- confirm whether this divergence is intended.
@np_utils.np_doc_only('concatenate')
def concatenate(arys, axis=0):
  if not isinstance(arys, (list, tuple)):
    arys = [arys]
  if not arys:
    raise ValueError('Need at least one array to concatenate.')
  dtype = np_utils.result_type(*arys)
  arys = [np_array_ops.array(array, dtype=dtype) for array in arys]
  return array_ops.concat(arys, axis)


# NumPy-compatible `tile`: whichever of `reps` and `a.shape` is shorter is
# left-padded with 1s so both have equal length before `array_ops.tile`.
@np_utils.np_doc_only('tile')
def tile(a, reps):  # pylint: disable=missing-function-docstring
  a = np_array_ops.array(a)
  reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1])

  a_rank = array_ops.rank(a)
  reps_size = array_ops.size(reps)
  reps = array_ops.pad(
      reps, [[math_ops.maximum(a_rank - reps_size, 0), 0]], constant_values=1)
  a_shape = array_ops.pad(
      array_ops.shape(a), [[math_ops.maximum(reps_size - a_rank, 0), 0]],
      constant_values=1)
  a = array_ops.reshape(a, a_shape)

  return array_ops.tile(a, reps)


# Thin wrapper over `math_ops.count_nonzero`.
@np_utils.np_doc('count_nonzero')
def count_nonzero(a, axis=None):
  return math_ops.count_nonzero(np_array_ops.array(a), axis)


# NumPy-compatible `argsort`. Only kind='quicksort' (unstable) and 'stable'
# are accepted and `order` must be None. A rank-0 input yields [0];
# axis=None sorts the flattened input. Indices are returned as `np.intp`.
@np_utils.np_doc('argsort')
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError("Only 'quicksort' and 'stable' arguments are supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")
  stable = (kind == 'stable')

  a = np_array_ops.array(a)

  def _argsort(a, axis, stable):
    if axis is None:
      a = array_ops.reshape(a, [-1])
      axis = 0

    return sort_ops.argsort(a, axis, stable=stable)

  # The rank is checked at run time so that scalars (which sort_ops cannot
  # handle along an axis) get NumPy's answer of [0].
  tf_ans = np_utils.cond(
      math_ops.equal(array_ops.rank(a), 0), lambda: constant_op.constant([0]),
      lambda: _argsort(a, axis, stable))

  return np_array_ops.array(tf_ans, dtype=np.intp)


# NumPy-compatible `sort`. Only kind='quicksort' is accepted and `order` must
# be None; axis=None sorts the flattened input.
@np_utils.np_doc('sort')
def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  if kind != 'quicksort':
    raise ValueError("Only 'quicksort' is supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  a = np_array_ops.array(a)

  if axis is None:
    return sort_ops.sort(array_ops.reshape(a, [-1]), 0)
  else:
    return sort_ops.sort(a, axis)


def _argminmax(fn, a, axis=None):
  """Shared implementation of `argmax`/`argmin`.

  Args:
    fn: the reduction to apply, called as `fn(input=..., axis=...)`
      (`math_ops.argmax` or `math_ops.argmin` in this module).
    a: array-like input.
    axis: axis to reduce, or None to reduce over the flattened input.

  Returns:
    The result of `fn`.
  """
  a = np_array_ops.array(a)
  if axis is None:
    # When axis is None numpy flattens the array.
    a_t = array_ops.reshape(a, [-1])
  else:
    a_t = np_array_ops.atleast_1d(a)
  return fn(input=a_t, axis=axis)


# Index of the maximum value (over the flattened input when axis is None).
@np_utils.np_doc('argmax')
def argmax(a, axis=None):
  return _argminmax(math_ops.argmax, a, axis)


# Index of the minimum value (over the flattened input when axis is None).
@np_utils.np_doc('argmin')
def argmin(a, axis=None):
  return _argminmax(math_ops.argmin, a, axis)


# NumPy-compatible `append`: with axis=None both inputs are raveled and joined
# into a 1-D result; otherwise they are concatenated along `axis`.
@np_utils.np_doc('append')
def append(arr, values, axis=None):
  if axis is None:
    return concatenate([np_array_ops.ravel(arr), np_array_ops.ravel(values)], 0)
  else:
    return concatenate([arr, values], axis=axis)


# NumPy-compatible `average`. `axis` must be None or a single int. Integer
# inputs are promoted to the default float type. With `weights`, they must
# either match `a`'s shape or be 1-D along `axis`. When `returned` is True the
# weight sum is broadcast to the shape of the average and returned with it.
@np_utils.np_doc('average')
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  if axis is not None and not isinstance(axis, six.integer_types):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('`axis` must be an integer. Tuple of ints is not '
                     'supported yet. Got type: %s' % type(axis))
  a = np_array_ops.array(a)
  if weights is None:  # Treat all weights as 1
    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      a = a.astype(
          np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
    avg = math_ops.reduce_mean(a, axis=axis)
    if returned:
      if axis is None:
        weights_sum = array_ops.size(a)
      else:
        weights_sum = array_ops.shape(a)[axis]
      weights_sum = math_ops.cast(weights_sum, a.dtype)
  else:
    if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      out_dtype = np_utils.result_type(a.dtype, weights)
    else:
      out_dtype = np_utils.result_type(a.dtype, weights,
                                       np_dtypes.default_float_type())
    a = np_array_ops.array(a, out_dtype)
    weights = np_array_ops.array(weights, out_dtype)

    def rank_equal_case():
      control_flow_ops.Assert(
          math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)),
          [array_ops.shape(a), array_ops.shape(weights)])
      weights_sum = math_ops.reduce_sum(weights, axis=axis)
      avg = math_ops.reduce_sum(a * weights, axis=axis) / weights_sum
      return avg, weights_sum

    if axis is None:
      avg, weights_sum = rank_equal_case()
    else:

      def rank_not_equal_case():
        control_flow_ops.Assert(
            array_ops.rank(weights) == 1, [array_ops.rank(weights)])
        weights_sum = math_ops.reduce_sum(weights)
        # Contract `a`'s `axis` dimension with the 1-D weights.
        axes = ops.convert_to_tensor([[axis], [0]])
        avg = math_ops.tensordot(a, weights, axes) / weights_sum
        return avg, weights_sum

      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, np_utils.cond will run shape checking on the true branch,
      # which will raise a shape-checking error.
      avg, weights_sum = np_utils.cond(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          rank_equal_case, rank_not_equal_case)

  avg = np_array_ops.array(avg)
  if returned:
    weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
    return avg, weights_sum
  return avg


# NumPy-compatible `trace`. With offset 0 and (axis1, axis2) naming the last
# two axes of a statically-ranked input, `math_ops.trace` is used directly;
# otherwise the requested diagonal is extracted and summed.
@np_utils.np_doc('trace')
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  a = np_array_ops.asarray(a, dtype)

  if offset == 0:
    a_shape = a.shape
    if a_shape.rank is not None:
      rank = len(a_shape)
      if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or
                                                 axis2 == rank - 1):
        return math_ops.trace(a)

  a = np_array_ops.diagonal(a, offset, axis1, axis2)
  return np_array_ops.sum(a, -1, dtype)


@np_utils.np_doc('meshgrid')
def meshgrid(*xi, **kwargs):
  """This currently requires copy=True and sparse=False."""
  # Only `sparse`, `copy` and `indexing` are read from kwargs; the first two
  # must keep their defaults, and only `indexing` is forwarded.
  sparse = kwargs.get('sparse', False)
  if sparse:
    raise ValueError('meshgrid doesnt support returning sparse arrays yet')

  copy = kwargs.get('copy', True)
  if not copy:
    raise ValueError('meshgrid only supports copy=True')

  indexing = kwargs.get('indexing', 'xy')

  xi = [np_array_ops.asarray(arg) for arg in xi]
  kwargs = {'indexing': indexing}

  outputs = array_ops.meshgrid(*xi, **kwargs)

  return outputs


# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument
# `subscripts`, even though the doc says it has.
# Only casting='safe' (promote operands to a common dtype) and casting='no'
# are supported; `optimize` is mapped onto TF's 'greedy'/'optimal' strategies.
@np_utils.np_doc_only('einsum')
def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstring
  casting = kwargs.get('casting', 'safe')
  optimize = kwargs.get('optimize', False)
  if casting == 'safe':
    operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
  elif casting == 'no':
    operands = [np_array_ops.asarray(x) for x in operands]
  else:
    raise ValueError('casting policy not supported: %s' % casting)
  if not optimize:
    # TF doesn't have a "no optimization" option.
    # TODO(wangpeng): Print a warning that np and tf use different
    # optimizations.
    tf_optimize = 'greedy'
  elif optimize == True:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
    tf_optimize = 'greedy'
  elif optimize == 'greedy':
    tf_optimize = 'greedy'
  elif optimize == 'optimal':
    tf_optimize = 'optimal'
  else:
    raise ValueError('`optimize` method not supported: %s' % optimize)
  res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
  return res


def _tensor_t(self):
  """Returns a Tensor which is the transpose of this Tensor."""
  return self.transpose()


def _tensor_ndim(self):
  """Returns the rank of the Tensor."""
  return self.shape.ndims


def _tensor_pos(self):
  """Returns self, for unary operator `+`."""
  return self


def _tensor_size(self):
  """Returns the number of elements in this Tensor, if fully known."""
  if not self.shape.is_fully_defined():
    return None
  # `np.prod([])` is the float 1.0, so a scalar Tensor used to report a float
  # size. Force an integer product and return a Python int, matching
  # `numpy.ndarray.size`.
  return int(np.prod(self.shape.as_list(), dtype=np.int64))


def _tensor_tolist(self):
  """Returns the value of an eager Tensor as a (nested) Python list.

  Raises:
    ValueError: if the Tensor is symbolic (not an `EagerTensor`).
  """
  if isinstance(self, ops.EagerTensor):
    return self._numpy().tolist()  # pylint: disable=protected-access

  raise ValueError('Symbolic Tensors do not support the tolist API.')


def enable_numpy_methods_on_tensor():
  """Adds additional NumPy methods on tf.Tensor class."""
  t = property(_tensor_t)
  setattr(ops.Tensor, 'T', t)

  ndim = property(_tensor_ndim)
  setattr(ops.Tensor, 'ndim', ndim)

  size = property(_tensor_size)
  setattr(ops.Tensor, 'size', size)

  setattr(ops.Tensor, '__pos__', _tensor_pos)
  setattr(ops.Tensor, 'tolist', _tensor_tolist)

  # TODO(b/178540516): Make a custom `setattr` that changes the method's
  # docstring to the TF one.
  setattr(ops.Tensor, 'transpose', np_array_ops.transpose)
  setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper)  # pylint: disable=protected-access
  setattr(ops.Tensor, 'ravel', np_array_ops.ravel)
  setattr(ops.Tensor, 'clip', clip)
  setattr(ops.Tensor, 'astype', math_ops.cast)
  setattr(ops.Tensor, '__round__', np_array_ops.around)
  setattr(ops.Tensor, 'max', np_array_ops.amax)
  setattr(ops.Tensor, 'mean', np_array_ops.mean)
  setattr(ops.Tensor, 'min', np_array_ops.amin)

  # TODO(wangpeng): Remove `data` when all uses of it are removed
  data = property(lambda self: self)
  setattr(ops.Tensor, 'data', data)