• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Math Operations.
16
17Note: Functions taking `Tensor` arguments can also take anything accepted by
18`tf.convert_to_tensor`.
19
20Note: Elementwise binary operations in TensorFlow follow [numpy-style
21broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
22
23TensorFlow provides a variety of math functions including:
24
25* Basic arithmetic operators and trigonometric functions.
26* Special math functions (like: `tf.math.igamma` and `tf.math.zeta`)
27* Complex number functions (like: `tf.math.imag` and `tf.math.angle`)
28* Reductions and scans (like: `tf.math.reduce_mean` and `tf.math.cumsum`)
29* Segment functions (like: `tf.math.segment_sum`)
30
31See: `tf.linalg` for matrix and tensor functions.
32
33<a id=Segmentation></a>
34
35## About Segmentation
36
37TensorFlow provides several operations that you can use to perform common
38math computations on tensor segments.
39Here a segmentation is a partitioning of a tensor along
40the first dimension, i.e. it  defines a mapping from the first dimension onto
41`segment_ids`. The `segment_ids` tensor should be the size of
42the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
43where `k<d0`.
44In particular, a segmentation of a matrix tensor is a mapping of rows to
45segments.
46
47For example:
48
49```python
50c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
51tf.segment_sum(c, tf.constant([0, 0, 1]))
52#  ==>  [[0 0 0 0]
53#        [5 6 7 8]]
54```
55
56The standard `segment_*` functions assert that the segment indices are sorted.
57If you have unsorted indices use the equivalent `unsorted_segment_` function.
58Thses functions take an additional argument `num_segments` so that the output
59tensor can be efficiently allocated.
60
61``` python
62c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
63tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
64# ==> [[ 6,  8, 10, 12],
65#       [-1, -2, -3, -4]]
66```
67
68"""
69from __future__ import absolute_import
70from __future__ import division
71from __future__ import print_function
72
73import numpy as np
74from six.moves import xrange  # pylint: disable=redefined-builtin
75
76from tensorflow.python.eager import context
77from tensorflow.python.framework import common_shapes
78from tensorflow.python.framework import constant_op
79from tensorflow.python.framework import dtypes
80from tensorflow.python.framework import graph_util
81from tensorflow.python.framework import ops
82from tensorflow.python.framework import sparse_tensor
83from tensorflow.python.framework import tensor_shape
84from tensorflow.python.ops import array_ops
85from tensorflow.python.ops import gen_data_flow_ops
86from tensorflow.python.ops import gen_math_ops
87from tensorflow.python.ops import gen_nn_ops
88from tensorflow.python.ops import gen_sparse_ops
89# go/tf-wildcard-import
90# pylint: disable=wildcard-import
91from tensorflow.python.ops.gen_math_ops import *
92# pylint: enable=wildcard-import
93from tensorflow.python.platform import tf_logging as logging
94from tensorflow.python.util import compat
95from tensorflow.python.util import deprecation
96from tensorflow.python.util import dispatch
97from tensorflow.python.util import nest
98from tensorflow.python.util.tf_export import tf_export
99
100# Aliases for some automatically-generated names.
101linspace = gen_math_ops.lin_space
102nextafter = gen_math_ops.next_after
103
104arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max)  # pylint: disable=used-before-assignment
105arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min)  # pylint: disable=used-before-assignment
106tf_export(v1=["arg_max"])(arg_max)
107tf_export(v1=["arg_min"])(arg_min)
108
109# This is set by resource_variable_ops.py. It is included in this way since
110# there is a circular dependency between math_ops and resource_variable_ops
111_resource_variable_type = None
112
113
114def _set_doc(doc):
115
116  def _decorator(func):
117    func.__doc__ = doc
118    return func
119
120  return _decorator
121
122
123# pylint: disable=redefined-builtin
124@tf_export(v1=["math.argmax", "argmax"])
125@deprecation.deprecated_args(None, "Use the `axis` argument instead",
126                             "dimension")
127@_set_doc(
128    gen_math_ops.arg_max.__doc__.replace("dimensions", "axes").replace(
129        "dimension", "axis"))
130def argmax(input,
131           axis=None,
132           name=None,
133           dimension=None,
134           output_type=dtypes.int64):
135  axis = deprecation.deprecated_argument_lookup(
136      "axis", axis, "dimension", dimension)
137  return argmax_v2(input, axis, output_type, name)
138
139
140@tf_export("math.argmax", "argmax", v1=[])
141def argmax_v2(input,
142              axis=None,
143              output_type=dtypes.int64,
144              name=None):
145  """Returns the index with the largest value across axes of a tensor.
146
147  Note that in case of ties the identity of the return value is not guaranteed.
148
149  Args:
150    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
151    `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`,
152    `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
153    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`.
154      int32 or int64, must be in the range `-rank(input), rank(input))`.
155      Describes which axis of the input Tensor to reduce across. For vectors,
156      use axis = 0.
157    output_type: An optional `tf.DType` from: `tf.int32, tf.int64`.
158      Defaults to `tf.int64`.
159    name: A name for the operation (optional).
160
161  Returns:
162    A `Tensor` of type `output_type`.
163  """
164  if axis is None:
165    axis = 0
166  return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
167
168
169@tf_export(v1=["math.argmin", "argmin"])
170@deprecation.deprecated_args(None, "Use the `axis` argument instead",
171                             "dimension")
172@_set_doc(
173    gen_math_ops.arg_min.__doc__.replace("dimensions", "axes").replace(
174        "dimension", "axis"))
175def argmin(input,
176           axis=None,
177           name=None,
178           dimension=None,
179           output_type=dtypes.int64):
180  axis = deprecation.deprecated_argument_lookup(
181      "axis", axis, "dimension", dimension)
182  return argmin_v2(input, axis, output_type, name)
183
184
185@tf_export("math.argmin", "argmin", v1=[])
186def argmin_v2(input,
187              axis=None,
188              output_type=dtypes.int64,
189              name=None):
190  """Returns the index with the smallest value across axes of a tensor.
191
192  Note that in case of ties the identity of the return value is not guaranteed.
193
194  Args:
195    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
196    `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`,
197    `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
198    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`.
199      int32 or int64, must be in the range `-rank(input), rank(input))`.
200      Describes which axis of the input Tensor to reduce across. For vectors,
201      use axis = 0.
202    output_type: An optional `tf.DType` from: `tf.int32, tf.int64`.
203      Defaults to `tf.int64`.
204    name: A name for the operation (optional).
205
206  Returns:
207    A `Tensor` of type `output_type`.
208  """
209  if axis is None:
210    axis = 0
211  return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
212
213
214# pylint: enable=redefined-builtin
215
216
217# pylint: disable=anomalous-backslash-in-string,protected-access
218# pylint: disable=g-docstring-has-escape
219@tf_export("math.abs", "abs")
220@dispatch.add_dispatch_support
221def abs(x, name=None):  # pylint: disable=redefined-builtin
222  r"""Computes the absolute value of a tensor.
223
224  Given a tensor `x` of complex numbers, this operation returns a tensor of type
225  `float32` or `float64` that is the absolute value of each element in `x`. All
226  elements in `x` must be complex numbers of the form \\(a + bj\\). The
227  absolute value is computed as \\( \sqrt{a^2 + b^2}\\).  For example:
228  ```python
229  x = tf.constant([[-2.25 + 4.75j], [-3.25 + 5.75j]])
230  tf.abs(x)  # [5.25594902, 6.60492229]
231  ```
232
233  Args:
234    x: A `Tensor` or `SparseTensor` of type `float16`, `float32`, `float64`,
235      `int32`, `int64`, `complex64` or `complex128`.
236    name: A name for the operation (optional).
237
238  Returns:
239    A `Tensor` or `SparseTensor` the same size and type as `x` with absolute
240      values.
241    Note, for `complex64` or `complex128` input, the returned `Tensor` will be
242      of type `float32` or `float64`, respectively.
243  """
244  with ops.name_scope(name, "Abs", [x]) as name:
245    x = ops.convert_to_tensor(x, name="x")
246    if x.dtype.is_complex:
247      return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name)
248    return gen_math_ops._abs(x, name=name)
249# pylint: enable=g-docstring-has-escape
250
251
252# pylint: disable=redefined-builtin
253def _bucketize(input, boundaries, name=None):
254  return gen_math_ops.bucketize(input=input, boundaries=boundaries, name=name)
255
256
257# pylint: enable=redefined-builtin
258
259
260class DivideDelegateWithName(object):
261  """Use Python2/Python3 division delegation to implement divide for tensors."""
262
263  def __init__(self, x, name):
264    """Construct DivideDelegateWithName.
265
266    Args:
267      x: Tensor to use as left operand in operator overloads
268      name: The name that is preferred for the op created.
269    """
270    self.x = x
271    self.name = name
272
273  def __truediv__(self, y):
274    return _truediv_python3(self.x, y, self.name)
275
276  def __floordiv__(self, y):
277    return floordiv(self.x, y, self.name)
278
279  def __div__(self, y):
280    return _div_python2(self.x, y, self.name)
281
282
283@tf_export("math.divide", "divide")
284@dispatch.add_dispatch_support
285def divide(x, y, name=None):
286  """Computes Python style division of `x` by `y`."""
287
288  if name is not None:
289    # Cannot use tensors operator overload, because it has no way to track
290    # override names. Use a dummy class to track the runtime division behavior
291    return DivideDelegateWithName(x, name) / y
292  else:
293    return x / y
294
295
296@tf_export("math.multiply", "multiply")
297@dispatch.add_dispatch_support
298def multiply(x, y, name=None):
299  return gen_math_ops.mul(x, y, name)
300
301
302multiply.__doc__ = gen_math_ops.mul.__doc__.replace("Multiply", "`tf.multiply`")
303
304
305# TODO(aselle): put deprecation in after another round of global code changes
306@deprecation.deprecated(
307    "2016-12-30",
308    "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
309def _mul(x, y, name=None):
310  return gen_math_ops.mul(x, y, name)
311
312
313_mul.__doc__ = (
314    gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
315
316
317@tf_export("math.subtract", "subtract")
318@dispatch.add_dispatch_support
319def subtract(x, y, name=None):
320  return gen_math_ops.sub(x, y, name)
321
322
323subtract.__doc__ = gen_math_ops.sub.__doc__.replace("`Sub`", "`tf.subtract`")
324
325
326# TODO(aselle): put deprecation in after another round of global code changes
327@deprecation.deprecated(
328    "2016-12-30",
329    "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
330def _sub(x, y, name=None):
331  return gen_math_ops.sub(x, y, name)
332
333
334_sub.__doc__ = (
335    gen_math_ops.sub.__doc__ + ("" if _sub.__doc__ is None else _sub.__doc__))
336
337
338negative = gen_math_ops.neg
339
340
341# pylint: disable=g-docstring-has-escape
342@deprecation.deprecated(
343    "2016-12-30",
344    "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
345def _neg(x, name=None):
346  """Computes numerical negative value element-wise.
347
348  I.e., \\(y = -x\\).
349
350  Args:
351    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
352      `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
353    name: A name for the operation (optional).
354
355  Returns:
356    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
357  """
358  return negative(x, name)
359
360
361# pylint: enable=g-docstring-has-escape
362
363
364@tf_export(v1=["math.scalar_mul", "scalar_mul"])
365def scalar_mul(scalar, x, name=None):
366  """Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
367
368  Intended for use in gradient code which might deal with `IndexedSlices`
369  objects, which are easy to multiply by a scalar but more expensive to
370  multiply with arbitrary tensors.
371
372  Args:
373    scalar: A 0-D scalar `Tensor`. Must have known shape.
374    x: A `Tensor` or `IndexedSlices` to be scaled.
375    name: A name for the operation (optional).
376
377  Returns:
378    `scalar * x` of the same type (`Tensor` or `IndexedSlices`) as `x`.
379
380  Raises:
381    ValueError: if scalar is not a 0-D `scalar`.
382  """
383  scalar = ops.convert_to_tensor(
384      scalar, dtype=x.dtype.base_dtype, name="scalar")
385  shape = scalar.get_shape()
386  if shape.ndims == 0:
387    if isinstance(x, ops.IndexedSlices):
388      return ops.IndexedSlices(gen_math_ops.mul(scalar, x.values, name),
389                               x.indices, x.dense_shape)
390    else:
391      return gen_math_ops.mul(scalar, x, name)
392  else:
393    raise ValueError("Only scalar multiply works, got shape %s" % shape)
394
395
396@tf_export("math.scalar_mul", "scalar_mul", v1=[])
397@_set_doc(scalar_mul.__doc__)
398def scalar_mul_v2(scalar, x, name=None):
399  with ops.name_scope(name, "scalar_mul", [x]) as name:
400    return scalar_mul(scalar, x, name)
401
402
403@tf_export("math.pow", "pow")
404@dispatch.add_dispatch_support
405def pow(x, y, name=None):  # pylint: disable=redefined-builtin
406  r"""Computes the power of one value to another.
407
408  Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
409  corresponding elements in `x` and `y`. For example:
410
411  ```python
412  x = tf.constant([[2, 2], [3, 3]])
413  y = tf.constant([[8, 16], [2, 3]])
414  tf.pow(x, y)  # [[256, 65536], [9, 27]]
415  ```
416
417  Args:
418    x: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, `int64`,
419     `complex64`, or `complex128`.
420    y: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, `int64`,
421     `complex64`, or `complex128`.
422    name: A name for the operation (optional).
423
424  Returns:
425    A `Tensor`.
426  """
427  with ops.name_scope(name, "Pow", [x]) as name:
428    return gen_math_ops._pow(x, y, name=name)
429
430
431# pylint: disable=redefined-builtin,redefined-outer-name
432@tf_export("dtypes.complex", "complex")
433@dispatch.add_dispatch_support
434def complex(real, imag, name=None):
435  r"""Converts two real numbers to a complex number.
436
437  Given a tensor `real` representing the real part of a complex number, and a
438  tensor `imag` representing the imaginary part of a complex number, this
439  operation returns complex numbers elementwise of the form \\(a + bj\\), where
440  *a* represents the `real` part and *b* represents the `imag` part.
441
442  The input tensors `real` and `imag` must have the same shape.
443
444  For example:
445
446  ```python
447  real = tf.constant([2.25, 3.25])
448  imag = tf.constant([4.75, 5.75])
449  tf.complex(real, imag)  # [[2.25 + 4.75j], [3.25 + 5.75j]]
450  ```
451
452  Args:
453    real: A `Tensor`. Must be one of the following types: `float32`,
454      `float64`.
455    imag: A `Tensor`. Must have the same type as `real`.
456    name: A name for the operation (optional).
457
458  Returns:
459    A `Tensor` of type `complex64` or `complex128`.
460  """
461  real = ops.convert_to_tensor(real, name="real")
462  imag = ops.convert_to_tensor(imag, name="imag")
463  with ops.name_scope(name, "Complex", [real, imag]) as name:
464    input_types = (real.dtype, imag.dtype)
465    if input_types == (dtypes.float64, dtypes.float64):
466      Tout = dtypes.complex128
467    elif input_types == (dtypes.float32, dtypes.float32):
468      Tout = dtypes.complex64
469    else:
470      raise TypeError("real and imag have incorrect types: "
471                      "{} {}".format(real.dtype.name, imag.dtype.name))
472    return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
473
474
475@tf_export("math.real", v1=["math.real", "real"])
476@deprecation.deprecated_endpoints("real")
477@dispatch.add_dispatch_support
478def real(input, name=None):
479  r"""Returns the real part of a complex (or real) tensor.
480
481  Given a tensor `input`, this operation returns a tensor of type `float` that
482  is the real part of each element in `input` considered as a complex number.
483
484  For example:
485
486  ```python
487  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
488  tf.real(x)  # [-2.25, 3.25]
489  ```
490
491  If `input` is already real, it is returned unchanged.
492
493  Args:
494    input: A `Tensor`. Must have numeric type.
495    name: A name for the operation (optional).
496
497  Returns:
498    A `Tensor` of type `float32` or `float64`.
499  """
500  with ops.name_scope(name, "Real", [input]) as name:
501    if input.dtype.is_complex:
502      real_dtype = input.dtype.real_dtype
503      return gen_math_ops.real(input, Tout=real_dtype, name=name)
504    else:
505      return input
506
507
508@tf_export("math.imag", v1=["math.imag", "imag"])
509@deprecation.deprecated_endpoints("imag")
510@dispatch.add_dispatch_support
511def imag(input, name=None):
512  r"""Returns the imaginary part of a complex (or real) tensor.
513
514  Given a tensor `input`, this operation returns a tensor of type `float` that
515  is the imaginary part of each element in `input` considered as a complex
516  number. If `input` is real, a tensor of all zeros is returned.
517
518  For example:
519
520  ```python
521  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
522  tf.imag(x)  # [4.75, 5.75]
523  ```
524
525  Args:
526    input: A `Tensor`. Must be one of the following types: `float`, `double`,
527      `complex64`, `complex128`.
528    name: A name for the operation (optional).
529
530  Returns:
531    A `Tensor` of type `float32` or `float64`.
532  """
533  with ops.name_scope(name, "Imag", [input]) as name:
534    if input.dtype.is_complex:
535      return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
536    else:
537      return array_ops.zeros_like(input)
538
539
540@tf_export("math.angle", v1=["math.angle", "angle"])
541@deprecation.deprecated_endpoints("angle")
542@dispatch.add_dispatch_support
543def angle(input, name=None):
544  r"""Returns the element-wise argument of a complex (or real) tensor.
545
546  Given a tensor `input`, this operation returns a tensor of type `float` that
547  is the argument of each element in `input` considered as a complex number.
548
549  The elements in `input` are considered to be complex numbers of the form
550  \\(a + bj\\), where *a* is the real part and *b* is the imaginary part.
551  If `input` is real then *b* is zero by definition.
552
553  The argument returned by this function is of the form \\(atan2(b, a)\\).
554  If `input` is real, a tensor of all zeros is returned.
555
556  For example:
557
558  ```
559  # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
560  tf.angle(input) ==> [2.0132, 1.056]
561  ```
562
563  Args:
564    input: A `Tensor`. Must be one of the following types: `float`, `double`,
565      `complex64`, `complex128`.
566    name: A name for the operation (optional).
567
568  Returns:
569    A `Tensor` of type `float32` or `float64`.
570  """
571  with ops.name_scope(name, "Angle", [input]) as name:
572    if input.dtype.is_complex:
573      return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
574    else:
575      return array_ops.zeros_like(input)
576
577
578# pylint: enable=redefined-outer-name,redefined-builtin
579
580
581@tf_export("math.round", "round")
582@dispatch.add_dispatch_support
583def round(x, name=None):  # pylint: disable=redefined-builtin
584  """Rounds the values of a tensor to the nearest integer, element-wise.
585
586  Rounds half to even.  Also known as bankers rounding. If you want to round
587  according to the current system rounding mode use tf::cint.
588  For example:
589
590  ```python
591  x = tf.constant([0.9, 2.5, 2.3, 1.5, -4.5])
592  tf.round(x)  # [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
593  ```
594
595  Args:
596    x: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, or `int64`.
597    name: A name for the operation (optional).
598
599  Returns:
600    A `Tensor` of same shape and type as `x`.
601  """
602  x = ops.convert_to_tensor(x, name="x")
603  if x.dtype.is_integer:
604    return x
605  else:
606    return gen_math_ops.round(x, name=name)
607
608
609@tf_export("dtypes.cast", "cast")
610@dispatch.add_dispatch_support
611def cast(x, dtype, name=None):
612  """Casts a tensor to a new type.
613
614  The operation casts `x` (in case of `Tensor`) or `x.values`
615  (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.
616
617  For example:
618
619  ```python
620  x = tf.constant([1.8, 2.2], dtype=tf.float32)
621  tf.cast(x, tf.int32)  # [1, 2], dtype=tf.int32
622  ```
623
624  The operation supports data types (for `x` and `dtype`) of
625  `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
626  `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
627  In case of casting from complex types (`complex64`, `complex128`) to real
628  types, only the real part of `x` is returned. In case of casting from real
629  types to complex types (`complex64`, `complex128`), the imaginary part of the
630  returned value is set to `0`. The handling of complex types here matches the
631  behavior of numpy.
632
633  Args:
634    x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
635      be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
636      `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
637      `bfloat16`.
638    dtype: The destination type. The list of supported dtypes is the same as
639      `x`.
640    name: A name for the operation (optional).
641
642  Returns:
643    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
644      same type as `dtype`.
645
646  Raises:
647    TypeError: If `x` cannot be cast to the `dtype`.
648  """
649  base_type = dtypes.as_dtype(dtype).base_dtype
650  if isinstance(x,
651                (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
652    return x
653  with ops.name_scope(name, "Cast", [x]) as name:
654    if isinstance(x, sparse_tensor.SparseTensor):
655      values_cast = cast(x.values, base_type, name=name)
656      x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
657    elif isinstance(x, ops.IndexedSlices):
658      values_cast = cast(x.values, base_type, name=name)
659      x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
660    else:
661      # TODO(josh11b): If x is not already a Tensor, we could return
662      # ops.convert_to_tensor(x, dtype=dtype, ...)  here, but that
663      # allows some conversions that cast() can't do, e.g. casting numbers to
664      # strings.
665      x = ops.convert_to_tensor(x, name="x")
666      if x.dtype.base_dtype != base_type:
667        x = gen_math_ops.cast(x, base_type, name=name)
668    if x.dtype.is_complex and base_type.is_floating:
669      logging.warn("Casting complex to real discards imaginary part.")
670    return x
671
672
673@tf_export("dtypes.saturate_cast", "saturate_cast")
674@dispatch.add_dispatch_support
675def saturate_cast(value, dtype, name=None):
676  """Performs a safe saturating cast of `value` to `dtype`.
677
678  This function casts the input to `dtype` without applying any scaling.  If
679  there is a danger that values would over or underflow in the cast, this op
680  applies the appropriate clamping before the cast.
681
682  Args:
683    value: A `Tensor`.
684    dtype: The desired output `DType`.
685    name: A name for the operation (optional).
686
687  Returns:
688    `value` safely cast to `dtype`.
689  """
690  # When casting to a type with smaller representable range, clamp.
691  # Note that this covers casting to unsigned types as well.
692  with ops.name_scope(name, "saturate_cast", [value]) as name:
693    value = ops.convert_to_tensor(value, name="value")
694    dtype = dtypes.as_dtype(dtype).base_dtype
695    if value.dtype.min < dtype.min:
696      value = gen_math_ops.maximum(value,
697                                   ops.convert_to_tensor(
698                                       dtype.min, dtype=value.dtype,
699                                       name="min"))
700    if value.dtype.max > dtype.max:
701      value = gen_math_ops.minimum(value,
702                                   ops.convert_to_tensor(
703                                       dtype.max, dtype=value.dtype,
704                                       name="max"))
705    return cast(value, dtype, name=name)
706
707@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
708@tf_export(v1=["to_float"])
709def to_float(x, name="ToFloat"):
710  """Casts a tensor to type `float32`.
711
712  Args:
713    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
714    name: A name for the operation (optional).
715
716  Returns:
717    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
718    type `float32`.
719
720  Raises:
721    TypeError: If `x` cannot be cast to the `float32`.
722  """
723  return cast(x, dtypes.float32, name=name)
724
725
726@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
727@tf_export(v1=["to_double"])
728def to_double(x, name="ToDouble"):
729  """Casts a tensor to type `float64`.
730
731  Args:
732    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
733    name: A name for the operation (optional).
734
735  Returns:
736    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
737    type `float64`.
738
739  Raises:
740    TypeError: If `x` cannot be cast to the `float64`.
741  """
742  return cast(x, dtypes.float64, name=name)
743
744
745@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
746@tf_export(v1=["to_int32"])
747def to_int32(x, name="ToInt32"):
748  """Casts a tensor to type `int32`.
749
750  Args:
751    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
752    name: A name for the operation (optional).
753
754  Returns:
755    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
756    type `int32`.
757
758  Raises:
759    TypeError: If `x` cannot be cast to the `int32`.
760  """
761  return cast(x, dtypes.int32, name=name)
762
763
764@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
765@tf_export(v1=["to_int64"])
766def to_int64(x, name="ToInt64"):
767  """Casts a tensor to type `int64`.
768
769  Args:
770    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
771    name: A name for the operation (optional).
772
773  Returns:
774    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
775    type `int64`.
776
777  Raises:
778    TypeError: If `x` cannot be cast to the `int64`.
779  """
780  return cast(x, dtypes.int64, name=name)
781
782
783@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
784@tf_export(v1=["to_bfloat16"])
785def to_bfloat16(x, name="ToBFloat16"):
786  """Casts a tensor to type `bfloat16`.
787
788  Args:
789    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
790    name: A name for the operation (optional).
791
792  Returns:
793    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
794    type `bfloat16`.
795
796  Raises:
797    TypeError: If `x` cannot be cast to the `bfloat16`.
798  """
799  return cast(x, dtypes.bfloat16, name=name)
800
801
802@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
803@tf_export(v1=["to_complex64"])
804def to_complex64(x, name="ToComplex64"):
805  """Casts a tensor to type `complex64`.
806
807  Args:
808    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
809    name: A name for the operation (optional).
810
811  Returns:
812    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
813    type `complex64`.
814
815  Raises:
816    TypeError: If `x` cannot be cast to the `complex64`.
817  """
818  return cast(x, dtypes.complex64, name=name)
819
820
821@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
822@tf_export(v1=["to_complex128"])
823def to_complex128(x, name="ToComplex128"):
824  """Casts a tensor to type `complex128`.
825
826  Args:
827    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
828    name: A name for the operation (optional).
829
830  Returns:
831    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
832    type `complex128`.
833
834  Raises:
835    TypeError: If `x` cannot be cast to the `complex128`.
836  """
837  return cast(x, dtypes.complex128, name=name)
838
839
840ops.Tensor._override_operator("__neg__", gen_math_ops.neg)
841ops.Tensor._override_operator("__abs__", abs)
842# __invert__ corresponds to the ~ operator.  Here we follow the numpy convention
843# ~ marks an elementwise bit-wise inverse.  This is only implemented for boolean
844# tensors and will throw a TypeError if used on nonboolean arrays
845ops.Tensor._override_operator("__invert__", gen_math_ops.logical_not)
846
847
848def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
849  """Register operators with different tensor and scalar versions.
850
851  If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices,
852  sp_values, sp_shape, dense)` and outputs `(new_sp_values)`.
853
854  Args:
855    func: the operator
856    op_name: name of the operator being overridden
857    clazz_object: class to override for.  Either `Tensor` or `SparseTensor`.
858  """
859
860  def binary_op_wrapper(x, y):
861    with ops.name_scope(None, op_name, [x, y]) as name:
862      if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
863        return func(x, y, name=name)
864      elif not isinstance(y, sparse_tensor.SparseTensor):
865        try:
866          y = ops.convert_to_tensor_v2(y, dtype_hint=x.dtype.base_dtype,
867                                       name="y")
868        except TypeError:
869          # If the RHS is not a tensor, it might be a tensor aware object
870          # that can implement the operator with knowledge of itself
871          # and the tensor.
872          if hasattr(type(y), "__r%s__" % op_name):
873            return NotImplemented
874          else:
875            raise
876      return func(x, y, name=name)
877
878  def binary_op_wrapper_sparse(sp_x, y):
879    with ops.name_scope(None, op_name, [sp_x, y]) as name:
880      y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y")
881      return sparse_tensor.SparseTensor(sp_x.indices,
882                                        func(
883                                            sp_x.indices,
884                                            sp_x.values,
885                                            sp_x.dense_shape,
886                                            y,
887                                            name=name), sp_x.dense_shape)
888
889  def r_binary_op_wrapper(y, x):
890    with ops.name_scope(None, op_name, [x, y]) as name:
891      x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
892      return func(x, y, name=name)
893
894  # Propagate func.__doc__ to the wrappers
895  try:
896    doc = func.__doc__
897  except AttributeError:
898    doc = None
899  binary_op_wrapper.__doc__ = doc
900  r_binary_op_wrapper.__doc__ = doc
901  binary_op_wrapper_sparse.__doc__ = doc
902
903  if clazz_object is ops.Tensor:
904    clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper)
905    del binary_op_wrapper
906    clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
907    del r_binary_op_wrapper
908  else:
909    clazz_object._override_operator("__%s__" % op_name,
910                                    binary_op_wrapper_sparse)
911    del binary_op_wrapper_sparse
912
913
914# Conversion table for __truediv__.  None entries mean no conversion required.
915_TRUEDIV_TABLE = {
916    dtypes.uint8: dtypes.float32,
917    dtypes.int8: dtypes.float32,
918    dtypes.uint16: dtypes.float32,
919    dtypes.int16: dtypes.float32,
920    dtypes.int32: dtypes.float64,
921    dtypes.int64: dtypes.float64,
922    dtypes.bfloat16: None,
923    dtypes.float16: None,
924    dtypes.float32: None,
925    dtypes.float64: None,
926    dtypes.complex64: None,
927    dtypes.complex128: None,
928}
929
930
931# NOTE: the support of "sparse (true)div dense" is currently not baked in into
932# "tf.(true_)div()".  Until such an API decision is made, the supported usage is
933# to explicitly use the "/" operator to invoke either truediv or div.
934def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
935  """Internal helper function for 'sp_t / dense_t'."""
936  with ops.name_scope(name, "truediv",
937                      [sp_indices, sp_values, sp_shape, y]) as name:
938    sp_values = ops.convert_to_tensor(sp_values, name="sp_values")
939    y = ops.convert_to_tensor(y, name="y")
940    x_dtype = sp_values.dtype.base_dtype
941    y_dtype = y.dtype.base_dtype
942    if x_dtype != y_dtype:
943      raise TypeError("x and y must have the same dtype, got %r != %r" %
944                      (x_dtype, y_dtype))
945    try:
946      dtype = _TRUEDIV_TABLE[x_dtype]
947    except KeyError:
948      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
949    if dtype is not None:
950      sp_values = cast(sp_values, dtype)
951      y = cast(y, dtype)
952    return gen_sparse_ops.sparse_dense_cwise_div(
953        sp_indices, sp_values, sp_shape, y, name=name)
954
955
956def _truediv_python3(x, y, name=None):
957  with ops.name_scope(name, "truediv", [x, y]) as name:
958    x = ops.convert_to_tensor(x, name="x")
959    y = ops.convert_to_tensor(y, name="y")
960    x_dtype = x.dtype.base_dtype
961    y_dtype = y.dtype.base_dtype
962    if x_dtype != y_dtype:
963      raise TypeError("x and y must have the same dtype, got %r != %r" %
964                      (x_dtype, y_dtype))
965    try:
966      dtype = _TRUEDIV_TABLE[x_dtype]
967    except KeyError:
968      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
969    if dtype is not None:
970      x = cast(x, dtype)
971      y = cast(y, dtype)
972    return gen_math_ops.real_div(x, y, name=name)
973
974
975def _div_python2(x, y, name=None):
976  """Divide two values using Python 2 semantics. Used for Tensor.__div__.
977
978  Args:
979    x: `Tensor` numerator of real numeric type.
980    y: `Tensor` denominator of real numeric type.
981    name: A name for the operation (optional).
982  Returns:
983    `x / y` returns the quotient of x and y.
984  """
985
986  with ops.name_scope(name, "div", [x, y]) as name:
987    x = ops.convert_to_tensor(x, name="x")
988    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
989    x_dtype = x.dtype.base_dtype
990    y_dtype = y.dtype.base_dtype
991    if x_dtype != y_dtype:
992      raise TypeError("x and y must have the same dtype, got %r != %r" %
993                      (x_dtype, y_dtype))
994    if x_dtype.is_floating or x_dtype.is_complex:
995      return gen_math_ops.real_div(x, y, name=name)
996    else:
997      return gen_math_ops.floor_div(x, y, name=name)
998
999
1000@tf_export("math.truediv", "truediv")
1001@dispatch.add_dispatch_support
1002def truediv(x, y, name=None):
1003  """Divides x / y elementwise (using Python 3 division operator semantics).
1004
1005  NOTE: Prefer using the Tensor operator or tf.divide which obey Python
1006  division operator semantics.
1007
1008  This function forces Python 3 division operator semantics where all integer
1009  arguments are cast to floating types first.   This op is generated by normal
1010  `x / y` division in Python 3 and in Python 2.7 with
1011  `from __future__ import division`.  If you want integer division that rounds
1012  down, use `x // y` or `tf.math.floordiv`.
1013
1014  `x` and `y` must have the same numeric type.  If the inputs are floating
1015  point, the output will have the same type.  If the inputs are integral, the
1016  inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
1017  and `int64` (matching the behavior of Numpy).
1018
1019  Args:
1020    x: `Tensor` numerator of numeric type.
1021    y: `Tensor` denominator of numeric type.
1022    name: A name for the operation (optional).
1023
1024  Returns:
1025    `x / y` evaluated in floating point.
1026
1027  Raises:
1028    TypeError: If `x` and `y` have different dtypes.
1029  """
1030  return _truediv_python3(x, y, name)
1031
1032
1033@deprecation.deprecated(
1034    date=None,
1035    instructions="Deprecated in favor of operator or tf.math.divide.")
1036@tf_export(v1=["div"])
1037def div(x, y, name=None):
1038  """Divides x / y elementwise (using Python 2 division operator semantics).
1039
1040  NOTE: Prefer using the Tensor division operator or tf.divide which obey Python
1041  division operator semantics.
1042
1043  This function divides `x` and `y`, forcing Python 2.7 semantics. That is,
1044  if one of `x` or `y` is a float, then the result will be a float.
1045  Otherwise, the output will be an integer type. Flooring semantics are used
1046  for integer division.
1047
1048  Args:
1049    x: `Tensor` numerator of real numeric type.
1050    y: `Tensor` denominator of real numeric type.
1051    name: A name for the operation (optional).
1052  Returns:
1053    `x / y` returns the quotient of x and y.
1054  """
1055  return _div_python2(x, y, name)
1056
1057
1058@tf_export("math.divide_no_nan", v1=["math.divide_no_nan", "div_no_nan"])
1059@deprecation.deprecated_endpoints("div_no_nan")
1060@dispatch.add_dispatch_support
1061def div_no_nan(x, y, name=None):
1062  """Computes an unsafe divide which returns 0 if the y is zero.
1063
1064  Args:
1065    x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
1066    y: A `Tensor` whose dtype is compatible with `x`.
1067    name: A name for the operation (optional).
1068  Returns:
1069    The element-wise value of the x divided by y.
1070  """
1071
1072  with ops.name_scope(name, "div_no_nan", [x, y]) as name:
1073    x = ops.convert_to_tensor(x, name="x")
1074    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
1075    x_dtype = x.dtype.base_dtype
1076    y_dtype = y.dtype.base_dtype
1077    if x_dtype != y_dtype:
1078      raise TypeError("x and y must have the same dtype, got %r != %r" %
1079                      (x_dtype, y_dtype))
1080    return gen_math_ops.div_no_nan(x, y, name=name)
1081
1082
1083@tf_export("math.multiply_no_nan")
1084@dispatch.add_dispatch_support
1085def multiply_no_nan(x, y, name=None):
1086  """Computes the product of x and y and returns 0 if the y is zero, even if x is NaN or infinite.
1087
1088  Args:
1089    x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
1090    y: A `Tensor` whose dtype is compatible with `x`.
1091    name: A name for the operation (optional).
1092
1093  Returns:
1094    The element-wise value of the x times y.
1095  """
1096
1097  with ops.name_scope(name, "multiply_no_nan", [x, y]) as name:
1098    x = ops.convert_to_tensor(x, name="x")
1099    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
1100    x_dtype = x.dtype.base_dtype
1101    y_dtype = y.dtype.base_dtype
1102    if x_dtype != y_dtype:
1103      raise TypeError(
1104          "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype))
1105    return gen_math_ops.mul_no_nan(x, y, name=name)
1106
1107
1108# TODO(aselle): This should be removed
1109mod = gen_math_ops.floor_mod
1110
1111
1112# TODO(aselle): Deprecate this once all internal functionality uses
1113# tf.truncatediv
1114@tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
1115@dispatch.add_dispatch_support
1116@deprecation.deprecated_endpoints("floordiv")
1117def floordiv(x, y, name=None):
1118  """Divides `x / y` elementwise, rounding toward the most negative integer.
1119
1120  The same as `tf.div(x,y)` for integers, but uses `tf.floor(tf.div(x,y))` for
1121  floating point arguments so that the result is always an integer (though
1122  possibly an integer represented as floating point).  This op is generated by
1123  `x // y` floor division in Python 3 and in Python 2.7 with
1124  `from __future__ import division`.
1125
1126  `x` and `y` must have the same type, and the result will have the same type
1127  as well.
1128
1129  Args:
1130    x: `Tensor` numerator of real numeric type.
1131    y: `Tensor` denominator of real numeric type.
1132    name: A name for the operation (optional).
1133
1134  Returns:
1135    `x / y` rounded down.
1136
1137  Raises:
1138    TypeError: If the inputs are complex.
1139  """
1140  with ops.name_scope(name, "floordiv", [x, y]) as name:
1141    return gen_math_ops.floor_div(x, y, name=name)
1142
1143
1144realdiv = gen_math_ops.real_div
1145truncatediv = gen_math_ops.truncate_div
1146# TODO(aselle): Rename this to floordiv when we can.
1147floor_div = gen_math_ops.floor_div
1148truncatemod = gen_math_ops.truncate_mod
1149floormod = gen_math_ops.floor_mod
1150
1151
1152def _mul_dispatch(x, y, name=None):
1153  """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
1154  is_tensor_y = isinstance(y, ops.Tensor)
1155  if is_tensor_y:
1156    return gen_math_ops.mul(x, y, name=name)
1157  else:
1158    assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
1159    new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
1160                                                     y.dense_shape, x, name)
1161    return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
1162
1163
1164# NOTE(aselle): When integer division is added for sparse_dense_cwise,
1165# div, truediv, and floordiv should be delegated appropriately for
1166# Python sematnics, analogous to dense cwise tensor operations.
1167_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div",
1168                              sparse_tensor.SparseTensor)
1169_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv",
1170                              sparse_tensor.SparseTensor)
1171_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
1172                              sparse_tensor.SparseTensor)
1173
1174_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
1175_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
1176_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
1177_OverrideBinaryOperatorHelper(_div_python2, "div")
1178_OverrideBinaryOperatorHelper(_truediv_python3, "truediv")
1179_OverrideBinaryOperatorHelper(floordiv, "floordiv")
1180_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
1181_OverrideBinaryOperatorHelper(pow, "pow")
1182
1183
1184@tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
1185@dispatch.add_dispatch_support
1186@deprecation.deprecated_endpoints("logical_xor")
1187def logical_xor(x, y, name="LogicalXor"):
1188  """x ^ y = (x | y) & ~(x & y)."""
1189  # TODO(alemi) Make this a cwise op if people end up relying on it.
1190  return gen_math_ops.logical_and(
1191      gen_math_ops.logical_or(x, y),
1192      gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)),
1193      name=name)
1194
1195
1196_OverrideBinaryOperatorHelper(gen_math_ops.logical_and, "and")
1197_OverrideBinaryOperatorHelper(gen_math_ops.logical_or, "or")
1198_OverrideBinaryOperatorHelper(logical_xor, "xor")
1199
1200ops.Tensor._override_operator("__lt__", gen_math_ops.less)
1201ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
1202ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
1203ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
1204
1205
1206@tf_export("range")
1207def range(start, limit=None, delta=1, dtype=None, name="range"):  # pylint: disable=redefined-builtin
1208  """Creates a sequence of numbers.
1209
1210  Creates a sequence of numbers that begins at `start` and extends by
1211  increments of `delta` up to but not including `limit`.
1212
1213  The dtype of the resulting tensor is inferred from the inputs unless
1214  it is provided explicitly.
1215
1216  Like the Python builtin `range`, `start` defaults to 0, so that
1217  `range(n) = range(0, n)`.
1218
1219  For example:
1220
1221  ```python
1222  start = 3
1223  limit = 18
1224  delta = 3
1225  tf.range(start, limit, delta)  # [3, 6, 9, 12, 15]
1226
1227  start = 3
1228  limit = 1
1229  delta = -0.5
1230  tf.range(start, limit, delta)  # [3, 2.5, 2, 1.5]
1231
1232  limit = 5
1233  tf.range(limit)  # [0, 1, 2, 3, 4]
1234  ```
1235
1236  Args:
1237    start: A 0-D `Tensor` (scalar). Acts as first entry in the range if
1238      `limit` is not None; otherwise, acts as range limit and first entry
1239      defaults to 0.
1240    limit: A 0-D `Tensor` (scalar). Upper limit of sequence,
1241      exclusive. If None, defaults to the value of `start` while the first
1242      entry of the range defaults to 0.
1243    delta: A 0-D `Tensor` (scalar). Number that increments
1244      `start`. Defaults to 1.
1245    dtype: The type of the elements of the resulting tensor.
1246    name: A name for the operation. Defaults to "range".
1247
1248  Returns:
1249    An 1-D `Tensor` of type `dtype`.
1250
1251  @compatibility(numpy)
1252  Equivalent to np.arange
1253  @end_compatibility
1254  """
1255  if limit is None:
1256    start, limit = 0, start
1257
1258  with ops.name_scope(name, "Range", [start, limit, delta]) as name:
1259    start = ops.convert_to_tensor(start, dtype=dtype, name="start")
1260    limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
1261    delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta")
1262
1263    # infer dtype if not explicitly provided
1264    if dtype is None:
1265      dtype_hierarchy = [
1266          dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
1267      ]
1268      assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
1269      inferred_dtype = max(
1270          [arg.dtype for arg in [start, limit, delta]],
1271          key=dtype_hierarchy.index)
1272
1273      start = cast(start, inferred_dtype)
1274      limit = cast(limit, inferred_dtype)
1275      delta = cast(delta, inferred_dtype)
1276
1277    return gen_math_ops._range(start, limit, delta, name=name)
1278
1279
1280# Reduction operations
1281def _ReductionDims(x, axis, reduction_indices=None):  # pylint: disable=invalid-name
1282  """Returns range(0, rank(x)) if reduction_indices is None."""
1283  # TODO(aselle): Remove this after deprecation
1284  if reduction_indices is not None:
1285    if axis is not None:
1286      raise ValueError("Can't specify both axis' and 'reduction_indices'.")
1287    axis = reduction_indices
1288  if axis is not None:
1289    return axis
1290  else:
1291    # Fast path: avoid creating Rank and Range ops if ndims is known.
1292    rank = common_shapes.rank(x)
1293    if rank is not None:
1294      return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
1295    if (isinstance(x, sparse_tensor.SparseTensor) and
1296        x.dense_shape.shape.is_fully_defined()):
1297      rank = x.dense_shape.shape.dims[0].value  # sparse.dense_shape is 1-D.
1298      return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
1299
1300    # Otherwise, we rely on Range and Rank to do the right thing at run-time.
1301    return range(0, array_ops.rank(x))
1302
1303
1304def _may_reduce_to_scalar(keepdims, axis, output):
1305  """Set a reduction's output shape to be a scalar if we are certain."""
1306  if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and (
1307      axis is None):
1308    output.set_shape(())
1309  return output
1310
1311
1312@tf_export(v1=["math.reduce_sum", "reduce_sum"])
1313@deprecation.deprecated_args(
1314    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1315def reduce_sum_v1(input_tensor,
1316                  axis=None,
1317                  keepdims=None,
1318                  name=None,
1319                  reduction_indices=None,
1320                  keep_dims=None):
1321  """Computes the sum of elements across dimensions of a tensor.
1322
1323  Reduces `input_tensor` along the dimensions given in `axis`.
1324  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1325  entry in `axis`. If `keepdims` is true, the reduced dimensions
1326  are retained with length 1.
1327
1328  If `axis` is None, all dimensions are reduced, and a
1329  tensor with a single element is returned.
1330
1331  For example:
1332
1333  ```python
1334  x = tf.constant([[1, 1, 1], [1, 1, 1]])
1335  tf.reduce_sum(x)  # 6
1336  tf.reduce_sum(x, 0)  # [2, 2, 2]
1337  tf.reduce_sum(x, 1)  # [3, 3]
1338  tf.reduce_sum(x, 1, keepdims=True)  # [[3], [3]]
1339  tf.reduce_sum(x, [0, 1])  # 6
1340  ```
1341
1342  Args:
1343    input_tensor: The tensor to reduce. Should have numeric type.
1344    axis: The dimensions to reduce. If `None` (the default),
1345      reduces all dimensions. Must be in the range
1346      `[-rank(input_tensor), rank(input_tensor))`.
1347    keepdims: If true, retains reduced dimensions with length 1.
1348    name: A name for the operation (optional).
1349    reduction_indices: The old (deprecated) name for axis.
1350    keep_dims: Deprecated alias for `keepdims`.
1351
1352  Returns:
1353    The reduced tensor, of the same dtype as the input_tensor.
1354
1355  @compatibility(numpy)
1356  Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to
1357  int64 while tensorflow returns the same dtype as the input.
1358  @end_compatibility
1359  """
1360  axis = deprecation.deprecated_argument_lookup(
1361      "axis", axis, "reduction_indices", reduction_indices)
1362  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1363                                                    "keep_dims", keep_dims)
1364  return reduce_sum(input_tensor, axis, keepdims, name)
1365
1366
1367@tf_export("math.reduce_sum", "reduce_sum", v1=[])
1368@dispatch.add_dispatch_support
1369def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
1370  """Computes the sum of elements across dimensions of a tensor.
1371
1372  Reduces `input_tensor` along the dimensions given in `axis`.
1373  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1374  entry in `axis`. If `keepdims` is true, the reduced dimensions
1375  are retained with length 1.
1376
1377  If `axis` is None, all dimensions are reduced, and a
1378  tensor with a single element is returned.
1379
1380  For example:
1381
1382  ```python
1383  x = tf.constant([[1, 1, 1], [1, 1, 1]])
1384  tf.reduce_sum(x)  # 6
1385  tf.reduce_sum(x, 0)  # [2, 2, 2]
1386  tf.reduce_sum(x, 1)  # [3, 3]
1387  tf.reduce_sum(x, 1, keepdims=True)  # [[3], [3]]
1388  tf.reduce_sum(x, [0, 1])  # 6
1389  ```
1390
1391  Args:
1392    input_tensor: The tensor to reduce. Should have numeric type.
1393    axis: The dimensions to reduce. If `None` (the default), reduces all
1394      dimensions. Must be in the range `[-rank(input_tensor),
1395      rank(input_tensor))`.
1396    keepdims: If true, retains reduced dimensions with length 1.
1397    name: A name for the operation (optional).
1398
1399  Returns:
1400    The reduced tensor, of the same dtype as the input_tensor.
1401
1402  @compatibility(numpy)
1403  Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to
1404  int64 while tensorflow returns the same dtype as the input.
1405  @end_compatibility
1406  """
1407  keepdims = False if keepdims is None else keepdims
1408  return _may_reduce_to_scalar(
1409      keepdims, axis,
1410      gen_math_ops._sum(
1411          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
1412          name=name))
1413
1414
1415@tf_export("math.reduce_euclidean_norm")
1416def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
1417  """Computes the Euclidean norm of elements across dimensions of a tensor.
1418
1419  Reduces `input_tensor` along the dimensions given in `axis`.
1420  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1421  entry in `axis`. If `keepdims` is true, the reduced dimensions
1422  are retained with length 1.
1423
1424  If `axis` is None, all dimensions are reduced, and a
1425  tensor with a single element is returned.
1426
1427  For example:
1428
1429  ```python
1430  x = tf.constant([[1, 2, 3], [1, 1, 1]])
1431  tf.reduce_euclidean_norm(x)  # sqrt(17)
1432  tf.reduce_euclidean_norm(x, 0)  # [sqrt(2), sqrt(5), sqrt(10)]
1433  tf.reduce_euclidean_norm(x, 1)  # [sqrt(14), sqrt(3)]
1434  tf.reduce_euclidean_norm(x, 1, keepdims=True)  # [[sqrt(14)], [sqrt(3)]]
1435  tf.reduce_euclidean_norm(x, [0, 1])  # sqrt(17)
1436  ```
1437
1438  Args:
1439    input_tensor: The tensor to reduce. Should have numeric type.
1440    axis: The dimensions to reduce. If `None` (the default), reduces all
1441      dimensions. Must be in the range `[-rank(input_tensor),
1442      rank(input_tensor))`.
1443    keepdims: If true, retains reduced dimensions with length 1.
1444    name: A name for the operation (optional).
1445
1446  Returns:
1447    The reduced tensor, of the same dtype as the input_tensor.
1448  """
1449  return _may_reduce_to_scalar(
1450      keepdims, axis,
1451      gen_math_ops.euclidean_norm(
1452          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
1453          name=name))
1454
1455
1456@tf_export(v1=["math.count_nonzero", "count_nonzero"])
1457@deprecation.deprecated_args(
1458    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1459@deprecation.deprecated_args(
1460    None, "reduction_indices is deprecated, use axis instead", "axis")
1461def count_nonzero(input_tensor=None,
1462                  axis=None,
1463                  keepdims=None,
1464                  dtype=dtypes.int64,
1465                  name=None,
1466                  reduction_indices=None,
1467                  keep_dims=None,
1468                  input=None):  # pylint: disable=redefined-builtin
1469  """Computes number of nonzero elements across dimensions of a tensor.
1470
1471  Reduces `input_tensor` along the dimensions given in `axis`.
1472  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1473  entry in `axis`. If `keepdims` is true, the reduced dimensions
1474  are retained with length 1.
1475
1476  If `axis` has no entries, all dimensions are reduced, and a
1477  tensor with a single element is returned.
1478
1479  **NOTE** Floating point comparison to zero is done by exact floating point
1480  equality check.  Small values are **not** rounded to zero for purposes of
1481  the nonzero check.
1482
1483  For example:
1484
1485  ```python
1486  x = tf.constant([[0, 1, 0], [1, 1, 0]])
1487  tf.count_nonzero(x)  # 3
1488  tf.count_nonzero(x, 0)  # [1, 2, 0]
1489  tf.count_nonzero(x, 1)  # [1, 2]
1490  tf.count_nonzero(x, 1, keepdims=True)  # [[1], [2]]
1491  tf.count_nonzero(x, [0, 1])  # 3
1492  ```
1493
1494  **NOTE** Strings are compared against zero-length empty string `""`. Any
1495  string with a size greater than zero is already considered as nonzero.
1496
1497  For example:
1498  ```python
1499  x = tf.constant(["", "a", "  ", "b", ""])
1500  tf.count_nonzero(x) # 3, with "a", "  ", and "b" as nonzero strings.
1501  ```
1502
1503  Args:
1504    input_tensor: The tensor to reduce. Should be of numeric type, `bool`,
1505      or `string`.
1506    axis: The dimensions to reduce. If `None` (the default),
1507      reduces all dimensions. Must be in the range
1508      `[-rank(input_tensor), rank(input_tensor))`.
1509    keepdims: If true, retains reduced dimensions with length 1.
1510    dtype: The output dtype; defaults to `tf.int64`.
1511    name: A name for the operation (optional).
1512    reduction_indices: The old (deprecated) name for axis.
1513    keep_dims: Deprecated alias for `keepdims`.
1514    input: Overrides input_tensor. For compatibility.
1515
1516  Returns:
1517    The reduced tensor (number of nonzero values).
1518  """
1519  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1520                                                    "keep_dims", keep_dims)
1521  input_tensor = deprecation.deprecated_argument_lookup(
1522      "input", input, "input_tensor", input_tensor)
1523  axis = deprecation.deprecated_argument_lookup(
1524      "axis", axis,
1525      "reduction_indices", reduction_indices
1526      )
1527
1528  return count_nonzero_v2(input_tensor, axis, keepdims, dtype, name)
1529
1530
1531@tf_export("math.count_nonzero", v1=[])
1532def count_nonzero_v2(input,  # pylint: disable=redefined-builtin
1533                     axis=None,
1534                     keepdims=None,
1535                     dtype=dtypes.int64,
1536                     name=None):
1537  """Computes number of nonzero elements across dimensions of a tensor.
1538
1539  Reduces `input` along the dimensions given in `axis`.
1540  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1541  entry in `axis`. If `keepdims` is true, the reduced dimensions
1542  are retained with length 1.
1543
1544  If `axis` has no entries, all dimensions are reduced, and a
1545  tensor with a single element is returned.
1546
1547  **NOTE** Floating point comparison to zero is done by exact floating point
1548  equality check.  Small values are **not** rounded to zero for purposes of
1549  the nonzero check.
1550
1551  For example:
1552
1553  ```python
1554  x = tf.constant([[0, 1, 0], [1, 1, 0]])
1555  tf.count_nonzero(x)  # 3
1556  tf.count_nonzero(x, 0)  # [1, 2, 0]
1557  tf.count_nonzero(x, 1)  # [1, 2]
1558  tf.count_nonzero(x, 1, keepdims=True)  # [[1], [2]]
1559  tf.count_nonzero(x, [0, 1])  # 3
1560  ```
1561
1562  **NOTE** Strings are compared against zero-length empty string `""`. Any
1563  string with a size greater than zero is already considered as nonzero.
1564
1565  For example:
1566  ```python
1567  x = tf.constant(["", "a", "  ", "b", ""])
1568  tf.count_nonzero(x) # 3, with "a", "  ", and "b" as nonzero strings.
1569  ```
1570
1571  Args:
1572    input: The tensor to reduce. Should be of numeric type, `bool`,
1573      or `string`.
1574    axis: The dimensions to reduce. If `None` (the default),
1575      reduces all dimensions. Must be in the range
1576      `[-rank(input), rank(input))`.
1577    keepdims: If true, retains reduced dimensions with length 1.
1578    dtype: The output dtype; defaults to `tf.int64`.
1579    name: A name for the operation (optional).
1580
1581  Returns:
1582    The reduced tensor (number of nonzero values).
1583  """
1584  if keepdims is None:
1585    keepdims = False
1586  with ops.name_scope(name, "count_nonzero", [input]):
1587    input = ops.convert_to_tensor(input, name="input")
1588    # A scalar of 'zero' is enough as `not_equal` will broadcast.
1589    zero = array_ops.zeros([], dtype=input.dtype)
1590    return cast(
1591        reduce_sum(
1592            # int64 reduction happens on GPU
1593            cast(gen_math_ops.not_equal(input, zero), dtypes.int64),
1594            axis=axis,
1595            keepdims=keepdims),
1596        dtype=dtype)
1597
1598
1599@tf_export(v1=["math.reduce_mean", "reduce_mean"])
1600def reduce_mean_v1(input_tensor,
1601                   axis=None,
1602                   keepdims=None,
1603                   name=None,
1604                   reduction_indices=None,
1605                   keep_dims=None):
1606  """Computes the mean of elements across dimensions of a tensor.
1607
1608  Reduces `input_tensor` along the dimensions given in `axis`.
1609  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1610  entry in `axis`. If `keepdims` is true, the reduced dimensions
1611  are retained with length 1.
1612
1613  If `axis` is None, all dimensions are reduced, and a
1614  tensor with a single element is returned.
1615
1616  For example:
1617
1618  ```python
1619  x = tf.constant([[1., 1.], [2., 2.]])
1620  tf.reduce_mean(x)  # 1.5
1621  tf.reduce_mean(x, 0)  # [1.5, 1.5]
1622  tf.reduce_mean(x, 1)  # [1.,  2.]
1623  ```
1624
1625  Args:
1626    input_tensor: The tensor to reduce. Should have numeric type.
1627    axis: The dimensions to reduce. If `None` (the default),
1628      reduces all dimensions. Must be in the range
1629      `[-rank(input_tensor), rank(input_tensor))`.
1630    keepdims: If true, retains reduced dimensions with length 1.
1631    name: A name for the operation (optional).
1632    reduction_indices: The old (deprecated) name for axis.
1633    keep_dims: Deprecated alias for `keepdims`.
1634
1635  Returns:
1636    The reduced tensor.
1637
1638  @compatibility(numpy)
1639  Equivalent to np.mean
1640
1641  Please note that `np.mean` has a `dtype` parameter that could be used to
1642  specify the output type. By default this is `dtype=float64`. On the other
1643  hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
1644  for example:
1645
1646  ```python
1647  x = tf.constant([1, 0, 1, 0])
1648  tf.reduce_mean(x)  # 0
1649  y = tf.constant([1., 0., 1., 0.])
1650  tf.reduce_mean(y)  # 0.5
1651  ```
1652
1653  @end_compatibility
1654  """
1655  axis = deprecation.deprecated_argument_lookup(
1656      "axis", axis, "reduction_indices", reduction_indices)
1657  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1658                                                    "keep_dims", keep_dims)
1659  return reduce_mean(input_tensor, axis, keepdims, name)
1660
1661
1662@tf_export("math.reduce_mean", "reduce_mean", v1=[])
1663@dispatch.add_dispatch_support
1664def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
1665  """Computes the mean of elements across dimensions of a tensor.
1666
1667  Reduces `input_tensor` along the dimensions given in `axis`.
1668  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1669  entry in `axis`. If `keepdims` is true, the reduced dimensions
1670  are retained with length 1.
1671
1672  If `axis` is None, all dimensions are reduced, and a
1673  tensor with a single element is returned.
1674
1675  For example:
1676
1677  ```python
1678  x = tf.constant([[1., 1.], [2., 2.]])
1679  tf.reduce_mean(x)  # 1.5
1680  tf.reduce_mean(x, 0)  # [1.5, 1.5]
1681  tf.reduce_mean(x, 1)  # [1.,  2.]
1682  ```
1683
1684  Args:
1685    input_tensor: The tensor to reduce. Should have numeric type.
1686    axis: The dimensions to reduce. If `None` (the default), reduces all
1687      dimensions. Must be in the range `[-rank(input_tensor),
1688      rank(input_tensor))`.
1689    keepdims: If true, retains reduced dimensions with length 1.
1690    name: A name for the operation (optional).
1691
1692  Returns:
1693    The reduced tensor.
1694
1695  @compatibility(numpy)
1696  Equivalent to np.mean
1697
1698  Please note that `np.mean` has a `dtype` parameter that could be used to
1699  specify the output type. By default this is `dtype=float64`. On the other
1700  hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
1701  for example:
1702
1703  ```python
1704  x = tf.constant([1, 0, 1, 0])
1705  tf.reduce_mean(x)  # 0
1706  y = tf.constant([1., 0., 1., 0.])
1707  tf.reduce_mean(y)  # 0.5
1708  ```
1709
1710  @end_compatibility
1711  """
1712  keepdims = False if keepdims is None else keepdims
1713  return _may_reduce_to_scalar(
1714      keepdims, axis,
1715      gen_math_ops.mean(
1716          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
1717          name=name))
1718
1719
1720@tf_export("math.reduce_variance")
1721def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
1722  """Computes the variance of elements across dimensions of a tensor.
1723
1724  Reduces `input_tensor` along the dimensions given in `axis`.
1725  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1726  entry in `axis`. If `keepdims` is true, the reduced dimensions
1727  are retained with length 1.
1728
1729  If `axis` is None, all dimensions are reduced, and a
1730  tensor with a single element is returned.
1731
1732  For example:
1733
1734  ```python
1735  x = tf.constant([[1., 2.], [3., 4.]])
1736  tf.reduce_variance(x)  # 1.25
1737  tf.reduce_variance(x, 0)  # [1., 1.]
1738  tf.reduce_variance(x, 1)  # [0.25,  0.25]
1739  ```
1740
1741  Args:
1742    input_tensor: The tensor to reduce. Should have numeric type.
1743    axis: The dimensions to reduce. If `None` (the default), reduces all
1744      dimensions. Must be in the range `[-rank(input_tensor),
1745      rank(input_tensor))`.
1746    keepdims: If true, retains reduced dimensions with length 1.
1747    name: A name scope for the associated operations (optional).
1748
1749  Returns:
1750    The reduced tensor, of the same dtype as the input_tensor.
1751
1752  @compatibility(numpy)
1753  Equivalent to np.var
1754
1755  Please note that `np.var` has a `dtype` parameter that could be used to
1756  specify the output type. By default this is `dtype=float64`. On the other
1757  hand, `tf.reduce_variance` has an aggressive type inference from
1758  `input_tensor`,
1759  @end_compatibility
1760  """
1761  name = name if name else "reduce_variance"
1762  with ops.name_scope(name):
1763    means = reduce_mean(input_tensor, axis=axis, keepdims=True)
1764    squared_deviations = gen_math_ops.square(input_tensor - means)
1765    return reduce_mean(squared_deviations, axis=axis, keepdims=keepdims)
1766
1767
1768@tf_export("math.reduce_std")
1769def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
1770  """Computes the standard deviation of elements across dimensions of a tensor.
1771
1772  Reduces `input_tensor` along the dimensions given in `axis`.
1773  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1774  entry in `axis`. If `keepdims` is true, the reduced dimensions
1775  are retained with length 1.
1776
1777  If `axis` is None, all dimensions are reduced, and a
1778  tensor with a single element is returned.
1779
1780  For example:
1781
1782  ```python
1783  x = tf.constant([[1., 2.], [3., 4.]])
1784  tf.reduce_std(x)  # 1.1180339887498949
1785  tf.reduce_std(x, 0)  # [1., 1.]
1786  tf.reduce_std(x, 1)  # [0.5,  0.5]
1787  ```
1788
1789  Args:
1790    input_tensor: The tensor to reduce. Should have numeric type.
1791    axis: The dimensions to reduce. If `None` (the default), reduces all
1792      dimensions. Must be in the range `[-rank(input_tensor),
1793      rank(input_tensor))`.
1794    keepdims: If true, retains reduced dimensions with length 1.
1795    name: A name scope for the associated operations (optional).
1796
1797  Returns:
1798    The reduced tensor, of the same dtype as the input_tensor.
1799
1800  @compatibility(numpy)
1801  Equivalent to np.std
1802
1803  Please note that `np.std` has a `dtype` parameter that could be used to
1804  specify the output type. By default this is `dtype=float64`. On the other
1805  hand, `tf.reduce_std` has an aggressive type inference from `input_tensor`,
1806  @end_compatibility
1807  """
1808  name = name if name else "reduce_std"
1809  with ops.name_scope(name):
1810    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
1811    return gen_math_ops.sqrt(variance)
1812
1813
1814@tf_export("math.reduce_prod", "reduce_prod", v1=[])
1815@dispatch.add_dispatch_support
1816def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
1817  """Computes the product of elements across dimensions of a tensor.
1818
1819  Reduces `input_tensor` along the dimensions given in `axis`.
1820  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1821  entry in `axis`. If `keepdims` is true, the reduced dimensions
1822  are retained with length 1.
1823
1824  If `axis` is None, all dimensions are reduced, and a
1825  tensor with a single element is returned.
1826
1827  Args:
1828    input_tensor: The tensor to reduce. Should have numeric type.
1829    axis: The dimensions to reduce. If `None` (the default),
1830      reduces all dimensions. Must be in the range
1831      `[-rank(input_tensor), rank(input_tensor))`.
1832    keepdims: If true, retains reduced dimensions with length 1.
1833    name: A name for the operation (optional).
1834
1835  Returns:
1836    The reduced tensor.
1837
1838  @compatibility(numpy)
1839  Equivalent to np.prod
1840  @end_compatibility
1841  """
1842  keepdims = False if keepdims is None else keepdims
1843  return _may_reduce_to_scalar(
1844      keepdims, axis,
1845      gen_math_ops.prod(
1846          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
1847          name=name))
1848
1849
1850@tf_export(v1=["math.reduce_prod", "reduce_prod"])
1851@deprecation.deprecated_args(
1852    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1853def reduce_prod_v1(input_tensor,
1854                   axis=None,
1855                   keepdims=None,
1856                   name=None,
1857                   reduction_indices=None,
1858                   keep_dims=None):
1859  """Computes the product of elements across dimensions of a tensor.
1860
1861  Reduces `input_tensor` along the dimensions given in `axis`.
1862  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1863  entry in `axis`. If `keepdims` is true, the reduced dimensions
1864  are retained with length 1.
1865
1866  If `axis` is None, all dimensions are reduced, and a
1867  tensor with a single element is returned.
1868
1869  Args:
1870    input_tensor: The tensor to reduce. Should have numeric type.
1871    axis: The dimensions to reduce. If `None` (the default), reduces all
1872      dimensions. Must be in the range `[-rank(input_tensor),
1873      rank(input_tensor))`.
1874    keepdims: If true, retains reduced dimensions with length 1.
1875    name: A name for the operation (optional).
1876    reduction_indices: The old (deprecated) name for axis.
1877    keep_dims: Deprecated alias for `keepdims`.
1878
1879  Returns:
1880    The reduced tensor.
1881
1882  @compatibility(numpy)
1883  Equivalent to np.prod
1884  @end_compatibility
1885  """
1886  axis = deprecation.deprecated_argument_lookup(
1887      "axis", axis, "reduction_indices", reduction_indices)
1888  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1889                                                    "keep_dims", keep_dims)
1890  return reduce_prod(input_tensor, axis, keepdims, name)
1891
1892
1893@tf_export(v1=["math.reduce_min", "reduce_min"])
1894@deprecation.deprecated_args(
1895    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1896def reduce_min_v1(input_tensor,
1897                  axis=None,
1898                  keepdims=None,
1899                  name=None,
1900                  reduction_indices=None,
1901                  keep_dims=None):
1902  """Computes the minimum of elements across dimensions of a tensor.
1903
1904  Reduces `input_tensor` along the dimensions given in `axis`.
1905  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1906  entry in `axis`. If `keepdims` is true, the reduced dimensions
1907  are retained with length 1.
1908
1909  If `axis` is None, all dimensions are reduced, and a
1910  tensor with a single element is returned.
1911
1912  Args:
1913    input_tensor: The tensor to reduce. Should have real numeric type.
1914    axis: The dimensions to reduce. If `None` (the default), reduces all
1915      dimensions. Must be in the range `[-rank(input_tensor),
1916      rank(input_tensor))`.
1917    keepdims: If true, retains reduced dimensions with length 1.
1918    name: A name for the operation (optional).
1919    reduction_indices: The old (deprecated) name for axis.
1920    keep_dims: Deprecated alias for `keepdims`.
1921
1922  Returns:
1923    The reduced tensor.
1924
1925  @compatibility(numpy)
1926  Equivalent to np.min
1927  @end_compatibility
1928  """
1929  axis = deprecation.deprecated_argument_lookup(
1930      "axis", axis, "reduction_indices", reduction_indices)
1931  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1932                                                    "keep_dims", keep_dims)
1933  return reduce_min(input_tensor, axis, keepdims, name)
1934
1935
1936@tf_export("math.reduce_min", "reduce_min", v1=[])
1937@dispatch.add_dispatch_support
1938def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
1939  """Computes the minimum of elements across dimensions of a tensor.
1940
1941  Reduces `input_tensor` along the dimensions given in `axis`.
1942  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1943  entry in `axis`. If `keepdims` is true, the reduced dimensions
1944  are retained with length 1.
1945
1946  If `axis` is None, all dimensions are reduced, and a
1947  tensor with a single element is returned.
1948
1949  Args:
1950    input_tensor: The tensor to reduce. Should have real numeric type.
1951    axis: The dimensions to reduce. If `None` (the default), reduces all
1952      dimensions. Must be in the range `[-rank(input_tensor),
1953      rank(input_tensor))`.
1954    keepdims: If true, retains reduced dimensions with length 1.
1955    name: A name for the operation (optional).
1956
1957  Returns:
1958    The reduced tensor.
1959
1960  @compatibility(numpy)
1961  Equivalent to np.min
1962  @end_compatibility
1963  """
1964  keepdims = False if keepdims is None else keepdims
1965  return _may_reduce_to_scalar(
1966      keepdims, axis,
1967      gen_math_ops._min(
1968          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
1969          name=name))
1970
1971
1972@tf_export(v1=["math.reduce_max", "reduce_max"])
1973@deprecation.deprecated_args(
1974    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1975def reduce_max_v1(input_tensor,
1976                  axis=None,
1977                  keepdims=None,
1978                  name=None,
1979                  reduction_indices=None,
1980                  keep_dims=None):
1981  """Computes the maximum of elements across dimensions of a tensor.
1982
1983  Reduces `input_tensor` along the dimensions given in `axis`.
1984  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1985  entry in `axis`. If `keepdims` is true, the reduced dimensions
1986  are retained with length 1.
1987
1988  If `axis` is None, all dimensions are reduced, and a
1989  tensor with a single element is returned.
1990
1991  Args:
1992    input_tensor: The tensor to reduce. Should have real numeric type.
1993    axis: The dimensions to reduce. If `None` (the default),
1994      reduces all dimensions. Must be in the range
1995      `[-rank(input_tensor), rank(input_tensor))`.
1996    keepdims: If true, retains reduced dimensions with length 1.
1997    name: A name for the operation (optional).
1998    reduction_indices: The old (deprecated) name for axis.
1999    keep_dims: Deprecated alias for `keepdims`.
2000
2001  Returns:
2002    The reduced tensor.
2003
2004  @compatibility(numpy)
2005  Equivalent to np.max
2006  @end_compatibility
2007  """
2008  axis = deprecation.deprecated_argument_lookup(
2009      "axis", axis, "reduction_indices", reduction_indices)
2010  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2011                                                    "keep_dims", keep_dims)
2012  return reduce_max(input_tensor, axis, keepdims, name)
2013
2014
2015@tf_export("math.reduce_max", "reduce_max", v1=[])
2016@dispatch.add_dispatch_support
2017def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
2018  """Computes the maximum of elements across dimensions of a tensor.
2019
2020  Reduces `input_tensor` along the dimensions given in `axis`.
2021  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2022  entry in `axis`. If `keepdims` is true, the reduced dimensions
2023  are retained with length 1.
2024
2025  If `axis` is None, all dimensions are reduced, and a
2026  tensor with a single element is returned.
2027
2028  Args:
2029    input_tensor: The tensor to reduce. Should have real numeric type.
2030    axis: The dimensions to reduce. If `None` (the default), reduces all
2031      dimensions. Must be in the range `[-rank(input_tensor),
2032      rank(input_tensor))`.
2033    keepdims: If true, retains reduced dimensions with length 1.
2034    name: A name for the operation (optional).
2035
2036  Returns:
2037    The reduced tensor.
2038
2039  @compatibility(numpy)
2040  Equivalent to np.max
2041  @end_compatibility
2042  """
2043  keepdims = False if keepdims is None else keepdims
2044  return _may_reduce_to_scalar(
2045      keepdims, axis,
2046      gen_math_ops._max(
2047          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2048          name=name))
2049
2050
2051@tf_export(v1=["math.reduce_all", "reduce_all"])
2052@deprecation.deprecated_args(
2053    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
2054def reduce_all_v1(input_tensor,
2055                  axis=None,
2056                  keepdims=None,
2057                  name=None,
2058                  reduction_indices=None,
2059                  keep_dims=None):
2060  """Computes the "logical and" of elements across dimensions of a tensor.
2061
2062  Reduces `input_tensor` along the dimensions given in `axis`.
2063  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2064  entry in `axis`. If `keepdims` is true, the reduced dimensions
2065  are retained with length 1.
2066
2067  If `axis` is None, all dimensions are reduced, and a
2068  tensor with a single element is returned.
2069
2070  For example:
2071
2072  ```python
2073  x = tf.constant([[True,  True], [False, False]])
2074  tf.reduce_all(x)  # False
2075  tf.reduce_all(x, 0)  # [False, False]
2076  tf.reduce_all(x, 1)  # [True, False]
2077  ```
2078
2079  Args:
2080    input_tensor: The boolean tensor to reduce.
2081    axis: The dimensions to reduce. If `None` (the default), reduces all
2082      dimensions. Must be in the range `[-rank(input_tensor),
2083      rank(input_tensor))`.
2084    keepdims: If true, retains reduced dimensions with length 1.
2085    name: A name for the operation (optional).
2086    reduction_indices: The old (deprecated) name for axis.
2087    keep_dims: Deprecated alias for `keepdims`.
2088
2089  Returns:
2090    The reduced tensor.
2091
2092  @compatibility(numpy)
2093  Equivalent to np.all
2094  @end_compatibility
2095  """
2096  axis = deprecation.deprecated_argument_lookup(
2097      "axis", axis, "reduction_indices", reduction_indices)
2098  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2099                                                    "keep_dims", keep_dims)
2100  return reduce_all(input_tensor, axis, keepdims, name)
2101
2102
2103@tf_export("reduce_all", "math.reduce_all", v1=[])
2104@dispatch.add_dispatch_support
2105def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
2106  """Computes the "logical and" of elements across dimensions of a tensor.
2107
2108  Reduces `input_tensor` along the dimensions given in `axis`.
2109  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2110  entry in `axis`. If `keepdims` is true, the reduced dimensions
2111  are retained with length 1.
2112
2113  If `axis` is None, all dimensions are reduced, and a
2114  tensor with a single element is returned.
2115
2116  For example:
2117
2118  ```python
2119  x = tf.constant([[True,  True], [False, False]])
2120  tf.reduce_all(x)  # False
2121  tf.reduce_all(x, 0)  # [False, False]
2122  tf.reduce_all(x, 1)  # [True, False]
2123  ```
2124
2125  Args:
2126    input_tensor: The boolean tensor to reduce.
2127    axis: The dimensions to reduce. If `None` (the default), reduces all
2128      dimensions. Must be in the range `[-rank(input_tensor),
2129      rank(input_tensor))`.
2130    keepdims: If true, retains reduced dimensions with length 1.
2131    name: A name for the operation (optional).
2132
2133  Returns:
2134    The reduced tensor.
2135
2136  @compatibility(numpy)
2137  Equivalent to np.all
2138  @end_compatibility
2139  """
2140  keepdims = False if keepdims is None else keepdims
2141  return _may_reduce_to_scalar(
2142      keepdims, axis,
2143      gen_math_ops._all(
2144          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2145          name=name))
2146
2147
2148@tf_export(v1=["math.reduce_any", "reduce_any"])
2149@deprecation.deprecated_args(
2150    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
2151def reduce_any_v1(input_tensor,
2152                  axis=None,
2153                  keepdims=None,
2154                  name=None,
2155                  reduction_indices=None,
2156                  keep_dims=None):
2157  """Computes the "logical or" of elements across dimensions of a tensor.
2158
2159  Reduces `input_tensor` along the dimensions given in `axis`.
2160  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2161  entry in `axis`. If `keepdims` is true, the reduced dimensions
2162  are retained with length 1.
2163
2164  If `axis` is None, all dimensions are reduced, and a
2165  tensor with a single element is returned.
2166
2167  For example:
2168
2169  ```python
2170  x = tf.constant([[True,  True], [False, False]])
2171  tf.reduce_any(x)  # True
2172  tf.reduce_any(x, 0)  # [True, True]
2173  tf.reduce_any(x, 1)  # [True, False]
2174  ```
2175
2176  Args:
2177    input_tensor: The boolean tensor to reduce.
2178    axis: The dimensions to reduce. If `None` (the default), reduces all
2179      dimensions. Must be in the range `[-rank(input_tensor),
2180      rank(input_tensor))`.
2181    keepdims: If true, retains reduced dimensions with length 1.
2182    name: A name for the operation (optional).
2183    reduction_indices: The old (deprecated) name for axis.
2184    keep_dims: Deprecated alias for `keepdims`.
2185
2186  Returns:
2187    The reduced tensor.
2188
2189  @compatibility(numpy)
2190  Equivalent to np.any
2191  @end_compatibility
2192  """
2193  axis = deprecation.deprecated_argument_lookup(
2194      "axis", axis, "reduction_indices", reduction_indices)
2195  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2196                                                    "keep_dims", keep_dims)
2197  return reduce_any(input_tensor, axis, keepdims, name)
2198
2199
2200@tf_export("math.reduce_any", "reduce_any", v1=[])
2201@dispatch.add_dispatch_support
2202def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
2203  """Computes the "logical or" of elements across dimensions of a tensor.
2204
2205  Reduces `input_tensor` along the dimensions given in `axis`.
2206  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2207  entry in `axis`. If `keepdims` is true, the reduced dimensions
2208  are retained with length 1.
2209
2210  If `axis` is None, all dimensions are reduced, and a
2211  tensor with a single element is returned.
2212
2213  For example:
2214
2215  ```python
2216  x = tf.constant([[True,  True], [False, False]])
2217  tf.reduce_any(x)  # True
2218  tf.reduce_any(x, 0)  # [True, True]
2219  tf.reduce_any(x, 1)  # [True, False]
2220  ```
2221
2222  Args:
2223    input_tensor: The boolean tensor to reduce.
2224    axis: The dimensions to reduce. If `None` (the default), reduces all
2225      dimensions. Must be in the range `[-rank(input_tensor),
2226      rank(input_tensor))`.
2227    keepdims: If true, retains reduced dimensions with length 1.
2228    name: A name for the operation (optional).
2229
2230  Returns:
2231    The reduced tensor.
2232
2233  @compatibility(numpy)
2234  Equivalent to np.any
2235  @end_compatibility
2236  """
2237  keepdims = False if keepdims is None else keepdims
2238  return _may_reduce_to_scalar(
2239      keepdims, axis,
2240      gen_math_ops._any(
2241          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2242          name=name))
2243
2244
2245@tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"])
2246@deprecation.deprecated_args(
2247    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
2248def reduce_logsumexp_v1(input_tensor,
2249                        axis=None,
2250                        keepdims=None,
2251                        name=None,
2252                        reduction_indices=None,
2253                        keep_dims=None):
2254  """Computes log(sum(exp(elements across dimensions of a tensor))).
2255
2256  Reduces `input_tensor` along the dimensions given in `axis`.
2257  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2258  entry in `axis`. If `keepdims` is true, the reduced dimensions
2259  are retained with length 1.
2260
2261  If `axis` has no entries, all dimensions are reduced, and a
2262  tensor with a single element is returned.
2263
2264  This function is more numerically stable than log(sum(exp(input))). It avoids
2265  overflows caused by taking the exp of large inputs and underflows caused by
2266  taking the log of small inputs.
2267
2268  For example:
2269
2270  ```python
2271  x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
2272  tf.reduce_logsumexp(x)  # log(6)
2273  tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
2274  tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
2275  tf.reduce_logsumexp(x, 1, keepdims=True)  # [[log(3)], [log(3)]]
2276  tf.reduce_logsumexp(x, [0, 1])  # log(6)
2277  ```
2278
2279  Args:
2280    input_tensor: The tensor to reduce. Should have numeric type.
2281    axis: The dimensions to reduce. If `None` (the default), reduces all
2282      dimensions. Must be in the range `[-rank(input_tensor),
2283      rank(input_tensor))`.
2284    keepdims: If true, retains reduced dimensions with length 1.
2285    name: A name for the operation (optional).
2286    reduction_indices: The old (deprecated) name for axis.
2287    keep_dims: Deprecated alias for `keepdims`.
2288
2289  Returns:
2290    The reduced tensor.
2291  """
2292  axis = deprecation.deprecated_argument_lookup(
2293      "axis", axis, "reduction_indices", reduction_indices)
2294  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2295                                                    "keep_dims", keep_dims)
2296  return reduce_logsumexp(input_tensor, axis, keepdims, name)
2297
2298
2299@tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[])
2300def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
2301  """Computes log(sum(exp(elements across dimensions of a tensor))).
2302
2303  Reduces `input_tensor` along the dimensions given in `axis`.
2304  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2305  entry in `axis`. If `keepdims` is true, the reduced dimensions
2306  are retained with length 1.
2307
2308  If `axis` has no entries, all dimensions are reduced, and a
2309  tensor with a single element is returned.
2310
2311  This function is more numerically stable than log(sum(exp(input))). It avoids
2312  overflows caused by taking the exp of large inputs and underflows caused by
2313  taking the log of small inputs.
2314
2315  For example:
2316
2317  ```python
2318  x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
2319  tf.reduce_logsumexp(x)  # log(6)
2320  tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
2321  tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
2322  tf.reduce_logsumexp(x, 1, keepdims=True)  # [[log(3)], [log(3)]]
2323  tf.reduce_logsumexp(x, [0, 1])  # log(6)
2324  ```
2325
2326  Args:
2327    input_tensor: The tensor to reduce. Should have numeric type.
2328    axis: The dimensions to reduce. If `None` (the default), reduces all
2329      dimensions. Must be in the range `[-rank(input_tensor),
2330      rank(input_tensor))`.
2331    keepdims: If true, retains reduced dimensions with length 1.
2332    name: A name for the operation (optional).
2333
2334  Returns:
2335    The reduced tensor.
2336  """
2337  keepdims = False if keepdims is None else keepdims
2338  input_tensor = ops.convert_to_tensor(input_tensor)
2339  with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
2340    raw_max = reduce_max(
2341        input_tensor,
2342        axis=axis,
2343        keepdims=True)
2344    my_max = array_ops.stop_gradient(
2345        array_ops.where(
2346            gen_math_ops.is_finite(raw_max), raw_max,
2347            array_ops.zeros_like(raw_max)))
2348    result = gen_math_ops.log(
2349        reduce_sum(
2350            gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
2351            axis,
2352            keepdims=keepdims))
2353    if not keepdims:
2354      my_max = array_ops.reshape(my_max, array_ops.shape(result))
2355    result = gen_math_ops.add(result, my_max)
2356    return _may_reduce_to_scalar(keepdims, axis, result)
2357
2358
2359@tf_export("linalg.trace", v1=["linalg.trace", "trace"])
2360@deprecation.deprecated_endpoints("trace")
2361def trace(x, name=None):
2362  """Compute the trace of a tensor `x`.
2363
2364  `trace(x)` returns the sum along the main diagonal of each inner-most matrix
2365  in x. If x is of rank `k` with shape `[I, J, K, ..., L, M, N]`, then output
2366  is a tensor of rank `k-2` with dimensions `[I, J, K, ..., L]` where
2367
2368  `output[i, j, k, ..., l] = trace(x[i, j, i, ..., l, :, :])`
2369
2370  For example:
2371
2372  ```python
2373  x = tf.constant([[1, 2], [3, 4]])
2374  tf.linalg.trace(x)  # 5
2375
2376  x = tf.constant([[1, 2, 3],
2377                   [4, 5, 6],
2378                   [7, 8, 9]])
2379  tf.linalg.trace(x)  # 15
2380
2381  x = tf.constant([[[1, 2, 3],
2382                    [4, 5, 6],
2383                    [7, 8, 9]],
2384                   [[-1, -2, -3],
2385                    [-4, -5, -6],
2386                    [-7, -8, -9]]])
2387  tf.linalg.trace(x)  # [15, -15]
2388  ```
2389
2390  Args:
2391    x: tensor.
2392    name: A name for the operation (optional).
2393
2394  Returns:
2395    The trace of input tensor.
2396  """
2397  with ops.name_scope(name, "Trace", [x]) as name:
2398    x = ops.convert_to_tensor(x, name="x")
2399    return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
2400
2401
2402@tf_export("linalg.matmul", "matmul")
2403def matmul(a,
2404           b,
2405           transpose_a=False,
2406           transpose_b=False,
2407           adjoint_a=False,
2408           adjoint_b=False,
2409           a_is_sparse=False,
2410           b_is_sparse=False,
2411           name=None):
2412  """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
2413
2414  The inputs must, following any transpositions, be tensors of rank >= 2
2415  where the inner 2 dimensions specify valid matrix multiplication arguments,
2416  and any further outer dimensions match.
2417
2418  Both matrices must be of the same type. The supported types are:
2419  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
2420
2421  Either matrix can be transposed or adjointed (conjugated and transposed) on
2422  the fly by setting one of the corresponding flag to `True`. These are `False`
2423  by default.
2424
2425  If one or both of the matrices contain a lot of zeros, a more efficient
2426  multiplication algorithm can be used by setting the corresponding
2427  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
2428  This optimization is only available for plain matrices (rank-2 tensors) with
2429  datatypes `bfloat16` or `float32`.
2430
2431  For example:
2432
2433  ```python
2434  # 2-D tensor `a`
2435  # [[1, 2, 3],
2436  #  [4, 5, 6]]
2437  a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
2438
2439  # 2-D tensor `b`
2440  # [[ 7,  8],
2441  #  [ 9, 10],
2442  #  [11, 12]]
2443  b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])
2444
2445  # `a` * `b`
2446  # [[ 58,  64],
2447  #  [139, 154]]
2448  c = tf.matmul(a, b)
2449
2450
2451  # 3-D tensor `a`
2452  # [[[ 1,  2,  3],
2453  #   [ 4,  5,  6]],
2454  #  [[ 7,  8,  9],
2455  #   [10, 11, 12]]]
2456  a = tf.constant(np.arange(1, 13, dtype=np.int32),
2457                  shape=[2, 2, 3])
2458
2459  # 3-D tensor `b`
2460  # [[[13, 14],
2461  #   [15, 16],
2462  #   [17, 18]],
2463  #  [[19, 20],
2464  #   [21, 22],
2465  #   [23, 24]]]
2466  b = tf.constant(np.arange(13, 25, dtype=np.int32),
2467                  shape=[2, 3, 2])
2468
2469  # `a` * `b`
2470  # [[[ 94, 100],
2471  #   [229, 244]],
2472  #  [[508, 532],
2473  #   [697, 730]]]
2474  c = tf.matmul(a, b)
2475
2476  # Since python >= 3.5 the @ operator is supported (see PEP 465).
2477  # In TensorFlow, it simply calls the `tf.matmul()` function, so the
2478  # following lines are equivalent:
2479  d = a @ b @ [[10.], [11.]]
2480  d = tf.matmul(tf.matmul(a, b), [[10.], [11.]])
2481  ```
2482
2483  Args:
2484    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
2485      `complex128` and rank > 1.
2486    b: `Tensor` with same type and rank as `a`.
2487    transpose_a: If `True`, `a` is transposed before multiplication.
2488    transpose_b: If `True`, `b` is transposed before multiplication.
2489    adjoint_a: If `True`, `a` is conjugated and transposed before
2490      multiplication.
2491    adjoint_b: If `True`, `b` is conjugated and transposed before
2492      multiplication.
2493    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
2494    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
2495    name: Name for the operation (optional).
2496
2497  Returns:
2498    A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
2499    the product of the corresponding matrices in `a` and `b`, e.g. if all
2500    transpose or adjoint attributes are `False`:
2501
2502    `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
2503    for all indices i, j.
2504
2505    Note: This is matrix product, not element-wise product.
2506
2507
2508  Raises:
2509    ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
2510      are both set to True.
2511  """
2512  with ops.name_scope(name, "MatMul", [a, b]) as name:
2513    if transpose_a and adjoint_a:
2514      raise ValueError("Only one of transpose_a and adjoint_a can be True.")
2515    if transpose_b and adjoint_b:
2516      raise ValueError("Only one of transpose_b and adjoint_b can be True.")
2517
2518    if context.executing_eagerly():
2519      if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
2520        a = ops.convert_to_tensor(a, name="a")
2521      if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
2522        b = ops.convert_to_tensor(b, name="b")
2523    else:
2524      a = ops.convert_to_tensor(a, name="a")
2525      b = ops.convert_to_tensor(b, name="b")
2526
2527    # TODO(apassos) remove _shape_tuple here when it is not needed.
2528    a_shape = a._shape_tuple()  # pylint: disable=protected-access
2529    b_shape = b._shape_tuple()  # pylint: disable=protected-access
2530    if (not a_is_sparse and
2531        not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
2532                              (b_shape is None or len(b_shape) > 2)):
2533      # BatchMatmul does not support transpose, so we conjugate the matrix and
2534      # use adjoint instead. Conj() is a noop for real matrices.
2535      if transpose_a:
2536        a = conj(a)
2537        adjoint_a = True
2538      if transpose_b:
2539        b = conj(b)
2540        adjoint_b = True
2541      return gen_math_ops.batch_mat_mul(
2542          a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
2543
2544    # Neither matmul nor sparse_matmul support adjoint, so we conjugate
2545    # the matrix and use transpose instead. Conj() is a noop for real
2546    # matrices.
2547    if adjoint_a:
2548      a = conj(a)
2549      transpose_a = True
2550    if adjoint_b:
2551      b = conj(b)
2552      transpose_b = True
2553
2554    use_sparse_matmul = False
2555    if a_is_sparse or b_is_sparse:
2556      sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
2557      use_sparse_matmul = (
2558          a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
2559    if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and
2560        a.dtype != b.dtype):
2561      # matmul currently doesn't handle mixed-precision inputs.
2562      use_sparse_matmul = True
2563    if use_sparse_matmul:
2564      ret = sparse_matmul(
2565          a,
2566          b,
2567          transpose_a=transpose_a,
2568          transpose_b=transpose_b,
2569          a_is_sparse=a_is_sparse,
2570          b_is_sparse=b_is_sparse,
2571          name=name)
2572      # sparse_matmul always returns float32, even with
2573      # bfloat16 inputs. This prevents us from configuring bfloat16 training.
2574      # casting to bfloat16 also matches non-sparse matmul behavior better.
2575      if a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16:
2576        ret = cast(ret, dtypes.bfloat16)
2577      return ret
2578    else:
2579      return gen_math_ops.mat_mul(
2580          a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
2581
2582
2583@tf_export("linalg.matvec")
2584def matvec(a,
2585           b,
2586           transpose_a=False,
2587           adjoint_a=False,
2588           a_is_sparse=False,
2589           b_is_sparse=False,
2590           name=None):
2591  """Multiplies matrix `a` by vector `b`, producing `a` * `b`.
2592
2593  The matrix `a` must, following any transpositions, be a tensor of rank >= 2,
2594  and we must have `shape(b) = shape(a)[:-2] + [shape(a)[-1]]`.
2595
2596  Both `a` and `b` must be of the same type. The supported types are:
2597  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
2598
2599  Matrix `a` can be transposed or adjointed (conjugated and transposed) on
2600  the fly by setting one of the corresponding flag to `True`. These are `False`
2601  by default.
2602
2603  If one or both of the inputs contain a lot of zeros, a more efficient
2604  multiplication algorithm can be used by setting the corresponding
2605  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
2606  This optimization is only available for plain matrices/vectors (rank-2/1
2607  tensors) with datatypes `bfloat16` or `float32`.
2608
2609  For example:
2610
2611  ```python
2612  # 2-D tensor `a`
2613  # [[1, 2, 3],
2614  #  [4, 5, 6]]
2615  a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
2616
2617  # 1-D tensor `b`
2618  # [7, 9, 11]
2619  b = tf.constant([7, 9, 11], shape=[3])
2620
2621  # `a` * `b`
2622  # [ 58,  64]
2623  c = tf.matvec(a, b)
2624
2625
2626  # 3-D tensor `a`
2627  # [[[ 1,  2,  3],
2628  #   [ 4,  5,  6]],
2629  #  [[ 7,  8,  9],
2630  #   [10, 11, 12]]]
2631  a = tf.constant(np.arange(1, 13, dtype=np.int32),
2632                  shape=[2, 2, 3])
2633
2634  # 2-D tensor `b`
2635  # [[13, 14, 15],
2636  #  [16, 17, 18]]
2637  b = tf.constant(np.arange(13, 19, dtype=np.int32),
2638                  shape=[2, 3])
2639
2640  # `a` * `b`
2641  # [[ 86, 212],
2642  #  [410, 563]]
2643  c = tf.matvec(a, b)
2644  ```
2645
2646  Args:
2647    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
2648      `complex128` and rank > 1.
2649    b: `Tensor` with same type and rank = `rank(a) - 1`.
2650    transpose_a: If `True`, `a` is transposed before multiplication.
2651    adjoint_a: If `True`, `a` is conjugated and transposed before
2652      multiplication.
2653    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
2654    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
2655    name: Name for the operation (optional).
2656
2657  Returns:
2658    A `Tensor` of the same type as `a` and `b` where each inner-most vector is
2659    the product of the corresponding matrices in `a` and vectors in `b`, e.g. if
2660    all transpose or adjoint attributes are `False`:
2661
2662    `output`[..., i] = sum_k (`a`[..., i, k] * `b`[..., k]), for all indices i.
2663
2664    Note: This is matrix-vector product, not element-wise product.
2665
2666
2667  Raises:
2668    ValueError: If transpose_a and adjoint_a are both set to True.
2669  """
2670  with ops.name_scope(name, "MatVec", [a, b]) as name:
2671    output = matmul(
2672        a,
2673        array_ops.expand_dims(b, axis=-1),
2674        transpose_a=transpose_a,
2675        adjoint_a=adjoint_a,
2676        a_is_sparse=a_is_sparse,
2677        b_is_sparse=b_is_sparse)
2678    return array_ops.squeeze(output, axis=-1)
2679
2680
2681_OverrideBinaryOperatorHelper(matmul, "matmul")
2682
2683sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")(
2684    gen_math_ops.sparse_mat_mul)
2685tf_export(v1=["sparse_matmul"])(sparse_matmul)
2686
2687
2688@ops.RegisterStatistics("MatMul", "flops")
2689def _calc_mat_mul_flops(graph, node):
2690  """Calculates the compute resources needed for MatMul."""
2691  transpose_a = node.attr["transpose_a"].b
2692  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
2693  a_shape.assert_is_fully_defined()
2694  if transpose_a:
2695    k = int(a_shape[0])
2696  else:
2697    k = int(a_shape[1])
2698  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
2699  output_shape.assert_is_fully_defined()
2700  output_count = np.prod(output_shape.as_list())
2701  return ops.OpStats("flops", (k * output_count * 2))
2702
2703
2704def _as_indexed_slices(x, optimize=True):
2705  """Convert 'x' to IndexedSlices.
2706
2707  Convert a dense Tensor to a block-sparse IndexedSlices.
2708
2709  Args:
2710    x: Either a Tensor object, or an IndexedSlices object.
2711    optimize: if true, attempt to optimize the conversion of 'x'.
2712
2713  Returns:
2714    An IndexedSlices object.
2715
2716  Raises:
2717    TypeError: If 'x' is not a Tensor or an IndexedSlices object.
2718  """
2719  # TODO(touts): op_scope
2720  if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
2721    raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
2722  if isinstance(x, ops.IndexedSlices):
2723    return x
2724  x_shape = array_ops.shape_internal(x, optimize=optimize)
2725  return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
2726
2727
2728def _as_indexed_slices_list(inputs, optimize=True):
2729  """Convert all elements of 'inputs' to IndexedSlices.
2730
2731  Additionally, homogenize the types of all the indices to
2732  either int32 or int64.
2733
2734  Args:
2735    inputs: List containing either Tensor or IndexedSlices objects.
2736    optimize: if true, attempt to optimize the conversion of each input.
2737
2738  Returns:
2739    A list of IndexedSlices objects.
2740
2741  Raises:
2742    TypeError: If 'inputs' is not a list or a tuple.
2743  """
2744  if not isinstance(inputs, (list, tuple)):
2745    raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
2746  outputs = [_as_indexed_slices(i, optimize=optimize) for i in inputs]
2747  with_int32_index = [
2748      o.indices for o in outputs if o.indices.dtype == dtypes.int32
2749  ]
2750  if not with_int32_index or len(with_int32_index) == len(outputs):
2751    return outputs
2752  casted_outputs = []
2753  for o in outputs:
2754    if o.indices.dtype == dtypes.int32:
2755      casted_outputs.append(
2756          ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64),
2757                            o.dense_shape))
2758    else:
2759      casted_outputs.append(o)
2760  return casted_outputs
2761
2762
2763@tf_export("math.add_n", "add_n")
2764@dispatch.add_dispatch_support
2765def add_n(inputs, name=None):
2766  """Adds all input tensors element-wise.
2767
2768  Converts `IndexedSlices` objects into dense tensors prior to adding.
2769
2770  Args:
2771    inputs: A list of `Tensor` or `IndexedSlices` objects, each with same shape
2772      and type.
2773    name: A name for the operation (optional).
2774
2775  Returns:
2776    A `Tensor` of same shape and type as the elements of `inputs`.
2777
2778  Raises:
2779    ValueError: If `inputs` don't all have same shape and dtype or the shape
2780    cannot be inferred.
2781  """
2782  if not inputs or not isinstance(inputs, (list, tuple)):
2783    raise ValueError("inputs must be a list of at least one "
2784                     "Tensor/IndexedSlices with the same dtype and shape")
2785  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
2786  if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
2787    raise ValueError("inputs must be a list of at least one "
2788                     "Tensor/IndexedSlices with the same dtype and shape")
2789
2790  if len(inputs) == 1:
2791    if isinstance(inputs[0], ops.IndexedSlices):
2792      values = ops.convert_to_tensor(inputs[0])
2793    else:
2794      values = inputs[0]
2795    if name:
2796      return array_ops.identity(values, name=name)
2797    return values
2798  return gen_math_ops.add_n(inputs, name=name)
2799
2800
2801@tf_export("math.accumulate_n", v1=["math.accumulate_n", "accumulate_n"])
2802@deprecation.deprecated_endpoints("accumulate_n")
2803def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
2804  """Returns the element-wise sum of a list of tensors.
2805
2806  Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
2807  otherwise, these are inferred.
2808
2809  `tf.math.accumulate_n` performs the same operation as `tf.add_n`, but does not
2810  wait for all of its inputs to be ready before beginning to sum. This can
2811  save memory if inputs are ready at different times, since minimum temporary
2812  storage is proportional to the output size rather than the inputs size.
2813
2814  `accumulate_n` is differentiable (but wasn't previous to TensorFlow 1.7).
2815
2816  For example:
2817
2818  ```python
2819  a = tf.constant([[1, 2], [3, 4]])
2820  b = tf.constant([[5, 0], [0, 6]])
2821  tf.math.accumulate_n([a, b, a])  # [[7, 4], [6, 14]]
2822
2823  # Explicitly pass shape and type
2824  tf.math.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
2825                                                                 # [[7,  4],
2826                                                                 #  [6, 14]]
2827  ```
2828
2829  Args:
2830    inputs: A list of `Tensor` objects, each with same shape and type.
2831    shape: Shape of elements of `inputs`.
2832    tensor_dtype: The type of `inputs`.
2833    name: A name for the operation (optional).
2834
2835  Returns:
2836    A `Tensor` of same shape and type as the elements of `inputs`.
2837
2838  Raises:
2839    ValueError: If `inputs` don't all have same shape and dtype or the shape
2840    cannot be inferred.
2841  """
2842
2843  def _input_error():
2844    return ValueError("inputs must be a list of at least one Tensor with the "
2845                      "same dtype and shape")
2846
2847  if not inputs or not isinstance(inputs, (list, tuple)):
2848    raise _input_error()
2849  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
2850  if not all(isinstance(x, ops.Tensor) for x in inputs):
2851    raise _input_error()
2852  if not all(x.dtype == inputs[0].dtype for x in inputs):
2853    raise _input_error()
2854  if shape is not None:
2855    shape = tensor_shape.as_shape(shape)
2856  else:
2857    shape = tensor_shape.unknown_shape()
2858  for input_tensor in inputs:
2859    if isinstance(input_tensor, ops.Tensor):
2860      shape = shape.merge_with(input_tensor.get_shape())
2861
2862  # tensor_dtype is for safety only; operator's output type computed in C++
2863  if tensor_dtype is not None and tensor_dtype != inputs[0].dtype:
2864    raise TypeError("tensor_dtype is {}, but input is of type {}".format(
2865        tensor_dtype, inputs[0].dtype))
2866
2867  if len(inputs) == 1 and name is None:
2868    return inputs[0]
2869  elif len(inputs) == 1 and name is not None:
2870    return array_ops.identity(inputs[0], name=name)
2871  elif context.executing_eagerly():
2872    # TemporaryVariable not currently supported in eager mode; fall back
2873    # onto AddN for now.
2874    # TODO(frreiss) remove this once the lifetime of eager variables gets
2875    # addressed
2876    return add_n(inputs, name=name)
2877  else:
2878    return gen_math_ops.accumulate_nv2(inputs, name=name, shape=shape)  # pylint: disable=protected-access
2879
2880
2881@ops.RegisterGradient("AccumulateNV2")
2882def _accumulate_n_grad(op, grad):
2883  """Same as gradient for AddN. Copies the gradient to all inputs."""
2884  # Not broadcasting.
2885  return [grad] * len(op.inputs)
2886
2887
2888@tf_export("math.sigmoid", "nn.sigmoid", "sigmoid")
2889def sigmoid(x, name=None):
2890  """Computes sigmoid of `x` element-wise.
2891
2892  Specifically, `y = 1 / (1 + exp(-x))`.
2893
2894  Args:
2895    x: A Tensor with type `float16`, `float32`, `float64`, `complex64`,
2896      or `complex128`.
2897    name: A name for the operation (optional).
2898
2899  Returns:
2900    A Tensor with the same type as `x`.
2901
2902  @compatibility(scipy)
2903  Equivalent to scipy.special.expit
2904  @end_compatibility
2905  """
2906  with ops.name_scope(name, "Sigmoid", [x]) as name:
2907    x = ops.convert_to_tensor(x, name="x")
2908    return gen_math_ops.sigmoid(x, name=name)
2909
2910
2911@tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
2912@dispatch.add_dispatch_support
2913@deprecation.deprecated_endpoints("log_sigmoid")
2914def log_sigmoid(x, name=None):
2915  """Computes log sigmoid of `x` element-wise.
2916
2917  Specifically, `y = log(1 / (1 + exp(-x)))`.  For numerical stability,
2918  we use `y = -tf.nn.softplus(-x)`.
2919
2920  Args:
2921    x: A Tensor with type `float32` or `float64`.
2922    name: A name for the operation (optional).
2923
2924  Returns:
2925    A Tensor with the same type as `x`.
2926  """
2927  with ops.name_scope(name, "LogSigmoid", [x]) as name:
2928    x = ops.convert_to_tensor(x, name="x")
2929    return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name)
2930
2931
2932@tf_export("math.bincount", v1=[])
2933def bincount(arr,
2934             weights=None,
2935             minlength=None,
2936             maxlength=None,
2937             dtype=dtypes.int32,
2938             name=None):
2939  """Counts the number of occurrences of each value in an integer array.
2940
2941  If `minlength` and `maxlength` are not given, returns a vector with length
2942  `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
2943  If `weights` are non-None, then index `i` of the output stores the sum of the
2944  value in `weights` at each index where the corresponding value in `arr` is
2945  `i`.
2946
2947  Args:
2948    arr: An int32 tensor of non-negative values.
2949    weights: If non-None, must be the same shape as arr. For each value in
2950      `arr`, the bin will be incremented by the corresponding weight instead of
2951      1.
2952    minlength: If given, ensures the output has length at least `minlength`,
2953      padding with zeros at the end if necessary.
2954    maxlength: If given, skips values in `arr` that are equal or greater than
2955      `maxlength`, ensuring that the output has length at most `maxlength`.
2956    dtype: If `weights` is None, determines the type of the output bins.
2957    name: A name scope for the associated operations (optional).
2958
2959  Returns:
2960    A vector with the same dtype as `weights` or the given `dtype`. The bin
2961    values.
2962  """
2963  name = "bincount" if name is None else name
2964  with ops.name_scope(name):
2965    arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
2966    array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0
2967    output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1)
2968    if minlength is not None:
2969      minlength = ops.convert_to_tensor(
2970          minlength, name="minlength", dtype=dtypes.int32)
2971      output_size = gen_math_ops.maximum(minlength, output_size)
2972    if maxlength is not None:
2973      maxlength = ops.convert_to_tensor(
2974          maxlength, name="maxlength", dtype=dtypes.int32)
2975      output_size = gen_math_ops.minimum(maxlength, output_size)
2976    if weights is not None:
2977      weights = ops.convert_to_tensor(weights, name="weights")
2978      return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
2979    weights = constant_op.constant([], dtype)
2980    return gen_math_ops.bincount(arr, output_size, weights)
2981
2982
2983@tf_export(v1=["math.bincount", "bincount"])
2984@deprecation.deprecated_endpoints("bincount")
2985def bincount_v1(arr,
2986                weights=None,
2987                minlength=None,
2988                maxlength=None,
2989                dtype=dtypes.int32):
2990  """Counts the number of occurrences of each value in an integer array.
2991
2992  If `minlength` and `maxlength` are not given, returns a vector with length
2993  `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
2994  If `weights` are non-None, then index `i` of the output stores the sum of the
2995  value in `weights` at each index where the corresponding value in `arr` is
2996  `i`.
2997
2998  Args:
2999    arr: An int32 tensor of non-negative values.
3000    weights: If non-None, must be the same shape as arr. For each value in
3001      `arr`, the bin will be incremented by the corresponding weight instead of
3002      1.
3003    minlength: If given, ensures the output has length at least `minlength`,
3004      padding with zeros at the end if necessary.
3005    maxlength: If given, skips values in `arr` that are equal or greater than
3006      `maxlength`, ensuring that the output has length at most `maxlength`.
3007    dtype: If `weights` is None, determines the type of the output bins.
3008
3009  Returns:
3010    A vector with the same dtype as `weights` or the given `dtype`. The bin
3011    values.
3012  """
3013  return bincount(arr, weights, minlength, maxlength, dtype)
3014
3015
3016@tf_export("math.cumsum", "cumsum")
3017def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
3018  """Compute the cumulative sum of the tensor `x` along `axis`.
3019
3020  By default, this op performs an inclusive cumsum, which means that the first
3021  element of the input is identical to the first element of the output:
3022
3023  ```python
3024  tf.cumsum([a, b, c])  # [a, a + b, a + b + c]
3025  ```
3026
3027  By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
3028  instead:
3029
3030  ```python
3031  tf.cumsum([a, b, c], exclusive=True)  # [0, a, a + b]
3032  ```
3033
3034  By setting the `reverse` kwarg to `True`, the cumsum is performed in the
3035  opposite direction:
3036
3037  ```python
3038  tf.cumsum([a, b, c], reverse=True)  # [a + b + c, b + c, c]
3039  ```
3040
3041  This is more efficient than using separate `tf.reverse` ops.
3042
3043  The `reverse` and `exclusive` kwargs can also be combined:
3044
3045  ```python
3046  tf.cumsum([a, b, c], exclusive=True, reverse=True)  # [b + c, c, 0]
3047  ```
3048
3049  Args:
3050    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
3051       `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
3052       `complex128`, `qint8`, `quint8`, `qint32`, `half`.
3053    axis: A `Tensor` of type `int32` (default: 0). Must be in the range
3054      `[-rank(x), rank(x))`.
3055    exclusive: If `True`, perform exclusive cumsum.
3056    reverse: A `bool` (default: False).
3057    name: A name for the operation (optional).
3058
3059  Returns:
3060    A `Tensor`. Has the same type as `x`.
3061  """
3062  with ops.name_scope(name, "Cumsum", [x]) as name:
3063    x = ops.convert_to_tensor(x, name="x")
3064    return gen_math_ops.cumsum(
3065        x, axis, exclusive=exclusive, reverse=reverse, name=name)
3066
3067
3068@tf_export("math.cumprod", v1=["math.cumprod", "cumprod"])
3069@deprecation.deprecated_endpoints("cumprod")
3070def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
3071  """Compute the cumulative product of the tensor `x` along `axis`.
3072
3073  By default, this op performs an inclusive cumprod, which means that the
3074  first element of the input is identical to the first element of the output:
3075
3076  ```python
3077  tf.math.cumprod([a, b, c])  # [a, a * b, a * b * c]
3078  ```
3079
3080  By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
3081  performed
3082  instead:
3083
3084  ```python
3085  tf.math.cumprod([a, b, c], exclusive=True)  # [1, a, a * b]
3086  ```
3087
3088  By setting the `reverse` kwarg to `True`, the cumprod is performed in the
3089  opposite direction:
3090
3091  ```python
3092  tf.math.cumprod([a, b, c], reverse=True)  # [a * b * c, b * c, c]
3093  ```
3094
3095  This is more efficient than using separate `tf.reverse` ops.
3096  The `reverse` and `exclusive` kwargs can also be combined:
3097
3098  ```python
3099  tf.math.cumprod([a, b, c], exclusive=True, reverse=True)  # [b * c, c, 1]
3100  ```
3101
3102  Args:
3103    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
3104       `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
3105       `complex128`, `qint8`, `quint8`, `qint32`, `half`.
3106    axis: A `Tensor` of type `int32` (default: 0). Must be in the range
3107      `[-rank(x), rank(x))`.
3108    exclusive: If `True`, perform exclusive cumprod.
3109    reverse: A `bool` (default: False).
3110    name: A name for the operation (optional).
3111
3112  Returns:
3113    A `Tensor`. Has the same type as `x`.
3114  """
3115  with ops.name_scope(name, "Cumprod", [x]) as name:
3116    x = ops.convert_to_tensor(x, name="x")
3117    return gen_math_ops.cumprod(
3118        x, axis, exclusive=exclusive, reverse=reverse, name=name)
3119
3120
3121@tf_export("math.conj", v1=["math.conj", "conj"])
3122@dispatch.add_dispatch_support
3123@deprecation.deprecated_endpoints("conj")
3124def conj(x, name=None):
3125  r"""Returns the complex conjugate of a complex number.
3126
3127  Given a tensor `input` of complex numbers, this operation returns a tensor of
3128  complex numbers that are the complex conjugate of each element in `input`. The
3129  complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the
3130  real part and *b* is the imaginary part.
3131
3132  The complex conjugate returned by this operation is of the form \\(a - bj\\).
3133
3134  For example:
3135
3136      # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
3137      tf.math.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
3138
3139  If `x` is real, it is returned unchanged.
3140
3141  Args:
3142    x: `Tensor` to conjugate.  Must have numeric or variant type.
3143    name: A name for the operation (optional).
3144
3145  Returns:
3146    A `Tensor` that is the conjugate of `x` (with the same type).
3147
3148  Raises:
3149    TypeError: If `x` is not a numeric tensor.
3150  """
3151  if isinstance(x, ops.Tensor):
3152    dt = x.dtype
3153    if dt.is_floating or dt.is_integer:
3154      return x
3155  with ops.name_scope(name, "Conj", [x]) as name:
3156    x = ops.convert_to_tensor(x, name="x")
3157    if x.dtype.is_complex or x.dtype == dtypes.variant:
3158      return gen_math_ops.conj(x, name=name)
3159    elif x.dtype.is_floating or x.dtype.is_integer:
3160      return x
3161    else:
3162      raise TypeError(
3163          "Expected numeric or variant tensor, got dtype %r" % x.dtype)
3164
3165
3166def _BroadcastShape(op):
3167  """Common shape function for binary operators that broadcast their inputs."""
3168  return [
3169      common_shapes.broadcast_shape(op.inputs[0].get_shape(),
3170                                    op.inputs[1].get_shape())
3171  ]
3172
3173
3174def reduced_shape(input_shape, axes):
3175  """Helper function for reduction ops.
3176
3177  Args:
3178    input_shape: 1-D Tensor, the shape of the Tensor being reduced.
3179    axes: 1-D Tensor, the reduction axes.
3180  Returns:
3181    A 1-D Tensor, the output shape as if keepdims were set to True.
3182  """
3183  # Example:
3184  # cast needed for SparseTensor reductions
3185  if context.executing_eagerly():
3186    input_shape = input_shape.numpy()
3187    axes = axes.numpy()
3188    input_shape[axes] = 1
3189    return input_shape
3190
3191  input_shape = cast(input_shape, dtypes.int32)  # [2, 3, 5, 7]
3192  axes = cast(axes, dtypes.int32)  # [1, 2]
3193
3194  input_rank = array_ops.size(input_shape)  # 4
3195  axes = (axes + input_rank) % input_rank
3196  axes_shape = array_ops.shape(axes)  # [2]
3197  return gen_data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
3198      [
3199          range(input_rank),  # [0, 1, 2, 3]
3200          axes
3201      ],  # [1, 2]
3202      [
3203          input_shape,  # [2, 3, 5, 7]
3204          array_ops.fill(axes_shape, 1)
3205      ])  # [1, 1]
3206
3207
3208def _unsorted_segment_N(data, segment_ids, num_segments):
3209  """ Helper function for unsorted_segment_mean/_sqrtN. Computes the number
3210      of segment entries with 0-entries set to 1 to allow division by N.
3211  """
3212  # bincount doesn't support negative indices so we use unsorted_segment_sum
3213  segment_ids_shape = array_ops.shape_internal(segment_ids)
3214  ones_tensor = array_ops.ones(segment_ids_shape, dtype=data.dtype)
3215  N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments)
3216  # add dimensions for all non-reduced axes
3217  ndims_output = data.shape.ndims - segment_ids.shape.ndims
3218  broadcast_shape = [num_segments] + [1] * ndims_output
3219  N = array_ops.reshape(N, broadcast_shape)
3220  return gen_math_ops.maximum(N, 1)
3221
3222
3223@tf_export(
3224    "math.unsorted_segment_mean",
3225    v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
3226@deprecation.deprecated_endpoints("unsorted_segment_mean")
3227@dispatch.add_dispatch_support
3228def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
3229  r"""Computes the mean along segments of a tensor.
3230
3231  Read [the section on
3232  segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
3233  for an explanation of segments.
3234
3235  This operator is similar to the unsorted segment sum operator found
3236  [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
3237  Instead of computing the sum over segments, it computes the mean of all
3238  entries belonging to a segment such that:
3239
3240  \\(output_i = 1/N_i \sum_{j...} data[j...]\\) where the sum is over tuples
3241  `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the number of
3242  occurrences of id \\i\\.
3243
3244  If there is no entry for a given segment ID `i`, it outputs 0.
3245
3246  If the given segment ID `i` is negative, the value is dropped and will not
3247  be added to the sum of the segment.
3248
3249  Args:
3250    data: A `Tensor` with floating point or complex dtype.
3251    segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
3252    num_segments: An integer scalar `Tensor`.  The number of distinct
3253      segment IDs.
3254    name: A name for the operation (optional).
3255
3256  Returns:
3257    A `Tensor`.  Has same shape as data, except for the first `segment_ids.rank`
3258    dimensions, which are replaced with a single dimension which has size
3259   `num_segments`.
3260  """
3261  with ops.name_scope(name, "UnsortedSegmentMean"):
3262    data = ops.convert_to_tensor(data)
3263    segment_ids = ops.convert_to_tensor(segment_ids)
3264    N = _unsorted_segment_N(data, segment_ids, num_segments)
3265    summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
3266    return summed / N
3267
3268
3269@tf_export(
3270    "math.unsorted_segment_sqrt_n",
3271    v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
3272@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
3273@dispatch.add_dispatch_support
3274def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
3275  r"""Computes the sum along segments of a tensor divided by the sqrt(N).
3276
3277  Read [the section on
3278  segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
3279  for an explanation of segments.
3280
3281  This operator is similar to the unsorted segment sum operator found
3282  [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
3283  Additionally to computing the sum over segments, it divides the results by
3284  sqrt(N).
3285
3286  \\(output_i = 1/sqrt(N_i) \sum_{j...} data[j...]\\) where the sum is over
3287  tuples `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the
3288  number of occurrences of id \\i\\.
3289
3290  If there is no entry for a given segment ID `i`, it outputs 0.
3291
3292  Note that this op only supports floating point and complex dtypes,
3293  due to tf.sqrt only supporting these types.
3294
3295  If the given segment ID `i` is negative, the value is dropped and will not
3296  be added to the sum of the segment.
3297
3298  Args:
3299    data: A `Tensor` with floating point or complex dtype.
3300    segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
3301    num_segments: An integer scalar `Tensor`.  The number of distinct
3302      segment IDs.
3303    name: A name for the operation (optional).
3304
3305  Returns:
3306    A `Tensor`.  Has same shape as data, except for the first `segment_ids.rank`
3307    dimensions, which are replaced with a single dimension which has size
3308   `num_segments`.
3309  """
3310  with ops.name_scope(name, "UnsortedSegmentSqrtN"):
3311    data = ops.convert_to_tensor(data)
3312    segment_ids = ops.convert_to_tensor(segment_ids)
3313    N = _unsorted_segment_N(data, segment_ids, num_segments)
3314    summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
3315    return summed / gen_math_ops.sqrt(N)
3316
3317
3318@tf_export(v1=["sparse.segment_sum", "sparse_segment_sum"])
3319@deprecation.deprecated_endpoints("sparse_segment_sum")
3320def sparse_segment_sum(data, indices, segment_ids, name=None,
3321                       num_segments=None):
3322  r"""Computes the sum along sparse segments of a tensor.
3323
3324  Read [the section on
3325  segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
3326  for an explanation of segments.
3327
3328  Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
3329  dimension, selecting a subset of dimension 0, specified by `indices`.
3330  `segment_ids` is allowed to have missing ids, in which case the output will
3331  be zeros at those indices. In those cases `num_segments` is used to determine
3332  the size of the output.
3333
3334  For example:
3335
3336  ```python
3337  c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
3338
3339  # Select two rows, one segment.
3340  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
3341  # => [[0 0 0 0]]
3342
3343  # Select two rows, two segment.
3344  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
3345  # => [[ 1  2  3  4]
3346  #     [-1 -2 -3 -4]]
3347
3348  # With missing segment ids.
3349  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
3350                        num_segments=4)
3351  # => [[ 1  2  3  4]
3352  #     [ 0  0  0  0]
3353  #     [-1 -2 -3 -4]
3354  #     [ 0  0  0  0]]
3355
3356  # Select all rows, two segments.
3357  tf.sparse.segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
3358  # => [[0 0 0 0]
3359  #     [5 6 7 8]]
3360
3361  # Which is equivalent to:
3362  tf.segment_sum(c, tf.constant([0, 0, 1]))
3363  ```
3364
3365  Args:
3366    data: A `Tensor` with data that will be assembled in the output.
3367    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
3368      `segment_ids`.
3369    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
3370      Values should be sorted and can be repeated.
3371    name: A name for the operation (optional).
3372    num_segments: An optional int32 scalar. Indicates the size of the output
3373      `Tensor`.
3374
3375  Returns:
3376    A `tensor` of the shape as data, except for dimension 0 which
3377    has size `k`, the number of segments specified via `num_segments` or
3378    inferred for the last element in `segments_ids`.
3379  """
3380  if num_segments is not None:
3381    return gen_math_ops.sparse_segment_sum_with_num_segments(
3382        data=data,
3383        indices=indices,
3384        segment_ids=segment_ids,
3385        num_segments=num_segments,
3386        name=name)
3387  else:
3388    return gen_math_ops.sparse_segment_sum(
3389        data=data, indices=indices, segment_ids=segment_ids, name=name)
3390
3391
3392@tf_export("sparse.segment_sum", v1=[])
3393def sparse_segment_sum_v2(data,
3394                          indices,
3395                          segment_ids,
3396                          num_segments=None,
3397                          name=None):
3398  return sparse_segment_mean(
3399      data, indices, segment_ids, name=name, num_segments=num_segments)
3400
3401
3402@tf_export(v1=["sparse.segment_mean", "sparse_segment_mean"])
3403@deprecation.deprecated_endpoints("sparse_segment_mean")
3404def sparse_segment_mean(data,
3405                        indices,
3406                        segment_ids,
3407                        name=None,
3408                        num_segments=None):
3409  r"""Computes the mean along sparse segments of a tensor.
3410
3411  Read [the section on
3412  segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
3413  for an explanation of segments.
3414
3415  Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
3416  dimension, selecting a subset of dimension 0, specified by `indices`.
3417  `segment_ids` is allowed to have missing ids, in which case the output will
3418  be zeros at those indices. In those cases `num_segments` is used to determine
3419  the size of the output.
3420
3421  Args:
3422    data: A `Tensor` with data that will be assembled in the output.
3423    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
3424      `segment_ids`.
3425    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
3426      Values should be sorted and can be repeated.
3427    name: A name for the operation (optional).
3428    num_segments: An optional int32 scalar. Indicates the size of the output
3429      `Tensor`.
3430
3431  Returns:
3432    A `tensor` of the shape as data, except for dimension 0 which
3433    has size `k`, the number of segments specified via `num_segments` or
3434    inferred for the last element in `segments_ids`.
3435  """
3436  if num_segments is not None:
3437    return gen_math_ops.sparse_segment_mean_with_num_segments(
3438        data=data,
3439        indices=indices,
3440        segment_ids=segment_ids,
3441        num_segments=num_segments,
3442        name=name)
3443  else:
3444    return gen_math_ops.sparse_segment_mean(
3445        data=data, indices=indices, segment_ids=segment_ids, name=name)
3446
3447
3448@tf_export("sparse.segment_mean", v1=[])
3449def sparse_segment_mean_v2(data,
3450                           indices,
3451                           segment_ids,
3452                           num_segments=None,
3453                           name=None):
3454  r"""Computes the mean along sparse segments of a tensor.
3455
3456  Read [the section on
3457  segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
3458  for an explanation of segments.
3459
3460  Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
3461  dimension, selecting a subset of dimension 0, specified by `indices`.
3462  `segment_ids` is allowed to have missing ids, in which case the output will
3463  be zeros at those indices. In those cases `num_segments` is used to determine
3464  the size of the output.
3465
3466  Args:
3467    data: A `Tensor` with data that will be assembled in the output.
3468    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
3469      `segment_ids`.
3470    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
3471      should be sorted and can be repeated.
3472    num_segments: An optional int32 scalar. Indicates the size of the output
3473      `Tensor`.
3474    name: A name for the operation (optional).
3475
3476  Returns:
3477    A `tensor` of the shape as data, except for dimension 0 which
3478    has size `k`, the number of segments specified via `num_segments` or
3479    inferred for the last element in `segments_ids`.
3480  """
3481  return sparse_segment_mean(
3482      data, indices, segment_ids, name=name, num_segments=num_segments)
3483
3484
3485@tf_export(v1=["sparse.segment_sqrt_n", "sparse_segment_sqrt_n"])
3486@deprecation.deprecated_endpoints("sparse_segment_sqrt_n")
3487def sparse_segment_sqrt_n(data,
3488                          indices,
3489                          segment_ids,
3490                          name=None,
3491                          num_segments=None):
3492  r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
3493
3494  `N` is the size of the segment being reduced.
3495
3496  Args:
3497    data: A `Tensor` with data that will be assembled in the output.
3498    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
3499      `segment_ids`.
3500    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
3501      Values should be sorted and can be repeated.
3502    name: A name for the operation (optional).
3503    num_segments: An optional int32 scalar. Indicates the size of the output
3504      `Tensor`.
3505
3506  Returns:
3507    A `tensor` of the shape as data, except for dimension 0 which
3508    has size `k`, the number of segments specified via `num_segments` or
3509    inferred for the last element in `segments_ids`.
3510  """
3511  if num_segments is not None:
3512    return gen_math_ops.sparse_segment_sqrt_n_with_num_segments(
3513        data=data,
3514        indices=indices,
3515        segment_ids=segment_ids,
3516        num_segments=num_segments,
3517        name=name)
3518  else:
3519    return gen_math_ops.sparse_segment_sqrt_n(
3520        data=data, indices=indices, segment_ids=segment_ids, name=name)
3521
3522
3523@tf_export("sparse.segment_sqrt_n", v1=[])
3524def sparse_segment_sqrt_n_v2(data,
3525                             indices,
3526                             segment_ids,
3527                             num_segments=None,
3528                             name=None):
3529  r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
3530
3531  `N` is the size of the segment being reduced.
3532
3533  Args:
3534    data: A `Tensor` with data that will be assembled in the output.
3535    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
3536      `segment_ids`.
3537    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
3538      should be sorted and can be repeated.
3539    num_segments: An optional int32 scalar. Indicates the size of the output
3540      `Tensor`.
3541    name: A name for the operation (optional).
3542
3543  Returns:
3544    A `tensor` of the shape as data, except for dimension 0 which
3545    has size `k`, the number of segments specified via `num_segments` or
3546    inferred for the last element in `segments_ids`.
3547  """
3548  return sparse_segment_sqrt_n(
3549      data, indices, segment_ids, name=name, num_segments=num_segments)
3550
3551
3552@tf_export("tensordot", "linalg.tensordot")
3553def tensordot(a, b, axes, name=None):
3554  r"""Tensor contraction of a and b along specified axes.
3555
3556  Tensordot (also known as tensor contraction) sums the product of elements
3557  from `a` and `b` over the indices specified by `a_axes` and `b_axes`.
3558  The lists `a_axes` and `b_axes` specify those pairs of axes along which to
3559  contract the tensors. The axis `a_axes[i]` of `a` must have the same dimension
3560  as axis `b_axes[i]` of `b` for all `i` in `range(0, len(a_axes))`. The lists
3561  `a_axes` and `b_axes` must have identical length and consist of unique
3562  integers that specify valid axes for each of the tensors.
3563
3564  This operation corresponds to `numpy.tensordot(a, b, axes)`.
3565
3566  Example 1: When `a` and `b` are matrices (order 2), the case `axes = 1`
3567  is equivalent to matrix multiplication.
3568
3569  Example 2: When `a` and `b` are matrices (order 2), the case
3570  `axes = [[1], [0]]` is equivalent to matrix multiplication.
3571
3572  Example 3: Suppose that \\(a_{ijk}\\) and \\(b_{lmn}\\) represent two
3573  tensors of order 3. Then, `contract(a, b, [[0], [2]])` is the order 4 tensor
3574  \\(c_{jklm}\\) whose entry
3575  corresponding to the indices \\((j,k,l,m)\\) is given by:
3576
3577  \\( c_{jklm} = \sum_i a_{ijk} b_{lmi} \\).
3578
3579  In general, `order(c) = order(a) + order(b) - 2*len(axes[0])`.
3580
3581  Args:
3582    a: `Tensor` of type `float32` or `float64`.
3583    b: `Tensor` with the same type as `a`.
3584    axes: Either a scalar `N`, or a list or an `int32` `Tensor` of shape [2, k].
3585      If axes is a scalar, sum over the last N axes of a and the first N axes of
3586      b in order. If axes is a list or `Tensor` the first and second row contain
3587      the set of unique integers specifying axes along which the contraction is
3588      computed, for `a` and `b`, respectively. The number of axes for `a` and
3589      `b` must be equal.
3590    name: A name for the operation (optional).
3591
3592  Returns:
3593    A `Tensor` with the same type as `a`.
3594
3595  Raises:
3596    ValueError: If the shapes of `a`, `b`, and `axes` are incompatible.
3597    IndexError: If the values in axes exceed the rank of the corresponding
3598      tensor.
3599  """
3600
3601  def _tensordot_reshape(a, axes, flipped=False):
3602    """Helper method to perform transpose and reshape for contraction op.
3603
3604    This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul`
3605    using `array_ops.transpose` and `array_ops.reshape`. The method takes a
3606    tensor and performs the correct transpose and reshape operation for a given
3607    set of indices. It returns the reshaped tensor as well as a list of indices
3608    necessary to reshape the tensor again after matrix multiplication.
3609
3610    Args:
3611      a: `Tensor`.
3612      axes: List or `int32` `Tensor` of unique indices specifying valid axes of
3613       `a`.
3614      flipped: An optional `bool`. Defaults to `False`. If `True`, the method
3615        assumes that `a` is the second argument in the contraction operation.
3616
3617    Returns:
3618      A tuple `(reshaped_a, free_dims, free_dims_static)` where `reshaped_a` is
3619      the tensor `a` reshaped to allow contraction via `matmul`, `free_dims` is
3620      either a list of integers or an `int32` `Tensor`, depending on whether
3621      the shape of a is fully specified, and free_dims_static is either a list
3622      of integers and None values, or None, representing the inferred
3623      static shape of the free dimensions
3624    """
3625    if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
3626      shape_a = a.get_shape().as_list()
3627      axes = [i if i >= 0 else i + len(shape_a) for i in axes]
3628      free = [i for i in xrange(len(shape_a)) if i not in axes]
3629      free_dims = [shape_a[i] for i in free]
3630      prod_free = int(np.prod([shape_a[i] for i in free]))
3631      prod_axes = int(np.prod([shape_a[i] for i in axes]))
3632      perm = list(axes) + free if flipped else free + list(axes)
3633      new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
3634      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
3635      return reshaped_a, free_dims, free_dims
3636    else:
3637      if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
3638        shape_a = a.get_shape().as_list()
3639        axes = [i if i >= 0 else i + len(shape_a) for i in axes]
3640        free = [i for i in xrange(len(shape_a)) if i not in axes]
3641        axes_dims = [shape_a[i] for i in axes]
3642        free_dims = [shape_a[i] for i in free]
3643        free_dims_static = free_dims
3644        axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
3645        free = ops.convert_to_tensor(free, dtype=dtypes.int32, name="free")
3646        shape_a = array_ops.shape(a)
3647      else:
3648        free_dims_static = None
3649        shape_a = array_ops.shape(a)
3650        rank_a = array_ops.rank(a)
3651        axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
3652        axes = array_ops.where(axes >= 0, axes, axes + rank_a)
3653        free, _ = array_ops.setdiff1d(range(rank_a), axes)
3654      free_dims = array_ops.gather(shape_a, free)
3655      axes_dims = array_ops.gather(shape_a, axes)
3656      prod_free_dims = reduce_prod(free_dims)
3657      prod_axes_dims = reduce_prod(axes_dims)
3658      if flipped:
3659        perm = array_ops.concat([axes, free], 0)
3660        new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
3661      else:
3662        perm = array_ops.concat([free, axes], 0)
3663        new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
3664      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
3665      return reshaped_a, free_dims, free_dims_static
3666
3667  def _tensordot_axes(a, axes):
3668    """Generates two sets of contraction axes for the two tensor arguments."""
3669    a_shape = a.get_shape()
3670    if isinstance(axes, compat.integral_types):
3671      if axes < 0:
3672        raise ValueError("'axes' must be at least 0.")
3673      if a_shape.ndims is not None:
3674        if axes > a_shape.ndims:
3675          raise ValueError("'axes' must not be larger than the number of "
3676                           "dimensions of tensor %s." % a)
3677        return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
3678                list(xrange(axes)))
3679      else:
3680        rank = array_ops.rank(a)
3681        return (range(rank - axes, rank, dtype=dtypes.int32),
3682                range(axes, dtype=dtypes.int32))
3683    elif isinstance(axes, (list, tuple)):
3684      if len(axes) != 2:
3685        raise ValueError("'axes' must be an integer or have length 2.")
3686      a_axes = axes[0]
3687      b_axes = axes[1]
3688      if isinstance(a_axes, compat.integral_types) and \
3689          isinstance(b_axes, compat.integral_types):
3690        a_axes = [a_axes]
3691        b_axes = [b_axes]
3692      if len(a_axes) != len(b_axes):
3693        raise ValueError(
3694            "Different number of contraction axes 'a' and 'b', %s != %s." %
3695            (len(a_axes), len(b_axes)))
3696      return a_axes, b_axes
3697    else:
3698      axes = ops.convert_to_tensor(axes, name="axes", dtype=dtypes.int32)
3699      return axes[0], axes[1]
3700
3701  with ops.name_scope(name, "Tensordot", [a, b, axes]) as name:
3702    a = ops.convert_to_tensor(a, name="a")
3703    b = ops.convert_to_tensor(b, name="b")
3704    a_axes, b_axes = _tensordot_axes(a, axes)
3705    a_reshape, a_free_dims, a_free_dims_static = _tensordot_reshape(a, a_axes)
3706    b_reshape, b_free_dims, b_free_dims_static = _tensordot_reshape(
3707        b, b_axes, True)
3708    ab_matmul = matmul(a_reshape, b_reshape)
3709    if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
3710      return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name)
3711    else:
3712      a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32)
3713      b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)
3714      product = array_ops.reshape(
3715          ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)
3716      if a_free_dims_static is not None and b_free_dims_static is not None:
3717        product.set_shape(a_free_dims_static + b_free_dims_static)
3718      return product
3719
3720
3721@tf_export("math.polyval")
3722def polyval(coeffs, x, name=None):
3723  r"""Computes the elementwise value of a polynomial.
3724
3725  If `x` is a tensor and `coeffs` is a list n + 1 tensors, this function returns
3726  the value of the n-th order polynomial
3727
3728     p(x) = coeffs[n-1] + coeffs[n-2] * x + ...  + coeffs[0] * x**(n-1)
3729
3730  evaluated using Horner's method, i.e.
3731
3732     p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] +
3733            x * coeffs[0]))
3734
3735  Args:
3736    coeffs: A list of `Tensor` representing the coefficients of the polynomial.
3737    x: A `Tensor` representing the variable of the polynomial.
3738    name: A name for the operation (optional).
3739
3740  Returns:
3741    A `tensor` of the shape as the expression p(x) with usual broadcasting rules
3742    for element-wise addition and multiplication applied.
3743
3744  @compatibility(numpy)
3745  Equivalent to numpy.polyval.
3746  @end_compatibility
3747  """
3748
3749  with ops.name_scope(name, "polyval", nest.flatten(coeffs) + [x]) as name:
3750    x = ops.convert_to_tensor(x, name="x")
3751    if len(coeffs) < 1:
3752      return array_ops.zeros_like(x, name=name)
3753    coeffs = [
3754        ops.convert_to_tensor(coeff, name=("coeff_%d" % index))
3755        for index, coeff in enumerate(coeffs)
3756    ]
3757    p = coeffs[0]
3758    for c in coeffs[1:]:
3759      p = c + p * x
3760    return p
3761