• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# Tests for this file live in python/kernel_tests/array_ops_test.py
16"""Support for manipulating tensors."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numbers
23import numpy as np
24
25from tensorflow.python.eager import context
26from tensorflow.python.framework import common_shapes
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_util
35# 'Constant' gets imported in the module 'array_ops'.
36from tensorflow.python.framework.constant_op import constant
37from tensorflow.python.ops import gen_array_ops
38from tensorflow.python.ops import gen_math_ops
39# go/tf-wildcard-import
40# pylint: disable=wildcard-import
41from tensorflow.python.ops.gen_array_ops import *
42from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse  # pylint: disable=unused-import
43from tensorflow.python.types import core
44from tensorflow.python.util import deprecation
45from tensorflow.python.util import dispatch
46from tensorflow.python.util import nest
47from tensorflow.python.util import tf_decorator
48from tensorflow.python.util.tf_export import tf_export
49# pylint: enable=wildcard-import
50
51# Used for slicing to specify a new 1 size dimension
52newaxis = None
53tf_export("newaxis").export_constant(__name__, "newaxis")
54
55# We override the 'slice' for the "slice" op, so we keep Python's
56# existing 'slice' for later use in this module.
57_BaseSlice = slice
58
59
60@tf_export("reshape", v1=["reshape", "manip.reshape"])
61@dispatch.add_dispatch_support
62def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
63  r"""Reshapes a tensor.
64
65  Given `tensor`, this operation returns a new `tf.Tensor` that has the same
66  values as `tensor` in the same order, except with a new shape given by
67  `shape`.
68
69  >>> t1 = [[1, 2, 3],
70  ...       [4, 5, 6]]
71  >>> print(tf.shape(t1).numpy())
72  [2 3]
73  >>> t2 = tf.reshape(t1, [6])
74  >>> t2
75  <tf.Tensor: shape=(6,), dtype=int32,
76    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
77  >>> tf.reshape(t2, [3, 2])
78  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
79    array([[1, 2],
80           [3, 4],
81           [5, 6]], dtype=int32)>
82
83  The `tf.reshape` does not change the order of or the total number of elements
84  in the tensor, and so it can reuse the underlying data buffer. This makes it
85  a fast operation independent of how big of a tensor it is operating on.
86
87  >>> tf.reshape([1, 2, 3], [2, 2])
88  Traceback (most recent call last):
89  ...
90  InvalidArgumentError: Input to reshape is a tensor with 3 values, but the
91  requested shape has 4
92
93  To instead reorder the data to rearrange the dimensions of a tensor, see
94  `tf.transpose`.
95
96  >>> t = [[1, 2, 3],
97  ...      [4, 5, 6]]
98  >>> tf.reshape(t, [3, 2]).numpy()
99  array([[1, 2],
100         [3, 4],
101         [5, 6]], dtype=int32)
102  >>> tf.transpose(t, perm=[1, 0]).numpy()
103  array([[1, 4],
104         [2, 5],
105         [3, 6]], dtype=int32)
106
107  If one component of `shape` is the special value -1, the size of that
108  dimension is computed so that the total size remains constant.  In particular,
109  a `shape` of `[-1]` flattens into 1-D.  At most one component of `shape` can
110  be -1.
111
112  >>> t = [[1, 2, 3],
113  ...      [4, 5, 6]]
114  >>> tf.reshape(t, [-1])
115  <tf.Tensor: shape=(6,), dtype=int32,
116    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
117  >>> tf.reshape(t, [3, -1])
118  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
119    array([[1, 2],
120           [3, 4],
121           [5, 6]], dtype=int32)>
122  >>> tf.reshape(t, [-1, 2])
123  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
124    array([[1, 2],
125           [3, 4],
126           [5, 6]], dtype=int32)>
127
128  `tf.reshape(t, [])` reshapes a tensor `t` with one element to a scalar.
129
130  >>> tf.reshape([7], []).numpy()
131  7
132
133  More examples:
134
135  >>> t = [1, 2, 3, 4, 5, 6, 7, 8, 9]
136  >>> print(tf.shape(t).numpy())
137  [9]
138  >>> tf.reshape(t, [3, 3])
139  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
140    array([[1, 2, 3],
141           [4, 5, 6],
142           [7, 8, 9]], dtype=int32)>
143
144  >>> t = [[[1, 1], [2, 2]],
145  ...      [[3, 3], [4, 4]]]
146  >>> print(tf.shape(t).numpy())
147  [2 2 2]
148  >>> tf.reshape(t, [2, 4])
149  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
150    array([[1, 1, 2, 2],
151           [3, 3, 4, 4]], dtype=int32)>
152
153  >>> t = [[[1, 1, 1],
154  ...       [2, 2, 2]],
155  ...      [[3, 3, 3],
156  ...       [4, 4, 4]],
157  ...      [[5, 5, 5],
158  ...       [6, 6, 6]]]
159  >>> print(tf.shape(t).numpy())
160  [3 2 3]
161  >>> # Pass '[-1]' to flatten 't'.
162  >>> tf.reshape(t, [-1])
163  <tf.Tensor: shape=(18,), dtype=int32,
164    numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
165    dtype=int32)>
166  >>> # -- Using -1 to infer the shape --
167  >>> # Here -1 is inferred to be 9:
168  >>> tf.reshape(t, [2, -1])
169  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
170    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
171           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
172  >>> # -1 is inferred to be 2:
173  >>> tf.reshape(t, [-1, 9])
174  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
175    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
176           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
177  >>> # -1 is inferred to be 3:
178  >>> tf.reshape(t, [ 2, -1, 3])
179  <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
180    array([[[1, 1, 1],
181            [2, 2, 2],
182            [3, 3, 3]],
183           [[4, 4, 4],
184            [5, 5, 5],
185            [6, 6, 6]]], dtype=int32)>
186
187  Args:
188    tensor: A `Tensor`.
189    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
190      Defines the shape of the output tensor.
191    name: Optional string. A name for the operation.
192
193  Returns:
194    A `Tensor`. Has the same type as `tensor`.
195  """
196  result = gen_array_ops.reshape(tensor, shape, name)
197  tensor_util.maybe_set_static_shape(result, shape)
198  return result
199
200
201@tf_export("fill")
202@dispatch.add_dispatch_support
203def fill(dims, value, name=None):
204  r"""Creates a tensor filled with a scalar value.
205
206  See also `tf.ones`, `tf.zeros`, `tf.one_hot`, `tf.eye`.
207
208  This operation creates a tensor of shape `dims` and fills it with `value`.
209
210  For example:
211
212  >>> tf.fill([2, 3], 9)
213  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
214  array([[9, 9, 9],
215         [9, 9, 9]], dtype=int32)>
216
217  `tf.fill` evaluates at graph runtime and supports dynamic shapes based on
218  other runtime `tf.Tensors`, unlike `tf.constant(value, shape=dims)`, which
219  embeds the value as a `Const` node.
220
221  Args:
222    dims: A 1-D sequence of non-negative numbers. Represents the shape of the
223      output `tf.Tensor`. Entries should be of type: `int32`, `int64`.
224    value: A value to fill the returned `tf.Tensor`.
225    name: Optional string. The name of the output `tf.Tensor`.
226
227  Returns:
228    A `tf.Tensor` with shape `dims` and the same dtype as `value`.
229
230  Raises:
231    InvalidArgumentError: `dims` contains negative entries.
232    NotFoundError: `dims` contains non-integer entries.
233
234  @compatibility(numpy)
235  Similar to `np.full`. In `numpy`, more parameters are supported. Passing a
236  number argument as the shape (`np.full(5, value)`) is valid in `numpy` for
237  specifying a 1-D shaped result, while TensorFlow does not support this syntax.
238  @end_compatibility
239  """
240  result = gen_array_ops.fill(dims, value, name=name)
241  tensor_util.maybe_set_static_shape(result, dims)
242  return result
243
244
245@tf_export("identity")
246@dispatch.add_dispatch_support
247def identity(input, name=None):  # pylint: disable=redefined-builtin
248  r"""Return a Tensor with the same shape and contents as input.
249
250  The return value is not the same Tensor as the original, but contains the same
251  values.  This operation is fast when used on the same device.
252
253  For example:
254
255  >>> a = tf.constant([0.78])
256  >>> a_identity = tf.identity(a)
257  >>> a.numpy()
258  array([0.78], dtype=float32)
259  >>> a_identity.numpy()
260  array([0.78], dtype=float32)
261
262  Calling `tf.identity` on a variable will make a Tensor that represents the
263  value of that variable at the time it is called. This is equivalent to calling
264  `<variable>.read_value()`.
265
266  >>> a = tf.Variable(5)
267  >>> a_identity = tf.identity(a)
268  >>> a.assign_add(1)
269  <tf.Variable ... shape=() dtype=int32, numpy=6>
270  >>> a.numpy()
271  6
272  >>> a_identity.numpy()
273  5
274
275  Args:
276    input: A `Tensor`, a `Variable`, a `CompositeTensor` or anything that can be
277    converted to a tensor using `tf.convert_to_tensor`.
278    name: A name for the operation (optional).
279
280  Returns:
281    A `Tensor` or CompositeTensor. Has the same type and contents as `input`.
282  """
283  if isinstance(input, composite_tensor.CompositeTensor):
284    return nest.map_structure(identity, input, expand_composites=True)
285  if context.executing_eagerly() and not hasattr(input, "graph"):
286    # Make sure we get an input with handle data attached from resource
287    # variables. Variables have correct handle data when graph building.
288    input = ops.convert_to_tensor(input)
289  ret = gen_array_ops.identity(input, name=name)
290  # Propagate handle data for happier shape inference for resource variables.
291  if hasattr(input, "_handle_data"):
292    ret._handle_data = input._handle_data  # pylint: disable=protected-access
293  return ret
294
295
296# pylint: disable=redefined-builtin,protected-access
297@tf_export(v1=["expand_dims"])
298@dispatch.add_dispatch_support
299@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
300def expand_dims(input, axis=None, name=None, dim=None):
301  """Returns a tensor with a length 1 axis inserted at index `axis`.
302
303  Given a tensor `input`, this operation inserts a dimension of length 1 at the
304  dimension index `axis` of `input`'s shape. The dimension index follows Python
305  indexing rules: It's zero-based, a negative index it is counted backward
306  from the end.
307
308  This operation is useful to:
309
310  * Add an outer "batch" dimension to a single element.
311  * Align axes for broadcasting.
312  * To add an inner vector length axis to a tensor of scalars.
313
314  For example:
315
316  If you have a single image of shape `[height, width, channels]`:
317
318  >>> image = tf.zeros([10,10,3])
319
320  You can add an outer `batch` axis by passing `axis=0`:
321
322  >>> tf.expand_dims(image, axis=0).shape.as_list()
323  [1, 10, 10, 3]
324
325  The new axis location matches Python `list.insert(axis, 1)`:
326
327  >>> tf.expand_dims(image, axis=1).shape.as_list()
328  [10, 1, 10, 3]
329
330  Following standard Python indexing rules, a negative `axis` counts from the
331  end so `axis=-1` adds an inner most dimension:
332
333  >>> tf.expand_dims(image, -1).shape.as_list()
334  [10, 10, 3, 1]
335
336  This operation requires that `axis` is a valid index for `input.shape`,
337  following Python indexing rules:
338
339  ```
340  -1-tf.rank(input) <= axis <= tf.rank(input)
341  ```
342
343  This operation is related to:
344
345  * `tf.squeeze`, which removes dimensions of size 1.
346  * `tf.reshape`, which provides more flexible reshaping capability.
347  * `tf.sparse.expand_dims`, which provides this functionality for
348    `tf.SparseTensor`
349
350  Args:
351    input: A `Tensor`.
352    axis: 0-D (scalar). Specifies the dimension index at which to expand the
353      shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`.
354    name: The name of the output `Tensor` (optional).
355    dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
356
357  Returns:
358    A `Tensor` with the same data as `input`, but its shape has an additional
359    dimension of size 1 added.
360
361  Raises:
362    ValueError: if either both or neither of `dim` and `axis` are specified.
363  """
364  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
365  if axis is None:
366    raise ValueError("Must specify an axis argument to tf.expand_dims()")
367  return expand_dims_v2(input, axis, name)
368
369
370@tf_export("expand_dims", v1=[])
371@dispatch.add_dispatch_support
372def expand_dims_v2(input, axis, name=None):
373  """Returns a tensor with a length 1 axis inserted at index `axis`.
374
375  Given a tensor `input`, this operation inserts a dimension of length 1 at the
376  dimension index `axis` of `input`'s shape. The dimension index follows Python
377  indexing rules: It's zero-based, a negative index it is counted backward
378  from the end.
379
380  This operation is useful to:
381
382  * Add an outer "batch" dimension to a single element.
383  * Align axes for broadcasting.
384  * To add an inner vector length axis to a tensor of scalars.
385
386  For example:
387
388  If you have a single image of shape `[height, width, channels]`:
389
390  >>> image = tf.zeros([10,10,3])
391
392  You can add an outer `batch` axis by passing `axis=0`:
393
394  >>> tf.expand_dims(image, axis=0).shape.as_list()
395  [1, 10, 10, 3]
396
397  The new axis location matches Python `list.insert(axis, 1)`:
398
399  >>> tf.expand_dims(image, axis=1).shape.as_list()
400  [10, 1, 10, 3]
401
402  Following standard Python indexing rules, a negative `axis` counts from the
403  end so `axis=-1` adds an inner most dimension:
404
405  >>> tf.expand_dims(image, -1).shape.as_list()
406  [10, 10, 3, 1]
407
408  This operation requires that `axis` is a valid index for `input.shape`,
409  following Python indexing rules:
410
411  ```
412  -1-tf.rank(input) <= axis <= tf.rank(input)
413  ```
414
415  This operation is related to:
416
417  * `tf.squeeze`, which removes dimensions of size 1.
418  * `tf.reshape`, which provides more flexible reshaping capability.
419  * `tf.sparse.expand_dims`, which provides this functionality for
420    `tf.SparseTensor`
421
422  Args:
423    input: A `Tensor`.
424    axis: Integer specifying the dimension index at which to expand the
425      shape of `input`. Given an input of D dimensions, `axis` must be in range
426      `[-(D+1), D]` (inclusive).
427    name: Optional string. The name of the output `Tensor`.
428
429  Returns:
430    A tensor with the same data as `input`, with an additional dimension
431    inserted at the index specified by `axis`.
432
433  Raises:
434    TypeError: If `axis` is not specified.
435    InvalidArgumentError: If `axis` is out of range `[-(D+1), D]`.
436  """
437  return gen_array_ops.expand_dims(input, axis, name)
438
439
440# pylint: enable=redefined-builtin,protected-access
441
442
443# Aliases for some automatically-generated names.
444# pylint: disable=protected-access
445@deprecation.deprecated("2016-11-30",
446                        "This op will be removed after the deprecation date. "
447                        "Please switch to tf.setdiff1d().")
448def listdiff(x, y, out_idx=None, name=None):
449  return gen_array_ops.list_diff(x, y, out_idx, name)
450
451
452listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__
453
454# pylint: enable=protected-access
455
456
457# pylint: disable=undefined-variable
458@deprecation.deprecated("2018-11-30",
459                        "This op will be removed after the deprecation date. "
460                        "Please switch to tf.sets.difference().")
461@tf_export(v1=["setdiff1d"])
462@dispatch.add_dispatch_support
463def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
464  """Computes the difference between two lists of numbers or strings.
465
466  Given a list x and a list y, this operation returns a list out that
467  represents all values that are in x but not in y. The returned list
468  out is sorted in the same order that the numbers appear in x
469  (duplicates are preserved). This operation also returns a list idx
470  that represents the position of each out element in x.
471
472  In other words:
473
474  ```python
475  out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]
476  ```
477
478  Example usage:
479
480  >>> x = [1, 2, 3, 4, 5, 6]
481  >>> y = [1, 3, 5]
482  >>> setdiff1d(x,y)
483  ListDiff(out=<tf.Tensor: id=2, shape=(3,), dtype=int32,
484  numpy=array([2, 4, 6], dtype=int32)>, idx=<tf.Tensor: id=3,
485  shape=(3,), dtype=int32, numpy=array([1, 3, 5], dtype=int32)>)
486
487  Args:
488    x: A Tensor. 1-D. Values to keep.
489    y: A Tensor. Must have the same type as x. 1-D. Values to remove.
490    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
491      tf.int32.
492    name: A name for the operation (optional).
493
494  Returns:
495    A tuple of Tensor objects (out, idx).
496    out: A Tensor. Has the same type as x.
497    idx: A Tensor of type out_idx.
498  """
499  return gen_array_ops.list_diff(x, y, index_dtype, name)
500
501
502setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__
503
504
505@tf_export("broadcast_dynamic_shape")
506@dispatch.add_dispatch_support
507def broadcast_dynamic_shape(shape_x, shape_y):
508  """Computes the shape of a broadcast given symbolic shapes.
509
510  When `shape_x` and `shape_y` are Tensors representing shapes (i.e. the result
511  of calling tf.shape on another Tensor) this computes a Tensor which is the
512  shape of the result of a broadcasting op applied in tensors of shapes
513  `shape_x` and `shape_y`.
514
515  This is useful when validating the result of a broadcasting operation when the
516  tensors do not have statically known shapes.
517
518  Example:
519
520  >>> shape_x = (1, 2, 3)
521  >>> shape_y = (5, 1, 3)
522  >>> tf.broadcast_dynamic_shape(shape_x, shape_y)
523  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([5, 2, 3], ...>
524
525  Args:
526    shape_x: A rank 1 integer `Tensor`, representing the shape of x.
527    shape_y: A rank 1 integer `Tensor`, representing the shape of y.
528
529  Returns:
530    A rank 1 integer `Tensor` representing the broadcasted shape.
531
532  Raises:
533    InvalidArgumentError: If the two shapes are incompatible for
534    broadcasting.
535  """
536  return gen_array_ops.broadcast_args(shape_x, shape_y)
537
538
539@tf_export("broadcast_static_shape")
540@dispatch.add_dispatch_support
541def broadcast_static_shape(shape_x, shape_y):
542  """Computes the shape of a broadcast given known shapes.
543
544  When `shape_x` and `shape_y` are fully known `TensorShape`s this computes a
545  `TensorShape` which is the shape of the result of a broadcasting op applied in
546  tensors of shapes `shape_x` and `shape_y`.
547
548  For example, if shape_x is `TensorShape([1, 2, 3])` and shape_y is
549  `TensorShape([5, 1, 3])`, the result is a TensorShape whose value is
550  `TensorShape([5, 2, 3])`.
551
552  This is useful when validating the result of a broadcasting operation when the
553  tensors have statically known shapes.
554
555  Example:
556
557  >>> shape_x = tf.TensorShape([1, 2, 3])
558  >>> shape_y = tf.TensorShape([5, 1 ,3])
559  >>> tf.broadcast_static_shape(shape_x, shape_y)
560  TensorShape([5, 2, 3])
561
562  Args:
563    shape_x: A `TensorShape`
564    shape_y: A `TensorShape`
565
566  Returns:
567    A `TensorShape` representing the broadcasted shape.
568
569  Raises:
570    ValueError: If the two shapes can not be broadcasted.
571  """
572  return common_shapes.broadcast_shape(shape_x, shape_y)
573
574
575@tf_export("shape", v1=[])
576@dispatch.add_dispatch_support
577def shape_v2(input, out_type=dtypes.int32, name=None):
578  # pylint: disable=redefined-builtin
579  """Returns a tensor containing the shape of the input tensor.
580
581  See also `tf.size`, `tf.rank`.
582
583  `tf.shape` returns a 1-D integer tensor representing the shape of `input`.
584  For a scalar input, the tensor returned has a shape of (0,) and its value is
585  the empty vector (i.e. []).
586
587  For example:
588
589  >>> tf.shape(1.)
590  <tf.Tensor: shape=(0,), dtype=int32, numpy=array([], dtype=int32)>
591
592  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
593  >>> tf.shape(t)
594  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 2, 3], dtype=int32)>
595
596  Note: When using symbolic tensors, such as when using the Keras API,
597  tf.shape() will return the shape of the symbolic tensor.
598
599  >>> a = tf.keras.layers.Input((None, 10))
600  >>> tf.shape(a)
601  <... shape=(3,) dtype=int32...>
602
603  In these cases, using `tf.Tensor.shape` will return more informative results.
604
605  >>> a.shape
606  TensorShape([None, None, 10])
607
608  (The first `None` represents the as yet unknown batch size.)
609
610  `tf.shape` and `Tensor.shape` should be identical in eager mode.  Within
611  `tf.function` or within a `compat.v1` context, not all dimensions may be
612  known until execution time. Hence when defining custom layers and models
613  for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`.
614
615  Args:
616    input: A `Tensor` or `SparseTensor`.
617    out_type: (Optional) The specified output type of the operation (`int32` or
618      `int64`). Defaults to `tf.int32`.
619    name: A name for the operation (optional).
620
621  Returns:
622    A `Tensor` of type `out_type`.
623  """
624  return shape(input, name, out_type)
625
626
627@tf_export(v1=["shape"])
628@dispatch.add_dispatch_support
629def shape(input, name=None, out_type=dtypes.int32):
630  # pylint: disable=redefined-builtin
631  """Returns the shape of a tensor.
632
633  This operation returns a 1-D integer tensor representing the shape of `input`.
634
635  For example:
636
637  ```python
638  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
639  tf.shape(t)  # [2, 2, 3]
640  ```
641
642  Args:
643    input: A `Tensor` or `SparseTensor`.
644    name: A name for the operation (optional).
645    out_type: (Optional) The specified output type of the operation (`int32`
646    or `int64`). Defaults to `tf.int32`.
647
648  Returns:
649    A `Tensor` of type `out_type`.
650  """
651  return shape_internal(input, name, optimize=True, out_type=out_type)
652
653
654def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
655  # pylint: disable=redefined-builtin
656  """Returns the shape of a tensor.
657
658  Args:
659    input: A `Tensor` or `SparseTensor`.
660    name: A name for the operation (optional).
661    optimize: if true, encode the shape as a constant when possible.
662    out_type: (Optional) The specified output type of the operation (`int32` or
663      `int64`). Defaults to tf.int32.
664
665  Returns:
666    A `Tensor` of type `out_type`.
667
668  """
669  with ops.name_scope(name, "Shape", [input]) as name:
670    if isinstance(
671        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
672      return gen_math_ops.cast(input.dense_shape, out_type)
673    else:
674      if not context.executing_eagerly():
675        input = ops.convert_to_tensor(input)
676        input_shape = input.get_shape()
677        if optimize and input_shape.is_fully_defined():
678          return constant(input_shape.as_list(), out_type, name=name)
679      return gen_array_ops.shape(input, name=name, out_type=out_type)
680
681
682@tf_export("shape_n")
683@dispatch.add_dispatch_support
684def shape_n(input, out_type=dtypes.int32, name=None):
685  # pylint: disable=redefined-builtin
686  """Returns shape of tensors.
687
688  Args:
689    input: A list of at least 1 `Tensor` object with the same type.
690    out_type: The specified output type of the operation (`int32` or `int64`).
691      Defaults to `tf.int32`(optional).
692    name: A name for the operation (optional).
693
694  Returns:
695    A list with the same length as `input` of `Tensor` objects with
696      type `out_type`.
697  """
698
699  return gen_array_ops.shape_n(input, out_type=out_type, name=name)
700
701
702@tf_export("size", v1=[])
703@dispatch.add_dispatch_support
704def size_v2(input, out_type=dtypes.int32, name=None):
705  # pylint: disable=redefined-builtin
706  """Returns the size of a tensor.
707
708  See also `tf.shape`.
709
710  Returns a 0-D `Tensor` representing the number of elements in `input`
711  of type `out_type`. Defaults to tf.int32.
712
713  For example:
714
715  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
716  >>> tf.size(t)
717  <tf.Tensor: shape=(), dtype=int32, numpy=12>
718
719  Args:
720    input: A `Tensor` or `SparseTensor`.
721    name: A name for the operation (optional).
722    out_type: (Optional) The specified non-quantized numeric output type of the
723      operation. Defaults to `tf.int32`.
724
725  Returns:
726    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
727
728  @compatibility(numpy)
729  Equivalent to np.size()
730  @end_compatibility
731  """
732
733  return size(input, name, out_type)
734
735
736@tf_export(v1=["size"])
737@dispatch.add_dispatch_support
738def size(input, name=None, out_type=dtypes.int32):
739  # pylint: disable=redefined-builtin
740  """Returns the size of a tensor.
741
742  Returns a 0-D `Tensor` representing the number of elements in `input`
743  of type `out_type`. Defaults to tf.int32.
744
745  For example:
746
747  ```python
748  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
749  tf.size(t)  # 12
750  ```
751
752  Args:
753    input: A `Tensor` or `SparseTensor`.
754    name: A name for the operation (optional).
755    out_type: (Optional) The specified non-quantized numeric output type of the
756      operation. Defaults to `tf.int32`.
757
758  Returns:
759    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
760
761  @compatibility(numpy)
762  Equivalent to np.size()
763  @end_compatibility
764  """
765  return size_internal(input, name, optimize=True, out_type=out_type)
766
767
768def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
769  # pylint: disable=redefined-builtin,protected-access
770  """Returns the size of a tensor.
771
772  Args:
773    input: A `Tensor` or `SparseTensor`.
774    name: A name for the operation (optional).
775    optimize: if true, encode the size as a constant when possible.
776    out_type: (Optional) The specified non-quantized numeric output type of the
777      operation. Defaults to `tf.int32`.
778
779  Returns:
780    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
781  """
782  if (context.executing_eagerly() and not hasattr(input, "graph") and
783      not isinstance(
784          input,
785          (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))):
786    input = ops.convert_to_tensor(input)
787    np_out_type = out_type.as_numpy_dtype
788    num_elements = np.prod(input._shape_tuple(), dtype=np_out_type)  # pylint: disable=protected-access
789    return ops.convert_to_tensor(num_elements, dtype=out_type)
790  with ops.name_scope(name, "Size", [input]) as name:
791    if isinstance(
792        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
793      return gen_math_ops.prod(
794          gen_math_ops.cast(input.dense_shape, out_type), 0, name=name)
795    else:
796      input = ops.convert_to_tensor(input)
797      input_shape = input.get_shape()
798      if optimize:
799        if input_shape.is_fully_defined():
800          return constant(input_shape.num_elements(), out_type, name=name)
801        if input_shape.dims and any(dim == 0 for dim in input_shape.dims):
802          return constant(0, out_type, name=name)
803      return gen_array_ops.size(input, name=name, out_type=out_type)
804
805
806@tf_export("rank")
807@dispatch.add_dispatch_support
808def rank(input, name=None):
809  # pylint: disable=redefined-builtin
810  """Returns the rank of a tensor.
811
812  See also `tf.shape`.
813
814  Returns a 0-D `int32` `Tensor` representing the rank of `input`.
815
816  For example:
817
818  ```python
819  # shape of tensor 't' is [2, 2, 3]
820  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
821  tf.rank(t)  # 3
822  ```
823
824  **Note**: The rank of a tensor is not the same as the rank of a matrix. The
825  rank of a tensor is the number of indices required to uniquely select each
826  element of the tensor. Rank is also known as "order", "degree", or "ndims."
827
828  Args:
829    input: A `Tensor` or `SparseTensor`.
830    name: A name for the operation (optional).
831
832  Returns:
833    A `Tensor` of type `int32`.
834
835  @compatibility(numpy)
836  Equivalent to np.ndim
837  @end_compatibility
838  """
839  return rank_internal(input, name, optimize=True)
840
841
842def rank_internal(input, name=None, optimize=True):
843  # pylint: disable=redefined-builtin
844  """Returns the rank of a tensor.
845
846  Args:
847    input: A `Tensor` or `SparseTensor`.
848    name: A name for the operation (optional).
849    optimize: if true, encode the rank as a constant when possible.
850
851  Returns:
852    A `Tensor` of type `int32`.
853  """
854  with ops.name_scope(name, "Rank", [input]) as name:
855    if isinstance(
856        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
857      return gen_array_ops.size(input.dense_shape, name=name)
858    else:
859      input = ops.convert_to_tensor(input)
860      input_shape = input.get_shape()
861      if optimize and input_shape.ndims is not None:
862        return constant(input_shape.ndims, dtypes.int32, name=name)
863      return gen_array_ops.rank(input, name=name)
864
865
866_SLICE_TYPE_ERROR = (
867    "Only integers, slices (`:`), ellipsis (`...`), "
868    "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
869    "indices")
870
871_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
872                           dtypes.int64_ref)
873
874
875def _check_index(idx):
876  """Check if a given value is a valid index into a tensor."""
877  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
878    return
879
880  # Optimistic check. Assumptions:
881  # * any object with a dtype is supported
882  # * any object with a dtype has a sizeable shape attribute.
883  dtype = getattr(idx, "dtype", None)
884  if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
885      idx.shape and len(idx.shape) == 1):
886    # TODO(slebedev): IndexError seems more appropriate here, but it
887    # will break `_slice_helper` contract.
888    raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))
889
890
891def _is_undefined_dimension(d):
892  return isinstance(d, tensor_shape.Dimension) and d.value is None
893
894
895@tf_export("__operators__.getitem", v1=[])
896@dispatch.add_dispatch_support
897def _slice_helper(tensor, slice_spec, var=None):
898  """Overload for Tensor.__getitem__.
899
900  This operation extracts the specified region from the tensor.
901  The notation is similar to NumPy with the restriction that
902  currently only support basic indexing. That means that
903  using a non-scalar tensor as input is not currently allowed.
904
905  Some useful examples:
906
907  ```python
908  # Strip leading and trailing 2 elements
909  foo = tf.constant([1,2,3,4,5,6])
910  print(foo[2:-2].eval())  # => [3,4]
911
912  # Skip every other row and reverse the order of the columns
913  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
914  print(foo[::2,::-1].eval())  # => [[3,2,1], [9,8,7]]
915
916  # Use scalar tensors as indices on both dimensions
917  print(foo[tf.constant(0), tf.constant(2)].eval())  # => 3
918
919  # Insert another dimension
920  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
921  print(foo[tf.newaxis, :, :].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
922  print(foo[:, tf.newaxis, :].eval()) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
923  print(foo[:, :, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]],
924  [[7],[8],[9]]]
925
926  # Ellipses (3 equivalent operations)
927  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
928  print(foo[tf.newaxis, :, :].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
929  print(foo[tf.newaxis, ...].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
930  print(foo[tf.newaxis].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
931
932  # Masks
933  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
934  print(foo[foo > 2].eval())  # => [3, 4, 5, 6, 7, 8, 9]
935  ```
936
937  Notes:
938    - `tf.newaxis` is `None` as in NumPy.
939    - An implicit ellipsis is placed at the end of the `slice_spec`
940    - NumPy advanced indexing is currently not supported.
941
942  Purpose in the API:
943
944    This method is exposed in TensorFlow's API so that library developers
945    can register dispatching for `Tensor.__getitem__` to allow it to handle
946    custom composite tensors & other custom objects.
947
948    The API symbol is not intended to be called by users directly and does
949    appear in TensorFlow's generated documentation.
950
951  Args:
952    tensor: An ops.Tensor object.
953    slice_spec: The arguments to Tensor.__getitem__.
954    var: In the case of variable slice assignment, the Variable object to slice
955      (i.e. tensor is the read-only view of this variable).
956
957  Returns:
958    The appropriate slice of "tensor", based on "slice_spec".
959
960  Raises:
961    ValueError: If a slice range is negative size.
962    TypeError: If the slice indices aren't int, slice, ellipsis,
963      tf.newaxis or scalar int32/int64 tensors.
964  """
965  tensor = ops.convert_to_tensor(tensor)
966  # TODO(wangpeng): Consider supporting var
967  if var is None and ops._numpy_style_slicing:  # pylint: disable=protected-access
968    return tensor._numpy_style_getitem(slice_spec)  # pylint: disable=protected-access
969
970  if isinstance(slice_spec, bool) or \
971  (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
972  (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
973    return boolean_mask(tensor=tensor, mask=slice_spec)
974
975  if not isinstance(slice_spec, (list, tuple)):
976    slice_spec = [slice_spec]
977
978  begin, end, strides = [], [], []
979  index = 0
980
981  new_axis_mask, shrink_axis_mask = 0, 0
982  begin_mask, end_mask = 0, 0
983  ellipsis_mask = 0
984  for s in slice_spec:
985    if isinstance(s, _BaseSlice):
986      if s.start is not None and not _is_undefined_dimension(s.start):
987        _check_index(s.start)
988        begin.append(s.start)
989      else:
990        begin.append(0)
991        begin_mask |= (1 << index)
992      if s.stop is not None and not _is_undefined_dimension(s.stop):
993        _check_index(s.stop)
994        end.append(s.stop)
995      else:
996        end.append(0)
997        end_mask |= (1 << index)
998      if s.step is not None and not _is_undefined_dimension(s.step):
999        _check_index(s.step)
1000        strides.append(s.step)
1001      else:
1002        strides.append(1)
1003    elif s is Ellipsis:
1004      begin.append(0)
1005      end.append(0)
1006      strides.append(1)
1007      ellipsis_mask |= (1 << index)
1008    elif s is newaxis:
1009      begin.append(0)
1010      end.append(0)
1011      strides.append(1)
1012      new_axis_mask |= (1 << index)
1013    else:
1014      _check_index(s)
1015      begin.append(s)
1016      end.append(s + 1)
1017      strides.append(1)
1018      shrink_axis_mask |= (1 << index)
1019    index += 1
1020
1021  # stack possibly involves no tensors, so we must use op_scope correct graph.
1022  with ops.name_scope(
1023      None,
1024      "strided_slice", [tensor] + begin + end + strides,
1025      skip_on_eager=False) as name:
1026    if begin:
1027      packed_begin, packed_end, packed_strides = (stack(begin), stack(end),
1028                                                  stack(strides))
1029      if (packed_begin.dtype == dtypes.int64 or
1030          packed_end.dtype == dtypes.int64 or
1031          packed_strides.dtype == dtypes.int64):
1032        if packed_begin.dtype != dtypes.int64:
1033          packed_begin = gen_math_ops.cast(packed_begin, dtypes.int64)
1034        if packed_end.dtype != dtypes.int64:
1035          packed_end = gen_math_ops.cast(packed_end, dtypes.int64)
1036        if packed_strides.dtype != dtypes.int64:
1037          packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64)
1038    else:
1039      var_empty = constant([], dtype=dtypes.int32)
1040      packed_begin = packed_end = packed_strides = var_empty
1041    return strided_slice(
1042        tensor,
1043        packed_begin,
1044        packed_end,
1045        packed_strides,
1046        begin_mask=begin_mask,
1047        end_mask=end_mask,
1048        shrink_axis_mask=shrink_axis_mask,
1049        new_axis_mask=new_axis_mask,
1050        ellipsis_mask=ellipsis_mask,
1051        var=var,
1052        name=name)
1053
1054
1055# pylint: disable=undefined-variable,protected-access,redefined-outer-name
1056@tf_export("slice")
1057@dispatch.add_dispatch_support
1058def slice(input_, begin, size, name=None):
1059  # pylint: disable=redefined-builtin
1060  """Extracts a slice from a tensor.
1061
1062  See also `tf.strided_slice`.
1063
1064  This operation extracts a slice of size `size` from a tensor `input_` starting
1065  at the location specified by `begin`. The slice `size` is represented as a
1066  tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
1067  of `input_` that you want to slice. The starting location (`begin`) for the
1068  slice is represented as an offset in each dimension of `input_`. In other
1069  words, `begin[i]` is the offset into the i'th dimension of `input_` that you
1070  want to slice from.
1071
1072  Note that `tf.Tensor.__getitem__` is typically a more pythonic way to
1073  perform slices, as it allows you to write `foo[3:7, :-2]` instead of
1074  `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`.
1075
1076  `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
1077  all remaining elements in dimension i are included in the
1078  slice. In other words, this is equivalent to setting:
1079
1080  `size[i] = input_.dim_size(i) - begin[i]`
1081
1082  This operation requires that:
1083
1084  `0 <= begin[i] <= begin[i] + size[i] <= Di  for i in [0, n]`
1085
1086  For example:
1087
1088  ```python
1089  t = tf.constant([[[1, 1, 1], [2, 2, 2]],
1090                   [[3, 3, 3], [4, 4, 4]],
1091                   [[5, 5, 5], [6, 6, 6]]])
1092  tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
1093  tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
1094                                     #   [4, 4, 4]]]
1095  tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
1096                                     #  [[5, 5, 5]]]
1097  ```
1098
1099  Args:
1100    input_: A `Tensor`.
1101    begin: An `int32` or `int64` `Tensor`.
1102    size: An `int32` or `int64` `Tensor`.
1103    name: A name for the operation (optional).
1104
1105  Returns:
1106    A `Tensor` the same type as `input_`.
1107  """
1108  return gen_array_ops._slice(input_, begin, size, name=name)
1109
1110
1111# pylint: disable=invalid-name
1112@tf_export("strided_slice")
1113@dispatch.add_dispatch_support
1114def strided_slice(input_,
1115                  begin,
1116                  end,
1117                  strides=None,
1118                  begin_mask=0,
1119                  end_mask=0,
1120                  ellipsis_mask=0,
1121                  new_axis_mask=0,
1122                  shrink_axis_mask=0,
1123                  var=None,
1124                  name=None):
1125  """Extracts a strided slice of a tensor (generalized Python array indexing).
1126
1127  See also `tf.slice`.
1128
1129  **Instead of calling this op directly most users will want to use the
1130  NumPy-style slicing syntax (e.g. `tensor[..., 3:4:-1, tf.newaxis, 3]`), which
1131  is supported via `tf.Tensor.__getitem__` and `tf.Variable.__getitem__`.**
1132  The interface of this op is a low-level encoding of the slicing syntax.
1133
1134  Roughly speaking, this op extracts a slice of size `(end-begin)/stride`
1135  from the given `input_` tensor. Starting at the location specified by `begin`
1136  the slice continues by adding `stride` to the index until all dimensions are
1137  not less than `end`.
1138  Note that a stride can be negative, which causes a reverse slice.
1139
1140  Given a Python slice `input[spec0, spec1, ..., specn]`,
1141  this function will be called as follows.
1142
1143  `begin`, `end`, and `strides` will be vectors of length n.
1144  n in general is not equal to the rank of the `input_` tensor.
1145
1146  In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`,
1147  `new_axis_mask`, `shrink_axis_mask`) the ith bit will correspond to
1148  the ith spec.
1149
1150  If the ith bit of `begin_mask` is set, `begin[i]` is ignored and
1151  the fullest possible range in that dimension is used instead.
1152  `end_mask` works analogously, except with the end range.
1153
1154  `foo[5:,:,:3]` on a 7x8x9 tensor is equivalent to `foo[5:7,0:8,0:3]`.
1155  `foo[::-1]` reverses a tensor with shape 8.
1156
1157  If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions
1158  as needed will be inserted between other dimensions. Only one
1159  non-zero bit is allowed in `ellipsis_mask`.
1160
1161  For example `foo[3:5,...,4:5]` on a shape 10x3x3x10 tensor is
1162  equivalent to `foo[3:5,:,:,4:5]` and
1163  `foo[3:5,...]` is equivalent to `foo[3:5,:,:,:]`.
1164
1165  If the ith bit of `new_axis_mask` is set, then `begin`,
1166  `end`, and `stride` are ignored and a new length 1 dimension is
1167  added at this point in the output tensor.
1168
1169  For example,
1170  `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.
1171
1172  If the ith bit of `shrink_axis_mask` is set, it implies that the ith
1173  specification shrinks the dimensionality by 1, taking on the value at index
1174  `begin[i]`. `end[i]` and `strides[i]` are ignored in this case. For example in
1175  Python one might do `foo[:, 3, :]` which would result in `shrink_axis_mask`
1176  equal to 2.
1177
1178
1179  NOTE: `begin` and `end` are zero-indexed.
1180  `strides` entries must be non-zero.
1181
1182
1183  ```python
1184  t = tf.constant([[[1, 1, 1], [2, 2, 2]],
1185                   [[3, 3, 3], [4, 4, 4]],
1186                   [[5, 5, 5], [6, 6, 6]]])
1187  tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])  # [[[3, 3, 3]]]
1188  tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])  # [[[3, 3, 3],
1189                                                        #   [4, 4, 4]]]
1190  tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])  # [[[4, 4, 4],
1191                                                           #   [3, 3, 3]]]
1192  ```
1193
1194  Args:
1195    input_: A `Tensor`.
1196    begin: An `int32` or `int64` `Tensor`.
1197    end: An `int32` or `int64` `Tensor`.
1198    strides: An `int32` or `int64` `Tensor`.
1199    begin_mask: An `int32` mask.
1200    end_mask: An `int32` mask.
1201    ellipsis_mask: An `int32` mask.
1202    new_axis_mask: An `int32` mask.
1203    shrink_axis_mask: An `int32` mask.
1204    var: The variable corresponding to `input_` or None
1205    name: A name for the operation (optional).
1206
1207  Returns:
1208    A `Tensor` the same type as `input`.
1209  """
1210
1211  if strides is None:
1212    strides = ones_like(begin)
1213
1214  op = gen_array_ops.strided_slice(
1215      input=input_,
1216      begin=begin,
1217      end=end,
1218      strides=strides,
1219      name=name,
1220      begin_mask=begin_mask,
1221      end_mask=end_mask,
1222      ellipsis_mask=ellipsis_mask,
1223      new_axis_mask=new_axis_mask,
1224      shrink_axis_mask=shrink_axis_mask)
1225
1226  parent_name = name
1227
1228  if var is not None:
1229    def assign(val, name=None):
1230      """Closure that holds all the arguments to create an assignment."""
1231
1232      if name is None:
1233        name = parent_name + "_assign"
1234
1235      return var._strided_slice_assign(
1236          begin=begin,
1237          end=end,
1238          strides=strides,
1239          value=val,
1240          name=name,
1241          begin_mask=begin_mask,
1242          end_mask=end_mask,
1243          ellipsis_mask=ellipsis_mask,
1244          new_axis_mask=new_axis_mask,
1245          shrink_axis_mask=shrink_axis_mask)
1246
1247    op.assign = assign
1248
1249  return op
1250
1251
1252def _SliceHelperVar(var, slice_spec):
1253  """Creates a slice helper object given a variable.
1254
1255  This allows creating a sub-tensor from part of the current contents
1256  of a variable. See `tf.Tensor.__getitem__` for detailed examples
1257  of slicing.
1258
1259  This function in addition also allows assignment to a sliced range.
1260  This is similar to `__setitem__` functionality in Python. However,
1261  the syntax is different so that the user can capture the assignment
1262  operation for grouping or passing to `sess.run()`.
1263  For example,
1264
1265  ```python
1266  import tensorflow as tf
1267  A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
1268  with tf.compat.v1.Session() as sess:
1269    sess.run(tf.compat.v1.global_variables_initializer())
1270    print(sess.run(A[:2, :2]))  # => [[1,2], [4,5]]
1271
1272    op = A[:2,:2].assign(22. * tf.ones((2, 2)))
1273    print(sess.run(op))  # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
1274  ```
1275
1276  Note that assignments currently do not support NumPy broadcasting
1277  semantics.
1278
1279  Args:
1280    var: An `ops.Variable` object.
1281    slice_spec: The arguments to `Tensor.__getitem__`.
1282
1283  Returns:
1284    The appropriate slice of "tensor", based on "slice_spec".
1285    As an operator. The operator also has a `assign()` method
1286    that can be used to generate an assignment operator.
1287
1288  Raises:
1289    ValueError: If a slice range is negative size.
1290    TypeError: TypeError: If the slice indices aren't int, slice,
1291      ellipsis, tf.newaxis or int32/int64 tensors.
1292
1293  """
1294
1295  return _slice_helper(var.value(), slice_spec, var)
1296
1297
1298ops.Tensor._override_operator("__getitem__", _slice_helper)
1299
1300
1301@tf_export("parallel_stack")
1302@dispatch.add_dispatch_support
1303def parallel_stack(values, name="parallel_stack"):
1304  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.
1305
1306  Requires that the shape of inputs be known at graph construction time.
1307
1308  Packs the list of tensors in `values` into a tensor with rank one higher than
1309  each tensor in `values`, by packing them along the first dimension.
1310  Given a list of length `N` of tensors of shape `(A, B, C)`; the `output`
1311  tensor will have the shape `(N, A, B, C)`.
1312
1313  For example:
1314
1315  ```python
1316  x = tf.constant([1, 4])
1317  y = tf.constant([2, 5])
1318  z = tf.constant([3, 6])
1319  tf.parallel_stack([x, y, z])  # [[1, 4], [2, 5], [3, 6]]
1320  ```
1321
1322  The difference between `stack` and `parallel_stack` is that `stack` requires
1323  all the inputs be computed before the operation will begin but doesn't require
1324  that the input shapes be known during graph construction.
1325
1326  `parallel_stack` will copy pieces of the input into the output as they become
1327  available, in some situations this can provide a performance benefit.
1328
1329  Unlike `stack`, `parallel_stack` does NOT support backpropagation.
1330
1331  This is the opposite of unstack.  The numpy equivalent is
1332
1333      tf.parallel_stack([x, y, z]) = np.asarray([x, y, z])
1334
1335  @compatibility(eager)
1336  parallel_stack is not compatible with eager execution.
1337  @end_compatibility
1338
1339  Args:
1340    values: A list of `Tensor` objects with the same shape and type.
1341    name: A name for this operation (optional).
1342
1343  Returns:
1344    output: A stacked `Tensor` with the same type as `values`.
1345
1346  Raises:
1347    RuntimeError: if executed in eager mode.
1348  """
1349  if context.executing_eagerly():
1350    raise RuntimeError("tf.parallel_stack() is not compatible with "
1351                       "eager execution.")
1352  with ops.name_scope(name):
1353    value_t = ops.convert_to_tensor(values[0])
1354    value_shape = ops.convert_to_tensor(value_t).get_shape()
1355
1356    output_shape = tensor_shape.TensorShape([len(values)])
1357    output_shape = output_shape.concatenate(value_shape)
1358    # expand_dims converts concat to stack.
1359    return gen_array_ops.parallel_concat(
1360        [expand_dims(value, 0) for value in values], shape=output_shape)
1361
1362
1363@tf_export("stack")
1364@dispatch.add_dispatch_support
1365def stack(values, axis=0, name="stack"):
1366  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
1367
1368  See also `tf.concat`, `tf.tile`, `tf.repeat`.
1369
1370  Packs the list of tensors in `values` into a tensor with rank one higher than
1371  each tensor in `values`, by packing them along the `axis` dimension.
1372  Given a list of length `N` of tensors of shape `(A, B, C)`;
1373
1374  if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
1375  if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
1376  Etc.
1377
1378  For example:
1379
1380  >>> x = tf.constant([1, 4])
1381  >>> y = tf.constant([2, 5])
1382  >>> z = tf.constant([3, 6])
1383  >>> tf.stack([x, y, z])
1384  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
1385  array([[1, 4],
1386         [2, 5],
1387         [3, 6]], dtype=int32)>
1388  >>> tf.stack([x, y, z], axis=1)
1389  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
1390  array([[1, 2, 3],
1391         [4, 5, 6]], dtype=int32)>
1392
1393  This is the opposite of unstack.  The numpy equivalent is `np.stack`
1394
1395  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
1396  True
1397
1398  Args:
1399    values: A list of `Tensor` objects with the same shape and type.
1400    axis: An `int`. The axis to stack along. Defaults to the first dimension.
1401      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
1402    name: A name for this operation (optional).
1403
1404  Returns:
1405    output: A stacked `Tensor` with the same type as `values`.
1406
1407  Raises:
1408    ValueError: If `axis` is out of the range [-(R+1), R+1).
1409  """
1410  if axis == 0:
1411    try:
1412      # If the input is a constant list, it can be converted to a constant op
1413      return ops.convert_to_tensor(values, name=name)
1414    except (TypeError, ValueError):
1415      pass  # Input list contains non-constant tensors
1416
1417  value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple()  # pylint: disable=protected-access
1418  if value_shape is not None:
1419    expanded_num_dims = len(value_shape) + 1
1420    if axis < -expanded_num_dims or axis >= expanded_num_dims:
1421      raise ValueError("axis = %d not in [%d, %d)" %
1422                       (axis, -expanded_num_dims, expanded_num_dims))
1423
1424  return gen_array_ops.pack(values, axis=axis, name=name)
1425
1426
1427# pylint: disable=invalid-name
1428def _autopacking_helper(list_or_tuple, dtype, name):
1429  """Converts the given list or tuple to a tensor by packing.
1430
1431  Args:
1432    list_or_tuple: A (possibly nested) list or tuple containing a tensor.
1433    dtype: The element type of the returned tensor.
1434    name: A name for the returned tensor.
1435
1436  Returns:
1437    A `tf.Tensor` with value equivalent to `list_or_tuple`.
1438  """
1439  if context.executing_eagerly():
1440    # NOTE: Fast path when all the items are tensors, this doesn't do any type
1441    # checking.
1442    if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
1443      return gen_array_ops.pack(list_or_tuple, name=name)
1444  must_pack = False
1445  converted_elems = []
1446  with ops.name_scope(name) as scope:
1447    for i, elem in enumerate(list_or_tuple):
1448      if isinstance(elem, core.Tensor):
1449        if dtype is not None and elem.dtype.base_dtype != dtype:
1450          raise TypeError("Cannot convert a list containing a tensor of dtype "
1451                          "%s to %s (Tensor is: %r)" %
1452                          (elem.dtype, dtype, elem))
1453        converted_elems.append(elem)
1454        must_pack = True
1455      elif isinstance(elem, (list, tuple)):
1456        converted_elem = _autopacking_helper(elem, dtype, str(i))
1457        if isinstance(converted_elem, core.Tensor):
1458          must_pack = True
1459        converted_elems.append(converted_elem)
1460      else:
1461        converted_elems.append(elem)
1462    if must_pack:
1463      elems_as_tensors = []
1464      for i, elem in enumerate(converted_elems):
1465        if isinstance(elem, core.Tensor):
1466          elems_as_tensors.append(elem)
1467        else:
1468          # NOTE(mrry): This is inefficient, but it enables us to
1469          # handle the case where the list arguments are other
1470          # convertible-to-tensor types, such as numpy arrays.
1471          elems_as_tensors.append(
1472              constant_op.constant(elem, dtype=dtype, name=str(i)))
1473      return gen_array_ops.pack(elems_as_tensors, name=scope)
1474    else:
1475      return converted_elems
1476
1477
1478def _get_dtype_from_nested_lists(list_or_tuple):
1479  """Returns the dtype of any tensor-like object in `list_or_tuple`, if found.
1480
1481  Args:
1482    list_or_tuple: A list or tuple representing an object that can be converted
1483      to a `tf.Tensor`.
1484
1485  Returns:
1486    The dtype of any tensor-like object in `list_or_tuple`, or `None` if no
1487    such object exists.
1488  """
1489  for elem in list_or_tuple:
1490    if isinstance(elem, core.Tensor):
1491      return elem.dtype.base_dtype
1492    elif isinstance(elem, (list, tuple)):
1493      maybe_dtype = _get_dtype_from_nested_lists(elem)
1494      if maybe_dtype is not None:
1495        return maybe_dtype
1496  return None
1497
1498
1499def _cast_nested_seqs_to_dtype(dtype):
1500
1501  def _maybe_cast(elem):
1502    if isinstance(elem, core.Tensor):
1503      if dtype != elem.dtype.base_dtype:
1504        elem = gen_math_ops.cast(elem, dtype)
1505    return elem
1506
1507  return _maybe_cast
1508
1509
1510_NON_AUTOPACKABLE_TYPES = set(np.core.numerictypes.ScalarType)
1511_NON_AUTOPACKABLE_TYPES.add(np.ndarray)
1512
1513
1514def _should_not_autopack(v):
1515  # The condition we really want is
1516  #    any(isinstance(elem, core.Tensor))
1517  # but it is >5x slower due to abc.ABCMeta.__instancecheck__.
1518  # pylint: disable=unidiomatic-typecheck
1519  # TODO(slebedev): add nest.all?
1520  return all(type(elem) in _NON_AUTOPACKABLE_TYPES for elem in nest.flatten(v))
1521  # pylint: enable=unidiomatic-typecheck
1522
1523
1524def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
1525  """Tensor conversion function that automatically packs arguments."""
1526  if as_ref or _should_not_autopack(v):
1527    return NotImplemented
1528  inferred_dtype = _get_dtype_from_nested_lists(v)
1529  if inferred_dtype is None:
1530    # We did not find any tensor-like objects in the nested lists, so defer to
1531    # other conversion functions.
1532    return NotImplemented
1533  if dtype is None:
1534    dtype = inferred_dtype
1535  elif dtype != inferred_dtype:
1536    v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
1537  return _autopacking_helper(v, dtype, name or "packed")
1538
1539
1540# pylint: enable=invalid-name
1541
1542# NOTE: Register this conversion function to run *before* one that
1543# assumes every element is a value.
1544ops.register_tensor_conversion_function((list, tuple),
1545                                        _autopacking_conversion_function, 99)
1546
1547
1548@tf_export("unstack")
1549@dispatch.add_dispatch_support
1550def unstack(value, num=None, axis=0, name="unstack"):
1551  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
1552
1553  Unpacks tensors from `value` by chipping it along the `axis` dimension.
1554
1555  >>> x = tf.reshape(tf.range(12), (3,4))
1556  >>>
1557  >>> p, q, r = tf.unstack(x)
1558  >>> p.shape.as_list()
1559  [4]
1560
1561  >>> i, j, k, l = tf.unstack(x, axis=1)
1562  >>> i.shape.as_list()
1563  [3]
1564
1565  This is the opposite of stack.
1566
1567  >>> x = tf.stack([i, j, k, l], axis=1)
1568
1569  More generally if you have a tensor of shape `(A, B, C, D)`:
1570
1571  >>> A, B, C, D = [2, 3, 4, 5]
1572  >>> t = tf.random.normal(shape=[A, B, C, D])
1573
1574  The number of tensor returned is equal to the length of the target `axis`:
1575
1576  >>> axis = 2
1577  >>> items = tf.unstack(t, axis=axis)
1578  >>> len(items) == t.shape[axis]
1579  True
1580
1581  The shape of each result tensor is equal to the shape of the input tensor,
1582  with the target `axis` removed.
1583
1584  >>> items[0].shape.as_list()  # [A, B, D]
1585  [2, 3, 5]
1586
1587  The value of each tensor `items[i]` is equal to the slice of `input` across
1588  `axis` at index `i`:
1589
1590  >>> for i in range(len(items)):
1591  ...   slice = t[:,:,i,:]
1592  ...   assert tf.reduce_all(slice == items[i])
1593
1594  #### Python iterable unpacking
1595
1596  With eager execution you _can_ unstack the 0th axis of a tensor using python's
1597  iterable unpacking:
1598
1599  >>> t = tf.constant([1,2,3])
1600  >>> a,b,c = t
1601
1602  `unstack` is still necessary because Iterable unpacking doesn't work in
1603  a `@tf.function`: Symbolic tensors are not iterable.
1604
1605  You need to use `tf.unstack` here:
1606
1607  >>> @tf.function
1608  ... def bad(t):
1609  ...   a,b,c = t
1610  ...   return a
1611  >>>
1612  >>> bad(t)
1613  Traceback (most recent call last):
1614  ...
1615  OperatorNotAllowedInGraphError: ...
1616
1617  >>> @tf.function
1618  ... def good(t):
1619  ...   a,b,c = tf.unstack(t)
1620  ...   return a
1621  >>>
1622  >>> good(t).numpy()
1623  1
1624
1625  #### Unknown shapes
1626
1627  Eager tensors have concrete values, so their shape is always known.
1628  Inside a `tf.function` the symbolic tensors may have unknown shapes.
1629  If the length of `axis` is unknown `tf.unstack` will fail because it cannot
1630  handle an unknown number of tensors:
1631
1632  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
1633  ... def bad(t):
1634  ...   tensors = tf.unstack(t)
1635  ...   return tensors[0]
1636  >>>
1637  >>> bad(tf.constant([1,2,3]))
1638  Traceback (most recent call last):
1639  ...
1640  ValueError: Cannot infer num from shape (None,)
1641
1642  If you know the `axis` length you can pass it as the `num` argument. But this
1643  must be a constant value.
1644
1645  If you actually need a variable number of tensors in a single `tf.function`
1646  trace, you will need to use exlicit loops and a `tf.TensorArray` instead.
1647
1648  Args:
1649    value: A rank `R > 0` `Tensor` to be unstacked.
1650    num: An `int`. The length of the dimension `axis`. Automatically inferred if
1651      `None` (the default).
1652    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
1653      Negative values wrap around, so the valid range is `[-R, R)`.
1654    name: A name for the operation (optional).
1655
1656  Returns:
1657    The list of `Tensor` objects unstacked from `value`.
1658
1659  Raises:
1660    ValueError: If `axis` is out of the range `[-R, R)`.
1661    ValueError: If `num` is unspecified and cannot be inferred.
1662    InvalidArgumentError: If `num` does not match the shape of `value`.
1663  """
1664  if num is None:
1665    value = ops.convert_to_tensor(value)
1666    value_shape = value.get_shape()
1667    if value_shape.ndims is not None:
1668      if axis < -value_shape.ndims or axis >= value_shape.ndims:
1669        raise ValueError("axis = %d not in [%d, %d)" %
1670                         (axis, -value_shape.ndims, value_shape.ndims))
1671      num = value_shape.dims[axis].value
1672  if num is None:
1673    raise ValueError("Cannot infer num from shape %s" % value_shape)
1674  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)
1675
1676
1677@tf_export("concat")
1678@dispatch.add_dispatch_support
1679def concat(values, axis, name="concat"):
1680  """Concatenates tensors along one dimension.
1681
1682  See also `tf.tile`, `tf.stack`, `tf.repeat`.
1683
1684  Concatenates the list of tensors `values` along dimension `axis`.  If
1685  `values[i].shape = [D0, D1, ... Daxis(i), ...Dn]`, the concatenated
1686  result has shape
1687
1688      [D0, D1, ... Raxis, ...Dn]
1689
1690  where
1691
1692      Raxis = sum(Daxis(i))
1693
1694  That is, the data from the input tensors is joined along the `axis`
1695  dimension.
1696
1697  The number of dimensions of the input tensors must match, and all dimensions
1698  except `axis` must be equal.
1699
1700  For example:
1701
1702  >>> t1 = [[1, 2, 3], [4, 5, 6]]
1703  >>> t2 = [[7, 8, 9], [10, 11, 12]]
1704  >>> tf.concat([t1, t2], 0)
1705  <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
1706  array([[ 1,  2,  3],
1707         [ 4,  5,  6],
1708         [ 7,  8,  9],
1709         [10, 11, 12]], dtype=int32)>
1710
1711  >>> tf.concat([t1, t2], 1)
1712  <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
1713  array([[ 1,  2,  3,  7,  8,  9],
1714         [ 4,  5,  6, 10, 11, 12]], dtype=int32)>
1715
1716  As in Python, the `axis` could also be negative numbers. Negative `axis`
1717  are interpreted as counting from the end of the rank, i.e.,
1718   `axis + rank(values)`-th dimension.
1719
1720  For example:
1721
1722  >>> t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
1723  >>> t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
1724  >>> tf.concat([t1, t2], -1)
1725  <tf.Tensor: shape=(2, 2, 4), dtype=int32, numpy=
1726    array([[[ 1,  2,  7,  4],
1727            [ 2,  3,  8,  4]],
1728           [[ 4,  4,  2, 10],
1729            [ 5,  3, 15, 11]]], dtype=int32)>
1730
1731  Note: If you are concatenating along a new axis consider using stack.
1732  E.g.
1733
1734  ```python
1735  tf.concat([tf.expand_dims(t, axis) for t in tensors], axis)
1736  ```
1737
1738  can be rewritten as
1739
1740  ```python
1741  tf.stack(tensors, axis=axis)
1742  ```
1743
1744  Args:
1745    values: A list of `Tensor` objects or a single `Tensor`.
1746    axis: 0-D `int32` `Tensor`.  Dimension along which to concatenate. Must be
1747      in the range `[-rank(values), rank(values))`. As in Python, indexing for
1748      axis is 0-based. Positive axis in the rage of `[0, rank(values))` refers
1749      to `axis`-th dimension. And negative axis refers to `axis +
1750      rank(values)`-th dimension.
1751    name: A name for the operation (optional).
1752
1753  Returns:
1754    A `Tensor` resulting from concatenation of the input tensors.
1755  """
1756  if not isinstance(values, (list, tuple)):
1757    values = [values]
1758  # TODO(mrry): Change to return values?
1759  if len(values) == 1:  # Degenerate case of one tensor.
1760    # Make a throwaway call to convert_to_tensor to make sure
1761    # that axis is of the correct type, and make sure that
1762    # the returned tensor is a scalar.
1763    # TODO(keveman): Implement a standalone type and shape checker.
1764    with ops.name_scope(name) as scope:
1765      ops.convert_to_tensor(
1766          axis, name="concat_dim",
1767          dtype=dtypes.int32).get_shape().assert_has_rank(0)
1768      return identity(values[0], name=name)
1769  return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
1770
1771
1772@tf_export(v1=["boolean_mask"])
1773@dispatch.add_dispatch_support
1774def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
1775  """Apply boolean mask to tensor.
1776
1777  Numpy equivalent is `tensor[mask]`.
1778
1779  In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
1780  the first K dimensions of `tensor`'s shape.  We then have:
1781    `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
1782  where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
1783  The `axis` could be used with `mask` to indicate the axis to mask from.
1784  In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
1785  the first `axis + dim(mask)` dimensions of `tensor`'s shape.
1786
1787  See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
1788  ragged tensors, and can be used if you need to preserve the masked dimensions
1789  of `tensor` (rather than flattening them, as `tf.boolean_mask` does).
1790
1791  Examples:
1792
1793  ```python
1794  # 1-D example
1795  tensor = [0, 1, 2, 3]
1796  mask = np.array([True, False, True, False])
1797  tf.boolean_mask(tensor, mask)  # [0, 2]
1798
1799  # 2-D example
1800  tensor = [[1, 2], [3, 4], [5, 6]]
1801  mask = np.array([True, False, True])
1802  tf.boolean_mask(tensor, mask)  # [[1, 2], [5, 6]]
1803  ```
1804
1805  Args:
1806    tensor:  N-D Tensor.
1807    mask:  K-D boolean Tensor, K <= N and K must be known statically.
1808    name:  A name for this operation (optional).
1809    axis:  A 0-D int Tensor representing the axis in `tensor` to mask from. By
1810      default, axis is 0 which will mask from the first dimension. Otherwise K +
1811      axis <= N.
1812
1813  Returns:
1814    (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
1815    to `True` values in `mask`.
1816
1817  Raises:
1818    ValueError:  If shapes do not conform.
1819  """
1820
1821  def _apply_mask_1d(reshaped_tensor, mask, axis=None):
1822    """Mask tensor along dimension 0 with a 1-D mask."""
1823    indices = squeeze(where_v2(mask), axis=[1])
1824    return gather(reshaped_tensor, indices, axis=axis)
1825
1826  with ops.name_scope(name, values=[tensor, mask]):
1827    tensor = ops.convert_to_tensor(tensor, name="tensor")
1828    mask = ops.convert_to_tensor(mask, name="mask")
1829
1830    shape_mask = mask.get_shape()
1831    ndims_mask = shape_mask.ndims
1832    shape_tensor = tensor.get_shape()
1833    if ndims_mask == 0:
1834      raise ValueError("mask cannot be scalar.")
1835    if ndims_mask is None:
1836      raise ValueError(
1837          "Number of mask dimensions must be specified, even if some dimensions"
1838          " are None.  E.g. shape=[None] is ok, but shape=None is not.")
1839    axis = 0 if axis is None else axis
1840    axis_value = tensor_util.constant_value(axis)
1841    if axis_value is not None:
1842      axis = axis_value
1843      shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)
1844
1845    leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0])
1846    tensor = reshape(
1847        tensor,
1848        concat([
1849            shape(tensor)[:axis], [leading_size],
1850            shape(tensor)[axis + ndims_mask:]
1851        ], 0))
1852    # TODO(yongtang): tf.reshape in C++ kernel might have set the shape
1853    # correctly, so the following may not be needed? It still might be possible
1854    # that there are some edge case where tensor_util.constant_value resolves
1855    # more cases than ShapeInference of tf.reshape in C++ kernel.
1856    if axis_value is not None:
1857      first_dim = shape_tensor[axis:axis + ndims_mask].num_elements()
1858      tensor.set_shape(
1859          tensor_shape.as_shape(shape_tensor[:axis]).concatenate(
1860              [first_dim]).concatenate(shape_tensor[axis + ndims_mask:]))
1861
1862    mask = reshape(mask, [-1])
1863    return _apply_mask_1d(tensor, mask, axis)
1864
1865
1866@tf_export("boolean_mask", v1=[])
1867@dispatch.add_dispatch_support
1868def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
1869  """Apply boolean mask to tensor.
1870
1871  Numpy equivalent is `tensor[mask]`.
1872
1873  In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
1874  the first K dimensions of `tensor`'s shape.  We then have:
1875    `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
1876  where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
1877  The `axis` could be used with `mask` to indicate the axis to mask from.
1878  In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
1879  the first `axis + dim(mask)` dimensions of `tensor`'s shape.
1880
1881  See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
1882  ragged tensors, and can be used if you need to preserve the masked dimensions
1883  of `tensor` (rather than flattening them, as `tf.boolean_mask` does).
1884
1885  Examples:
1886
1887  >>> tensor = [0, 1, 2, 3]  # 1-D example
1888  >>> mask = np.array([True, False, True, False])
1889  >>> tf.boolean_mask(tensor, mask)
1890  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2], dtype=int32)>
1891
1892  >>> tensor = [[1, 2], [3, 4], [5, 6]] # 2-D example
1893  >>> mask = np.array([True, False, True])
1894  >>> tf.boolean_mask(tensor, mask)
1895  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
1896  array([[1, 2],
1897         [5, 6]], dtype=int32)>
1898
1899  Args:
1900    tensor:  N-D Tensor.
1901    mask:  K-D boolean Tensor, K <= N and K must be known statically.
1902    axis:  A 0-D int Tensor representing the axis in `tensor` to mask from. By
1903      default, axis is 0 which will mask from the first dimension. Otherwise K +
1904      axis <= N.
1905    name:  A name for this operation (optional).
1906
1907  Returns:
1908    (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
1909    to `True` values in `mask`.
1910
1911  Raises:
1912    ValueError:  If shapes do not conform.
1913
1914  Examples:
1915
1916  ```python
1917  # 2-D example
1918  tensor = [[1, 2], [3, 4], [5, 6]]
1919  mask = np.array([True, False, True])
1920  boolean_mask(tensor, mask)  # [[1, 2], [5, 6]]
1921  ```
1922  """
1923  return boolean_mask(tensor, mask, name, axis)
1924
1925
1926@tf_export("sparse.mask", v1=["sparse.mask", "sparse_mask"])
1927@deprecation.deprecated_endpoints("sparse_mask")
1928def sparse_mask(a, mask_indices, name=None):
1929  """Masks elements of `IndexedSlices`.
1930
1931  Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
1932  contains a subset of the slices of `a`. Only the slices at indices not
1933  specified in `mask_indices` are returned.
1934
1935  This is useful when you need to extract a subset of slices in an
1936  `IndexedSlices` object.
1937
1938  For example:
1939
1940  ```python
1941  # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
1942  # with shape [1000, 10]
1943  a.indices  # [12, 26, 37, 45]
1944  tf.shape(a.values)  # [4, 10]
1945
1946  # `b` will be the subset of `a` slices at its second and third indices, so
1947  # we want to mask its first and last indices (which are at absolute
1948  # indices 12, 45)
1949  b = tf.sparse.mask(a, [12, 45])
1950
1951  b.indices  # [26, 37]
1952  tf.shape(b.values)  # [2, 10]
1953  ```
1954
1955  Args:
1956    a: An `IndexedSlices` instance.
1957    mask_indices: Indices of elements to mask.
1958    name: A name for the operation (optional).
1959
1960  Returns:
1961    The masked `IndexedSlices` instance.
1962  """
1963  with ops.name_scope(name, "sparse_mask", [a, mask_indices]) as name:
1964    indices = a.indices
1965    out_indices, to_gather = gen_array_ops.list_diff(indices, mask_indices)
1966    out_values = gather(a.values, to_gather, name=name)
1967    return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
1968
1969
1970@tf_export("unique")
1971@dispatch.add_dispatch_support
1972def unique(x, out_idx=dtypes.int32, name=None):
1973  """Finds unique elements in a 1-D tensor.
1974
1975  See also `tf.unique_with_counts`.
1976
1977  This operation returns a tensor `y` containing all of the unique elements
1978  of `x` sorted in the same order that they occur in `x`. This operation
1979  also returns a tensor `idx` the same size as `x` that contains the index
1980  of each value of `x` in the unique output `y`. In other words:
1981
1982
1983    y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]
1984
1985  Example usage:
1986
1987  >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
1988  >>> y, idx = unique(x)
1989  >>> y
1990  <tf.Tensor: id=5, shape=(5,), dtype=int32,
1991  numpy=array([1, 2, 4, 7, 8], dtype=int32)>
1992  >>> idx
1993  <tf.Tensor: id=6, shape=(9,), dtype=int32,
1994  numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
1995
1996  Args:
1997    x: A Tensor. 1-D.
1998    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
1999      tf.int32.
2000    name: A name for the operation (optional).
2001
2002  Returns:
2003    A tuple of Tensor objects (y, idx).
2004      y: A Tensor. Has the same type as x.
2005      idx: A Tensor of type out_idx.
2006
2007  """
2008  # TODO(yongtang): switch to v2 once API deprecation
2009  # period (3 weeks) pass.
2010  # TODO(yongtang): The documentation should also
2011  # be updated when switch  to v2.
2012  return gen_array_ops.unique(x, out_idx, name)
2013
2014
2015unique.__doc__ = gen_array_ops.unique.__doc__
2016
2017
2018@tf_export("unique_with_counts")
2019@dispatch.add_dispatch_support
2020def unique_with_counts(x, out_idx=dtypes.int32, name=None):
2021  """Finds unique elements in a 1-D tensor.
2022
2023  See also `tf.unique`.
2024
2025  This operation returns a tensor `y` containing all of the unique elements
2026  of `x` sorted in the same order that they occur in `x`. This operation
2027  also returns a tensor `idx` the same size as `x` that contains the index
2028  of each value of `x` in the unique output `y`. Finally, it returns a
2029  third tensor `count` that contains the count of each element of `y`
2030  in `x`. In other words:
2031
2032    y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]
2033
2034  Example usage:
2035
2036  >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
2037  >>> y, idx, count = unique_with_counts(x)
2038  >>> y
2039  <tf.Tensor: id=8, shape=(5,), dtype=int32,
2040  numpy=array([1, 2, 4, 7, 8], dtype=int32)>
2041  >>> idx
2042  <tf.Tensor: id=9, shape=(9,), dtype=int32,
2043  numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
2044  >>> count
2045  <tf.Tensor: id=10, shape=(5,), dtype=int32,
2046  numpy=array([2, 1, 3, 1, 2], dtype=int32)>
2047
2048  Args:
2049    x: A Tensor. 1-D.
2050    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
2051      tf.int32.
2052    name: A name for the operation (optional).
2053
2054  Returns:
2055    A tuple of Tensor objects (y, idx, count).
2056      y: A Tensor. Has the same type as x.
2057      idx: A Tensor of type out_idx.
2058      count: A Tensor of type out_idx.
2059
2060  """
2061  # TODO(yongtang): switch to v2 once API deprecation
2062  # period (3 weeks) pass.
2063  # TODO(yongtang): The documentation should also
2064  # be updated when switch  to v2.
2065  return gen_array_ops.unique_with_counts(x, out_idx, name)
2066
2067
2068unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
2069
2070
2071@tf_export("split")
2072@dispatch.add_dispatch_support
2073def split(value, num_or_size_splits, axis=0, num=None, name="split"):
2074  """Splits a tensor `value` into a list of sub tensors.
2075
2076  See also `tf.unstack`.
2077
2078  If `num_or_size_splits` is an integer,  then `value` is split along the
2079  dimension `axis` into `num_or_size_splits` smaller tensors. This requires that
2080  `value.shape[axis]` is divisible by `num_or_size_splits`.
2081
2082  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
2083  `len(num_or_size_splits)` elements. The shape of the `i`-th
2084  element has the same size as the `value` except along dimension `axis` where
2085  the size is `num_or_size_splits[i]`.
2086
2087  For example:
2088
2089  >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
2090  >>>
2091  >>> # Split `x` into 3 tensors along dimension 1
2092  >>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
2093  >>> tf.shape(s0).numpy()
2094  array([ 5, 10], dtype=int32)
2095  >>>
2096  >>> # Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
2097  >>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
2098  >>> tf.shape(split0).numpy()
2099  array([5, 4], dtype=int32)
2100  >>> tf.shape(split1).numpy()
2101  array([ 5, 15], dtype=int32)
2102  >>> tf.shape(split2).numpy()
2103  array([ 5, 11], dtype=int32)
2104
2105  Args:
2106    value: The `Tensor` to split.
2107    num_or_size_splits: Either an integer indicating the number of splits along
2108      `axis` or a 1-D integer `Tensor` or Python list containing the sizes of
2109      each output tensor along `axis`. If a scalar, then it must evenly divide
2110      `value.shape[axis]`; otherwise the sum of sizes along the split axis
2111      must match that of the `value`.
2112    axis: An integer or scalar `int32` `Tensor`. The dimension along which to
2113      split. Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
2114    num: Optional, used to specify the number of outputs when it cannot be
2115      inferred from the shape of `size_splits`.
2116    name: A name for the operation (optional).
2117
2118  Returns:
2119    if `num_or_size_splits` is a scalar returns a list of `num_or_size_splits`
2120    `Tensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
2121    `num_or_size_splits.get_shape[0]` `Tensor` objects resulting from splitting
2122    `value`.
2123
2124  Raises:
2125    ValueError: If `num` is unspecified and cannot be inferred.
2126  """
2127  if isinstance(num_or_size_splits,
2128                (numbers.Integral, tensor_shape.Dimension)):
2129    return gen_array_ops.split(
2130        axis=axis, num_split=num_or_size_splits, value=value, name=name)
2131
2132  size_splits = ops.convert_to_tensor(num_or_size_splits)
2133
2134  if size_splits._rank() == 0:
2135    raise ValueError(
2136        "Rank-0 tensors are not supported as the num_or_size_splits argument "
2137        "to split. Argument provided: %s" % (num_or_size_splits,))
2138
2139  if num is None:
2140    size_splits_shape = size_splits._shape_tuple()
2141    if size_splits_shape:
2142      num = size_splits_shape[0]
2143    if num is None:
2144      raise ValueError("Cannot infer num from shape %s" % num_or_size_splits)
2145
2146  return gen_array_ops.split_v(
2147      value=value, size_splits=size_splits, axis=axis, num_split=num, name=name)
2148
2149
2150@tf_export("transpose", v1=[])
2151@dispatch.add_dispatch_support
2152def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
2153  """Transposes `a`, where `a` is a Tensor.
2154
2155  Permutes the dimensions according to the value of `perm`.
2156
2157  The returned tensor's dimension `i` will correspond to the input dimension
2158  `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is the rank
2159  of the input tensor. Hence by default, this operation performs a regular
2160  matrix transpose on 2-D input Tensors.
2161
2162  If conjugate is `True` and `a.dtype` is either `complex64` or `complex128`
2163  then the values of `a` are conjugated and transposed.
2164
2165  @compatibility(numpy)
2166  In `numpy` transposes are memory-efficient constant time operations as they
2167  simply return a new view of the same data with adjusted `strides`.
2168
2169  TensorFlow does not support strides, so `transpose` returns a new tensor with
2170  the items permuted.
2171  @end_compatibility
2172
2173  For example:
2174
2175  >>> x = tf.constant([[1, 2, 3], [4, 5, 6]])
2176  >>> tf.transpose(x)
2177  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
2178  array([[1, 4],
2179         [2, 5],
2180         [3, 6]], dtype=int32)>
2181
2182  Equivalently, you could call `tf.transpose(x, perm=[1, 0])`.
2183
2184  If `x` is complex, setting conjugate=True gives the conjugate transpose:
2185
2186  >>> x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2187  ...                  [4 + 4j, 5 + 5j, 6 + 6j]])
2188  >>> tf.transpose(x, conjugate=True)
2189  <tf.Tensor: shape=(3, 2), dtype=complex128, numpy=
2190  array([[1.-1.j, 4.-4.j],
2191         [2.-2.j, 5.-5.j],
2192         [3.-3.j, 6.-6.j]])>
2193
2194  'perm' is more useful for n-dimensional tensors where n > 2:
2195
2196  >>> x = tf.constant([[[ 1,  2,  3],
2197  ...                   [ 4,  5,  6]],
2198  ...                  [[ 7,  8,  9],
2199  ...                   [10, 11, 12]]])
2200
2201  As above, simply calling `tf.transpose` will default to `perm=[2,1,0]`.
2202
2203  To take the transpose of the matrices in dimension-0 (such as when you are
2204  transposing matrices where 0 is the batch dimension), you would set
2205  `perm=[0,2,1]`.
2206
2207  >>> tf.transpose(x, perm=[0, 2, 1])
2208  <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
2209  array([[[ 1,  4],
2210          [ 2,  5],
2211          [ 3,  6]],
2212          [[ 7, 10],
2213          [ 8, 11],
2214          [ 9, 12]]], dtype=int32)>
2215
2216  Note: This has a shorthand `linalg.matrix_transpose`):
2217
2218  Args:
2219    a: A `Tensor`.
2220    perm: A permutation of the dimensions of `a`.  This should be a vector.
2221    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2222      to tf.math.conj(tf.transpose(input)).
2223    name: A name for the operation (optional).
2224
2225  Returns:
2226    A transposed `Tensor`.
2227  """
2228  return transpose(a=a, perm=perm, name=name, conjugate=conjugate)
2229
2230
2231@tf_export(v1=["transpose"])
2232@dispatch.add_dispatch_support
2233def transpose(a, perm=None, name="transpose", conjugate=False):
2234  """Transposes `a`.
2235
2236  Permutes the dimensions according to `perm`.
2237
2238  The returned tensor's dimension i will correspond to the input dimension
2239  `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
2240  the rank of the input tensor. Hence by default, this operation performs a
2241  regular matrix transpose on 2-D input Tensors. If conjugate is True and
2242  `a.dtype` is either `complex64` or `complex128` then the values of `a`
2243  are conjugated and transposed.
2244
2245  @compatibility(numpy)
2246  In `numpy` transposes are memory-efficient constant time operations as they
2247  simply return a new view of the same data with adjusted `strides`.
2248
2249  TensorFlow does not support strides, so `transpose` returns a new tensor with
2250  the items permuted.
2251  @end_compatibility
2252
2253  For example:
2254
2255  ```python
2256  x = tf.constant([[1, 2, 3], [4, 5, 6]])
2257  tf.transpose(x)  # [[1, 4]
2258                   #  [2, 5]
2259                   #  [3, 6]]
2260
2261  # Equivalently
2262  tf.transpose(x, perm=[1, 0])  # [[1, 4]
2263                                #  [2, 5]
2264                                #  [3, 6]]
2265
2266  # If x is complex, setting conjugate=True gives the conjugate transpose
2267  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2268                   [4 + 4j, 5 + 5j, 6 + 6j]])
2269  tf.transpose(x, conjugate=True)  # [[1 - 1j, 4 - 4j],
2270                                   #  [2 - 2j, 5 - 5j],
2271                                   #  [3 - 3j, 6 - 6j]]
2272
2273  # 'perm' is more useful for n-dimensional tensors, for n > 2
2274  x = tf.constant([[[ 1,  2,  3],
2275                    [ 4,  5,  6]],
2276                   [[ 7,  8,  9],
2277                    [10, 11, 12]]])
2278
2279  # Take the transpose of the matrices in dimension-0
2280  # (this common operation has a shorthand `linalg.matrix_transpose`)
2281  tf.transpose(x, perm=[0, 2, 1])  # [[[1,  4],
2282                                   #   [2,  5],
2283                                   #   [3,  6]],
2284                                   #  [[7, 10],
2285                                   #   [8, 11],
2286                                   #   [9, 12]]]
2287  ```
2288
2289  Args:
2290    a: A `Tensor`.
2291    perm: A permutation of the dimensions of `a`.
2292    name: A name for the operation (optional).
2293    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2294      to tf.math.conj(tf.transpose(input)).
2295
2296  Returns:
2297    A transposed `Tensor`.
2298  """
2299  with ops.name_scope(name, "transpose", [a]) as name:
2300    if not tensor_util.is_tf_type(a):
2301      a = ops.convert_to_tensor(a, name="a")
2302
2303    if conjugate and a.dtype.is_complex:
2304      transpose_fn = gen_array_ops.conjugate_transpose
2305    else:
2306      transpose_fn = gen_array_ops.transpose
2307
2308    if perm is not None:
2309      return transpose_fn(a, perm, name=name)
2310
2311    rank = a.shape.rank
2312    if rank is None:
2313      perm = gen_math_ops._range(gen_array_ops.rank(a) - 1, -1, -1)
2314    else:
2315      perm = np.arange(rank - 1, -1, -1, dtype=np.int32)
2316    return transpose_fn(a, perm, name=name)
2317
2318
2319# pylint: disable=invalid-name
2320@tf_export(
2321    "linalg.matrix_transpose",
2322    v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"])
2323@dispatch.add_dispatch_support
2324@deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose")
2325def matrix_transpose(a, name="matrix_transpose", conjugate=False):
2326  """Transposes last two dimensions of tensor `a`.
2327
2328  For example:
2329
2330  ```python
2331  x = tf.constant([[1, 2, 3], [4, 5, 6]])
2332  tf.linalg.matrix_transpose(x)  # [[1, 4],
2333                                 #  [2, 5],
2334                                 #  [3, 6]]
2335
2336  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2337                   [4 + 4j, 5 + 5j, 6 + 6j]])
2338  tf.linalg.matrix_transpose(x, conjugate=True)  # [[1 - 1j, 4 - 4j],
2339                                                 #  [2 - 2j, 5 - 5j],
2340                                                 #  [3 - 3j, 6 - 6j]]
2341
2342  # Matrix with two batch dimensions.
2343  # x.shape is [1, 2, 3, 4]
2344  # tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
2345  ```
2346
2347  Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
2348  This is done with minimal cost, and is preferable to using this function. E.g.
2349
2350  ```python
2351  # Good!  Transpose is taken at minimal additional cost.
2352  tf.matmul(matrix, b, transpose_b=True)
2353
2354  # Inefficient!
2355  tf.matmul(matrix, tf.linalg.matrix_transpose(b))
2356  ```
2357
2358  @compatibility(numpy)
2359  In `numpy` transposes are memory-efficient constant time operations as they
2360  simply return a new view of the same data with adjusted `strides`.
2361
2362  TensorFlow does not support strides, `linalg.matrix_transpose` returns a new
2363  tensor with the items permuted.
2364  @end_compatibility
2365
2366  Args:
2367    a: A `Tensor` with `rank >= 2`.
2368    name: A name for the operation (optional).
2369    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2370      to tf.math.conj(tf.linalg.matrix_transpose(input)).
2371
2372  Returns:
2373    A transposed batch matrix `Tensor`.
2374
2375  Raises:
2376    ValueError:  If `a` is determined statically to have `rank < 2`.
2377  """
2378  with ops.name_scope(name, values=[a]):
2379    a = ops.convert_to_tensor(a, name="a")
2380
2381    # If we know the number of dimensions (statically), we can do two things:
2382    # 1. Check that `a` is a (batch) matrix.
2383    # 2. Use a Python list for perm.  This preserves static shape information
2384    #    and avoids extra computations.
2385    a_shape = a.get_shape()
2386    ndims = a_shape.ndims
2387    if ndims is not None:
2388      if ndims < 2:
2389        raise ValueError(
2390            "Argument 'a' should be a (batch) matrix, with rank >= 2.  Found: "
2391            "%s" % a_shape)
2392      perm = list(range(ndims - 2)) + [ndims - 1] + [ndims - 2]
2393    else:
2394      a_rank = rank(a)
2395      perm = concat(
2396          (gen_math_ops._range(0, a_rank - 2, 1), [a_rank - 1, a_rank - 2]), 0)
2397
2398    return transpose(a, perm=perm, conjugate=conjugate)
2399
2400
2401@tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"])
2402@dispatch.add_dispatch_support
2403@deprecation.deprecated_endpoints("matrix_diag")
2404def matrix_diag(diagonal,
2405                name="diag",
2406                k=0,
2407                num_rows=-1,
2408                num_cols=-1,
2409                padding_value=0,
2410                align="RIGHT_LEFT"):
2411  """Returns a batched diagonal tensor with given batched diagonal values.
2412
2413  Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
2414  diagonals of a matrix, with everything else padded with `padding`. `num_rows`
2415  and `num_cols` specify the dimension of the innermost matrix of the output. If
2416  both are not specified, the op assumes the innermost matrix is square and
2417  infers its size from `k` and the innermost dimension of `diagonal`. If only
2418  one of them is specified, the op assumes the unspecified value is the smallest
2419  possible based on other criteria.
2420
2421  Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor
2422  has rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only
2423  one diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has
2424  rank `r` with shape `[I, J, ..., L, num_rows, num_cols]`.
2425
2426  The second innermost dimension of `diagonal` has double meaning. When `k` is
2427  scalar or `k[0] == k[1]`, `M` is part of the batch size [I, J, ..., M], and
2428  the output tensor is:
2429
2430  ```
2431  output[i, j, ..., l, m, n]
2432    = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper
2433      padding_value                             ; otherwise
2434  ```
2435
2436  Otherwise, `M` is treated as the number of diagonals for the matrix in the
2437  same batch (`M = k[1]-k[0]+1`), and the output tensor is:
2438
2439  ```
2440  output[i, j, ..., l, m, n]
2441    = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
2442      padding_value                                     ; otherwise
2443  ```
2444  where `d = n - m`, `diag_index = k[1] - d`, and
2445  `index_in_diag = n - max(d, 0) + offset`.
2446
2447  `offset` is zero except when the alignment of the diagonal is to the right.
2448  ```
2449  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2450                                             and `d >= 0`) or
2451                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2452                                             and `d <= 0`)
2453           0                          ; otherwise
2454  ```
2455  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2456
2457  For example:
2458
2459  ```
2460  # The main diagonal.
2461  diagonal = np.array([[1, 2, 3, 4],            # Input shape: (2, 4)
2462                       [5, 6, 7, 8]])
2463  tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0],  # Output shape: (2, 4, 4)
2464                                 [0, 2, 0, 0],
2465                                 [0, 0, 3, 0],
2466                                 [0, 0, 0, 4]],
2467                                [[5, 0, 0, 0],
2468                                 [0, 6, 0, 0],
2469                                 [0, 0, 7, 0],
2470                                 [0, 0, 0, 8]]]
2471
2472  # A superdiagonal (per batch).
2473  diagonal = np.array([[1, 2, 3],  # Input shape: (2, 3)
2474                       [4, 5, 6]])
2475  tf.matrix_diag(diagonal, k = 1)
2476    ==> [[[0, 1, 0, 0],  # Output shape: (2, 4, 4)
2477          [0, 0, 2, 0],
2478          [0, 0, 0, 3],
2479          [0, 0, 0, 0]],
2480         [[0, 4, 0, 0],
2481          [0, 0, 5, 0],
2482          [0, 0, 0, 6],
2483          [0, 0, 0, 0]]]
2484
2485  # A tridiagonal band (per batch).
2486  diagonals = np.array([[[8, 9, 0],  # Input shape: (2, 2, 3)
2487                         [1, 2, 3],
2488                         [0, 4, 5]],
2489                        [[2, 3, 0],
2490                         [6, 7, 9],
2491                         [0, 9, 1]]])
2492  tf.matrix_diag(diagonals, k = (-1, 1))
2493    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
2494          [4, 2, 9],
2495          [0, 5, 3]],
2496         [[6, 2, 0],
2497          [9, 7, 3],
2498          [0, 1, 9]]]
2499
2500  # RIGHT_LEFT alignment.
2501  diagonals = np.array([[[0, 8, 9],  # Input shape: (2, 2, 3)
2502                         [1, 2, 3],
2503                         [4, 5, 0]],
2504                        [[0, 2, 3],
2505                         [6, 7, 9],
2506                         [9, 1, 0]]])
2507  tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT")
2508    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
2509          [4, 2, 9],
2510          [0, 5, 3]],
2511         [[6, 2, 0],
2512          [9, 7, 3],
2513          [0, 1, 9]]]
2514
2515  # Rectangular matrix.
2516  diagonal = np.array([1, 2])  # Input shape: (2)
2517  tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4)
2518    ==> [[0, 0, 0, 0],  # Output shape: (3, 4)
2519         [1, 0, 0, 0],
2520         [0, 2, 0, 0]]
2521
2522  # Rectangular matrix with inferred num_cols and padding_value = 9.
2523  tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
2524    ==> [[9, 9],  # Output shape: (3, 2)
2525         [1, 9],
2526         [9, 2]]
2527  ```
2528
2529  Args:
2530    diagonal: A `Tensor` with `rank k >= 1`.
2531    name: A name for the operation (optional).
2532    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2533      main diagonal, and negative value means subdiagonals. `k` can be a single
2534      integer (for a single diagonal) or a pair of integers specifying the low
2535      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2536    num_rows: The number of rows of the output matrix. If it is not provided,
2537      the op assumes the output matrix is a square matrix and infers the matrix
2538      size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
2539    num_cols: The number of columns of the output matrix. If it is not provided,
2540      the op assumes the output matrix is a square matrix and infers the matrix
2541      size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
2542    padding_value: The value to fill the area outside the specified diagonal
2543      band with. Default is 0.
2544    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2545      `align` is a string specifying how superdiagonals and subdiagonals should
2546      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2547      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2548      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2549      the left (right-pads the row). It is the packing format LAPACK uses.
2550      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2551
2552  Returns:
2553    A Tensor. Has the same type as `diagonal`.
2554  """
2555  # Special case to sidestep the tf.constant conversion error:
2556  # TypeError: Expected bool, got 0 of type 'int' instead.
2557  if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
2558    padding_value = bool(padding_value)
2559
2560  return gen_array_ops.matrix_diag_v3(
2561      diagonal=diagonal,
2562      k=k,
2563      num_rows=num_rows,
2564      num_cols=num_cols,
2565      padding_value=padding_value,
2566      align=align,
2567      name=name)
2568
2569
2570@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
2571@dispatch.add_dispatch_support
2572@deprecation.deprecated_endpoints("matrix_diag_part")
2573def matrix_diag_part(
2574    input,  # pylint:disable=redefined-builtin
2575    name="diag_part",
2576    k=0,
2577    padding_value=0,
2578    align="RIGHT_LEFT"):
2579  """Returns the batched diagonal part of a batched tensor.
2580
2581  Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
2582  `input`.
2583
2584  Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`.
2585  Let `max_diag_len` be the maximum length among all diagonals to be extracted,
2586  `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
2587  Let `num_diags` be the number of diagonals to extract,
2588  `num_diags = k[1] - k[0] + 1`.
2589
2590  If `num_diags == 1`, the output tensor is of rank `r - 1` with shape
2591  `[I, J, ..., L, max_diag_len]` and values:
2592
2593  ```
2594  diagonal[i, j, ..., l, n]
2595    = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
2596      padding_value                 ; otherwise.
2597  ```
2598  where `y = max(-k[1], 0)`, `x = max(k[1], 0)`.
2599
2600  Otherwise, the output tensor has rank `r` with dimensions
2601  `[I, J, ..., L, num_diags, max_diag_len]` with values:
2602
2603  ```
2604  diagonal[i, j, ..., l, m, n]
2605    = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
2606      padding_value                 ; otherwise.
2607  ```
2608  where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
2609
2610  `offset` is zero except when the alignment of the diagonal is to the right.
2611  ```
2612  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2613                                             and `d >= 0`) or
2614                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2615                                             and `d <= 0`)
2616           0                          ; otherwise
2617  ```
2618  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2619
2620  The input must be at least a matrix.
2621
2622  For example:
2623
2624  ```
2625  input = np.array([[[1, 2, 3, 4],  # Input shape: (2, 3, 4)
2626                     [5, 6, 7, 8],
2627                     [9, 8, 7, 6]],
2628                    [[5, 4, 3, 2],
2629                     [1, 2, 3, 4],
2630                     [5, 6, 7, 8]]])
2631
2632  # A main diagonal from each batch.
2633  tf.linalg.diag_part(input) ==> [[1, 6, 7],  # Output shape: (2, 3)
2634                                  [5, 2, 7]]
2635
2636  # A superdiagonal from each batch.
2637  tf.linalg.diag_part(input, k = 1)
2638    ==> [[2, 7, 6],  # Output shape: (2, 3)
2639         [4, 3, 8]]
2640
2641  # A band from each batch.
2642  tf.linalg.diag_part(input, k = (-1, 2))
2643    ==> [[[3, 8, 0],  # Output shape: (2, 4, 3)
2644          [2, 7, 6],
2645          [1, 6, 7],
2646          [0, 5, 8]],
2647         [[3, 4, 0],
2648          [4, 3, 8],
2649          [5, 2, 7],
2650          [0, 1, 6]]]
2651
2652  # RIGHT_LEFT alignment.
2653  tf.linalg.diag_part(input, k = (-1, 2), align="RIGHT_LEFT")
2654    ==> [[[0, 3, 8],  # Output shape: (2, 4, 3)
2655          [2, 7, 6],
2656          [1, 6, 7],
2657          [5, 8, 0]],
2658         [[0, 3, 4],
2659          [4, 3, 8],
2660          [5, 2, 7],
2661          [1, 6, 0]]]
2662
2663  # max_diag_len can be shorter than the main diagonal.
2664  tf.linalg.diag_part(input, k = (-2, -1))
2665    ==> [[[5, 8],
2666          [0, 9]],
2667         [[1, 6],
2668          [0, 5]]]
2669
2670  # padding_value = 9
2671  tf.linalg.diag_part(input, k = (1, 3), padding_value = 9)
2672    ==> [[[4, 9, 9],  # Output shape: (2, 3, 3)
2673          [3, 8, 9],
2674          [2, 7, 6]],
2675         [[2, 9, 9],
2676          [3, 4, 9],
2677          [4, 3, 8]]]
2678
2679  ```
2680
2681  Args:
2682    input: A `Tensor` with `rank k >= 2`.
2683    name: A name for the operation (optional).
2684    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2685      main diagonal, and negative value means subdiagonals. `k` can be a single
2686      integer (for a single diagonal) or a pair of integers specifying the low
2687      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2688    padding_value: The value to fill the area outside the specified diagonal
2689      band with. Default is 0.
2690    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2691      `align` is a string specifying how superdiagonals and subdiagonals should
2692      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2693      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2694      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2695      the left (right-pads the row). It is the packing format LAPACK uses.
2696      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2697
2698  Returns:
2699    A Tensor containing diagonals of `input`. Has the same type as `input`.
2700
2701  Raises:
2702    InvalidArgumentError: When `k` is out of bound or when `k[0]>k[1:]`.
2703  """
2704  # Special case to sidestep the tf.constant conversion error:
2705  # TypeError: Expected bool, got 0 of type 'int' instead.
2706  if hasattr(input, "dtype") and input.dtype == "bool":
2707    padding_value = bool(padding_value)
2708
2709  return gen_array_ops.matrix_diag_part_v3(
2710      input=input, k=k, padding_value=padding_value, align=align, name=name)
2711
2712
2713@tf_export(
2714    "linalg.tensor_diag_part", v1=["linalg.tensor_diag_part", "diag_part"])
2715@dispatch.add_dispatch_support
2716@deprecation.deprecated_endpoints("diag_part")
2717def tensor_diag_part(
2718    input,  # pylint:disable=redefined-builtin
2719    name=None):
2720  """Returns the diagonal part of the tensor.
2721
2722  This operation returns a tensor with the `diagonal` part
2723  of the `input`. The `diagonal` part is computed as follows:
2724
2725  Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a
2726  tensor of rank `k` with dimensions `[D1,..., Dk]` where:
2727
2728  `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`.
2729
2730  For a rank 2 tensor, `linalg.diag_part` and `linalg.tensor_diag_part`
2731  produce the same result. For rank 3 and higher, linalg.diag_part extracts
2732  the diagonal of each inner-most matrix in the tensor. An example where
2733  they differ is given below.
2734
2735  >>> x = [[[[1111,1112],[1121,1122]],
2736  ...       [[1211,1212],[1221,1222]]],
2737  ...      [[[2111, 2112], [2121, 2122]],
2738  ...       [[2211, 2212], [2221, 2222]]]
2739  ...      ]
2740  >>> tf.linalg.tensor_diag_part(x)
2741  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2742  array([[1111, 1212],
2743         [2121, 2222]], dtype=int32)>
2744  >>> tf.linalg.diag_part(x).shape
2745  TensorShape([2, 2, 2])
2746
2747  Args:
2748    input: A `Tensor` with rank `2k`.
2749    name: A name for the operation (optional).
2750
2751  Returns:
2752    A Tensor containing diagonals of `input`. Has the same type as `input`, and
2753    rank `k`.
2754  """
2755  return gen_array_ops.diag_part(input=input, name=name)
2756
2757
2758@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
2759@dispatch.add_dispatch_support
2760@deprecation.deprecated_endpoints("matrix_set_diag")
2761def matrix_set_diag(
2762    input,  # pylint:disable=redefined-builtin
2763    diagonal,
2764    name="set_diag",
2765    k=0,
2766    align="RIGHT_LEFT"):
2767  """Returns a batched matrix tensor with new batched diagonal values.
2768
2769  Given `input` and `diagonal`, this operation returns a tensor with the
2770  same shape and values as `input`, except for the specified diagonals of the
2771  innermost matrices. These will be overwritten by the values in `diagonal`.
2772
2773  `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
2774  `k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
2775  Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
2776  `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
2777  `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
2778  `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
2779
2780  The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`.
2781  If `k` is scalar or `k[0] == k[1]`:
2782
2783  ```
2784  output[i, j, ..., l, m, n]
2785    = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
2786      input[i, j, ..., l, m, n]              ; otherwise
2787  ```
2788
2789  Otherwise,
2790
2791  ```
2792  output[i, j, ..., l, m, n]
2793    = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
2794      input[i, j, ..., l, m, n]                         ; otherwise
2795  ```
2796  where `d = n - m`, `diag_index = k[1] - d`, and
2797  `index_in_diag = n - max(d, 0) + offset`.
2798
2799  `offset` is zero except when the alignment of the diagonal is to the right.
2800  ```
2801  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2802                                             and `d >= 0`) or
2803                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2804                                             and `d <= 0`)
2805           0                          ; otherwise
2806  ```
2807  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2808
2809  For example:
2810
2811  ```
2812  # The main diagonal.
2813  input = np.array([[[7, 7, 7, 7],              # Input shape: (2, 3, 4)
2814                     [7, 7, 7, 7],
2815                     [7, 7, 7, 7]],
2816                    [[7, 7, 7, 7],
2817                     [7, 7, 7, 7],
2818                     [7, 7, 7, 7]]])
2819  diagonal = np.array([[1, 2, 3],               # Diagonal shape: (2, 3)
2820                       [4, 5, 6]])
2821  tf.matrix_set_diag(input, diagonal)
2822    ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
2823          [7, 2, 7, 7],
2824          [7, 7, 3, 7]],
2825         [[4, 7, 7, 7],
2826          [7, 5, 7, 7],
2827          [7, 7, 6, 7]]]
2828
2829  # A superdiagonal (per batch).
2830  tf.matrix_set_diag(input, diagonal, k = 1)
2831    ==> [[[7, 1, 7, 7],  # Output shape: (2, 3, 4)
2832          [7, 7, 2, 7],
2833          [7, 7, 7, 3]],
2834         [[7, 4, 7, 7],
2835          [7, 7, 5, 7],
2836          [7, 7, 7, 6]]]
2837
2838  # A band of diagonals.
2839  diagonals = np.array([[[9, 1, 0],  # Diagonal shape: (2, 4, 3)
2840                         [6, 5, 8],
2841                         [1, 2, 3],
2842                         [0, 4, 5]],
2843                        [[1, 2, 0],
2844                         [5, 6, 4],
2845                         [6, 1, 2],
2846                         [0, 3, 4]]])
2847  tf.matrix_set_diag(input, diagonals, k = (-1, 2))
2848    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
2849          [4, 2, 5, 1],
2850          [7, 5, 3, 8]],
2851         [[6, 5, 1, 7],
2852          [3, 1, 6, 2],
2853          [7, 4, 2, 4]]]
2854
2855  # RIGHT_LEFT alignment.
2856  diagonals = np.array([[[0, 9, 1],  # Diagonal shape: (2, 4, 3)
2857                         [6, 5, 8],
2858                         [1, 2, 3],
2859                         [4, 5, 0]],
2860                        [[0, 1, 2],
2861                         [5, 6, 4],
2862                         [6, 1, 2],
2863                         [3, 4, 0]]])
2864  tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT")
2865    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
2866          [4, 2, 5, 1],
2867          [7, 5, 3, 8]],
2868         [[6, 5, 1, 7],
2869          [3, 1, 6, 2],
2870          [7, 4, 2, 4]]]
2871
2872  ```
2873
2874  Args:
2875    input: A `Tensor` with rank `k + 1`, where `k >= 1`.
2876    diagonal:  A `Tensor` with rank `k`, when `d_lower == d_upper`, or `k + 1`,
2877      otherwise. `k >= 1`.
2878    name: A name for the operation (optional).
2879    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2880      main diagonal, and negative value means subdiagonals. `k` can be a single
2881      integer (for a single diagonal) or a pair of integers specifying the low
2882      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2883    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2884      `align` is a string specifying how superdiagonals and subdiagonals should
2885      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2886      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2887      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2888      the left (right-pads the row). It is the packing format LAPACK uses.
2889      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2890  """
2891  return gen_array_ops.matrix_set_diag_v3(
2892      input=input, diagonal=diagonal, k=k, align=align, name=name)
2893
2894
2895# pylint: enable=invalid-name
2896
2897
2898def _constant_if_small(value, shape, dtype, name):
2899  try:
2900    if np.prod(shape) < 1000:
2901      return constant(value, shape=shape, dtype=dtype, name=name)
2902  except (NotImplementedError, TypeError):
2903    # Happens when shape is a Tensor, list with Tensor elements, etc.
2904    pass
2905  return None
2906
2907
2908def _tag_zeros_tensor(fun):
2909  """ Tags the result of function by setting _is_zeros_tensor attribute.
2910
2911  This is useful to compute Hessians of fused ops such as cross_entropy.
2912  """
2913
2914  def wrapped(*args, **kwargs):
2915    tensor = fun(*args, **kwargs)
2916    tensor._is_zeros_tensor = True
2917    return tensor
2918
2919  return tf_decorator.make_decorator(fun, wrapped)
2920
2921
2922@tf_export("zeros")
2923@dispatch.add_dispatch_support
2924@_tag_zeros_tensor
2925def zeros(shape, dtype=dtypes.float32, name=None):
2926  """Creates a tensor with all elements set to zero.
2927
2928  See also `tf.zeros_like`, `tf.ones`, `tf.fill`, `tf.eye`.
2929
2930  This operation returns a tensor of type `dtype` with shape `shape` and
2931  all elements set to zero.
2932
2933  >>> tf.zeros([3, 4], tf.int32)
2934  <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
2935  array([[0, 0, 0, 0],
2936         [0, 0, 0, 0],
2937         [0, 0, 0, 0]], dtype=int32)>
2938
2939  Args:
2940    shape: A `list` of integers, a `tuple` of integers, or
2941      a 1-D `Tensor` of type `int32`.
2942    dtype: The DType of an element in the resulting `Tensor`.
2943    name: Optional string. A name for the operation.
2944
2945  Returns:
2946    A `Tensor` with all elements set to zero.
2947  """
2948  dtype = dtypes.as_dtype(dtype).base_dtype
2949  with ops.name_scope(name, "zeros", [shape]) as name:
2950    if dtype == dtypes.bool:
2951      zero = False
2952    elif dtype == dtypes.string:
2953      zero = ""
2954    elif dtype.is_quantized:
2955      zero = np.zeros([]).astype(dtype.as_numpy_dtype)
2956    else:
2957      zero = 0
2958
2959    if not isinstance(shape, ops.Tensor):
2960      try:
2961        if not context.executing_eagerly():
2962          # Create a constant if it won't be very big. Otherwise create a fill
2963          # op to prevent serialized GraphDefs from becoming too large.
2964          output = _constant_if_small(zero, shape, dtype, name)
2965          if output is not None:
2966            return output
2967
2968        # Go through tensor shapes to get int64-if-needed semantics
2969        shape = constant_op._tensor_shape_tensor_conversion_function(
2970            tensor_shape.TensorShape(shape))
2971      except (TypeError, ValueError, errors.UnimplementedError):
2972        # Happens when shape is a list with tensor elements
2973        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
2974    if not shape._shape_tuple():
2975      shape = reshape(shape, [-1])  # Ensure it's a vector
2976    output = fill(shape, constant(zero, dtype=dtype), name=name)
2977  assert output.dtype.base_dtype == dtype
2978  return output
2979
2980
2981@tf_export(v1=["zeros_like"])
2982@dispatch.add_dispatch_support
2983def zeros_like(tensor, dtype=None, name=None, optimize=True):
2984  """Creates a tensor with all elements set to zero.
2985
2986  See also `tf.zeros`.
2987
2988  Given a single tensor (`tensor`), this operation returns a tensor of the
2989  same type and shape as `tensor` with all elements set to zero. Optionally,
2990  you can use `dtype` to specify a new type for the returned tensor.
2991
2992  Examples:
2993
2994    >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
2995    >>> tf.zeros_like(tensor)
2996    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
2997    array([[0, 0, 0],
2998           [0, 0, 0]], dtype=int32)>
2999
3000    >>> tf.zeros_like(tensor, dtype=tf.float32)
3001    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
3002    array([[0., 0., 0.],
3003           [0., 0., 0.]], dtype=float32)>
3004
3005  Args:
3006    tensor: A `Tensor`.
3007    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3008      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3009      `complex64`, `complex128`, `bool` or `string`. (optional)
3010    name: A name for the operation (optional).
3011    optimize: if `True`, attempt to statically determine the shape of `tensor`
3012      and encode it as a constant. (optional, defaults to `True`)
3013
3014  Returns:
3015    A `Tensor` with all elements set to zero.
3016  """
3017  return zeros_like_impl(tensor, dtype, name, optimize)
3018
3019
3020@tf_export("zeros_like", v1=[])
3021@dispatch.add_dispatch_support
3022def zeros_like_v2(
3023    input,  # pylint: disable=redefined-builtin
3024    dtype=None,
3025    name=None):
3026  """Creates a tensor with all elements set to zero.
3027
3028  See also `tf.zeros`.
3029
3030  Given a single tensor or array-like object (`input`), this operation returns
3031  a tensor of the same type and shape as `input` with all elements set to zero.
3032  Optionally, you can use `dtype` to specify a new type for the returned tensor.
3033
3034  Examples:
3035
3036    >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3037    >>> tf.zeros_like(tensor)
3038    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3039    array([[0, 0, 0],
3040           [0, 0, 0]], dtype=int32)>
3041
3042    >>> tf.zeros_like(tensor, dtype=tf.float32)
3043    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
3044    array([[0., 0., 0.],
3045           [0., 0., 0.]], dtype=float32)>
3046
3047    >>> tf.zeros_like([[1, 2, 3], [4, 5, 6]])
3048    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3049    array([[0, 0, 0],
3050           [0, 0, 0]], dtype=int32)>
3051
3052  Args:
3053    input: A `Tensor` or array-like object.
3054    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3055      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3056      `complex64`, `complex128`, `bool` or `string` (optional).
3057    name: A name for the operation (optional).
3058
3059  Returns:
3060    A `Tensor` with all elements set to zero.
3061  """
3062  return zeros_like_impl(input, dtype, name, optimize=True)
3063
3064
3065@_tag_zeros_tensor
3066def zeros_like_impl(tensor, dtype, name, optimize=True):
3067  """Internal implementation for the v1/v2 zeros_like API calls."""
3068  with ops.name_scope(name, "zeros_like", [tensor]) as name:
3069    if not tensor_util.is_tf_type(tensor):
3070      tensor = ops.convert_to_tensor(tensor, name="tensor")
3071    tensor_shape = tensor.shape
3072    tensor_dtype = tensor.dtype
3073
3074    if context.executing_eagerly():
3075      if dtype is not None and dtype != tensor_dtype:
3076        return zeros(
3077            shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
3078      return gen_array_ops.zeros_like(tensor, name=name)
3079
3080    # For now, variant types must be created via zeros_like; as we need to
3081    # pass the input variant object to the proper zeros callback.
3082
3083    if (optimize and tensor_shape.is_fully_defined() and
3084        tensor_dtype != dtypes.variant):
3085      # We can produce a zeros tensor independent of the value of 'tensor',
3086      # since the shape is known statically.
3087      return zeros(tensor_shape, dtype=dtype or tensor_dtype, name=name)
3088
3089    if dtype is not None and dtype != tensor_dtype and dtype != dtypes.variant:
3090      return zeros(
3091          shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
3092    else:
3093      return gen_array_ops.zeros_like(tensor, name=name)
3094
3095
3096@tf_export(v1=["ones_like"])
3097@dispatch.add_dispatch_support
3098def ones_like(tensor, dtype=None, name=None, optimize=True):
3099  """Creates a tensor with all elements set to 1.
3100
3101  See also `tf.ones`.
3102
3103  Given a single tensor (`tensor`), this operation returns a tensor of the same
3104  type and shape as `tensor` with all elements set to 1. Optionally, you can
3105  specify a new type (`dtype`) for the returned tensor.
3106
3107  For example:
3108
3109  ```python
3110  tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3111  tf.ones_like(tensor)  # [[1, 1, 1], [1, 1, 1]]
3112  ```
3113
3114  Args:
3115    tensor: A `Tensor`.
3116    dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
3117      `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, `complex64`,
3118      `complex128` or `bool`.
3119    name: A name for the operation (optional).
3120    optimize: if true, attempt to statically determine the shape of 'tensor' and
3121      encode it as a constant.
3122
3123  Returns:
3124    A `Tensor` with all elements set to 1.
3125  """
3126  return ones_like_impl(tensor, dtype, name, optimize)
3127
3128
3129@tf_export("ones_like", v1=[])
3130@dispatch.add_dispatch_support
3131def ones_like_v2(
3132    input,  # pylint: disable=redefined-builtin
3133    dtype=None,
3134    name=None):
3135  """Creates a tensor of all ones that has the same shape as the input.
3136
3137  See also `tf.ones`.
3138
3139  Given a single tensor (`tensor`), this operation returns a tensor of the
3140  same type and shape as `tensor` with all elements set to 1. Optionally,
3141  you can use `dtype` to specify a new type for the returned tensor.
3142
3143  For example:
3144
3145  >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3146  >>> tf.ones_like(tensor)
3147  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3148    array([[1, 1, 1],
3149           [1, 1, 1]], dtype=int32)>
3150
3151  Args:
3152    input: A `Tensor`.
3153    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3154      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3155      `complex64`, `complex128`, `bool` or `string`.
3156    name: A name for the operation (optional).
3157
3158  Returns:
3159    A `Tensor` with all elements set to one.
3160  """
3161  return ones_like_impl(input, dtype, name, optimize=True)
3162
3163
3164def ones_like_impl(tensor, dtype, name, optimize=True):
3165  """Internal implementation for the v1/v2 ones_like API calls."""
3166  with ops.name_scope(name, "ones_like", [tensor]) as name:
3167    tensor = ops.convert_to_tensor(tensor, name="tensor")
3168    ones_shape = shape_internal(tensor, optimize=optimize)
3169    if dtype is None:
3170      dtype = tensor.dtype
3171    ret = ones(ones_shape, dtype=dtype, name=name)
3172    if not context.executing_eagerly():
3173      ret.set_shape(tensor.get_shape())
3174    return ret
3175
3176
3177@tf_export("ones")
3178@dispatch.add_dispatch_support
3179def ones(shape, dtype=dtypes.float32, name=None):
3180  """Creates a tensor with all elements set to one (1).
3181
3182  See also `tf.ones_like`, `tf.zeros`, `tf.fill`, `tf.eye`.
3183
3184  This operation returns a tensor of type `dtype` with shape `shape` and
3185  all elements set to one.
3186
3187  >>> tf.ones([3, 4], tf.int32)
3188  <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3189  array([[1, 1, 1, 1],
3190         [1, 1, 1, 1],
3191         [1, 1, 1, 1]], dtype=int32)>
3192
3193  Args:
3194    shape: A `list` of integers, a `tuple` of integers, or
3195      a 1-D `Tensor` of type `int32`.
3196    dtype: Optional DType of an element in the resulting `Tensor`. Default is
3197      `tf.float32`.
3198    name: Optional string. A name for the operation.
3199
3200  Returns:
3201    A `Tensor` with all elements set to one (1).
3202  """
3203  dtype = dtypes.as_dtype(dtype).base_dtype
3204  with ops.name_scope(name, "ones", [shape]) as name:
3205    if dtype == dtypes.bool:
3206      one = True
3207    elif dtype.is_quantized:
3208      one = np.ones([]).astype(dtype.as_numpy_dtype)
3209    else:
3210      one = 1
3211    if not isinstance(shape, ops.Tensor):
3212      try:
3213        if not context.executing_eagerly():
3214          # Create a constant if it won't be very big. Otherwise create a fill
3215          # op to prevent serialized GraphDefs from becoming too large.
3216          output = _constant_if_small(one, shape, dtype, name)
3217          if output is not None:
3218            return output
3219
3220        # Go through tensor shapes to get int64-if-needed semantics
3221        shape = constant_op._tensor_shape_tensor_conversion_function(
3222            tensor_shape.TensorShape(shape))
3223      except (TypeError, ValueError):
3224        # Happens when shape is a list with tensor elements
3225        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
3226    if not shape._shape_tuple():
3227      shape = reshape(shape, [-1])  # Ensure it's a vector
3228    output = fill(shape, constant(one, dtype=dtype), name=name)
3229  assert output.dtype.base_dtype == dtype
3230  return output
3231
3232
3233@tf_export(v1=["placeholder"])
3234def placeholder(dtype, shape=None, name=None):
3235  """Inserts a placeholder for a tensor that will be always fed.
3236
3237  **Important**: This tensor will produce an error if evaluated. Its value must
3238  be fed using the `feed_dict` optional argument to `Session.run()`,
3239  `Tensor.eval()`, or `Operation.run()`.
3240
3241  For example:
3242
3243  ```python
3244  x = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))
3245  y = tf.matmul(x, x)
3246
3247  with tf.compat.v1.Session() as sess:
3248    print(sess.run(y))  # ERROR: will fail because x was not fed.
3249
3250    rand_array = np.random.rand(1024, 1024)
3251    print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.
3252  ```
3253
3254  Args:
3255    dtype: The type of elements in the tensor to be fed.
3256    shape: The shape of the tensor to be fed (optional). If the shape is not
3257      specified, you can feed a tensor of any shape.
3258    name: A name for the operation (optional).
3259
3260  Returns:
3261    A `Tensor` that may be used as a handle for feeding a value, but not
3262    evaluated directly.
3263
3264  Raises:
3265    RuntimeError: if eager execution is enabled
3266
3267  @compatibility(TF2)
3268  This API is not compatible with eager execution and `tf.function`. To migrate
3269  to TF2, rewrite the code to be compatible with eager execution. Check the
3270  [migration
3271  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
3272  on replacing `Session.run` calls. In TF2, you can just pass tensors directly
3273  into ops and layers. If you want to explicitly set up your inputs, also see
3274  [Keras functional API](https://www.tensorflow.org/guide/keras/functional) on
3275  how to use `tf.keras.Input` to replace `tf.compat.v1.placeholder`.
3276  `tf.function` arguments also do the job of `tf.compat.v1.placeholder`.
3277  For more details please read [Better
3278  performance with tf.function](https://www.tensorflow.org/guide/function).
3279  @end_compatibility
3280  """
3281  if context.executing_eagerly():
3282    raise RuntimeError("tf.placeholder() is not compatible with "
3283                       "eager execution.")
3284
3285  return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
3286
3287
3288@tf_export(v1=["placeholder_with_default"])
3289def placeholder_with_default(input, shape, name=None):  # pylint: disable=redefined-builtin
3290  """A placeholder op that passes through `input` when its output is not fed.
3291
3292  @compatibility(TF2)
3293  This API is strongly discouraged for use with eager execution and
3294  `tf.function`. The primary use of this API is for testing computation wrapped
3295  within a `tf.function` where the input tensors might not have statically known
3296  fully-defined shapes. The same can be achieved by creating a
3297  [concrete function](
3298  https://www.tensorflow.org/guide/function#obtaining_concrete_functions)
3299  from the `tf.function` with a `tf.TensorSpec` input which has partially
3300  defined shapes. For example, the code
3301
3302  >>> @tf.function
3303  ... def f():
3304  ...   x = tf.compat.v1.placeholder_with_default(
3305  ...       tf.constant([[1., 2., 3.], [4., 5., 6.]]), [None, 3])
3306  ...   y = tf.constant([[1.],[2.], [3.]])
3307  ...   z = tf.matmul(x, y)
3308  ...   assert z.shape[0] == None
3309  ...   assert z.shape[1] == 1
3310
3311  >>> f()
3312
3313  can easily be replaced by
3314
3315  >>> @tf.function
3316  ... def f(x):
3317  ...   y = tf.constant([[1.],[2.], [3.]])
3318  ...   z = tf.matmul(x, y)
3319  ...   assert z.shape[0] == None
3320  ...   assert z.shape[1] == 1
3321
3322  >>> g = f.get_concrete_function(tf.TensorSpec([None, 3]))
3323
3324  You can learn more about `tf.function` at [Better
3325  performance with tf.function](https://www.tensorflow.org/guide/function).
3326  @end_compatibility
3327
3328  Args:
3329    input: A `Tensor`. The default value to produce when output is not fed.
3330    shape: A `tf.TensorShape` or list of `int`s. The (possibly partial) shape of
3331      the tensor.
3332    name: A name for the operation (optional).
3333
3334  Returns:
3335    A `Tensor`. Has the same type as `input`.
3336  """
3337  return gen_array_ops.placeholder_with_default(input, shape, name)
3338
3339
3340@tf_export(v1=["sparse.placeholder", "sparse_placeholder"])
3341@deprecation.deprecated_endpoints("sparse_placeholder")
3342def sparse_placeholder(dtype, shape=None, name=None):
3343  """Inserts a placeholder for a sparse tensor that will be always fed.
3344
3345  **Important**: This sparse tensor will produce an error if evaluated.
3346  Its value must be fed using the `feed_dict` optional argument to
3347  `Session.run()`, `Tensor.eval()`, or `Operation.run()`.
3348
3349  For example:
3350
3351  ```python
3352  x = tf.compat.v1.sparse.placeholder(tf.float32)
3353  y = tf.sparse.reduce_sum(x)
3354
3355  with tf.compat.v1.Session() as sess:
3356    print(sess.run(y))  # ERROR: will fail because x was not fed.
3357
3358    indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
3359    values = np.array([1.0, 2.0], dtype=np.float32)
3360    shape = np.array([7, 9, 2], dtype=np.int64)
3361    print(sess.run(y, feed_dict={
3362      x: tf.compat.v1.SparseTensorValue(indices, values, shape)}))  # Will
3363      succeed.
3364    print(sess.run(y, feed_dict={
3365      x: (indices, values, shape)}))  # Will succeed.
3366
3367    sp = tf.sparse.SparseTensor(indices=indices, values=values,
3368                                dense_shape=shape)
3369    sp_value = sp.eval(session=sess)
3370    print(sess.run(y, feed_dict={x: sp_value}))  # Will succeed.
3371  ```
3372
3373  @compatibility{eager} Placeholders are not compatible with eager execution.
3374
3375  Args:
3376    dtype: The type of `values` elements in the tensor to be fed.
3377    shape: The shape of the tensor to be fed (optional). If the shape is not
3378      specified, you can feed a sparse tensor of any shape.
3379    name: A name for prefixing the operations (optional).
3380
3381  Returns:
3382    A `SparseTensor` that may be used as a handle for feeding a value, but not
3383    evaluated directly.
3384
3385  Raises:
3386    RuntimeError: if eager execution is enabled
3387  """
3388  if context.executing_eagerly():
3389    raise RuntimeError("`sparse_placeholder` is not compatible with "
3390                       "eager execution.")
3391
3392  shape_name = (name + "/shape") if name is not None else None
3393  default_shape_name = (name + "/shape_default") if name is not None else None
3394  if shape is None:
3395    rank = None
3396    dense_shape = placeholder(dtypes.int64, shape=[rank], name=shape_name)
3397    dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
3398  else:
3399    if isinstance(shape, ops.Tensor):
3400      rank = shape.get_shape()[0]
3401      dense_shape_default = tensor_util.constant_value_as_shape(shape)
3402    else:
3403      rank = len(shape)
3404      # determine the shape, to override the `.shape` property of the
3405      # `SparseTensor`
3406      dense_shape_default = tensor_shape.TensorShape(
3407          tuple(None if dim == -1 else dim for dim in shape))
3408      shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
3409      shape = tuple(-1 if dim is None else dim for dim in shape)
3410      shape = ops.convert_to_tensor(
3411          shape, dtype=dtypes.int64, name=default_shape_name)
3412
3413    # `dense_shape` needs to be feedable (for users that treat this as an
3414    # actual placeholder). `constant_value_as_shape` sets constants to
3415    # not-feedable. placeholder_with_default works, but blocks `SparseTensor`
3416    # from reading the default value back out.
3417    dense_shape = placeholder_with_default(
3418        shape, shape=shape.shape, name=shape_name)
3419
3420  result = sparse_tensor.SparseTensor(
3421      values=placeholder(
3422          dtype,
3423          shape=[None],
3424          name=(name + "/values") if name is not None else None),
3425      indices=placeholder(
3426          dtypes.int64,
3427          shape=[None, rank],
3428          name=(name + "/indices") if name is not None else None),
3429      dense_shape=dense_shape)
3430
3431  # Now the SparseTensor.shape is a list of `None`s, since it couldn't read the
3432  # default shape out of the placeholder. Override that
3433  # shape to be the value determined here, so partial shapes can be
3434  # propagated.
3435  result._dense_shape_default = dense_shape_default
3436  return result
3437
3438# pylint: enable=redefined-outer-name
3439
3440
3441@tf_export("pad", v1=[])
3442@dispatch.add_dispatch_support
3443def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
3444  """Pads a tensor.
3445
3446  This operation pads a `tensor` according to the `paddings` you specify.
3447  `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3448  `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3449  many values to add before the contents of `tensor` in that dimension, and
3450  `paddings[D, 1]` indicates how many values to add after the contents of
3451  `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3452  and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3453  `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3454  no greater than `tensor.dim_size(D)`.
3455
3456  The padded size of each dimension D of the output is:
3457
3458  `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3459
3460  For example:
3461
3462  ```python
3463  t = tf.constant([[1, 2, 3], [4, 5, 6]])
3464  paddings = tf.constant([[1, 1,], [2, 2]])
3465  # 'constant_values' is 0.
3466  # rank of 't' is 2.
3467  tf.pad(t, paddings, "CONSTANT")  # [[0, 0, 0, 0, 0, 0, 0],
3468                                   #  [0, 0, 1, 2, 3, 0, 0],
3469                                   #  [0, 0, 4, 5, 6, 0, 0],
3470                                   #  [0, 0, 0, 0, 0, 0, 0]]
3471
3472  tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
3473                                  #  [3, 2, 1, 2, 3, 2, 1],
3474                                  #  [6, 5, 4, 5, 6, 5, 4],
3475                                  #  [3, 2, 1, 2, 3, 2, 1]]
3476
3477  tf.pad(t, paddings, "SYMMETRIC")  # [[2, 1, 1, 2, 3, 3, 2],
3478                                    #  [2, 1, 1, 2, 3, 3, 2],
3479                                    #  [5, 4, 4, 5, 6, 6, 5],
3480                                    #  [5, 4, 4, 5, 6, 6, 5]]
3481  ```
3482
3483  Args:
3484    tensor: A `Tensor`.
3485    paddings: A `Tensor` of type `int32`.
3486    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
3487    constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3488      same type as `tensor`.
3489    name: A name for the operation (optional).
3490
3491  Returns:
3492    A `Tensor`. Has the same type as `tensor`.
3493
3494  Raises:
3495    ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3496  """
3497  return pad(tensor, paddings, mode, name, constant_values)
3498
3499
3500@tf_export(v1=["pad"])
3501@dispatch.add_dispatch_support
3502def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pylint: disable=invalid-name
3503  """Pads a tensor.
3504
3505  This operation pads a `tensor` according to the `paddings` you specify.
3506  `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3507  `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3508  many values to add before the contents of `tensor` in that dimension, and
3509  `paddings[D, 1]` indicates how many values to add after the contents of
3510  `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3511  and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3512  `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3513  no greater than `tensor.dim_size(D)`.
3514
3515  The padded size of each dimension D of the output is:
3516
3517  `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3518
3519  For example:
3520
3521  ```python
3522  t = tf.constant([[1, 2, 3], [4, 5, 6]])
3523  paddings = tf.constant([[1, 1,], [2, 2]])
3524  # 'constant_values' is 0.
3525  # rank of 't' is 2.
3526  tf.pad(t, paddings, "CONSTANT")  # [[0, 0, 0, 0, 0, 0, 0],
3527                                   #  [0, 0, 1, 2, 3, 0, 0],
3528                                   #  [0, 0, 4, 5, 6, 0, 0],
3529                                   #  [0, 0, 0, 0, 0, 0, 0]]
3530
3531  tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
3532                                  #  [3, 2, 1, 2, 3, 2, 1],
3533                                  #  [6, 5, 4, 5, 6, 5, 4],
3534                                  #  [3, 2, 1, 2, 3, 2, 1]]
3535
3536  tf.pad(t, paddings, "SYMMETRIC")  # [[2, 1, 1, 2, 3, 3, 2],
3537                                    #  [2, 1, 1, 2, 3, 3, 2],
3538                                    #  [5, 4, 4, 5, 6, 6, 5],
3539                                    #  [5, 4, 4, 5, 6, 6, 5]]
3540  ```
3541
3542  Args:
3543    tensor: A `Tensor`.
3544    paddings: A `Tensor` of type `int32`.
3545    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
3546    name: A name for the operation (optional).
3547    constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3548      same type as `tensor`.
3549
3550  Returns:
3551    A `Tensor`. Has the same type as `tensor`.
3552
3553  Raises:
3554    ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3555  """
3556
3557  # Convert lower/mixed case to upper for NumPy compatibility
3558  # NumPy uses all lower-case modes.
3559  mode = mode.upper()
3560  if mode == "CONSTANT":
3561    # TODO(rjryan): Once the forward compatibility period (3 weeks) have passed
3562    # remove the "Pad" fallback here.
3563    if not tensor_util.is_tf_type(constant_values) and constant_values == 0:
3564      result = gen_array_ops.pad(tensor, paddings, name=name)
3565    else:
3566      result = gen_array_ops.pad_v2(
3567          tensor, paddings, constant_values, name=name)
3568  elif mode == "REFLECT":
3569    result = gen_array_ops.mirror_pad(
3570        tensor, paddings, mode="REFLECT", name=name)
3571  elif mode == "SYMMETRIC":
3572    result = gen_array_ops.mirror_pad(
3573        tensor, paddings, mode="SYMMETRIC", name=name)
3574  else:
3575    raise ValueError("Unknown padding mode: %s" % mode)
3576
3577  # Restore shape information where possible.
3578  if not context.executing_eagerly():
3579    paddings_constant = _get_paddings_constant(paddings)
3580    input_shape = (
3581        tensor_shape.TensorShape(tensor.shape)
3582        if isinstance(tensor, ops.Tensor) else result.op.inputs[0].shape)
3583    if (input_shape.ndims is not None and
3584        not result.shape.is_fully_defined() and paddings_constant is not None):
3585      new_shape = []
3586      for padding, dim in zip(paddings_constant, input_shape.as_list()):
3587        if padding is None or dim is None or any((x is None for x in padding)):
3588          new_shape.append(None)
3589        else:
3590          new_shape.append(sum(padding) + dim)
3591      result.set_shape(new_shape)
3592
3593  return result
3594
3595
3596def _get_paddings_constant(paddings):
3597  """Helper to get the constant values of the paddings arg to pad().
3598
3599  Used under V1 graph mode to facilitate computation of the shape of the output
3600  tensor of `pad()`.
3601
3602  Args:
3603    paddings: The same paddings arg as passed to pad(). Can be a Tensor, or
3604      a nested list or tuple of Tensor and/or numbers.
3605
3606  Returns:
3607    A nested list or numbers or `None`, in which `None` indicates unknown
3608    padding size.
3609  """
3610  if isinstance(paddings, ops.Tensor):
3611    return tensor_util.constant_value(paddings, partial=True)
3612  elif isinstance(paddings, (list, tuple)):
3613    return [_get_paddings_constant(x) for x in paddings]
3614  else:
3615    return paddings
3616
3617
3618@tf_export("meshgrid")
3619@dispatch.add_dispatch_support
3620def meshgrid(*args, **kwargs):
3621  """Broadcasts parameters for evaluation on an N-D grid.
3622
3623  Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
3624  of N-D coordinate arrays for evaluating expressions on an N-D grid.
3625
3626  Notes:
3627
3628  `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
3629  When the `indexing` argument is set to 'xy' (the default), the broadcasting
3630  instructions for the first two dimensions are swapped.
3631
3632  Examples:
3633
3634  Calling `X, Y = meshgrid(x, y)` with the tensors
3635
3636  ```python
3637  x = [1, 2, 3]
3638  y = [4, 5, 6]
3639  X, Y = tf.meshgrid(x, y)
3640  # X = [[1, 2, 3],
3641  #      [1, 2, 3],
3642  #      [1, 2, 3]]
3643  # Y = [[4, 4, 4],
3644  #      [5, 5, 5],
3645  #      [6, 6, 6]]
3646  ```
3647
3648  Args:
3649    *args: `Tensor`s with rank 1.
3650    **kwargs:
3651      - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
3652      - name: A name for the operation (optional).
3653
3654  Returns:
3655    outputs: A list of N `Tensor`s with rank N.
3656
3657  Raises:
3658    TypeError: When no keyword arguments (kwargs) are passed.
3659    ValueError: When indexing keyword argument is not one of `xy` or `ij`.
3660  """
3661
3662  indexing = kwargs.pop("indexing", "xy")
3663  name = kwargs.pop("name", "meshgrid")
3664  if kwargs:
3665    key = list(kwargs.keys())[0]
3666    raise TypeError("'{}' is an invalid keyword argument "
3667                    "for this function".format(key))
3668
3669  if indexing not in ("xy", "ij"):
3670    raise ValueError("indexing parameter must be either 'xy' or 'ij'")
3671
3672  with ops.name_scope(name, "meshgrid", args) as name:
3673    ndim = len(args)
3674    s0 = (1,) * ndim
3675
3676    if not ndim:
3677      return []
3678
3679    # Prepare reshape by inserting dimensions with size 1 where needed
3680    output = []
3681    for i, x in enumerate(args):
3682      output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
3683    # Create parameters for broadcasting each tensor to the full size
3684    shapes = [size(x) for x in args]
3685
3686    output_dtype = ops.convert_to_tensor(args[0]).dtype.base_dtype
3687
3688    if indexing == "xy" and ndim > 1:
3689      output[0] = reshape(output[0], (1, -1) + (1,) * (ndim - 2))
3690      output[1] = reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
3691      shapes[0], shapes[1] = shapes[1], shapes[0]
3692
3693    # TODO(nolivia): improve performance with a broadcast
3694    mult_fact = ones(shapes, output_dtype)
3695    return [x * mult_fact for x in output]
3696
3697
3698NEW_AXIS = -1
3699SHRINK_AXIS = -2
3700
3701
3702# PEP-8 naming
3703# pylint: disable=invalid-name,redefined-outer-name
3704def _compute_size_of_strided_dim(shrink, spec, size):
3705  """Computes the size of a single strided slice dimension."""
3706
3707  unknown = None  # Document what None means here.
3708  use_full_range = None  # Document other use of None.
3709  # if this is a shrink axis (i.e. a non-range index)
3710  # it either will produce an error or return 1
3711  if shrink:
3712    return 1
3713  if size is unknown or size.value is unknown:
3714    return unknown
3715  size = size.value
3716  stride = spec.step
3717  if stride is not unknown:
3718    if stride == 0:
3719      return unknown
3720    stride = spec.step
3721    valid_range = [0, size] if stride > 0 else [-1, size - 1]
3722
3723    # PEP-8 naming
3724    # pylint: disable=invalid-name
3725    def canonical(x, c):
3726      if x is use_full_range:
3727        return valid_range[c] if stride > 0 else valid_range[(c + 1) & 1]
3728      else:
3729        x_fwd = size + x if x < 0 else x  # make negative indices positive
3730        return max(valid_range[0], min(valid_range[1], x_fwd))
3731
3732    begin = canonical(spec.start, 0)
3733    end = canonical(spec.stop, 1)
3734    interval_length = end - begin
3735    if interval_length == 0 or ((interval_length < 0) != (stride < 0)):
3736      return 0
3737    else:
3738      remainder = 1 if interval_length % stride != 0 else 0
3739      return interval_length // stride + remainder
3740  else:
3741    return unknown  # unknown because stride is unknown
3742
3743
3744def _TileGradShape(op):
3745  """Shape function for the TileGrad op."""
3746  multiples_shape = op.inputs[1].get_shape().with_rank(1)
3747  input_shape = op.inputs[0].get_shape().with_rank(multiples_shape[0])
3748  # NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
3749  # it is a vector of non-negative integers, and (ii) doing so allows
3750  # us to handle partially-known multiples.
3751  multiples = tensor_util.constant_value_as_shape(op.inputs[1]).with_rank(
3752      input_shape.ndims)
3753  if multiples.ndims is None:
3754    return [tensor_shape.unknown_shape()]
3755  else:
3756    output_dims = []
3757    for dim, multiple in zip(input_shape.dims, multiples.dims):
3758      output_dims.append(dim // multiple)
3759    return [tensor_shape.TensorShape(output_dims)]
3760
3761
3762@tf_export("edit_distance")
3763@dispatch.add_dispatch_support
3764def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
3765  """Computes the Levenshtein distance between sequences.
3766
3767  This operation takes variable-length sequences (`hypothesis` and `truth`),
3768  each provided as a `SparseTensor`, and computes the Levenshtein distance.
3769  You can normalize the edit distance by length of `truth` by setting
3770  `normalize` to true.
3771
3772  For example:
3773
3774  Given the following input,
3775  * `hypothesis` is a `tf.SparseTensor` of shape `[2, 1, 1]`
3776  * `truth` is a `tf.SparseTensor` of shape `[2, 2, 2]`
3777
3778  >>> hypothesis = tf.SparseTensor(
3779  ...   [[0, 0, 0],
3780  ...    [1, 0, 0]],
3781  ...   ["a", "b"],
3782  ...   (2, 1, 1))
3783  >>> truth = tf.SparseTensor(
3784  ...   [[0, 1, 0],
3785  ...    [1, 0, 0],
3786  ...    [1, 0, 1],
3787  ...    [1, 1, 0]],
3788  ...    ["a", "b", "c", "a"],
3789  ...    (2, 2, 2))
3790  >>> tf.edit_distance(hypothesis, truth, normalize=True)
3791  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3792  array([[inf, 1. ],
3793         [0.5, 1. ]], dtype=float32)>
3794
3795  The operation returns a dense Tensor of shape `[2, 2]` with
3796  edit distances normalized by `truth` lengths.
3797
3798  **Note**: It is possible to calculate edit distance between two
3799  sparse tensors with variable-length values. However, attempting to create
3800  them while eager execution is enabled will result in a `ValueError`.
3801
3802  For the following  inputs,
3803
3804  ```python
3805  # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
3806  #   (0,0) = ["a"]
3807  #   (1,0) = ["b"]
3808  hypothesis = tf.sparse.SparseTensor(
3809      [[0, 0, 0],
3810       [1, 0, 0]],
3811      ["a", "b"],
3812      (2, 1, 1))
3813
3814  # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
3815  #   (0,0) = []
3816  #   (0,1) = ["a"]
3817  #   (1,0) = ["b", "c"]
3818  #   (1,1) = ["a"]
3819  truth = tf.sparse.SparseTensor(
3820      [[0, 1, 0],
3821       [1, 0, 0],
3822       [1, 0, 1],
3823       [1, 1, 0]],
3824      ["a", "b", "c", "a"],
3825      (2, 2, 2))
3826
3827  normalize = True
3828
3829  # The output would be a dense Tensor of shape `(2,)`, with edit distances
3830  normalized by 'truth' lengths.
3831  # output => array([0., 0.5], dtype=float32)
3832  ```
3833
3834  Args:
3835    hypothesis: A `SparseTensor` containing hypothesis sequences.
3836    truth: A `SparseTensor` containing truth sequences.
3837    normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
3838      length of `truth.`
3839    name: A name for the operation (optional).
3840
3841  Returns:
3842    A dense `Tensor` with rank `R - 1`, where R is the rank of the
3843    `SparseTensor` inputs `hypothesis` and `truth`.
3844
3845  Raises:
3846    TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
3847  """
3848  if not isinstance(
3849      hypothesis,
3850      (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3851    raise TypeError("Hypothesis must be a SparseTensor.")
3852  if not isinstance(
3853      truth, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3854    raise TypeError("Truth must be a SparseTensor.")
3855
3856  return gen_array_ops.edit_distance(
3857      hypothesis.indices,
3858      hypothesis.values,
3859      hypothesis.dense_shape,
3860      truth.indices,
3861      truth.values,
3862      truth.dense_shape,
3863      normalize=normalize,
3864      name=name)
3865
3866
3867@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
3868def _FakeQuantWithMinMaxArgsGradient(op, grad):
3869  """Gradient for FakeQuantWithMinMaxArgs op."""
3870  return fake_quant_with_min_max_args_gradient(
3871      grad,
3872      op.inputs[0],
3873      min=op.get_attr("min"),
3874      max=op.get_attr("max"),
3875      num_bits=op.get_attr("num_bits"),
3876      narrow_range=op.get_attr("narrow_range"))
3877
3878
3879@ops.RegisterGradient("FakeQuantWithMinMaxVars")
3880def _FakeQuantWithMinMaxVarsGradient(op, grad):
3881  """Gradient for FakeQuantWithMinMaxVars op."""
3882  return fake_quant_with_min_max_vars_gradient(
3883      grad,
3884      op.inputs[0],
3885      op.inputs[1],
3886      op.inputs[2],
3887      num_bits=op.get_attr("num_bits"),
3888      narrow_range=op.get_attr("narrow_range"))
3889
3890
3891@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
3892def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
3893  """Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
3894  return fake_quant_with_min_max_vars_per_channel_gradient(
3895      grad,
3896      op.inputs[0],
3897      op.inputs[1],
3898      op.inputs[2],
3899      num_bits=op.get_attr("num_bits"),
3900      narrow_range=op.get_attr("narrow_range"))
3901
3902
3903@ops.RegisterGradient("QuantizeAndDequantizeV4")
3904def _QuantizeAndDequantizeV4Grad(op, grad):
3905  """Gradient for QuantizeAndDequantizeV4 op."""
3906  return quantize_and_dequantize_v4_grad(
3907      grad,
3908      op.inputs[0],
3909      op.inputs[1],
3910      op.inputs[2],
3911      axis=op.get_attr("axis"))
3912
3913
3914@ops.RegisterGradient("QuantizeAndDequantizeV4Grad")
3915def _QuantizeAndDequantizeV4GradGrad(op, grad):
3916  """Gradient for QuantizeAndDequantizeV4Grad op."""
3917  return _QuantizeAndDequantizeV4Grad(op, grad)
3918
3919
3920@tf_export("required_space_to_batch_paddings")
3921def required_space_to_batch_paddings(input_shape,
3922                                     block_shape,
3923                                     base_paddings=None,
3924                                     name=None):
3925  """Calculate padding required to make block_shape divide input_shape.
3926
3927  This function can be used to calculate a suitable paddings argument for use
3928  with space_to_batch_nd and batch_to_space_nd.
3929
3930  Args:
3931    input_shape: int32 Tensor of shape [N].
3932    block_shape: int32 Tensor of shape [N].
3933    base_paddings: Optional int32 Tensor of shape [N, 2].  Specifies the minimum
3934      amount of padding to use.  All elements must be >= 0.  If not specified,
3935      defaults to 0.
3936    name: string.  Optional name prefix.
3937
3938  Returns:
3939    (paddings, crops), where:
3940
3941    `paddings` and `crops` are int32 Tensors of rank 2 and shape [N, 2]
3942    satisfying:
3943
3944        paddings[i, 0] = base_paddings[i, 0].
3945        0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
3946        (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0
3947
3948        crops[i, 0] = 0
3949        crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]
3950
3951  Raises: ValueError if called with incompatible shapes.
3952  """
3953  with ops.name_scope(name, "required_space_to_batch_paddings",
3954                      [input_shape, block_shape]):
3955    input_shape = ops.convert_to_tensor(
3956        input_shape, dtype=dtypes.int32, name="input_shape")
3957    block_shape = ops.convert_to_tensor(
3958        block_shape, dtype=dtypes.int32, name="block_shape")
3959
3960    block_shape.get_shape().assert_is_fully_defined()
3961    block_shape.get_shape().assert_has_rank(1)
3962    num_block_dims = block_shape.get_shape().dims[0].value
3963    if num_block_dims == 0:
3964      return zeros([0, 2], dtypes.int32), zeros([0, 2], dtypes.int32)
3965
3966    input_shape.get_shape().assert_is_compatible_with([num_block_dims])
3967
3968    if base_paddings is not None:
3969      base_paddings = ops.convert_to_tensor(
3970          base_paddings, dtype=dtypes.int32, name="base_paddings")
3971      base_paddings.get_shape().assert_is_compatible_with([num_block_dims, 2])
3972    else:
3973      base_paddings = zeros([num_block_dims, 2], dtypes.int32)
3974
3975    const_block_shape = tensor_util.constant_value(block_shape)
3976    const_input_shape = tensor_util.constant_value(input_shape)
3977    const_base_paddings = tensor_util.constant_value(base_paddings)
3978    if (const_block_shape is not None and const_input_shape is not None and
3979        const_base_paddings is not None):
3980      block_shape = const_block_shape
3981      input_shape = const_input_shape
3982      base_paddings = const_base_paddings
3983
3984    # Use same expression for both constant and non-constant case.
3985    pad_start = base_paddings[:, 0]
3986    orig_pad_end = base_paddings[:, 1]
3987    full_input_shape = input_shape + pad_start + orig_pad_end
3988    pad_end_extra = (block_shape - full_input_shape % block_shape) % block_shape
3989    pad_end = orig_pad_end + pad_end_extra
3990
3991    result_paddings = stack(
3992        [[pad_start[i], pad_end[i]] for i in range(num_block_dims)],
3993        name="paddings")
3994    result_crops = stack([[0, pad_end_extra[i]] for i in range(num_block_dims)],
3995                         name="crops")
3996    return result_paddings, result_crops
3997
3998
3999@tf_export(v1=["nn.space_to_batch", "space_to_batch"])
4000@dispatch.add_dispatch_support
4001@deprecation.deprecated_endpoints("space_to_batch")
4002def space_to_batch(  # pylint: disable=missing-docstring
4003    input,  # pylint: disable=redefined-builtin
4004    paddings,
4005    block_size=None,
4006    name=None,
4007    block_shape=None):  # pylint: disable=redefined-builtin
4008  block_size = deprecation.deprecated_argument_lookup("block_shape",
4009                                                      block_shape, "block_size",
4010                                                      block_size)
4011  result = space_to_batch_nd(
4012      input,
4013      paddings=paddings,
4014      block_shape=np.array([block_size, block_size], dtype=np.int64),
4015      name=name)
4016  result.set_shape(result.get_shape().with_rank(4))
4017  return result
4018
4019
4020space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
4021
4022
4023@tf_export("space_to_batch", "nn.space_to_batch", v1=[])
4024@dispatch.add_dispatch_support
4025def space_to_batch_v2(input, block_shape, paddings, name=None):  # pylint: disable=redefined-builtin
4026  return space_to_batch_nd(input, block_shape, paddings, name)
4027
4028
4029space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__
4030
4031
4032@tf_export(v1=["nn.space_to_depth", "space_to_depth"])
4033@dispatch.add_dispatch_support
4034@deprecation.deprecated_endpoints("space_to_depth")
4035def space_to_depth(input, block_size, name=None, data_format="NHWC"):  # pylint: disable=redefined-builtin
4036  return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
4037
4038
4039space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
4040
4041
4042@tf_export("nn.space_to_depth", v1=[])
4043@dispatch.add_dispatch_support
4044def space_to_depth_v2(input, block_size, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4045  return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
4046
4047
4048space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__
4049
4050
4051@tf_export(v1=["nn.depth_to_space", "depth_to_space"])
4052@dispatch.add_dispatch_support
4053@deprecation.deprecated_endpoints("depth_to_space")
4054def depth_to_space(input, block_size, name=None, data_format="NHWC"):  # pylint: disable=redefined-builtin
4055  return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
4056
4057
4058depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
4059
4060
4061@tf_export("nn.depth_to_space", v1=[])
4062@dispatch.add_dispatch_support
4063def depth_to_space_v2(input, block_size, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4064  return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
4065
4066
4067depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__
4068
4069
4070@tf_export(v1=["batch_to_space"])
4071@dispatch.add_dispatch_support
4072def batch_to_space(input, crops, block_size, name=None, block_shape=None):  # pylint: disable=redefined-builtin,missing-docstring
4073  block_size = deprecation.deprecated_argument_lookup("block_shape",
4074                                                      block_shape, "block_size",
4075                                                      block_size)
4076  result = batch_to_space_nd(
4077      input,
4078      crops=crops,
4079      block_shape=np.array([block_size, block_size], dtype=np.int64),
4080      name=name)
4081  result.set_shape(result.get_shape().with_rank(4))
4082  return result
4083
4084
4085batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
4086
4087
4088@tf_export("batch_to_space", v1=[])
4089@dispatch.add_dispatch_support
4090def batch_to_space_v2(input, block_shape, crops, name=None):  # pylint: disable=redefined-builtin
4091  """BatchToSpace for N-D tensors of type T.
4092
4093  This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
4094  shape `block_shape + [batch]`, interleaves these blocks back into the grid
4095  defined by the spatial dimensions `[1, ..., M]`, to obtain a result with the
4096  same rank as the input.  The spatial dimensions of this intermediate result
4097  are then optionally cropped according to `crops` to produce the output.  This
4098  is the reverse of SpaceToBatch (see `tf.space_to_batch`).
4099
4100  Args:
4101    input: A N-D `Tensor` with shape `input_shape = [batch] + spatial_shape +
4102      remaining_shape`, where `spatial_shape` has M dimensions.
4103    block_shape: A 1-D `Tensor` with shape [M]. Must be one of the following
4104      types: `int32`, `int64`. All values must be >= 1. For backwards
4105      compatibility with TF 1.0, this parameter may be an int, in which case it
4106      is converted to
4107      `numpy.array([block_shape, block_shape],
4108      dtype=numpy.int64)`.
4109    crops: A  2-D `Tensor` with shape `[M, 2]`. Must be one of the
4110      following types: `int32`, `int64`. All values must be >= 0.
4111      `crops[i] = [crop_start, crop_end]` specifies the amount to crop from
4112      input dimension `i + 1`, which corresponds to spatial dimension `i`.
4113      It is required that
4114      `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`.
4115      This operation is equivalent to the following steps:
4116      1. Reshape `input` to `reshaped` of shape: [block_shape[0], ...,
4117        block_shape[M-1], batch / prod(block_shape), input_shape[1], ...,
4118        input_shape[N-1]]
4119      2. Permute dimensions of `reshaped` to produce `permuted` of shape
4120         [batch / prod(block_shape),  input_shape[1], block_shape[0], ...,
4121         input_shape[M], block_shape[M-1], input_shape[M+1],
4122        ..., input_shape[N-1]]
4123      3. Reshape `permuted` to produce `reshaped_permuted` of shape
4124         [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
4125         input_shape[M] * block_shape[M-1], input_shape[M+1], ...,
4126         input_shape[N-1]]
4127      4. Crop the start and end of dimensions `[1, ..., M]` of
4128         `reshaped_permuted` according to `crops` to produce the output
4129         of shape:
4130         [batch / prod(block_shape),  input_shape[1] *
4131           block_shape[0] - crops[0,0] - crops[0,1], ..., input_shape[M] *
4132           block_shape[M-1] - crops[M-1,0] - crops[M-1,1],  input_shape[M+1],
4133           ..., input_shape[N-1]]
4134    name: A name for the operation (optional).
4135
4136  Examples:
4137
4138  1. For the following input of shape `[4, 1, 1, 1]`,
4139     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4140
4141     ```python
4142     [[[[1]]],
4143      [[[2]]],
4144      [[[3]]],
4145      [[[4]]]]
4146     ```
4147
4148    The output tensor has shape `[1, 2, 2, 1]` and value:
4149
4150     ```
4151     x = [[[[1], [2]],
4152         [[3], [4]]]]
4153     ```
4154
4155  2. For the following input of shape `[4, 1, 1, 3]`,
4156     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4157
4158     ```python
4159     [[[1,  2,   3]],
4160      [[4,  5,   6]],
4161      [[7,  8,   9]],
4162      [[10, 11, 12]]]
4163     ```
4164
4165    The output tensor has shape `[1, 2, 2, 3]` and value:
4166
4167    ```python
4168     x = [[[[1, 2, 3], [4,  5,  6 ]],
4169           [[7, 8, 9], [10, 11, 12]]]]
4170     ```
4171
4172  3. For the following
4173     input of shape `[4, 2, 2, 1]`,
4174     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4175
4176     ```python
4177     x = [[[[1], [3]], [[ 9], [11]]],
4178          [[[2], [4]], [[10], [12]]],
4179          [[[5], [7]], [[13], [15]]],
4180          [[[6], [8]], [[14], [16]]]]
4181     ```
4182
4183    The output tensor has shape `[1, 4, 4, 1]` and value:
4184
4185    ```python
4186     x = [[[1],  [2],  [ 3], [ 4]],
4187          [[5],  [6],  [ 7], [ 8]],
4188          [[9],  [10], [11], [12]],
4189          [[13], [14], [15], [16]]]
4190     ```
4191
4192  4. For the following input of shape
4193      `[8, 1, 3, 1]`,
4194      `block_shape = [2, 2]`, and `crops = [[0, 0], [2, 0]]`:
4195
4196      ```python
4197      x = [[[[0], [ 1], [ 3]]],
4198           [[[0], [ 9], [11]]],
4199           [[[0], [ 2], [ 4]]],
4200           [[[0], [10], [12]]],
4201           [[[0], [ 5], [ 7]]],
4202           [[[0], [13], [15]]],
4203           [[[0], [ 6], [ 8]]],
4204           [[[0], [14], [16]]]]
4205      ```
4206
4207      The output tensor has shape `[2, 2, 4, 1]` and value:
4208
4209      ```python
4210      x = [[[[ 1], [ 2], [ 3], [ 4]],
4211            [[ 5], [ 6], [ 7], [ 8]]],
4212           [[[ 9], [10], [11], [12]],
4213            [[13], [14], [15], [16]]]]
4214      ```
4215
4216  Returns:
4217    A `Tensor`. Has the same type as `input`.
4218  """
4219  if isinstance(block_shape, int):
4220    block_shape = np.array([block_shape, block_shape], dtype=np.int64)
4221
4222  return batch_to_space_nd(
4223      input=input, block_shape=block_shape, crops=crops, name=name)
4224
4225
4226@tf_export("one_hot")
4227@dispatch.add_dispatch_support
4228def one_hot(indices,
4229            depth,
4230            on_value=None,
4231            off_value=None,
4232            axis=None,
4233            dtype=None,
4234            name=None):
4235  """Returns a one-hot tensor.
4236
4237  See also `tf.fill`, `tf.eye`.
4238
4239  The locations represented by indices in `indices` take value `on_value`,
4240  while all other locations take value `off_value`.
4241
4242  `on_value` and `off_value` must have matching data types. If `dtype` is also
4243  provided, they must be the same data type as specified by `dtype`.
4244
4245  If `on_value` is not provided, it will default to the value `1` with type
4246  `dtype`
4247
4248  If `off_value` is not provided, it will default to the value `0` with type
4249  `dtype`
4250
4251  If the input `indices` is rank `N`, the output will have rank `N+1`. The
4252  new axis is created at dimension `axis` (default: the new axis is appended
4253  at the end).
4254
4255  If `indices` is a scalar the output shape will be a vector of length `depth`
4256
4257  If `indices` is a vector of length `features`, the output shape will be:
4258
4259  ```
4260    features x depth if axis == -1
4261    depth x features if axis == 0
4262  ```
4263
4264  If `indices` is a matrix (batch) with shape `[batch, features]`, the output
4265  shape will be:
4266
4267  ```
4268    batch x features x depth if axis == -1
4269    batch x depth x features if axis == 1
4270    depth x batch x features if axis == 0
4271  ```
4272
4273  If `indices` is a RaggedTensor, the 'axis' argument must be positive and refer
4274  to a non-ragged axis. The output will be equivalent to applying 'one_hot' on
4275  the values of the RaggedTensor, and creating a new RaggedTensor from the
4276  result.
4277
4278  If `dtype` is not provided, it will attempt to assume the data type of
4279  `on_value` or `off_value`, if one or both are passed in. If none of
4280  `on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
4281  value `tf.float32`.
4282
4283  Note: If a non-numeric data type output is desired (`tf.string`, `tf.bool`,
4284  etc.), both `on_value` and `off_value` _must_ be provided to `one_hot`.
4285
4286  For example:
4287
4288  ```python
4289  indices = [0, 1, 2]
4290  depth = 3
4291  tf.one_hot(indices, depth)  # output: [3 x 3]
4292  # [[1., 0., 0.],
4293  #  [0., 1., 0.],
4294  #  [0., 0., 1.]]
4295
4296  indices = [0, 2, -1, 1]
4297  depth = 3
4298  tf.one_hot(indices, depth,
4299             on_value=5.0, off_value=0.0,
4300             axis=-1)  # output: [4 x 3]
4301  # [[5.0, 0.0, 0.0],  # one_hot(0)
4302  #  [0.0, 0.0, 5.0],  # one_hot(2)
4303  #  [0.0, 0.0, 0.0],  # one_hot(-1)
4304  #  [0.0, 5.0, 0.0]]  # one_hot(1)
4305
4306  indices = [[0, 2], [1, -1]]
4307  depth = 3
4308  tf.one_hot(indices, depth,
4309             on_value=1.0, off_value=0.0,
4310             axis=-1)  # output: [2 x 2 x 3]
4311  # [[[1.0, 0.0, 0.0],   # one_hot(0)
4312  #   [0.0, 0.0, 1.0]],  # one_hot(2)
4313  #  [[0.0, 1.0, 0.0],   # one_hot(1)
4314  #   [0.0, 0.0, 0.0]]]  # one_hot(-1)
4315
4316  indices = tf.ragged.constant([[0, 1], [2]])
4317  depth = 3
4318  tf.one_hot(indices, depth)  # output: [2 x None x 3]
4319  # [[[1., 0., 0.],
4320  #   [0., 1., 0.]],
4321  #  [[0., 0., 1.]]]
4322  ```
4323
4324  Args:
4325    indices: A `Tensor` of indices.
4326    depth: A scalar defining the depth of the one hot dimension.
4327    on_value: A scalar defining the value to fill in output when `indices[j]
4328      = i`. (default: 1)
4329    off_value: A scalar defining the value to fill in output when `indices[j]
4330      != i`. (default: 0)
4331    axis: The axis to fill (default: -1, a new inner-most axis).
4332    dtype: The data type of the output tensor.
4333    name: A name for the operation (optional).
4334
4335  Returns:
4336    output: The one-hot tensor.
4337
4338  Raises:
4339    TypeError: If dtype of either `on_value` or `off_value` don't match `dtype`
4340    TypeError: If dtype of `on_value` and `off_value` don't match one another
4341  """
4342  with ops.name_scope(
4343      name, "one_hot",
4344      [indices, depth, on_value, off_value, axis, dtype]) as name:
4345    on_exists = on_value is not None
4346    off_exists = off_value is not None
4347
4348    if on_exists:
4349      on_value = ops.convert_to_tensor(on_value, dtype_hint=dtype)
4350    if off_exists:
4351      off_value = ops.convert_to_tensor(off_value, dtype_hint=dtype)
4352
4353    on_dtype = on_value.dtype.base_dtype if on_exists else None
4354    off_dtype = off_value.dtype.base_dtype if off_exists else None
4355
4356    if on_exists or off_exists:
4357      if dtype is not None:
4358        # Ensure provided on_value and/or off_value match dtype
4359        if on_exists and on_dtype != dtype:
4360          raise TypeError("dtype {0} of on_value does not match "
4361                          "dtype parameter {1}".format(on_dtype, dtype))
4362        if off_exists and off_dtype != dtype:
4363          raise TypeError("dtype {0} of off_value does not match "
4364                          "dtype parameter {1}".format(off_dtype, dtype))
4365      else:
4366        # dtype not provided: automatically assign it
4367        dtype = on_dtype if on_exists else off_dtype
4368    elif dtype is None:
4369      # None of on_value, off_value, or dtype provided. Default dtype to float32
4370      dtype = dtypes.float32
4371
4372    if not on_exists:
4373      # on_value not provided: assign to value 1 of type dtype
4374      on_value = ops.convert_to_tensor(1, dtype, name="on_value")
4375      on_dtype = dtype
4376    if not off_exists:
4377      # off_value not provided: assign to value 0 of type dtype
4378      off_value = ops.convert_to_tensor(0, dtype, name="off_value")
4379      off_dtype = dtype
4380
4381    if on_dtype != off_dtype:
4382      raise TypeError("dtype {0} of on_value does not match "
4383                      "dtype {1} of off_value".format(on_dtype, off_dtype))
4384
4385    return gen_array_ops.one_hot(indices, depth, on_value, off_value, axis,
4386                                 name)
4387
4388
4389def _all_dimensions(x):
4390  """Returns a 1D-tensor listing all dimensions in x."""
4391  # Fast path: avoid creating Rank and Range ops if ndims is known.
4392  if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
4393    return constant_op.constant(
4394        np.arange(x.get_shape().ndims), dtype=dtypes.int32)
4395  if (isinstance(x, sparse_tensor.SparseTensor) and
4396      x.dense_shape.get_shape().is_fully_defined()):
4397    r = x.dense_shape.get_shape().dims[0].value  # sparse.dense_shape is 1-D.
4398    return constant_op.constant(np.arange(r), dtype=dtypes.int32)
4399
4400  # Otherwise, we rely on `range` and `rank` to do the right thing at runtime.
4401  return gen_math_ops._range(0, rank(x), 1)
4402
4403
4404@tf_export("sequence_mask")
4405@dispatch.add_dispatch_support
4406def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
4407  """Returns a mask tensor representing the first N positions of each cell.
4408
4409  If `lengths` has shape `[d_1, d_2, ..., d_n]` the resulting tensor `mask` has
4410  dtype `dtype` and shape `[d_1, d_2, ..., d_n, maxlen]`, with
4411
4412  ```
4413  mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
4414  ```
4415
4416  Examples:
4417
4418  ```python
4419  tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
4420                                  #  [True, True, True, False, False],
4421                                  #  [True, True, False, False, False]]
4422
4423  tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
4424                                    #   [True, True, True]],
4425                                    #  [[True, True, False],
4426                                    #   [False, False, False]]]
4427  ```
4428
4429  Args:
4430    lengths: integer tensor, all its values <= maxlen.
4431    maxlen: scalar integer tensor, size of last dimension of returned tensor.
4432      Default is the maximum value in `lengths`.
4433    dtype: output type of the resulting tensor.
4434    name: name of the op.
4435
4436  Returns:
4437    A mask tensor of shape `lengths.shape + (maxlen,)`, cast to specified dtype.
4438  Raises:
4439    ValueError: if `maxlen` is not a scalar.
4440  """
4441  with ops.name_scope(name, "SequenceMask", [lengths, maxlen]):
4442    lengths = ops.convert_to_tensor(lengths)
4443
4444    if maxlen is None:
4445      maxlen = gen_math_ops._max(lengths, _all_dimensions(lengths))
4446      maxlen = gen_math_ops.maximum(constant(0, maxlen.dtype), maxlen)
4447    else:
4448      maxlen = ops.convert_to_tensor(maxlen)
4449    if maxlen.get_shape().ndims is not None and maxlen.get_shape().ndims != 0:
4450      raise ValueError("maxlen must be scalar for sequence_mask")
4451
4452    # The basic idea is to compare a range row vector of size maxlen:
4453    # [0, 1, 2, 3, 4]
4454    # to length as a matrix with 1 column: [[1], [3], [2]].
4455    # Because of broadcasting on both arguments this comparison results
4456    # in a matrix of size (len(lengths), maxlen)
4457    row_vector = gen_math_ops._range(
4458        constant(0, maxlen.dtype), maxlen, constant(1, maxlen.dtype))
4459    # Since maxlen >= max(lengths), it is safe to use maxlen as a cast
4460    # authoritative type. Whenever maxlen fits into tf.int32, so do the lengths.
4461    matrix = gen_math_ops.cast(expand_dims(lengths, -1), maxlen.dtype)
4462    result = row_vector < matrix
4463    if dtype is None or result.dtype.is_compatible_with(dtype):
4464      return result
4465    else:
4466      return gen_math_ops.cast(result, dtype)
4467
4468
4469@tf_export(v1=["squeeze"])
4470@dispatch.add_dispatch_support
4471@deprecation.deprecated_args(None, "Use the `axis` argument instead",
4472                             "squeeze_dims")
4473def squeeze(input, axis=None, name=None, squeeze_dims=None):
4474  # pylint: disable=redefined-builtin
4475  """Removes dimensions of size 1 from the shape of a tensor.
4476
4477  Given a tensor `input`, this operation returns a tensor of the same type with
4478  all dimensions of size 1 removed. If you don't want to remove all size 1
4479  dimensions, you can remove specific size 1 dimensions by specifying
4480  `axis`.
4481
4482  For example:
4483
4484  >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4485  >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4486  >>> print(tf.shape(tf.squeeze(t)).numpy())
4487  [2 3]
4488
4489  Or, to remove specific size 1 dimensions:
4490
4491  >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4492  >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4493  >>> print(tf.shape(tf.squeeze(t, [2, 4])).numpy())
4494  [1 2 3 1]
4495
4496  Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4497  time, where `N` is the number of elements in the squeezed dimensions.
4498
4499  Args:
4500    input: A `Tensor`. The `input` to squeeze.
4501    axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4502      squeezes the dimensions listed. The dimension index starts at 0. It is an
4503      error to squeeze a dimension that is not 1. Must be in the range
4504      `[-rank(input), rank(input))`. Must be specified if `input` is a
4505      `RaggedTensor`.
4506    name: A name for the operation (optional).
4507    squeeze_dims: Deprecated keyword argument that is now axis.
4508
4509  Returns:
4510    A `Tensor`. Has the same type as `input`.
4511    Contains the same data as `input`, but has one or more dimensions of
4512    size 1 removed.
4513
4514  Raises:
4515    ValueError: When both `squeeze_dims` and `axis` are specified.
4516  """
4517  axis = deprecation.deprecated_argument_lookup("axis", axis, "squeeze_dims",
4518                                                squeeze_dims)
4519  if np.isscalar(axis):
4520    axis = [axis]
4521  return gen_array_ops.squeeze(input, axis, name)
4522
4523
4524@tf_export("squeeze", v1=[])
4525@dispatch.add_dispatch_support
4526def squeeze_v2(input, axis=None, name=None):
4527  """Removes dimensions of size 1 from the shape of a tensor.
4528
4529  Given a tensor `input`, this operation returns a tensor of the same type with
4530  all dimensions of size 1 removed. If you don't want to remove all size 1
4531  dimensions, you can remove specific size 1 dimensions by specifying
4532  `axis`.
4533
4534  For example:
4535
4536  ```python
4537  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4538  tf.shape(tf.squeeze(t))  # [2, 3]
4539  ```
4540
4541  Or, to remove specific size 1 dimensions:
4542
4543  ```python
4544  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4545  tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]
4546  ```
4547
4548  Unlike the older op `tf.compat.v1.squeeze`, this op does not accept a
4549  deprecated `squeeze_dims` argument.
4550
4551  Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4552  time, where `N` is the number of elements in the squeezed dimensions.
4553
4554  Args:
4555    input: A `Tensor`. The `input` to squeeze.
4556    axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4557      squeezes the dimensions listed. The dimension index starts at 0. It is an
4558      error to squeeze a dimension that is not 1. Must be in the range
4559      `[-rank(input), rank(input))`. Must be specified if `input` is a
4560      `RaggedTensor`.
4561    name: A name for the operation (optional).
4562
4563  Returns:
4564    A `Tensor`. Has the same type as `input`.
4565    Contains the same data as `input`, but has one or more dimensions of
4566    size 1 removed.
4567
4568  Raises:
4569    ValueError: The input cannot be converted to a tensor, or the specified
4570      axis cannot be squeezed.
4571  """
4572  # pylint: disable=redefined-builtin
4573  return squeeze(input, axis, name)
4574
4575
4576@tf_export(v1=["where"])
4577@dispatch.add_dispatch_support
4578def where(condition, x=None, y=None, name=None):
4579  """Return the elements, either from `x` or `y`, depending on the `condition`.
4580
4581  If both `x` and `y` are None, then this operation returns the coordinates of
4582  true elements of `condition`.  The coordinates are returned in a 2-D tensor
4583  where the first dimension (rows) represents the number of true elements, and
4584  the second dimension (columns) represents the coordinates of the true
4585  elements. Keep in mind, the shape of the output tensor can vary depending on
4586  how many true values there are in input. Indices are output in row-major
4587  order.
4588
4589  If both non-None, `x` and `y` must have the same shape.
4590  The `condition` tensor must be a scalar if `x` and `y` are scalar.
4591  If `x` and `y` are tensors of higher rank, then `condition` must be either a
4592  vector with size matching the first dimension of `x`, or must have the same
4593  shape as `x`.
4594
4595  The `condition` tensor acts as a mask that chooses, based on the value at each
4596  element, whether the corresponding element / row in the output should be taken
4597  from `x` (if true) or `y` (if false).
4598
4599  If `condition` is a vector and `x` and `y` are higher rank matrices, then it
4600  chooses which row (outer dimension) to copy from `x` and `y`. If `condition`
4601  has the same shape as `x` and `y`, then it chooses which element to copy from
4602  `x` and `y`.
4603
4604  Args:
4605    condition: A `Tensor` of type `bool`
4606    x: A Tensor which may have the same shape as `condition`. If `condition` is
4607      rank 1, `x` may have higher rank, but its first dimension must match the
4608      size of `condition`.
4609    y: A `tensor` with the same shape and type as `x`.
4610    name: A name of the operation (optional)
4611
4612  Returns:
4613    A `Tensor` with the same type and shape as `x`, `y` if they are non-None.
4614    Otherwise, a `Tensor` with shape `(num_true, rank(condition))`.
4615
4616  Raises:
4617    ValueError: When exactly one of `x` or `y` is non-None.
4618
4619  @compatibility(TF2)
4620
4621  This API is compatible with eager execution and `tf.function`. However, this
4622  is still a legacy API endpoint originally designed for TF1. To migrate to
4623  fully-native TF2, please replace its usage with `tf.where` instead, which is
4624  directly backwards compatible with `tf.compat.v1.where`.
4625
4626  However,`tf.compat.v1.where` is more restrictive than `tf.where`, requiring
4627  `x` and `y` to have the same shape, and returning a `Tensor` with the same
4628  type and shape as `x`, `y` (if they are both non-None).
4629
4630  `tf.where` will accept `x`, `y` that are not the same shape as long as they
4631  are broadcastable with one another and with `condition`, and will return a
4632  `Tensor` with shape broadcast from `condition`, `x`, and `y`.
4633
4634  For example, the following works with `tf.where` but not `tf.compat.v1.where`:
4635
4636  >>> tf.where([True, False, False, True], [1,2,3,4], [100])
4637  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
4638  dtype=int32)>
4639
4640  >>> tf.where(True, [1,2,3,4], 100)
4641  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
4642  dtype=int32)>
4643
4644  @end_compatibility
4645  """
4646  if x is None and y is None:
4647    with ops.name_scope(name, "Where", [condition]) as name:
4648      condition = ops.convert_to_tensor(
4649          condition, preferred_dtype=dtypes.bool, name="condition")
4650      return gen_array_ops.where(condition=condition, name=name)
4651  elif x is not None and y is not None:
4652    return gen_math_ops.select(condition=condition, x=x, y=y, name=name)
4653  else:
4654    raise ValueError("x and y must both be non-None or both be None.")
4655
4656
4657@tf_export("where", v1=["where_v2"])
4658@dispatch.add_dispatch_support
4659def where_v2(condition, x=None, y=None, name=None):
4660  """Return the elements where `condition` is `True` (multiplexing `x` and `y`).
4661
4662  This operator has two modes: in one mode both `x` and `y` are provided, in
4663  another mode neither are provided. `condition` is always expected to be a
4664  `tf.Tensor` of type `bool`.
4665
4666  #### Retrieving indices of `True` elements
4667
4668  If `x` and `y` are not provided (both are None):
4669
4670  `tf.where` will return the indices of `condition` that are `True`, in
4671  the form of a 2-D tensor with shape (n, d).
4672  (Where n is the number of matching indices in `condition`,
4673  and d is the number of dimensions in `condition`).
4674
4675  Indices are output in row-major order.
4676
4677  >>> tf.where([True, False, False, True])
4678  <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
4679  array([[0],
4680         [3]])>
4681
4682  >>> tf.where([[True, False], [False, True]])
4683  <tf.Tensor: shape=(2, 2), dtype=int64, numpy=
4684  array([[0, 0],
4685         [1, 1]])>
4686
4687  >>> tf.where([[[True, False], [False, True], [True, True]]])
4688  <tf.Tensor: shape=(4, 3), dtype=int64, numpy=
4689  array([[0, 0, 0],
4690         [0, 1, 1],
4691         [0, 2, 0],
4692         [0, 2, 1]])>
4693
4694  #### Multiplexing between `x` and `y`
4695
4696  If `x` and `y` are provided (both have non-None values):
4697
4698  `tf.where` will choose an output shape from the shapes of `condition`, `x`,
4699  and `y` that all three shapes are
4700  [broadcastable](https://docs.scipy.org/doc/numpy/reference/ufuncs.html) to.
4701
4702  The `condition` tensor acts as a mask that chooses whether the corresponding
4703  element / row in the output should be taken from `x`
4704  (if the element in `condition` is True) or `y` (if it is false).
4705
4706  >>> tf.where([True, False, False, True], [1,2,3,4], [100,200,300,400])
4707  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 200, 300,   4],
4708  dtype=int32)>
4709  >>> tf.where([True, False, False, True], [1,2,3,4], [100])
4710  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
4711  dtype=int32)>
4712  >>> tf.where([True, False, False, True], [1,2,3,4], 100)
4713  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
4714  dtype=int32)>
4715  >>> tf.where([True, False, False, True], 1, 100)
4716  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   1],
4717  dtype=int32)>
4718
4719  >>> tf.where(True, [1,2,3,4], 100)
4720  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
4721  dtype=int32)>
4722  >>> tf.where(False, [1,2,3,4], 100)
4723  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100],
4724  dtype=int32)>
4725
4726  Note that if the gradient of either branch of the tf.where generates
4727  a NaN, then the gradient of the entire tf.where will be NaN. This is because
4728  the gradient calculation for tf.where combines the two branches, for
4729  performance reasons.
4730
4731  A workaround is to use an inner tf.where to ensure the function has
4732  no asymptote, and to avoid computing a value whose gradient is NaN by
4733  replacing dangerous inputs with safe inputs.
4734
4735  Instead of this,
4736
4737  >>> x = tf.constant(0., dtype=tf.float32)
4738  >>> with tf.GradientTape() as tape:
4739  ...   tape.watch(x)
4740  ...   y = tf.where(x < 1., 0., 1. / x)
4741  >>> print(tape.gradient(y, x))
4742  tf.Tensor(nan, shape=(), dtype=float32)
4743
4744  Although, the `1. / x` values are never used, its gradient is a NaN when x =
4745  0. Instead, we should guard that with another `tf.where`
4746
4747  >>> x = tf.constant(0., dtype=tf.float32)
4748  >>> with tf.GradientTape() as tape:
4749  ...   tape.watch(x)
4750  ...   safe_x = tf.where(tf.equal(x, 0.), 1., x)
4751  ...   y = tf.where(x < 1., 0., 1. / safe_x)
4752  >>> print(tape.gradient(y, x))
4753  tf.Tensor(0.0, shape=(), dtype=float32)
4754
4755  Args:
4756    condition: A `tf.Tensor` of type `bool`
4757    x: If provided, a Tensor which is of the same type as `y`, and has a shape
4758      broadcastable with `condition` and `y`.
4759    y: If provided, a Tensor which is of the same type as `x`, and has a shape
4760      broadcastable with `condition` and `x`.
4761    name: A name of the operation (optional).
4762
4763  Returns:
4764    If `x` and `y` are provided:
4765      A `Tensor` with the same type as `x` and `y`, and shape that
4766      is broadcast from `condition`, `x`, and `y`.
4767    Otherwise, a `Tensor` with shape `(num_true, dim_size(condition))`.
4768
4769  Raises:
4770    ValueError: When exactly one of `x` or `y` is non-None, or the shapes
4771      are not all broadcastable.
4772  """
4773  if x is None and y is None:
4774    with ops.name_scope(name, "Where", [condition]) as name:
4775      condition = ops.convert_to_tensor(
4776          condition, preferred_dtype=dtypes.bool, name="condition")
4777      return gen_array_ops.where(condition=condition, name=name)
4778  elif x is not None and y is not None:
4779    return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
4780  else:
4781    raise ValueError("x and y must both be non-None or both be None.")
4782
4783
4784# pylint: disable=redefined-builtin
4785@tf_export(v1=["reverse_sequence"])
4786@deprecation.deprecated_args(None,
4787                             "seq_dim is deprecated, use seq_axis instead",
4788                             "seq_dim")
4789@deprecation.deprecated_args(None,
4790                             "batch_dim is deprecated, use batch_axis instead",
4791                             "batch_dim")
4792def reverse_sequence(input,
4793                     seq_lengths,
4794                     seq_axis=None,
4795                     batch_axis=None,
4796                     name=None,
4797                     seq_dim=None,
4798                     batch_dim=None):
4799  """Reverses variable length slices.
4800
4801  This op first slices `input` along the dimension `batch_axis`, and for
4802  each slice `i`, reverses the first `seq_lengths[i]` elements along the
4803  dimension `seq_axis`.
4804
4805  The elements of `seq_lengths` must obey `seq_lengths[i] <=
4806  input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4807  `input.dims[batch_axis]`.
4808
4809  The output slice `i` along dimension `batch_axis` is then given by
4810  input slice `i`, with the first `seq_lengths[i]` slices along
4811  dimension `seq_axis` reversed.
4812
4813  Example usage:
4814
4815  >>> seq_lengths = [7, 2, 3, 5]
4816  >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
4817  ...          [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
4818  >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
4819  >>> output
4820  <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
4821  array([[0, 0, 5, 4, 3, 2, 1, 0],
4822         [2, 1, 0, 0, 0, 0, 0, 0],
4823         [3, 2, 1, 4, 0, 0, 0, 0],
4824         [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
4825
4826  Args:
4827    input: A `Tensor`. The input to reverse.
4828    seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
4829      `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
4830      input.dims(seq_axis)`
4831    seq_axis: An `int`. The dimension which is partially reversed.
4832    batch_axis: An optional `int`. Defaults to `0`. The dimension along which
4833      reversal is performed.
4834    name: A name for the operation (optional).
4835
4836  Returns:
4837    A Tensor. Has the same type as input.
4838  """
4839  seq_axis = deprecation.deprecated_argument_lookup("seq_axis", seq_axis,
4840                                                    "seq_dim", seq_dim)
4841  batch_axis = deprecation.deprecated_argument_lookup("batch_axis", batch_axis,
4842                                                      "batch_dim", batch_dim)
4843  return gen_array_ops.reverse_sequence(
4844      input=input,
4845      seq_lengths=seq_lengths,
4846      seq_dim=seq_axis,
4847      batch_dim=batch_axis,
4848      name=name)
4849
4850
4851@tf_export("reverse_sequence", v1=[])
4852@dispatch.add_dispatch_support
4853def reverse_sequence_v2(input,
4854                        seq_lengths,
4855                        seq_axis=None,
4856                        batch_axis=None,
4857                        name=None):
4858  """Reverses variable length slices.
4859
4860  This op first slices `input` along the dimension `batch_axis`, and for
4861  each slice `i`, reverses the first `seq_lengths[i]` elements along the
4862  dimension `seq_axis`.
4863
4864  The elements of `seq_lengths` must obey `seq_lengths[i] <=
4865  input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4866  `input.dims[batch_axis]`.
4867
4868  The output slice `i` along dimension `batch_axis` is then given by
4869  input slice `i`, with the first `seq_lengths[i]` slices along
4870  dimension `seq_axis` reversed.
4871
4872  Example usage:
4873
4874  >>> seq_lengths = [7, 2, 3, 5]
4875  >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
4876  ...          [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
4877  >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
4878  >>> output
4879  <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
4880  array([[0, 0, 5, 4, 3, 2, 1, 0],
4881         [2, 1, 0, 0, 0, 0, 0, 0],
4882         [3, 2, 1, 4, 0, 0, 0, 0],
4883         [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
4884
4885  Args:
4886    input: A `Tensor`. The input to reverse.
4887    seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
4888      `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
4889      input.dims(seq_axis)`
4890    seq_axis: An `int`. The dimension which is partially reversed.
4891    batch_axis: An optional `int`. Defaults to `0`. The dimension along which
4892      reversal is performed.
4893    name: A name for the operation (optional).
4894
4895  Returns:
4896    A Tensor. Has the same type as input.
4897  """
4898  return gen_array_ops.reverse_sequence(
4899      input=input,
4900      seq_lengths=seq_lengths,
4901      seq_dim=seq_axis,
4902      batch_dim=batch_axis,
4903      name=name)
4904
4905# pylint: enable=redefined-builtin
4906
4907
4908@tf_export(v1=["gather"])
4909@deprecation.deprecated_args(None,
4910                             ("The `validate_indices` argument has no effect. "
4911                              "Indices are always validated on CPU and never "
4912                              "validated on GPU."),
4913                             ("validate_indices", None))
4914@dispatch.add_dispatch_support
4915def gather(params,
4916           indices,
4917           validate_indices=None,
4918           name=None,
4919           axis=None,
4920           batch_dims=0):  # pylint: disable=g-doc-args
4921  r"""Gather slices from params axis `axis` according to indices.
4922
4923  Gather slices from `params` axis `axis` according to `indices`.  `indices`
4924  must be an integer tensor of any dimension (often 1-D).
4925
4926  `Tensor.__getitem__` works for scalars, `tf.newaxis`, and
4927  [python slices](https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing)
4928
4929  `tf.gather` extends indexing to handle tensors of indices.
4930
4931  In the simplest case it's identical to scalar indexing:
4932
4933  >>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
4934  >>> params[3].numpy()
4935  b'p3'
4936  >>> tf.gather(params, 3).numpy()
4937  b'p3'
4938
4939  The most common case is to pass a single axis tensor of indices (this
4940  can't be expressed as a python slice because the indices are not sequential):
4941
4942  >>> indices = [2, 0, 2, 5]
4943  >>> tf.gather(params, indices).numpy()
4944  array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
4945
4946  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
4947  <img style="width:100%" src="https://www.tensorflow.org/images/Gather.png"
4948  alt>
4949  </div>
4950
4951  The indices can have any shape. When the `params` has 1 axis, the
4952  output shape is equal to the input shape:
4953
4954  >>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
4955  array([[b'p2', b'p0'],
4956         [b'p2', b'p5']], dtype=object)
4957
4958  The `params` may also have any shape. `gather` can select slices
4959  across any axis depending on the `axis` argument (which defaults to 0).
4960  Below it is used to gather first rows, then columns from a matrix:
4961
4962  >>> params = tf.constant([[0, 1.0, 2.0],
4963  ...                       [10.0, 11.0, 12.0],
4964  ...                       [20.0, 21.0, 22.0],
4965  ...                       [30.0, 31.0, 32.0]])
4966  >>> tf.gather(params, indices=[3,1]).numpy()
4967  array([[30., 31., 32.],
4968         [10., 11., 12.]], dtype=float32)
4969  >>> tf.gather(params, indices=[2,1], axis=1).numpy()
4970  array([[ 2.,  1.],
4971         [12., 11.],
4972         [22., 21.],
4973         [32., 31.]], dtype=float32)
4974
4975  More generally: The output shape has the same shape as the input, with the
4976  indexed-axis replaced by the shape of the indices.
4977
4978  >>> def result_shape(p_shape, i_shape, axis=0):
4979  ...   return p_shape[:axis] + i_shape + p_shape[axis+1:]
4980  >>>
4981  >>> result_shape([1, 2, 3], [], axis=1)
4982  [1, 3]
4983  >>> result_shape([1, 2, 3], [7], axis=1)
4984  [1, 7, 3]
4985  >>> result_shape([1, 2, 3], [7, 5], axis=1)
4986  [1, 7, 5, 3]
4987
4988  Here are some examples:
4989
4990  >>> params.shape.as_list()
4991  [4, 3]
4992  >>> indices = tf.constant([[0, 2]])
4993  >>> tf.gather(params, indices=indices, axis=0).shape.as_list()
4994  [1, 2, 3]
4995  >>> tf.gather(params, indices=indices, axis=1).shape.as_list()
4996  [4, 1, 2]
4997
4998  >>> params = tf.random.normal(shape=(5, 6, 7, 8))
4999  >>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
5000  >>> result = tf.gather(params, indices, axis=2)
5001  >>> result.shape.as_list()
5002  [5, 6, 10, 11, 8]
5003
5004  This is because each index takes a slice from `params`, and
5005  places it at the corresponding location in the output. For the above example
5006
5007  >>> # For any location in indices
5008  >>> a, b = 0, 1
5009  >>> tf.reduce_all(
5010  ...     # the corresponding slice of the result
5011  ...     result[:, :, a, b, :] ==
5012  ...     # is equal to the slice of `params` along `axis` at the index.
5013  ...     params[:, :, indices[a, b], :]
5014  ... ).numpy()
5015  True
5016
5017  ### Batching:
5018
5019  The `batch_dims` argument lets you gather different items from each element
5020  of a batch.
5021
5022  Using `batch_dims=1` is equivalent to having an outer loop over the first
5023  axis of `params` and `indices`:
5024
5025  >>> params = tf.constant([
5026  ...     [0, 0, 1, 0, 2],
5027  ...     [3, 0, 0, 0, 4],
5028  ...     [0, 5, 0, 6, 0]])
5029  >>> indices = tf.constant([
5030  ...     [2, 4],
5031  ...     [0, 4],
5032  ...     [1, 3]])
5033
5034  >>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
5035  array([[1, 2],
5036         [3, 4],
5037         [5, 6]], dtype=int32)
5038
5039  This is is equivalent to:
5040
5041  >>> def manually_batched_gather(params, indices, axis):
5042  ...   batch_dims=1
5043  ...   result = []
5044  ...   for p,i in zip(params, indices):
5045  ...     r = tf.gather(p, i, axis=axis-batch_dims)
5046  ...     result.append(r)
5047  ...   return tf.stack(result)
5048  >>> manually_batched_gather(params, indices, axis=1).numpy()
5049  array([[1, 2],
5050         [3, 4],
5051         [5, 6]], dtype=int32)
5052
5053  Higher values of `batch_dims` are equivalent to multiple nested loops over
5054  the outer axes of `params` and `indices`. So the overall shape function is
5055
5056  >>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
5057  ...   return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]
5058  >>>
5059  >>> batched_result_shape(
5060  ...     p_shape=params.shape.as_list(),
5061  ...     i_shape=indices.shape.as_list(),
5062  ...     axis=1,
5063  ...     batch_dims=1)
5064  [3, 2]
5065
5066  >>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
5067  [3, 2]
5068
5069  This comes up naturally if you need to use the indices of an operation like
5070  `tf.argsort`, or `tf.math.top_k` where the last dimension of the indices
5071  indexes into the last dimension of input, at the corresponding location.
5072  In this case you can use `tf.gather(values, indices, batch_dims=-1)`.
5073
5074  See also:
5075
5076  * `tf.Tensor.__getitem__`: The direct tensor index operation (`t[]`), handles
5077    scalars and python-slices `tensor[..., 7, 1:-1]`
5078  * `tf.scatter`: A collection of operations similar to `__setitem__`
5079    (`t[i] = x`)
5080  * `tf.gather_nd`: An operation similar to `tf.gather` but gathers across
5081    multiple axis at once (it can gather elements of a matrix instead of rows
5082    or columns)
5083  * `tf.boolean_mask`, `tf.where`: Binary indexing.
5084  * `tf.slice` and `tf.strided_slice`: For lower level access to the
5085    implementation of `__getitem__`'s python-slice handling (`t[1:-1:2]`)
5086
5087  Args:
5088    params: The `Tensor` from which to gather values. Must be at least rank
5089      `axis + 1`.
5090    indices: The index `Tensor`.  Must be one of the following types: `int32`,
5091      `int64`. The values must be in range `[0, params.shape[axis])`.
5092    validate_indices: Deprecated, does nothing. Indices are always validated on
5093      CPU, never validated on GPU.
5094
5095      Caution: On CPU, if an out of bound index is found, an error is raised.
5096      On GPU, if an out of bound index is found, a 0 is stored in the
5097      corresponding output value.
5098    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5099      `axis` in `params` to gather `indices` from. Must be greater than or equal
5100      to `batch_dims`.  Defaults to the first non-batch dimension. Supports
5101      negative indexes.
5102    batch_dims: An `integer`.  The number of batch dimensions.  Must be less
5103      than or equal to `rank(indices)`.
5104    name: A name for the operation (optional).
5105
5106  Returns:
5107    A `Tensor`. Has the same type as `params`.
5108  """
5109  del validate_indices
5110
5111  if axis is None:
5112    axis = batch_dims
5113  if tensor_util.constant_value(axis) != 0:
5114    return gen_array_ops.gather_v2(
5115        params, indices, axis, batch_dims=batch_dims, name=name)
5116  try:
5117    # TODO(apassos) find a less bad way of detecting resource variables
5118    # without introducing a circular dependency.
5119    return params.sparse_read(indices, name=name)
5120  except AttributeError:
5121    return gen_array_ops.gather_v2(params, indices, axis, name=name)
5122
5123
5124@tf_export("gather", v1=[])
5125@dispatch.add_dispatch_support
5126def gather_v2(params,
5127              indices,
5128              validate_indices=None,
5129              axis=None,
5130              batch_dims=0,
5131              name=None):
5132  return gather(
5133      params,
5134      indices,
5135      validate_indices=validate_indices,
5136      name=name,
5137      axis=axis,
5138      batch_dims=batch_dims)
5139
5140
5141gather_v2.__doc__ = gather.__doc__
5142
5143
5144@tf_export(v1=["batch_gather"])
5145@dispatch.add_dispatch_support
5146@deprecation.deprecated(
5147    "2017-10-25", "`tf.batch_gather` is deprecated, please use `tf.gather` "
5148    "with `batch_dims=-1` instead.")  # pylint: disable=missing-docstring
5149def batch_gather(params, indices, name=None):
5150  """Gather slices from params according to indices with leading batch dims."""
5151  with ops.name_scope(name, "BatchGather", [params, indices]):
5152    indices = ops.convert_to_tensor(indices, name="indices")
5153    params = ops.convert_to_tensor(params, name="params")
5154    if indices.shape.ndims is None:
5155      raise ValueError(
5156          "batch_gather does not allow indices with unknown shape.")
5157    return _batch_gather(params, indices, batch_dims=indices.shape.ndims - 1)
5158
5159
5160def _batch_gather(params, indices, batch_dims, axis=None):
5161  r"""Gather slices from params according to indices with leading batch dims.
5162
5163  This operation assumes that the leading `batch_dims` dimensions of `indices`
5164  and `params` are batch dimensions; and performs a `tf.gather` operation within
5165  each batch. (If `batch_dims` is not specified, then it defaults to
5166  `rank(indices)-1`.)  In the case in which `batch_dims==0`, this operation
5167  is equivalent to `tf.gather`.
5168
5169  Args:
5170    params: A Tensor. The tensor from which to gather values.
5171    indices: A Tensor. Must be one of the following types: int32, int64. Index
5172      tensor. Must be in range `[0, params.shape[batch_dims]]`.
5173    batch_dims: An integer or none.  The number of batch dimensions.  Must be
5174      less than `rank(indices)`.  Defaults to `rank(indices) - 1` if None.
5175    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5176      `axis` in `params` to gather `indices` from. Must be greater than or equal
5177      to `batch_dims`.  Defaults to the first non-batch dimension. Supports
5178      negative indexes.
5179
5180  Returns:
5181    A Tensor. Has the same type as `params`.
5182
5183  Raises:
5184    ValueError: if `indices` has an unknown shape.
5185  """
5186  if batch_dims is not None and not isinstance(batch_dims, int):
5187    raise TypeError("batch_dims must be an int; got %r" % (batch_dims,))
5188  indices = ops.convert_to_tensor(indices, name="indices")
5189  params = ops.convert_to_tensor(params, name="params")
5190
5191  indices_ndims = indices.shape.ndims
5192  if indices_ndims is None:
5193    raise ValueError("tf.gather does not allow indices with unknown "
5194                     "rank when batch_dims is specified.")
5195  if batch_dims is None:
5196    batch_dims = indices_ndims - 1
5197  if batch_dims < 0:
5198    batch_dims += indices_ndims
5199  if batch_dims < 0 or batch_dims >= indices_ndims:
5200    raise ValueError("batch_dims = %d must be less than rank(indices) = %d" %
5201                     (batch_dims, indices_ndims))
5202  if params.shape.ndims is not None and batch_dims >= params.shape.ndims:
5203    raise ValueError("batch_dims = %d must be less than rank(params) = %d" %
5204                     (batch_dims, params.shape.ndims))
5205
5206  # Handle axis by transposing the axis dimension to be the first non-batch
5207  # dimension, recursively calling batch_gather with axis=0, and then
5208  # transposing the result to put the pre-axis dimensions before the indices
5209  # dimensions.
5210  if axis is not None and axis != batch_dims:
5211    # Adjust axis to be positive.
5212    if not isinstance(axis, int):
5213      axis = tf.where(axis < 0, axis + array_ops.rank(params), axis)
5214    elif axis < 0 and params.shape.ndims is None:
5215      axis = axis + array_ops.rank(params)
5216    else:
5217      if (axis < -params.shape.ndims) or (axis >= params.shape.ndims):
5218        raise ValueError("axis (%d) out of range [%d, %d)" %
5219                         (axis, -params.shape.ndims, params.shape.ndims))
5220      if axis < 0:
5221        axis += params.shape.ndims
5222      if axis < batch_dims:
5223        raise ValueError("batch_dims = %d must be less than or equal to "
5224                         "axis = %d" % (batch_dims, axis))
5225
5226    # Move params[axis] up to params[batch_dims].
5227    perm = [
5228        list(range(batch_dims)), [axis],
5229        gen_math_ops._range(batch_dims, axis, 1),
5230        gen_math_ops._range(axis + 1, rank(params), 1)
5231    ]
5232    params = transpose(params, concat(perm, axis=0))
5233
5234    result = _batch_gather(params, indices, batch_dims=batch_dims)
5235
5236    # Move the result dimensions corresponding to params[batch_dims:axis]
5237    # to just before the dimensions corresponding to indices[batch_dims:].
5238    params_start = indices_ndims + axis - batch_dims
5239    perm = [
5240        list(range(batch_dims)),
5241        gen_math_ops._range(indices_ndims, params_start, 1),
5242        list(range(batch_dims, indices_ndims)),
5243        gen_math_ops._range(params_start, rank(result), 1)
5244    ]
5245    return transpose(result, perm=concat(perm, axis=0))
5246
5247  indices_shape = shape(indices)
5248  params_shape = shape(params)
5249  batch_indices = indices
5250  indices_dtype = indices.dtype.base_dtype
5251  accum_dim_value = ones((), dtype=indices_dtype)
5252  # Use correct type for offset index computation
5253  casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
5254  for dim in range(batch_dims, 0, -1):
5255    dim_value = casted_params_shape[dim - 1]
5256    accum_dim_value *= casted_params_shape[dim]
5257    start = zeros((), dtype=indices_dtype)
5258    step = ones((), dtype=indices_dtype)
5259    dim_indices = gen_math_ops._range(start, dim_value, step)
5260    dim_indices *= accum_dim_value
5261    dim_shape = stack(
5262        [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
5263    batch_indices += reshape(dim_indices, dim_shape)
5264
5265  flat_indices = reshape(batch_indices, [-1])
5266  outer_shape = params_shape[batch_dims + 1:]
5267  flat_inner_shape = gen_math_ops.prod(params_shape[:batch_dims + 1], [0],
5268                                       False)
5269
5270  flat_params = reshape(params, concat([[flat_inner_shape], outer_shape],
5271                                       axis=0))
5272  flat_result = gather(flat_params, flat_indices)
5273  result = reshape(flat_result, concat([indices_shape, outer_shape], axis=0))
5274  final_shape = indices.get_shape()[:batch_dims].merge_with(
5275      params.get_shape()[:batch_dims])
5276  final_shape = final_shape.concatenate(indices.get_shape().dims[batch_dims:])
5277  final_shape = final_shape.concatenate(params.get_shape()[batch_dims + 1:])
5278  result.set_shape(final_shape)
5279  return result
5280
5281
5282@tf_export(v1=["gather_nd", "manip.gather_nd"])
5283@dispatch.add_dispatch_support
5284@deprecated_endpoints("manip.gather_nd")
5285def gather_nd(params, indices, name=None, batch_dims=0):
5286  r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
5287
5288  `indices` is an K-dimensional integer tensor, best thought of as a
5289  (K-1)-dimensional tensor of indices into `params`, where each element defines
5290  a slice of `params`:
5291
5292      output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
5293
5294  Whereas in `tf.gather` `indices` defines slices into the first
5295  dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
5296  first `N` dimensions of `params`, where `N = indices.shape[-1]`.
5297
5298  The last dimension of `indices` can be at most the rank of
5299  `params`:
5300
5301      indices.shape[-1] <= params.rank
5302
5303  The last dimension of `indices` corresponds to elements
5304  (if `indices.shape[-1] == params.rank`) or slices
5305  (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]`
5306  of `params`.  The output tensor has shape
5307
5308      indices.shape[:-1] + params.shape[indices.shape[-1]:]
5309
5310  Additionally both 'params' and 'indices' can have M leading batch
5311  dimensions that exactly match. In this case 'batch_dims' must be M.
5312
5313  Note that on CPU, if an out of bound index is found, an error is returned.
5314  On GPU, if an out of bound index is found, a 0 is stored in the
5315  corresponding output value.
5316
5317  Some examples below.
5318
5319  Simple indexing into a matrix:
5320
5321  ```python
5322      indices = [[0, 0], [1, 1]]
5323      params = [['a', 'b'], ['c', 'd']]
5324      output = ['a', 'd']
5325  ```
5326
5327  Slice indexing into a matrix:
5328
5329  ```python
5330      indices = [[1], [0]]
5331      params = [['a', 'b'], ['c', 'd']]
5332      output = [['c', 'd'], ['a', 'b']]
5333  ```
5334
5335  Indexing into a 3-tensor:
5336
5337  ```python
5338      indices = [[1]]
5339      params = [[['a0', 'b0'], ['c0', 'd0']],
5340                [['a1', 'b1'], ['c1', 'd1']]]
5341      output = [[['a1', 'b1'], ['c1', 'd1']]]
5342
5343
5344      indices = [[0, 1], [1, 0]]
5345      params = [[['a0', 'b0'], ['c0', 'd0']],
5346                [['a1', 'b1'], ['c1', 'd1']]]
5347      output = [['c0', 'd0'], ['a1', 'b1']]
5348
5349
5350      indices = [[0, 0, 1], [1, 0, 1]]
5351      params = [[['a0', 'b0'], ['c0', 'd0']],
5352                [['a1', 'b1'], ['c1', 'd1']]]
5353      output = ['b0', 'b1']
5354  ```
5355
5356  The examples below are for the case when only indices have leading extra
5357  dimensions. If both 'params' and 'indices' have leading batch dimensions, use
5358  the 'batch_dims' parameter to run gather_nd in batch mode.
5359
5360  Batched indexing into a matrix:
5361
5362  ```python
5363      indices = [[[0, 0]], [[0, 1]]]
5364      params = [['a', 'b'], ['c', 'd']]
5365      output = [['a'], ['b']]
5366  ```
5367
5368  Batched slice indexing into a matrix:
5369
5370  ```python
5371      indices = [[[1]], [[0]]]
5372      params = [['a', 'b'], ['c', 'd']]
5373      output = [[['c', 'd']], [['a', 'b']]]
5374  ```
5375
5376  Batched indexing into a 3-tensor:
5377
5378  ```python
5379      indices = [[[1]], [[0]]]
5380      params = [[['a0', 'b0'], ['c0', 'd0']],
5381                [['a1', 'b1'], ['c1', 'd1']]]
5382      output = [[[['a1', 'b1'], ['c1', 'd1']]],
5383                [[['a0', 'b0'], ['c0', 'd0']]]]
5384
5385      indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
5386      params = [[['a0', 'b0'], ['c0', 'd0']],
5387                [['a1', 'b1'], ['c1', 'd1']]]
5388      output = [[['c0', 'd0'], ['a1', 'b1']],
5389                [['a0', 'b0'], ['c1', 'd1']]]
5390
5391
5392      indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
5393      params = [[['a0', 'b0'], ['c0', 'd0']],
5394                [['a1', 'b1'], ['c1', 'd1']]]
5395      output = [['b0', 'b1'], ['d0', 'c1']]
5396  ```
5397
5398  Examples with batched 'params' and 'indices':
5399
5400  ```python
5401      batch_dims = 1
5402      indices = [[1], [0]]
5403      params = [[['a0', 'b0'], ['c0', 'd0']],
5404                [['a1', 'b1'], ['c1', 'd1']]]
5405      output = [['c0', 'd0'], ['a1', 'b1']]
5406
5407      batch_dims = 1
5408      indices = [[[1]], [[0]]]
5409      params = [[['a0', 'b0'], ['c0', 'd0']],
5410                [['a1', 'b1'], ['c1', 'd1']]]
5411      output = [[['c0', 'd0']], [['a1', 'b1']]]
5412
5413      batch_dims = 1
5414      indices = [[[1, 0]], [[0, 1]]]
5415      params = [[['a0', 'b0'], ['c0', 'd0']],
5416                [['a1', 'b1'], ['c1', 'd1']]]
5417      output = [['c0'], ['b1']]
5418  ```
5419
5420  See also `tf.gather`.
5421
5422  Args:
5423    params: A `Tensor`. The tensor from which to gather values.
5424    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
5425      Index tensor.
5426    name: A name for the operation (optional).
5427    batch_dims: An integer or a scalar 'Tensor'. The number of batch dimensions.
5428
5429  Returns:
5430    A `Tensor`. Has the same type as `params`.
5431  """
5432  batch_dims_ = tensor_util.constant_value(batch_dims)
5433  if batch_dims_ is not None:
5434    batch_dims = int(batch_dims_)
5435  if batch_dims == 0:
5436    try:
5437      # TODO(apassos) find a less bad way of detecting resource variables
5438      # without introducing a circular dependency.
5439      return params.gather_nd(indices, name=name)
5440    except AttributeError:
5441      return gen_array_ops.gather_nd(params, indices, name=name)
5442  else:
5443    return batch_gather_nd(params, indices, batch_dims=batch_dims, name=name)
5444
5445
5446@tf_export("gather_nd", v1=[])
5447@dispatch.add_dispatch_support
5448def gather_nd_v2(params, indices, batch_dims=0, name=None):
5449  return gather_nd(params, indices, name=name, batch_dims=batch_dims)
5450
5451
5452gather_nd_v2.__doc__ = gather_nd.__doc__
5453
5454
5455def batch_gather_nd(params, indices, batch_dims, name=None):
5456  """gather_nd implementation with batch support."""
5457  with ops.name_scope(name, "BatchGatherND", [params, indices]):
5458    indices = ops.convert_to_tensor(indices, name="indices")
5459    params = ops.convert_to_tensor(params, name="params")
5460
5461    if not isinstance(batch_dims, int):
5462      raise TypeError("batch_dims must be an int; got %r" % (batch_dims,))
5463    if batch_dims < 0:
5464      raise ValueError("tf.gather_nd does not allow negative batch_dims.")
5465    params_ndims = params.shape.ndims
5466    indices_ndims = indices.shape.ndims
5467    if indices_ndims is not None and batch_dims >= indices_ndims:
5468      raise ValueError("batch_dims = %d must be less than rank(indices) = %d" %
5469                       (batch_dims, indices_ndims))
5470    if params_ndims is not None and batch_dims >= params_ndims:
5471      raise ValueError("batch_dims = %d must be less than rank(params) = %d" %
5472                       (batch_dims, params_ndims))
5473
5474    expand = batch_dims == 0
5475    if expand:
5476      # Normally gather_nd will be called when batch_dims == 0.
5477      # But if this function is called with batch_dims = 0, e.g. for testing
5478      # purposes, this adds a dummy batch dimension to make batch_dims = 1.
5479      params = expand_dims(params, axis=0)
5480      indices = expand_dims(indices, axis=0)
5481      batch_dims = 1
5482
5483    params_shape = shape(params)
5484    indices_shape = shape(indices)
5485    batch_shape = params_shape[:batch_dims]
5486    batch_size = gen_math_ops.prod(batch_shape, [0])
5487    index_internal_ndims = rank(indices) - batch_dims - 1
5488    indices_internal_shape = indices_shape[batch_dims:-1]
5489
5490    # Assuming a 'params' with shape [b1, ..., bM, g1, ..., gN] and an 'indices'
5491    # with shape [b1, ..., bM, i1, ..., iK, C], where C <= N, we need to modify
5492    # 'indices' s.t. it has shape [i1, ..., iK, D], where D <= M + N and slices
5493    # to the entire 'params' tensor.
5494    # Assuming we have a batch of shape [B1, B2], we use meshgrid to create a
5495    # grid of size B1 x B2.
5496    batch_dim_list = unstack(batch_shape, axis=0)
5497    dim_ranges = [
5498        gen_math_ops.cast(gen_math_ops._range(0, x, 1), indices.dtype)
5499        for x in batch_dim_list
5500    ]
5501    mesh_list = meshgrid(*dim_ranges, indexing="ij") if dim_ranges else []
5502    # Then we flatten and stack the tensors to form a (B1.B2) by 2 matrix.
5503    flat_list = [reshape(x, shape=(-1,)) for x in mesh_list]
5504    index_grid = transpose(stack(flat_list, axis=0))
5505    # We need to concatenate these batch coordinates with the internal indices.
5506    # concat -> index_grid [B1.B2, 2] with indices [i1, ..., iK, C]
5507    # So we reshape them both to [(B1.B2), i1, ..., iK, *]
5508    index_grid_shape = shape(index_grid)
5509    index_grid = reshape(
5510        index_grid,
5511        concat([
5512            index_grid_shape[:1],
5513            ones(index_internal_ndims, dtype=dtypes.int32), index_grid_shape[1:]
5514        ],
5515               axis=0))
5516    tile_shape = concat(((1,), indices_internal_shape, (1,)), axis=0)
5517    index_grid = tile(index_grid, multiples=tile_shape)
5518    # index_grid now has shape [(B1.B2), i1, ..., iK, 2]
5519    flat_shape = concat(([batch_size], indices_shape[batch_dims:]), axis=0)
5520    flat_indices = reshape(indices, shape=flat_shape)
5521    # flat_indices now has shape [(B1.B2), i1, ..., iK, C]
5522    indices = concat((index_grid, flat_indices), axis=-1)
5523    # indices has shape [(B1.B2), i1, ..., iK, 2+C]
5524    out = gen_array_ops.gather_nd(params, indices)
5525    # out has shape [(B1.B2), i1, ..., iK, N-C]. Now we reshape batch to
5526    # its original form.
5527    out_shape = shape(out)
5528    out = reshape(out, shape=concat((batch_shape, out_shape[1:]), axis=0))
5529    if expand:
5530      out = squeeze(out, axis=0)
5531  return out
5532
5533
5534@deprecation.deprecated_endpoints("tensor_scatter_update")
5535@tf_export(
5536    "tensor_scatter_nd_update",
5537    v1=["tensor_scatter_nd_update", "tensor_scatter_update"])
5538@dispatch.add_dispatch_support
5539def tensor_scatter_nd_update(tensor, indices, updates, name=None):
5540  """"Scatter `updates` into an existing tensor according to `indices`.
5541
5542  This operation creates a new tensor by applying sparse `updates` to the
5543  input `tensor`. This is similar to an index assignment.
5544
5545  ```
5546  # Not implemented: tensors cannot be updated inplace.
5547  tensor[indices] = updates
5548  ```
5549
5550  If an out of bound index is found on CPU, an error is returned.
5551
5552  > **WARNING**: There are some GPU specific semantics for this operation.
5553  >
5554  > - If an out of bound index is found, the index is ignored.
5555  > - The order in which updates are applied is nondeterministic, so the output
5556  >   will be nondeterministic if `indices` contains duplicates.
5557
5558  This operation is very similar to `tf.scatter_nd`, except that the updates are
5559  scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
5560  for the existing tensor cannot be re-used, a copy is made and updated.
5561
5562  In general:
5563
5564  * `indices` is an integer tensor - the indices to update in `tensor`.
5565  * `indices` has **at least two** axes, the last axis is the depth of the
5566    index vectors.
5567  * For each index vector in `indices` there is a corresponding entry in
5568    `updates`.
5569  * If the length of the index vectors matches the rank of the `tensor`, then
5570    the index vectors each point to scalars in `tensor` and each update is a
5571    scalar.
5572  * If the length of the index vectors is less than the rank of `tensor`, then
5573    the index vectors each point to slices of `tensor` and shape of the updates
5574    must match that slice.
5575
5576  Overall this leads to the following shape constraints:
5577
5578  ```
5579  assert tf.rank(indices) >= 2
5580  index_depth = indices.shape[-1]
5581  batch_shape = indices.shape[:-1]
5582  assert index_depth <= tf.rank(tensor)
5583  outer_shape = tensor.shape[:index_depth]
5584  inner_shape = tensor.shape[index_depth:]
5585  assert updates.shape == batch_shape + inner_shape
5586  ```
5587
5588  Typical usage is often much simpler than this general form, and it
5589  can be better understood starting with simple examples:
5590
5591  ### Scalar updates
5592
5593  The simplest usage inserts scalar elements into a tensor by index.
5594  In this case, the `index_depth` must equal the rank of the
5595  input `tensor`, slice each column of `indices` is an index into an axis of the
5596  input `tensor`.
5597
5598  In this simplest case the shape constraints are:
5599
5600  ```
5601  num_updates, index_depth = indices.shape.as_list()
5602  assert updates.shape == [num_updates]
5603  assert index_depth == tf.rank(tensor)`
5604  ```
5605
5606  For example, to insert 4 scattered elements in a rank-1 tensor with
5607  8 elements.
5608
5609  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
5610  <img style="width:100%"
5611    src="https://www.tensorflow.org/images/ScatterNd1.png">
5612  </div>
5613
5614  This scatter operation would look like this:
5615
5616  >>> tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
5617  >>> indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
5618  >>> updates = [9, 10, 11, 12]            # num_updates == 4
5619  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5620  tf.Tensor([ 0 9  0 10  11  0  0 12], shape=(8,), dtype=int32)
5621
5622  The length (first axis) of `updates` must equal the length of the `indices`:
5623  `num_updates`. This is the number of updates being inserted. Each scalar
5624  update is inserted into `tensor` at the indexed location.
5625
5626  For a higher rank input `tensor` scalar updates can be inserted by using an
5627  `index_depth` that matches `tf.rank(tensor)`:
5628
5629  >>> tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
5630  >>> indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
5631  >>> updates = [5, 10]                    # num_updates == 2
5632  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5633  tf.Tensor(
5634      [[ 1  5]
5635       [ 1  1]
5636       [10  1]], shape=(3, 2), dtype=int32)
5637
5638  ### Slice updates
5639
5640  When the input `tensor` has more than one axis scatter can be used to update
5641  entire slices.
5642
5643  In this case it's helpful to think of the input `tensor` as being a two level
5644  array-of-arrays. The shape of this two level array is split into the
5645  `outer_shape` and the `inner_shape`.
5646
5647  `indices` indexes into the outer level of the input tensor (`outer_shape`).
5648  and replaces the sub-array at that location with the corresponding item from
5649  the `updates` list. The shape of each update is `inner_shape`.
5650
5651  When updating a list of slices the shape constraints are:
5652
5653  ```
5654  num_updates, index_depth = indices.shape.as_list()
5655  inner_shape = tensor.shape[:index_depth]
5656  outer_shape = tensor.shape[index_depth:]
5657  assert updates.shape == [num_updates, inner_shape]
5658  ```
5659
5660  For example, to update rows of a `(6, 3)` `tensor`:
5661
5662  >>> tensor = tf.zeros([6, 3], dtype=tf.int32)
5663
5664  Use an index depth of one.
5665
5666  >>> indices = tf.constant([[2], [4]])     # num_updates == 2, index_depth == 1
5667  >>> num_updates, index_depth = indices.shape.as_list()
5668
5669  The `outer_shape` is `6`, the inner shape is `3`:
5670
5671  >>> outer_shape = tensor.shape[:index_depth]
5672  >>> inner_shape = tensor.shape[index_depth:]
5673
5674  2 rows are being indexed so 2 `updates` must be supplied.
5675  Each update must be shaped to match the `inner_shape`.
5676
5677  >>> # num_updates == 2, inner_shape==3
5678  >>> updates = tf.constant([[1, 2, 3],
5679  ...                        [4, 5, 6]])
5680
5681  Altogether this gives:
5682
5683  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5684  array([[0, 0, 0],
5685         [0, 0, 0],
5686         [1, 2, 3],
5687         [0, 0, 0],
5688         [4, 5, 6],
5689         [0, 0, 0]], dtype=int32)
5690
5691  #### More slice update examples
5692
5693  A tensor representing a batch of uniformly sized video clips naturally has 5
5694  axes: `[batch_size, time, width, height, channels]`.
5695
5696  For example:
5697
5698  >>> batch_size, time, width, height, channels = 13,11,7,5,3
5699  >>> video_batch = tf.zeros([batch_size, time, width, height, channels])
5700
5701  To replace a selection of video clips:
5702    * Use an `index_depth` of 1 (indexing the `outer_shape`: `[batch_size]`)
5703    * Provide updates each with a shape matching the `inner_shape`:
5704      `[time, width, height, channels]`.
5705
5706  To replace the first two clips with ones:
5707
5708  >>> indices = [[0],[1]]
5709  >>> new_clips = tf.ones([2, time, width, height, channels])
5710  >>> tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
5711
5712  To replace a selection of frames in the videos:
5713
5714  * `indices` must have an `index_depth` of 2 for the `outer_shape`:
5715    `[batch_size, time]`.
5716  * `updates` must be shaped like a list of images.  Each update must have a
5717    shape, matching the `inner_shape`: `[width, height, channels]`.
5718
5719  To replace the first frame of the first three video clips:
5720
5721  >>> indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
5722  >>> new_images = tf.ones([
5723  ...   # num_updates=3, inner_shape=(width, height, channels)
5724  ...   3, width, height, channels])
5725  >>> tf.tensor_scatter_nd_update(video_batch, indices, new_images)
5726
5727  ### Folded indices
5728
5729  In simple cases it's convenient to think of `indices` and `updates` as
5730  lists, but this is not a strict requirement. Instead of a flat `num_updates`,
5731  the `indices` and `updates` can be folded into a `batch_shape`. This
5732  `batch_shape` is all axes of the `indices`, except for the innermost
5733  `index_depth` axis.
5734
5735  ```
5736  index_depth = indices.shape[-1]
5737  batch_shape = indices.shape[:-1]
5738  ```
5739
5740  Note: The one exception is that the `batch_shape` cannot be `[]`. You can't
5741  update a single index by passing indices with shape `[index_depth]`.
5742
5743  `updates` must have a matching `batch_shape` (the axes before `inner_shape`).
5744
5745  ```
5746  assert updates.shape == batch_shape + inner_shape
5747  ```
5748
5749  Note: The result is equivalent to flattening the `batch_shape` axes of
5750  `indices` and `updates`. This generalization just avoids the need
5751  for reshapes when it is more natural to construct "folded" indices and
5752  updates.
5753
5754  With this generalization the full shape constraints are:
5755
5756  ```
5757  assert tf.rank(indices) >= 2
5758  index_depth = indices.shape[-1]
5759  batch_shape = indices.shape[:-1]
5760  assert index_depth <= tf.rank(tensor)
5761  outer_shape = tensor.shape[:index_depth]
5762  inner_shape = tensor.shape[index_depth:]
5763  assert updates.shape == batch_shape + inner_shape
5764  ```
5765
5766  For example, to draw an `X` on a `(5,5)` matrix start with these indices:
5767
5768  >>> tensor = tf.zeros([5,5])
5769  >>> indices = tf.constant([
5770  ...  [[0,0],
5771  ...   [1,1],
5772  ...   [2,2],
5773  ...   [3,3],
5774  ...   [4,4]],
5775  ...  [[0,4],
5776  ...   [1,3],
5777  ...   [2,2],
5778  ...   [3,1],
5779  ...   [4,0]],
5780  ... ])
5781  >>> indices.shape.as_list()  # batch_shape == [2, 5], index_depth == 2
5782  [2, 5, 2]
5783
5784  Here the `indices` do not have a shape of `[num_updates, index_depth]`, but a
5785  shape of `batch_shape+[index_depth]`.
5786
5787  Since the `index_depth` is equal to the rank of `tensor`:
5788
5789  * `outer_shape` is `(5,5)`
5790  * `inner_shape` is `()` - each update is scalar
5791  * `updates.shape` is `batch_shape + inner_shape == (5,2) + ()`
5792
5793  >>> updates = [
5794  ...   [1,1,1,1,1],
5795  ...   [1,1,1,1,1],
5796  ... ]
5797
5798  Putting this together gives:
5799
5800  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5801  array([[1., 0., 0., 0., 1.],
5802         [0., 1., 0., 1., 0.],
5803         [0., 0., 1., 0., 0.],
5804         [0., 1., 0., 1., 0.],
5805         [1., 0., 0., 0., 1.]], dtype=float32)
5806
5807  Args:
5808    tensor: Tensor to copy/update.
5809    indices: Indices to update.
5810    updates: Updates to apply at the indices.
5811    name: Optional name for the operation.
5812
5813  Returns:
5814    A new tensor with the given shape and updates applied according to the
5815    indices.
5816  """
5817  return gen_array_ops.tensor_scatter_update(
5818      tensor=tensor, indices=indices, updates=updates, name=name)
5819
5820
5821# Define quantize_v2 here in order to make name the second-to-last attribute,
5822# because round_mode was added later.
5823# (And also now because of 'axis' processing).
5824@tf_export(v1=["quantize_v2"])
5825@dispatch.add_dispatch_support
5826@deprecation.deprecated(
5827    "2017-10-25",
5828    "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
5829    "instead.")  # pylint: disable=missing-docstring
5830def quantize_v2(
5831    input,  # pylint: disable=redefined-builtin
5832    min_range,
5833    max_range,
5834    T,
5835    mode="MIN_COMBINED",
5836    name=None,
5837    round_mode="HALF_AWAY_FROM_ZERO",
5838    narrow_range=False,
5839    axis=None,
5840    ensure_minimum_range=0.01):
5841  if axis is None:
5842    axis = -1
5843  elif axis < 0:
5844    if input.shape.ndims is None:
5845      raise ValueError("input should have known rank to use negative axis.")
5846    axis %= input.shape.ndims
5847
5848  if ensure_minimum_range != 0.01:
5849    return gen_array_ops.quantize_v2(
5850        input,
5851        min_range,
5852        max_range,
5853        T=T,
5854        mode=mode,
5855        name=name,
5856        round_mode=round_mode,
5857        narrow_range=narrow_range,
5858        axis=axis,
5859        ensure_minimum_range=ensure_minimum_range)
5860  return gen_array_ops.quantize_v2(
5861      input,
5862      min_range,
5863      max_range,
5864      T=T,
5865      mode=mode,
5866      name=name,
5867      round_mode=round_mode,
5868      narrow_range=narrow_range,
5869      axis=axis)
5870
5871
5872quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead."""
5873
5874
5875# We want to expose tf.quantization.quantize instead of
5876# tf.quantization.quantize; we can deprecate tf.quantization.quantize in next
5877# version of TensorFlow.
5878@tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"])
5879@dispatch.add_dispatch_support
5880@deprecation.deprecated_endpoints("quantize")
5881def quantize(
5882    input,  # pylint: disable=redefined-builtin
5883    min_range,
5884    max_range,
5885    T,
5886    mode="MIN_COMBINED",
5887    round_mode="HALF_AWAY_FROM_ZERO",
5888    name=None,
5889    narrow_range=False,
5890    axis=None,
5891    ensure_minimum_range=0.01):
5892  """Quantize the input tensor."""
5893  if ensure_minimum_range != 0.01:
5894    return quantize_v2(
5895        input,
5896        min_range,
5897        max_range,
5898        T,
5899        mode=mode,
5900        round_mode=round_mode,
5901        name=name,
5902        narrow_range=narrow_range,
5903        axis=axis,
5904        ensure_minimum_range=ensure_minimum_range)
5905  return quantize_v2(
5906      input,
5907      min_range,
5908      max_range,
5909      T,
5910      mode=mode,
5911      round_mode=round_mode,
5912      name=name,
5913      narrow_range=narrow_range,
5914      axis=axis)
5915
5916
5917@tf_export("quantization.dequantize", v1=["quantization.dequantize",
5918                                          "dequantize"])
5919@dispatch.add_dispatch_support
5920@deprecation.deprecated_endpoints("dequantize")
5921def dequantize(  # pylint: disable=missing-docstring
5922    input,  # pylint: disable=redefined-builtin
5923    min_range,
5924    max_range,
5925    mode="MIN_COMBINED",
5926    name=None,
5927    axis=None,
5928    narrow_range=False,
5929    dtype=dtypes.float32):
5930  if axis is None:
5931    axis = -1
5932  elif axis < 0:
5933    if input.shape.ndims is None:
5934      raise ValueError("input should have known rank to use negative axis.")
5935    axis %= input.shape.ndims
5936
5937  if axis >= 0 or narrow_range:
5938    return gen_array_ops.dequantize(
5939        input,
5940        min_range,
5941        max_range,
5942        mode=mode,
5943        name=name,
5944        narrow_range=narrow_range,
5945        axis=axis,
5946        dtype=dtype)
5947  return gen_array_ops.dequantize(
5948      input, min_range, max_range, mode=mode, name=name, dtype=dtype)
5949
5950
5951dequantize.__doc__ = gen_array_ops.dequantize.__doc__
5952
5953
5954@tf_export("quantization.quantize_and_dequantize")
5955@dispatch.add_dispatch_support
5956@deprecation.deprecated(None,
5957                        "This Op has been deprecated, use" +
5958                        "`quantize_and_dequantize_v2` instead. To " +
5959                        "To simulate the V1 the behavior of " +
5960                        "tf.quantization.quantize_and_dequantize(...) use " +
5961                        "tf.grad_pass_through(" +
5962                        "tf.quantization.quantize_and_dequantize_v2)(...).")
5963def quantize_and_dequantize(
5964    input,  # pylint: disable=redefined-builtin
5965    input_min,
5966    input_max,
5967    signed_input=True,
5968    num_bits=8,
5969    range_given=False,
5970    round_mode="HALF_TO_EVEN",
5971    name=None,
5972    narrow_range=False,
5973    axis=None):
5974  """Quantizes then dequantizes a tensor.
5975
5976  Args:
5977    input: A `Tensor` to quantize and dequantize.
5978    input_min: If range_given=True, the minimum input value, that needs to be
5979      represented in the quantized representation. If axis is specified, this
5980      should be a vector of minimum values for each slice along axis.
5981    input_max: If range_given=True, the maximum input value that needs to be
5982      represented in the quantized representation. If axis is specified, this
5983      should be a vector of maximum values for each slice along axis.
5984    signed_input: True if the quantization is signed or unsigned.
5985    num_bits: The bitwidth of the quantization.
5986    range_given: If true use `input_min` and `input_max` for the range of the
5987      input, otherwise determine min and max from the input `Tensor`.
5988    round_mode: Rounding mode when rounding from float values to quantized ones.
5989      one of ['HALF_TO_EVEN', 'HALF_UP']
5990    name: Optional name for the operation.
5991    narrow_range: If true, then the absolute value of the quantized minimum
5992      value is the same as the quantized maximum value, instead of 1 greater.
5993      i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
5994    axis: Integer. If specified, refers to a dimension of the input tensor, such
5995      that quantization will be per slice along that dimension.
5996
5997  Returns:
5998    A `Tensor`. Each element is the result of quantizing and dequantizing the
5999    corresponding element of `input`.
6000  """
6001  if axis is None:
6002    axis = -1
6003  elif axis < 0:
6004    if input.shape.ndims is None:
6005      raise ValueError("input should have known rank to use negative axis.")
6006    axis %= input.shape.ndims
6007
6008  return gen_array_ops.quantize_and_dequantize_v2(
6009      input,
6010      input_min=input_min,
6011      input_max=input_max,
6012      signed_input=signed_input,
6013      num_bits=num_bits,
6014      range_given=range_given,
6015      round_mode=round_mode,
6016      narrow_range=narrow_range,
6017      axis=axis,
6018      name=name)
6019
6020
6021@tf_export("quantization.quantize_and_dequantize_v2")
6022@dispatch.add_dispatch_support
6023def quantize_and_dequantize_v2(
6024    input,  # pylint: disable=redefined-builtin
6025    input_min,
6026    input_max,
6027    signed_input=True,
6028    num_bits=8,
6029    range_given=False,
6030    round_mode="HALF_TO_EVEN",
6031    name=None,
6032    narrow_range=False,
6033    axis=None):
6034  """Quantizes then dequantizes a tensor.
6035
6036  Updates the gradient definition for quantization that is outside the range to
6037  be 0.To simulate the V1 the behavior of
6038  tf.quantization.quantize_and_dequantize(...) use
6039  tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
6040
6041  Example usage:
6042
6043  ```python
6044  def getQuantizeOp(input):
6045      input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
6046      net = tf.quantization.quantize_and_dequantize(input,
6047                                                    input_min=min_threshold,
6048                                                    input_max=max_threshold,
6049                                                    range_given=True)
6050
6051  To simulate v1 behavior:
6052
6053  def testDecomposeQuantizeDequantize(self):
6054      def f(input_tensor):
6055        return tf.quantization.quantize_and_dequantize_v2(input_tensor,
6056                                                          input_min = 5.0,
6057                                                          input_max= -10.0,
6058                                                          range_given=True)
6059      input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
6060      net = tf.grad_pass_through(f)(input_tensor)
6061  ```
6062
6063  Args:
6064    input: A `Tensor` to quantize and dequantize.
6065    input_min: If range_given=True, the minimum input value, that needs to be
6066      represented in the quantized representation. If axis is specified, this
6067      should be a vector of minimum values for each slice along axis.
6068    input_max: If range_given=True, the maximum input value that needs to be
6069      represented in the quantized representation. If axis is specified, this
6070      should be a vector of maximum values for each slice along axis.
6071    signed_input: True if the quantization is signed or unsigned.
6072    num_bits: The bitwidth of the quantization.
6073    range_given: If true use `input_min` and `input_max` for the range of the
6074      input, otherwise determine min and max from the input `Tensor`.
6075    round_mode: Rounding mode when rounding from float values to quantized ones.
6076      one of ['HALF_TO_EVEN', 'HALF_UP']
6077    name: Optional name for the operation.
6078    narrow_range: If true, then the absolute value of the quantized minimum
6079      value is the same as the quantized maximum value, instead of 1 greater.
6080      i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
6081    axis: Integer. If specified, refers to a dimension of the input tensor, such
6082      that quantization will be per slice along that dimension.
6083
6084  Returns:
6085    A `Tensor`. Each element is the result of quantizing and dequantizing the
6086    corresponding element of `input`.
6087  """
6088  if axis is None:
6089    axis = -1
6090  elif axis < 0:
6091    if input.shape.ndims is None:
6092      raise ValueError("input should have known rank to use negative axis.")
6093    axis %= input.shape.ndims
6094
6095  return gen_array_ops.quantize_and_dequantize_v4(
6096      input,
6097      input_min=input_min,
6098      input_max=input_max,
6099      signed_input=signed_input,
6100      num_bits=num_bits,
6101      range_given=range_given,
6102      round_mode=round_mode,
6103      narrow_range=narrow_range,
6104      axis=axis,
6105      name=name)
6106
6107
6108@tf_export("searchsorted")
6109@dispatch.add_dispatch_support
6110def searchsorted(sorted_sequence,
6111                 values,
6112                 side="left",
6113                 out_type=dtypes.int32,
6114                 name=None):
6115  """Searches for where a value would go in a sorted sequence.
6116
6117  This is not a method for checking containment (like python `in`).
6118
6119  The typical use case for this operation is "binning", "bucketing", or
6120  "discretizing". The `values` are assigned to bucket-indices based on the
6121  **edges** listed in `sorted_sequence`. This operation
6122  returns the bucket-index for each value.
6123
6124  >>> edges = [-1, 3.3, 9.1, 10.0]
6125  >>> values = [0.0, 4.1, 12.0]
6126  >>> tf.searchsorted(edges, values).numpy()
6127  array([1, 2, 4], dtype=int32)
6128
6129  The `side` argument controls which index is returned if a value lands exactly
6130  on an edge:
6131
6132  >>> seq = [0, 3, 9, 10, 10]
6133  >>> values = [0, 4, 10]
6134  >>> tf.searchsorted(seq, values).numpy()
6135  array([0, 2, 3], dtype=int32)
6136  >>> tf.searchsorted(seq, values, side="right").numpy()
6137  array([1, 2, 5], dtype=int32)
6138
6139  The `axis` is not settable for this operation. It always operates on the
6140  innermost dimension (`axis=-1`). The operation will accept any number of
6141  outer dimensions. Here it is applied to the rows of a matrix:
6142
6143  >>> sorted_sequence = [[0., 3., 8., 9., 10.],
6144  ...                    [1., 2., 3., 4., 5.]]
6145  >>> values = [[9.8, 2.1, 4.3],
6146  ...           [0.1, 6.6, 4.5, ]]
6147  >>> tf.searchsorted(sorted_sequence, values).numpy()
6148  array([[4, 1, 2],
6149         [0, 5, 4]], dtype=int32)
6150
6151  Note: This operation assumes that `sorted_sequence` **is sorted** along the
6152  innermost axis, maybe using `tf.sort(..., axis=-1)`. **If the sequence is not
6153  sorted no error is raised** and the content of the returned tensor is not well
6154  defined.
6155
6156  Args:
6157    sorted_sequence: N-D `Tensor` containing a sorted sequence.
6158    values: N-D `Tensor` containing the search values.
6159    side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
6160      upper_bound.
6161    out_type: The output type (`int32` or `int64`).  Default is `tf.int32`.
6162    name: Optional name for the operation.
6163
6164  Returns:
6165    An N-D `Tensor` the size of `values` containing the result of applying
6166    either lower_bound or upper_bound (depending on side) to each value.  The
6167    result is not a global index to the entire `Tensor`, but the index in the
6168    last dimension.
6169
6170  Raises:
6171    ValueError: If the last dimension of `sorted_sequence >= 2^31-1` elements.
6172                If the total size of `values` exceeds `2^31 - 1` elements.
6173                If the first `N-1` dimensions of the two tensors don't match.
6174  """
6175  sequence_size = shape_internal(sorted_sequence)[-1]
6176  values_size = shape_internal(values)[-1]
6177  sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
6178  values_2d = reshape(values, [-1, values_size])
6179  if side == "right":
6180    output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
6181                                       name)
6182  elif side == "left":
6183    output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
6184                                       name)
6185  else:
6186    raise ValueError("side must be either 'right' or 'left'.  Saw: %s." % side)
6187  return reshape(output, shape_internal(values))
6188
6189
6190quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
6191
6192
6193@tf_export("image.extract_patches")
6194@dispatch.add_dispatch_support
6195def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
6196  r"""Extract `patches` from `images`.
6197
6198  This op collects patches from the input image, as if applying a
6199  convolution. All extracted patches are stacked in the depth (last) dimension
6200  of the output.
6201
6202  Specifically, the op extracts patches of shape `sizes` which are `strides`
6203  apart in the input image. The output is subsampled using the `rates` argument,
6204  in the same manner as "atrous" or "dilated" convolutions.
6205
6206  The result is a 4D tensor which is indexed by batch, row, and column.
6207  `output[i, x, y]` contains a flattened patch of size `sizes[1], sizes[2]`
6208  which is taken from the input starting at
6209  `images[i, x*strides[1], y*strides[2]]`.
6210
6211  Each output patch can be reshaped to `sizes[1], sizes[2], depth`, where
6212  `depth` is `images.shape[3]`.
6213
6214  The output elements are taken from the input at intervals given by the `rate`
6215  argument, as in dilated convolutions.
6216
6217  The `padding` argument has no effect on the size of each patch, it determines
6218  how many patches are extracted. If `VALID`, only patches which are fully
6219  contained in the input image are included. If `SAME`, all patches whose
6220  starting point is inside the input are included, and areas outside the input
6221  default to zero.
6222
6223  Example:
6224
6225  ```
6226    n = 10
6227    # images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100
6228    images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]
6229
6230    # We generate two outputs as follows:
6231    # 1. 3x3 patches with stride length 5
6232    # 2. Same as above, but the rate is increased to 2
6233    tf.image.extract_patches(images=images,
6234                             sizes=[1, 3, 3, 1],
6235                             strides=[1, 5, 5, 1],
6236                             rates=[1, 1, 1, 1],
6237                             padding='VALID')
6238
6239    # Yields:
6240    [[[[ 1  2  3 11 12 13 21 22 23]
6241       [ 6  7  8 16 17 18 26 27 28]]
6242      [[51 52 53 61 62 63 71 72 73]
6243       [56 57 58 66 67 68 76 77 78]]]]
6244  ```
6245
6246  If we mark the pixels in the input image which are taken for the output with
6247  `*`, we see the pattern:
6248
6249  ```
6250     *  *  *  4  5  *  *  *  9 10
6251     *  *  * 14 15  *  *  * 19 20
6252     *  *  * 24 25  *  *  * 29 30
6253    31 32 33 34 35 36 37 38 39 40
6254    41 42 43 44 45 46 47 48 49 50
6255     *  *  * 54 55  *  *  * 59 60
6256     *  *  * 64 65  *  *  * 69 70
6257     *  *  * 74 75  *  *  * 79 80
6258    81 82 83 84 85 86 87 88 89 90
6259    91 92 93 94 95 96 97 98 99 100
6260  ```
6261
6262  ```
6263    tf.image.extract_patches(images=images,
6264                             sizes=[1, 3, 3, 1],
6265                             strides=[1, 5, 5, 1],
6266                             rates=[1, 2, 2, 1],
6267                             padding='VALID')
6268
6269    # Yields:
6270    [[[[  1   3   5  21  23  25  41  43  45]
6271       [  6   8  10  26  28  30  46  48  50]]
6272
6273      [[ 51  53  55  71  73  75  91  93  95]
6274       [ 56  58  60  76  78  80  96  98 100]]]]
6275  ```
6276
6277  We can again draw the effect, this time using the symbols `*`, `x`, `+` and
6278  `o` to distinguish the patches:
6279
6280  ```
6281     *  2  *  4  *  x  7  x  9  x
6282    11 12 13 14 15 16 17 18 19 20
6283     * 22  * 24  *  x 27  x 29  x
6284    31 32 33 34 35 36 37 38 39 40
6285     * 42  * 44  *  x 47  x 49  x
6286     + 52  + 54  +  o 57  o 59  o
6287    61 62 63 64 65 66 67 68 69 70
6288     + 72  + 74  +  o 77  o 79  o
6289    81 82 83 84 85 86 87 88 89 90
6290     + 92  + 94  +  o 97  o 99  o
6291  ```
6292
6293  Args:
6294    images: A 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
6295    sizes: The size of the extracted patches. Must be
6296      `[1, size_rows, size_cols, 1]`.
6297    strides: A 1-D Tensor of length 4. How far the centers of two consecutive
6298      patches are in the images. Must be: `[1, stride_rows, stride_cols, 1]`.
6299    rates: A 1-D Tensor of length 4. Must be: `[1, rate_rows, rate_cols, 1]`.
6300      This is the input stride, specifying how far two consecutive patch samples
6301      are in the input. Equivalent to extracting patches with `patch_sizes_eff =
6302      patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by subsampling
6303      them spatially by a factor of `rates`. This is equivalent to `rate` in
6304      dilated (a.k.a. Atrous) convolutions.
6305    padding: The type of padding algorithm to use.
6306    name: A name for the operation (optional).
6307
6308  Returns:
6309    A 4-D Tensor of the same type as the input.
6310  """
6311  return gen_array_ops.extract_image_patches(images, sizes, strides, rates,
6312                                             padding, name)
6313
6314
6315@tf_export(v1=["image.extract_image_patches", "extract_image_patches"])
6316@dispatch.add_dispatch_support
6317@deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead",
6318                             "ksizes")
6319def extract_image_patches(  # pylint: disable=missing-docstring
6320    images,
6321    ksizes=None,
6322    strides=None,
6323    rates=None,
6324    padding=None,
6325    name=None,
6326    sizes=None):
6327  """Extract patches from images and put them in the "depth" output dimension.
6328
6329  Args:
6330    `images`: A `Tensor`. Must be one of the following types: `float32`,
6331      `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`,
6332      `uint16`, `half`, `uint32`, `uint64`. 4-D Tensor with shape
6333    `[batch, in_rows, in_cols, depth]`. `ksizes`: A list of `ints` that has
6334      length `>= 4`. The size of the sliding window for each
6335    dimension of `images`. `strides`: A list of `ints` that has length `>= 4`.
6336      1-D of length 4. How far the centers of two consecutive
6337    patches are in the images. Must be:
6338    `[1, stride_rows, stride_cols, 1]`. `rates`: A list of `ints`
6339    that has length `>= 4`. 1-D of length 4. Must be: `[1, rate_rows, rate_cols,
6340      1]`. This is the input stride, specifying how far two consecutive patch
6341      samples are in the input. Equivalent to extracting patches with
6342      `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`,
6343      followed by subsampling them spatially by a factor of `rates`. This is
6344      equivalent to `rate` in dilated (a.k.a. Atrous) convolutions.
6345    `padding`: A `string` from: "SAME", "VALID". The type of padding algorithm
6346      to use.
6347    We specify the size-related attributes as:  ``` ksizes = [1, ksize_rows,
6348      ksize_cols, 1] strides = [1, strides_rows, strides_cols, 1] rates = [1,
6349      rates_rows, rates_cols, 1]
6350    name: A name for the operation (optional). ```
6351
6352  Returns:
6353    A Tensor. Has the same type as images.
6354  """
6355  ksizes = deprecation.deprecated_argument_lookup("sizes", sizes, "ksizes",
6356                                                  ksizes)
6357  return gen_array_ops.extract_image_patches(images, ksizes, strides, rates,
6358                                             padding, name)
6359
6360
6361extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__
6362
6363
6364@tf_export("fingerprint")
6365@dispatch.add_dispatch_support
6366def fingerprint(data, method="farmhash64", name=None):
6367  r"""Generates fingerprint values.
6368
6369  Generates fingerprint values of `data`.
6370
6371  Fingerprint op considers the first dimension of `data` as the batch dimension,
6372  and `output[i]` contains the fingerprint value generated from contents in
6373  `data[i, ...]` for all `i`.
6374
6375  Fingerprint op writes fingerprint values as byte arrays. For example, the
6376  default method `farmhash64` generates a 64-bit fingerprint value at a time.
6377  This 8-byte value is written out as an `tf.uint8` array of size 8, in
6378  little-endian order.
6379
6380  For example, suppose that `data` has data type `tf.int32` and shape (2, 3, 4),
6381  and that the fingerprint method is `farmhash64`. In this case, the output
6382  shape is (2, 8), where 2 is the batch dimension size of `data`, and 8 is the
6383  size of each fingerprint value in bytes. `output[0, :]` is generated from
6384  12 integers in `data[0, :, :]` and similarly `output[1, :]` is generated from
6385  other 12 integers in `data[1, :, :]`.
6386
6387  Note that this op fingerprints the raw underlying buffer, and it does not
6388  fingerprint Tensor's metadata such as data type and/or shape. For example, the
6389  fingerprint values are invariant under reshapes and bitcasts as long as the
6390  batch dimension remain the same:
6391
6392  ```python
6393  tf.fingerprint(data) == tf.fingerprint(tf.reshape(data, ...))
6394  tf.fingerprint(data) == tf.fingerprint(tf.bitcast(data, ...))
6395  ```
6396
6397  For string data, one should expect `tf.fingerprint(data) !=
6398  tf.fingerprint(tf.string.reduce_join(data))` in general.
6399
6400  Args:
6401    data: A `Tensor`. Must have rank 1 or higher.
6402    method: A `Tensor` of type `tf.string`. Fingerprint method used by this op.
6403      Currently available method is `farmhash64`.
6404    name: A name for the operation (optional).
6405
6406  Returns:
6407    A two-dimensional `Tensor` of type `tf.uint8`. The first dimension equals to
6408    `data`'s first dimension, and the second dimension size depends on the
6409    fingerprint algorithm.
6410  """
6411  return gen_array_ops.fingerprint(data, method, name)
6412
6413
6414def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
6415  """Converts the given value to an integer Tensor."""
6416  tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
6417  if tensor.dtype.is_integer:
6418    tensor = gen_math_ops.cast(tensor, dtype)
6419  else:
6420    raise TypeError("%s must be an integer tensor; dtype=%s" %
6421                    (name, tensor.dtype))
6422  return tensor
6423
6424
6425def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
6426  """Validate an `axis` parameter, and normalize it to be positive.
6427
6428  If `ndims` is known (i.e., not `None`), then check that `axis` is in the
6429  range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
6430  `axis + ndims` (otherwise).
6431  If `ndims` is not known, and `axis` is positive, then return it as-is.
6432  If `ndims` is not known, and `axis` is negative, then report an error.
6433
6434  Args:
6435    axis: An integer constant
6436    ndims: An integer constant, or `None`
6437    axis_name: The name of `axis` (for error messages).
6438    ndims_name: The name of `ndims` (for error messages).
6439
6440  Returns:
6441    The normalized `axis` value.
6442
6443  Raises:
6444    ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
6445      `ndims is None`.
6446  """
6447  if not isinstance(axis, int):
6448    raise TypeError("%s must be an int; got %s" %
6449                    (axis_name, type(axis).__name__))
6450  if ndims is not None:
6451    if 0 <= axis < ndims:
6452      return axis
6453    elif -ndims <= axis < 0:
6454      return axis + ndims
6455    else:
6456      raise ValueError("%s=%s out of bounds: expected %s<=%s<%s" %
6457                       (axis_name, axis, -ndims, axis_name, ndims))
6458  elif axis < 0:
6459    raise ValueError("%s may only be negative if %s is statically known." %
6460                     (axis_name, ndims_name))
6461  return axis
6462
6463
6464# This op is intended to exactly match the semantics of numpy.repeat, with
6465# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
6466# when axis is not specified.  Rather than implement that special behavior, we
6467# simply make `axis` be a required argument.
6468#
6469# External (OSS) `tf.repeat` feature request:
6470# https://github.com/tensorflow/tensorflow/issues/8246
6471def repeat_with_axis(data, repeats, axis, name=None):
6472  """Repeats elements of `data`.
6473
6474  Args:
6475    data: An `N`-dimensional tensor.
6476    repeats: A 1-D integer tensor specifying how many times each element in
6477      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
6478      Supports broadcasting from a scalar value.
6479    axis: `int`.  The axis along which to repeat values.  Must be less than
6480      `max(N, 1)`.
6481    name: A name for the operation.
6482
6483  Returns:
6484    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
6485    except that dimension `axis` has size `sum(repeats)`.
6486
6487  Example usage:
6488
6489  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6490  <tf.Tensor: shape=(5,), dtype=string,
6491  numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6492  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6493  <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6494  array([[1, 2],
6495         [1, 2],
6496         [3, 4],
6497         [3, 4],
6498         [3, 4]], dtype=int32)>
6499  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6500  <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6501  array([[1, 1, 2, 2, 2],
6502         [3, 3, 4, 4, 4]], dtype=int32)>
6503
6504  """
6505  if not isinstance(axis, int):
6506    raise TypeError("axis must be an int; got %s" % type(axis).__name__)
6507
6508  with ops.name_scope(name, "Repeat", [data, repeats]):
6509    data = ops.convert_to_tensor(data, name="data")
6510    repeats = convert_to_int_tensor(repeats, name="repeats")
6511    repeats.shape.with_rank_at_most(1)
6512
6513    # If `data` is a scalar, then upgrade it to a vector.
6514    data = _with_nonzero_rank(data)
6515    data_shape = shape(data)
6516
6517    # If `axis` is negative, then convert it to a positive value.
6518    axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
6519
6520    # If we know that `repeats` is a scalar, then we can just tile & reshape.
6521    if repeats.shape.num_elements() == 1:
6522      repeats = reshape(repeats, [])
6523      expanded = expand_dims(data, axis + 1)
6524      tiled = tile_one_dimension(expanded, axis + 1, repeats)
6525      result_shape = concat([
6526          data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
6527      ],
6528                            axis=0)
6529      return reshape(tiled, result_shape)
6530
6531
6532    # Check data Tensor shapes.
6533    if repeats.shape.ndims == 1:
6534      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
6535
6536    repeats = broadcast_to(repeats, [data_shape[axis]])
6537    repeats_original = repeats
6538
6539    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
6540    if repeats.shape.ndims != axis + 1:
6541      repeats_shape = shape(repeats)
6542      repeats_ndims = rank(repeats)
6543      broadcast_shape = concat(
6544          [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
6545      repeats = broadcast_to(repeats, broadcast_shape)
6546      repeats.set_shape([None] * (axis + 1))
6547
6548    # Create a "sequence mask" based on `repeats`, where slices across `axis`
6549    # contain one `True` value for each repetition.  E.g., if
6550    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
6551    max_repeat = gen_math_ops.maximum(
6552        0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
6553    mask = sequence_mask(repeats, max_repeat)
6554
6555    # Add a new dimension around each value that needs to be repeated, and
6556    # then tile that new dimension to match the maximum number of repetitions.
6557    expanded = expand_dims(data, axis + 1)
6558    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
6559
6560    # Use `boolean_mask` to discard the extra repeated values.  This also
6561    # flattens all dimensions up through `axis`.
6562    masked = boolean_mask(tiled, mask)
6563
6564    # Reshape the output tensor to add the outer dimensions back.
6565    if axis == 0:
6566      result = masked
6567    else:
6568      repeated_dim_size = gen_math_ops._sum(
6569          repeats_original,
6570          axis=gen_math_ops._range(0, rank(repeats_original), 1))
6571      result_shape = concat(
6572          [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
6573          axis=0)
6574      result = reshape(masked, result_shape)
6575
6576    # Preserve shape information.
6577    if data.shape.ndims is not None:
6578      new_axis_size = 0 if repeats.shape[0] == 0 else None
6579      result.set_shape(data.shape[:axis].concatenate(
6580          [new_axis_size]).concatenate(data.shape[axis + 1:]))
6581
6582    return result
6583
6584
6585def tile_one_dimension(data, axis, multiple):
6586  """Tiles a single dimension of a tensor."""
6587  # Assumes axis is a nonnegative int.
6588  if data.shape.ndims is not None:
6589    multiples = [1] * data.shape.ndims
6590    multiples[axis] = multiple
6591  else:
6592    ones_value = ones(rank(data), dtypes.int32)
6593    multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]],
6594                       axis=0)
6595  return tile(data, multiples)
6596
6597
6598def _with_nonzero_rank(data):
6599  """If `data` is scalar, then add a dimension; otherwise return as-is."""
6600  if data.shape.ndims is not None:
6601    if data.shape.ndims == 0:
6602      return stack([data])
6603    else:
6604      return data
6605  else:
6606    data_shape = shape(data)
6607    data_ndims = rank(data)
6608    return reshape(data, concat([[1], data_shape], axis=0)[-data_ndims:])
6609
6610
6611@tf_export("repeat")
6612@dispatch.add_dispatch_support
6613def repeat(input, repeats, axis=None, name=None):  # pylint: disable=redefined-builtin
6614  """Repeat elements of `input`.
6615
6616  See also `tf.concat`, `tf.stack`, `tf.tile`.
6617
6618  Args:
6619    input: An `N`-dimensional Tensor.
6620    repeats: An 1-D `int` Tensor. The number of repetitions for each element.
6621      repeats is broadcasted to fit the shape of the given axis. `len(repeats)`
6622      must equal `input.shape[axis]` if axis is not None.
6623    axis: An int. The axis along which to repeat values. By default (axis=None),
6624      use the flattened input array, and return a flat output array.
6625    name: A name for the operation.
6626
6627  Returns:
6628    A Tensor which has the same shape as `input`, except along the given axis.
6629      If axis is None then the output array is flattened to match the flattened
6630      input array.
6631
6632  Example usage:
6633
6634  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6635  <tf.Tensor: shape=(5,), dtype=string,
6636  numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6637
6638  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6639  <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6640  array([[1, 2],
6641         [1, 2],
6642         [3, 4],
6643         [3, 4],
6644         [3, 4]], dtype=int32)>
6645
6646  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6647  <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6648  array([[1, 1, 2, 2, 2],
6649         [3, 3, 4, 4, 4]], dtype=int32)>
6650
6651  >>> repeat(3, repeats=4)
6652  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([3, 3, 3, 3], dtype=int32)>
6653
6654  >>> repeat([[1,2], [3,4]], repeats=2)
6655  <tf.Tensor: shape=(8,), dtype=int32,
6656  numpy=array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)>
6657
6658  """
6659  if axis is None:
6660    input = reshape(input, [-1])
6661    axis = 0
6662  return repeat_with_axis(input, repeats, axis, name)
6663
6664
6665@tf_export("guarantee_const")
6666@deprecation.deprecated(None, "Not for public use.")
6667def guarantee_const(input, name=None):    # pylint: disable=redefined-builtin
6668  """Promise to the TF runtime that the input tensor is a constant.
6669
6670  The runtime is then free to make optimizations based on this.
6671
6672  Returns the input tensor without modification.
6673
6674  Args:
6675    input: A `Tensor`.
6676    name: A name for this operation.
6677
6678  Returns:
6679    A `Tensor`. Has the same dtype as `input`.
6680  """
6681  return gen_array_ops.guarantee_const(input=input, name=name)
6682