1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# Tests for this file live in python/kernel_tests/array_ops_test.py
16"""Support for manipulating tensors."""
17
18import numbers
19import numpy as np
20
21from tensorflow.python.eager import context
22from tensorflow.python.eager import tape
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import indexed_slices
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33# 'Constant' gets imported in the module 'array_ops'.
34from tensorflow.python.framework.constant_op import constant
35from tensorflow.python.ops import gen_array_ops
36from tensorflow.python.ops import gen_math_ops
37# go/tf-wildcard-import
38# pylint: disable=wildcard-import
39from tensorflow.python.ops.gen_array_ops import *
40from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse  # pylint: disable=unused-import
41from tensorflow.python.types import core
42from tensorflow.python.util import _pywrap_utils
43from tensorflow.python.util import deprecation
44from tensorflow.python.util import dispatch
45from tensorflow.python.util import lazy_loader
46from tensorflow.python.util import nest
47from tensorflow.python.util import tf_decorator
48from tensorflow.python.util.tf_export import tf_export
49# pylint: enable=wildcard-import
50
51math_ops = lazy_loader.LazyLoader(
52    "math_ops", globals(), "tensorflow.python.ops.math_ops")
53
54# Used for slicing to specify a new 1 size dimension
55newaxis = None
56tf_export("newaxis").export_constant(__name__, "newaxis")
57
58# We override the 'slice' for the "slice" op, so we keep Python's
59# existing 'slice' for later use in this module.
60_BaseSlice = slice
61
62
63@tf_export("reshape", v1=["reshape", "manip.reshape"])
64@dispatch.add_dispatch_support
65def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
66  r"""Reshapes a tensor.
67
68  Given `tensor`, this operation returns a new `tf.Tensor` that has the same
69  values as `tensor` in the same order, except with a new shape given by
70  `shape`.
71
72  >>> t1 = [[1, 2, 3],
73  ...       [4, 5, 6]]
74  >>> print(tf.shape(t1).numpy())
75  [2 3]
76  >>> t2 = tf.reshape(t1, [6])
77  >>> t2
78  <tf.Tensor: shape=(6,), dtype=int32,
79    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
80  >>> tf.reshape(t2, [3, 2])
81  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
82    array([[1, 2],
83           [3, 4],
84           [5, 6]], dtype=int32)>
85
86  `tf.reshape` does not change the order of or the total number of elements in
87  the tensor, so it can reuse the underlying data buffer. This makes it a fast
88  operation regardless of how big a tensor it is operating on.
89
90  >>> tf.reshape([1, 2, 3], [2, 2])
91  Traceback (most recent call last):
92  ...
93  InvalidArgumentError: Input to reshape is a tensor with 3 values, but the
94  requested shape has 4
95
96  To instead reorder the data to rearrange the dimensions of a tensor, see
97  `tf.transpose`.
98
99  >>> t = [[1, 2, 3],
100  ...      [4, 5, 6]]
101  >>> tf.reshape(t, [3, 2]).numpy()
102  array([[1, 2],
103         [3, 4],
104         [5, 6]], dtype=int32)
105  >>> tf.transpose(t, perm=[1, 0]).numpy()
106  array([[1, 4],
107         [2, 5],
108         [3, 6]], dtype=int32)
109
110  If one component of `shape` is the special value -1, the size of that
111  dimension is computed so that the total size remains constant.  In particular,
112  a `shape` of `[-1]` flattens into 1-D.  At most one component of `shape` can
113  be -1.
114
115  >>> t = [[1, 2, 3],
116  ...      [4, 5, 6]]
117  >>> tf.reshape(t, [-1])
118  <tf.Tensor: shape=(6,), dtype=int32,
119    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
120  >>> tf.reshape(t, [3, -1])
121  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
122    array([[1, 2],
123           [3, 4],
124           [5, 6]], dtype=int32)>
125  >>> tf.reshape(t, [-1, 2])
126  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
127    array([[1, 2],
128           [3, 4],
129           [5, 6]], dtype=int32)>
130
131  `tf.reshape(t, [])` reshapes a tensor `t` with one element to a scalar.
132
133  >>> tf.reshape([7], []).numpy()
134  7
135
136  More examples:
137
138  >>> t = [1, 2, 3, 4, 5, 6, 7, 8, 9]
139  >>> print(tf.shape(t).numpy())
140  [9]
141  >>> tf.reshape(t, [3, 3])
142  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
143    array([[1, 2, 3],
144           [4, 5, 6],
145           [7, 8, 9]], dtype=int32)>
146
147  >>> t = [[[1, 1], [2, 2]],
148  ...      [[3, 3], [4, 4]]]
149  >>> print(tf.shape(t).numpy())
150  [2 2 2]
151  >>> tf.reshape(t, [2, 4])
152  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
153    array([[1, 1, 2, 2],
154           [3, 3, 4, 4]], dtype=int32)>
155
156  >>> t = [[[1, 1, 1],
157  ...       [2, 2, 2]],
158  ...      [[3, 3, 3],
159  ...       [4, 4, 4]],
160  ...      [[5, 5, 5],
161  ...       [6, 6, 6]]]
162  >>> print(tf.shape(t).numpy())
163  [3 2 3]
164  >>> # Pass '[-1]' to flatten 't'.
165  >>> tf.reshape(t, [-1])
166  <tf.Tensor: shape=(18,), dtype=int32,
167    numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
168    dtype=int32)>
169  >>> # -- Using -1 to infer the shape --
170  >>> # Here -1 is inferred to be 9:
171  >>> tf.reshape(t, [2, -1])
172  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
173    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
174           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
175  >>> # -1 is inferred to be 2:
176  >>> tf.reshape(t, [-1, 9])
177  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
178    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
179           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
180  >>> # -1 is inferred to be 3:
181  >>> tf.reshape(t, [ 2, -1, 3])
182  <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
183    array([[[1, 1, 1],
184            [2, 2, 2],
185            [3, 3, 3]],
186           [[4, 4, 4],
187            [5, 5, 5],
188            [6, 6, 6]]], dtype=int32)>
189
190  Args:
191    tensor: A `Tensor`.
192    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
193      Defines the shape of the output tensor.
194    name: Optional string. A name for the operation.
195
196  Returns:
197    A `Tensor`. Has the same type as `tensor`.
198  """
199  result = gen_array_ops.reshape(tensor, shape, name)
200  tensor_util.maybe_set_static_shape(result, shape)
201  return result
202
203
204@tf_export("fill")
205@dispatch.add_dispatch_support
206def fill(dims, value, name=None):
207  r"""Creates a tensor filled with a scalar value.
208
209  See also `tf.ones`, `tf.zeros`, `tf.one_hot`, `tf.eye`.
210
211  This operation creates a tensor of shape `dims` and fills it with `value`.
212
213  For example:
214
215  >>> tf.fill([2, 3], 9)
216  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
217  array([[9, 9, 9],
218         [9, 9, 9]], dtype=int32)>
219
220  `tf.fill` evaluates at graph runtime and supports dynamic shapes based on
221  other runtime `tf.Tensors`, unlike `tf.constant(value, shape=dims)`, which
222  embeds the value as a `Const` node.
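
  For example, `dims` may itself be computed at run time from another tensor
  (a minimal illustrative sketch):

  >>> rows = tf.shape(tf.zeros([3, 4]))[0]
  >>> tf.fill([rows, 2], 7).shape.as_list()
  [3, 2]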
223
224  Args:
225    dims: A 1-D sequence of non-negative numbers. Represents the shape of the
226      output `tf.Tensor`. Entries should be of type: `int32`, `int64`.
227    value: A value to fill the returned `tf.Tensor`.
228    name: Optional string. The name of the output `tf.Tensor`.
229
230  Returns:
231    A `tf.Tensor` with shape `dims` and the same dtype as `value`.
232
233  Raises:
234    InvalidArgumentError: `dims` contains negative entries.
235    NotFoundError: `dims` contains non-integer entries.
236
237  @compatibility(numpy)
238  Similar to `np.full`. In `numpy`, more parameters are supported. Passing a
239  number argument as the shape (`np.full(5, value)`) is valid in `numpy` for
240  specifying a 1-D shaped result, while TensorFlow does not support this syntax.
241  @end_compatibility
242  """
243  result = gen_array_ops.fill(dims, value, name=name)
244  tensor_util.maybe_set_static_shape(result, dims)
245  return result
246
247
248@tf_export("identity")
249@dispatch.add_dispatch_support
250def identity(input, name=None):  # pylint: disable=redefined-builtin
251  r"""Return a Tensor with the same shape and contents as input.
252
253  The return value is not the same Tensor as the original, but contains the same
254  values.  This operation is fast when used on the same device.
255
256  For example:
257
258  >>> a = tf.constant([0.78])
259  >>> a_identity = tf.identity(a)
260  >>> a.numpy()
261  array([0.78], dtype=float32)
262  >>> a_identity.numpy()
263  array([0.78], dtype=float32)
264
265  Calling `tf.identity` on a variable will make a Tensor that represents the
266  value of that variable at the time it is called. This is equivalent to calling
267  `<variable>.read_value()`.
268
269  >>> a = tf.Variable(5)
270  >>> a_identity = tf.identity(a)
271  >>> a.assign_add(1)
272  <tf.Variable ... shape=() dtype=int32, numpy=6>
273  >>> a.numpy()
274  6
275  >>> a_identity.numpy()
276  5
277
278  Args:
279    input: A `Tensor`, a `Variable`, a `CompositeTensor` or anything that can be
280      converted to a tensor using `tf.convert_to_tensor`.
281    name: A name for the operation (optional).
282
283  Returns:
284    A `Tensor` or CompositeTensor. Has the same type and contents as `input`.
285  """
286  # Don't expand ResourceVariables, so identity(variable) will return a Tensor.
287  if (isinstance(input, composite_tensor.CompositeTensor) and
288      not _pywrap_utils.IsResourceVariable(input)):
289    return nest.map_structure(identity, input, expand_composites=True)
290  if context.executing_eagerly() and not hasattr(input, "graph"):
291    # Make sure we get an input with handle data attached from resource
292    # variables. Variables have correct handle data when graph building.
293    input = ops.convert_to_tensor(input)
294  ret = gen_array_ops.identity(input, name=name)
295  # Propagate handle data for happier shape inference for resource variables.
296  if hasattr(input, "_handle_data"):
297    ret._handle_data = input._handle_data  # pylint: disable=protected-access
298  return ret
299
300
301# pylint: disable=redefined-builtin,protected-access
302@tf_export(v1=["expand_dims"])
303@dispatch.add_dispatch_support
304@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
305def expand_dims(input, axis=None, name=None, dim=None):
306  """Returns a tensor with a length 1 axis inserted at index `axis`.
307
308  Given a tensor `input`, this operation inserts a dimension of length 1 at the
309  dimension index `axis` of `input`'s shape. The dimension index follows Python
310  indexing rules: it is zero-based, and a negative index counts backward from
311  the end.
312
313  This operation is useful to:
314
315  * Add an outer "batch" dimension to a single element.
316  * Align axes for broadcasting.
317  * Add an inner vector length axis to a tensor of scalars.
318
319  For example:
320
321  If you have a single image of shape `[height, width, channels]`:
322
323  >>> image = tf.zeros([10,10,3])
324
325  You can add an outer `batch` axis by passing `axis=0`:
326
327  >>> tf.expand_dims(image, axis=0).shape.as_list()
328  [1, 10, 10, 3]
329
330  The new axis location matches Python `list.insert(axis, 1)`:
331
332  >>> tf.expand_dims(image, axis=1).shape.as_list()
333  [10, 1, 10, 3]
334
335  Following standard Python indexing rules, a negative `axis` counts from the
336  end so `axis=-1` adds an inner most dimension:
337
338  >>> tf.expand_dims(image, -1).shape.as_list()
339  [10, 10, 3, 1]
340
341  This operation requires that `axis` is a valid index for `input.shape`,
342  following Python indexing rules:
343
344  ```
345  -1-tf.rank(input) <= axis <= tf.rank(input)
346  ```
347
348  This operation is related to:
349
350  * `tf.squeeze`, which removes dimensions of size 1.
351  * `tf.reshape`, which provides more flexible reshaping capability.
352  * `tf.sparse.expand_dims`, which provides this functionality for
353    `tf.SparseTensor`
354
355  Args:
356    input: A `Tensor`.
357    axis: 0-D (scalar). Specifies the dimension index at which to expand the
358      shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`.
359    name: The name of the output `Tensor` (optional).
360    dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
361
362  Returns:
363    A `Tensor` with the same data as `input`, but its shape has an additional
364    dimension of size 1 added.
365
366  Raises:
367    ValueError: if either both or neither of `dim` and `axis` are specified.
368  """
369  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
370  if axis is None:
371    raise ValueError("Must specify an axis argument to tf.expand_dims()")
372  return expand_dims_v2(input, axis, name)
373
374
375@tf_export("expand_dims", v1=[])
376@dispatch.add_dispatch_support
377def expand_dims_v2(input, axis, name=None):
378  """Returns a tensor with a length 1 axis inserted at index `axis`.
379
380  Given a tensor `input`, this operation inserts a dimension of length 1 at the
381  dimension index `axis` of `input`'s shape. The dimension index follows Python
382  indexing rules: it is zero-based, and a negative index counts backward from
383  the end.
384
385  This operation is useful to:
386
387  * Add an outer "batch" dimension to a single element.
388  * Align axes for broadcasting.
389  * Add an inner vector length axis to a tensor of scalars.
390
391  For example:
392
393  If you have a single image of shape `[height, width, channels]`:
394
395  >>> image = tf.zeros([10,10,3])
396
397  You can add an outer `batch` axis by passing `axis=0`:
398
399  >>> tf.expand_dims(image, axis=0).shape.as_list()
400  [1, 10, 10, 3]
401
402  The new axis location matches Python `list.insert(axis, 1)`:
403
404  >>> tf.expand_dims(image, axis=1).shape.as_list()
405  [10, 1, 10, 3]
406
407  Following standard Python indexing rules, a negative `axis` counts from the
408  end so `axis=-1` adds an inner most dimension:
409
410  >>> tf.expand_dims(image, -1).shape.as_list()
411  [10, 10, 3, 1]
412
413  This operation requires that `axis` is a valid index for `input.shape`,
414  following Python indexing rules:
415
416  ```
417  -1-tf.rank(input) <= axis <= tf.rank(input)
418  ```
419
420  This operation is related to:
421
422  * `tf.squeeze`, which removes dimensions of size 1.
423  * `tf.reshape`, which provides more flexible reshaping capability.
424  * `tf.sparse.expand_dims`, which provides this functionality for
425    `tf.SparseTensor`
426
427  Args:
428    input: A `Tensor`.
429    axis: Integer specifying the dimension index at which to expand the
430      shape of `input`. Given an input of D dimensions, `axis` must be in range
431      `[-(D+1), D]` (inclusive).
432    name: Optional string. The name of the output `Tensor`.
433
434  Returns:
435    A tensor with the same data as `input`, with an additional dimension
436    inserted at the index specified by `axis`.
437
438  Raises:
439    TypeError: If `axis` is not specified.
440    InvalidArgumentError: If `axis` is out of range `[-(D+1), D]`.
441  """
442  return gen_array_ops.expand_dims(input, axis, name)
443
444
445# pylint: enable=redefined-builtin,protected-access
446
447
448# Aliases for some automatically-generated names.
449# pylint: disable=protected-access
450@deprecation.deprecated("2016-11-30",
451                        "This op will be removed after the deprecation date. "
452                        "Please switch to tf.setdiff1d().")
453def listdiff(x, y, out_idx=None, name=None):
454  return gen_array_ops.list_diff(x, y, out_idx, name)
455
456
457listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__
458
459# pylint: enable=protected-access
460
461
462# pylint: disable=undefined-variable
463@deprecation.deprecated("2018-11-30",
464                        "This op will be removed after the deprecation date. "
465                        "Please switch to tf.sets.difference().")
466@tf_export(v1=["setdiff1d"])
467@dispatch.add_dispatch_support
468def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
469  """Computes the difference between two lists of numbers or strings.
470
471  Given a list x and a list y, this operation returns a list out that
472  represents all values that are in x but not in y. The returned list
473  out is sorted in the same order that the numbers appear in x
474  (duplicates are preserved). This operation also returns a list idx
475  that represents the position of each out element in x.
476
477  In other words:
478
479  ```python
480  out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]
481  ```
482
483  Example usage:
484
485  >>> x = [1, 2, 3, 4, 5, 6]
486  >>> y = [1, 3, 5]
487  >>> setdiff1d(x,y)
488  ListDiff(out=<tf.Tensor: id=2, shape=(3,), dtype=int32,
489  numpy=array([2, 4, 6], dtype=int32)>, idx=<tf.Tensor: id=3,
490  shape=(3,), dtype=int32, numpy=array([1, 3, 5], dtype=int32)>)
491
492  Args:
493    x: A Tensor. 1-D. Values to keep.
494    y: A Tensor. Must have the same type as x. 1-D. Values to remove.
495    index_dtype: An optional `tf.DType` from: `tf.int32`, `tf.int64`. Defaults
496      to `tf.int32`.
497    name: A name for the operation (optional).
498
499  Returns:
500    A tuple of Tensor objects (out, idx).
501    out: A Tensor. Has the same type as x.
502    idx: A Tensor of type index_dtype.
503  """
504  return gen_array_ops.list_diff(x, y, index_dtype, name)
505
506
507setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__
508
509
510@tf_export("broadcast_dynamic_shape")
511@dispatch.add_dispatch_support
512def broadcast_dynamic_shape(shape_x, shape_y):
513  """Computes the shape of a broadcast given symbolic shapes.
514
515  When `shape_x` and `shape_y` are Tensors representing shapes (i.e. the result
516  of calling tf.shape on another Tensor) this computes a Tensor which is the
517  shape of the result of a broadcasting op applied to tensors of shapes
518  `shape_x` and `shape_y`.
519
520  This is useful when validating the result of a broadcasting operation when the
521  tensors do not have statically known shapes.
522
523  Example:
524
525  >>> shape_x = (1, 2, 3)
526  >>> shape_y = (5, 1, 3)
527  >>> tf.broadcast_dynamic_shape(shape_x, shape_y)
528  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([5, 2, 3], ...>
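
  The same works with shapes computed via `tf.shape` (an illustrative sketch):

  >>> x = tf.ones([1, 2, 3])
  >>> y = tf.ones([5, 1, 3])
  >>> tf.broadcast_dynamic_shape(tf.shape(x), tf.shape(y))
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([5, 2, 3], ...>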
529
530  Args:
531    shape_x: A rank 1 integer `Tensor`, representing the shape of x.
532    shape_y: A rank 1 integer `Tensor`, representing the shape of y.
533
534  Returns:
535    A rank 1 integer `Tensor` representing the broadcasted shape.
536
537  Raises:
538    InvalidArgumentError: If the two shapes are incompatible for
539      broadcasting.
540  """
541  return gen_array_ops.broadcast_args(shape_x, shape_y)
542
543
544@tf_export("broadcast_static_shape")
545@dispatch.add_dispatch_support
546def broadcast_static_shape(shape_x, shape_y):
547  """Computes the shape of a broadcast given known shapes.
548
549  When `shape_x` and `shape_y` are fully known `TensorShape`s this computes a
550  `TensorShape` which is the shape of the result of a broadcasting op applied to
551  tensors of shapes `shape_x` and `shape_y`.
552
553  For example, if shape_x is `TensorShape([1, 2, 3])` and shape_y is
554  `TensorShape([5, 1, 3])`, the result is a TensorShape whose value is
555  `TensorShape([5, 2, 3])`.
556
557  This is useful when validating the result of a broadcasting operation when the
558  tensors have statically known shapes.
559
560  Example:
561
562  >>> shape_x = tf.TensorShape([1, 2, 3])
563  >>> shape_y = tf.TensorShape([5, 1, 3])
564  >>> tf.broadcast_static_shape(shape_x, shape_y)
565  TensorShape([5, 2, 3])
566
567  Args:
568    shape_x: A `TensorShape`
569    shape_y: A `TensorShape`
570
571  Returns:
572    A `TensorShape` representing the broadcasted shape.
573
574  Raises:
575    ValueError: If the two shapes can not be broadcasted.
576  """
577  return common_shapes.broadcast_shape(shape_x, shape_y)
578
579
580@tf_export("shape", v1=[])
581@dispatch.add_dispatch_support
582def shape_v2(input, out_type=dtypes.int32, name=None):
583  # pylint: disable=redefined-builtin
584  """Returns a tensor containing the shape of the input tensor.
585
586  See also `tf.size`, `tf.rank`.
587
588  `tf.shape` returns a 1-D integer tensor representing the shape of `input`.
589  For a scalar input, the tensor returned has a shape of (0,) and its value is
590  the empty vector (i.e. []).
591
592  For example:
593
594  >>> tf.shape(1.)
595  <tf.Tensor: shape=(0,), dtype=int32, numpy=array([], dtype=int32)>
596
597  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
598  >>> tf.shape(t)
599  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 2, 3], dtype=int32)>
600
601  Note: When using symbolic tensors, such as when using the Keras API,
602  tf.shape() will return the shape of the symbolic tensor.
603
604  >>> a = tf.keras.layers.Input((None, 10))
605  >>> tf.shape(a)
606  <... shape=(3,) dtype=int32...>
607
608  In these cases, using `tf.Tensor.shape` will return more informative results.
609
610  >>> a.shape
611  TensorShape([None, None, 10])
612
613  (The first `None` represents the as yet unknown batch size.)
614
615  `tf.shape` and `Tensor.shape` should be identical in eager mode.  Within
616  `tf.function` or within a `compat.v1` context, not all dimensions may be
617  known until execution time. Hence when defining custom layers and models
618  for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`.
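
  For example (a minimal sketch), inside a `tf.function` traced with an unknown
  first dimension, `x.shape[0]` is `None` at trace time while `tf.shape(x)[0]`
  is computed when the function runs:

  >>> @tf.function(input_signature=[tf.TensorSpec(shape=[None, 2])])
  ... def dynamic_rows(x):
  ...   return tf.shape(x)[0]  # `x.shape[0]` would be `None` here
  >>> dynamic_rows(tf.ones([3, 2])).numpy()
  3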
619
620  Args:
621    input: A `Tensor` or `SparseTensor`.
622    out_type: (Optional) The specified output type of the operation (`int32` or
623      `int64`). Defaults to `tf.int32`.
624    name: A name for the operation (optional).
625
626  Returns:
627    A `Tensor` of type `out_type`.
628  """
629  return shape(input, name, out_type)
630
631
632@tf_export(v1=["shape"])
633@dispatch.add_dispatch_support
634def shape(input, name=None, out_type=dtypes.int32):
635  # pylint: disable=redefined-builtin
636  """Returns the shape of a tensor.
637
638  This operation returns a 1-D integer tensor representing the shape of `input`.
639
640  For example:
641
642  ```python
643  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
644  tf.shape(t)  # [2, 2, 3]
645  ```
646
647  Args:
648    input: A `Tensor` or `SparseTensor`.
649    name: A name for the operation (optional).
650    out_type: (Optional) The specified output type of the operation (`int32`
651    or `int64`). Defaults to `tf.int32`.
652
653  Returns:
654    A `Tensor` of type `out_type`.
655  """
656  return shape_internal(input, name, optimize=True, out_type=out_type)
657
658
659def shape_internal(input, name=None, optimize=True, out_type=None):
660  # pylint: disable=redefined-builtin
661  """Returns the shape of a tensor.
662
663  If `out_type` is not specified and the shape is fully known, then we look at
664  the dimension values to determine whether to return an int32 or int64 tensor.
665  If the shape is not fully known, we default to int32.
666
667  Args:
668    input: A `Tensor` or `SparseTensor`.
669    name: A name for the operation (optional).
670    optimize: if true, encode the shape as a constant when possible.
671    out_type: (Optional) The specified output type of the operation (`int32` or
672      `int64`). Defaults to tf.int32.
673
674  Returns:
675    A `Tensor` of type `out_type`.
676
677  """
678  with ops.name_scope(name, "Shape", [input]) as name:
679    if isinstance(
680        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
681      if not out_type:
682        out_type = dtypes.int32
683      return gen_math_ops.cast(input.dense_shape, out_type)
684    else:
685      if not context.executing_eagerly():
686        input = ops.convert_to_tensor(input)
687        input_shape = input.get_shape()
688        if optimize and input_shape.is_fully_defined():
689          # For fully defined shapes, if the out_type is not specified, we pick
690          # int32 / int64 based on the actual values.
691          if not out_type:
692            return constant_op._tensor_shape_tensor_conversion_function(  # pylint: disable=protected-access
693                input_shape)
694          return constant(input_shape.as_list(), out_type, name=name)
695      if not out_type:
696        out_type = dtypes.int32
697      return gen_array_ops.shape(input, name=name, out_type=out_type)
698
699
700@tf_export("shape_n")
701@dispatch.add_dispatch_support
702def shape_n(input, out_type=dtypes.int32, name=None):
703  # pylint: disable=redefined-builtin
704  """Returns shape of tensors.
705
706  Args:
707    input: A list of at least 1 `Tensor` object with the same type.
708    out_type: The specified output type of the operation (`int32` or `int64`).
709      Defaults to `tf.int32` (optional).
710    name: A name for the operation (optional).
711
712  Returns:
713    A list with the same length as `input` of `Tensor` objects with
714      type `out_type`.
715  """
716
717  return gen_array_ops.shape_n(input, out_type=out_type, name=name)
718
719
720@tf_export("size", v1=[])
721@dispatch.add_dispatch_support
722def size_v2(input, out_type=dtypes.int32, name=None):
723  # pylint: disable=redefined-builtin
724  """Returns the size of a tensor.
725
726  See also `tf.shape`.
727
728  Returns a 0-D `Tensor` representing the number of elements in `input`
729  of type `out_type`. Defaults to tf.int32.
730
731  For example:
732
733  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
734  >>> tf.size(t)
735  <tf.Tensor: shape=(), dtype=int32, numpy=12>
736
737  Args:
738    input: A `Tensor` or `SparseTensor`.
739    name: A name for the operation (optional).
740    out_type: (Optional) The specified non-quantized numeric output type of the
741      operation. Defaults to `tf.int32`.
742
743  Returns:
744    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
745
746  @compatibility(numpy)
747  Equivalent to np.size()
748  @end_compatibility
749  """
750
751  return size(input, name, out_type)
752
753
754@tf_export(v1=["size"])
755@dispatch.add_dispatch_support
756def size(input, name=None, out_type=dtypes.int32):
757  # pylint: disable=redefined-builtin
758  """Returns the size of a tensor.
759
760  Returns a 0-D `Tensor` representing the number of elements in `input`
761  of type `out_type`. Defaults to tf.int32.
762
763  For example:
764
765  ```python
766  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
767  tf.size(t)  # 12
768  ```
769
770  Args:
771    input: A `Tensor` or `SparseTensor`.
772    name: A name for the operation (optional).
773    out_type: (Optional) The specified non-quantized numeric output type of the
774      operation. Defaults to `tf.int32`.
775
776  Returns:
777    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
778
779  @compatibility(numpy)
780  Equivalent to np.size()
781  @end_compatibility
782  """
783  return size_internal(input, name, optimize=True, out_type=out_type)
784
785
786def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
787  # pylint: disable=redefined-builtin,protected-access
788  """Returns the size of a tensor.
789
790  Args:
791    input: A `Tensor` or `SparseTensor`.
792    name: A name for the operation (optional).
793    optimize: if true, encode the size as a constant when possible.
794    out_type: (Optional) The specified non-quantized numeric output type of the
795      operation. Defaults to `tf.int32`.
796
797  Returns:
798    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
799  """
800  if (context.executing_eagerly() and not hasattr(input, "graph") and
801      not isinstance(
802          input,
803          (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))):
804    input = ops.convert_to_tensor(input)
805    np_out_type = out_type.as_numpy_dtype
806    num_elements = np.prod(input._shape_tuple(), dtype=np_out_type)  # pylint: disable=protected-access
807    return ops.convert_to_tensor(num_elements, dtype=out_type)
808  with ops.name_scope(name, "Size", [input]) as name:
809    if isinstance(
810        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
811      return gen_math_ops.prod(
812          gen_math_ops.cast(input.dense_shape, out_type), 0, name=name)
813    else:
814      input = ops.convert_to_tensor(input)
815      input_shape = input.get_shape()
816      if optimize:
817        if input_shape.is_fully_defined():
818          return constant(input_shape.num_elements(), out_type, name=name)
819        if input_shape.dims and any(dim == 0 for dim in input_shape.dims):
820          return constant(0, out_type, name=name)
821      return gen_array_ops.size(input, name=name, out_type=out_type)
822
823
824@tf_export("rank")
825@dispatch.add_dispatch_support
826def rank(input, name=None):
827  # pylint: disable=redefined-builtin
828  """Returns the rank of a tensor.
829
830  See also `tf.shape`.
831
832  Returns a 0-D `int32` `Tensor` representing the rank of `input`.
833
834  For example:
835
836  ```python
837  # shape of tensor 't' is [2, 2, 3]
838  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
839  tf.rank(t)  # 3
840  ```
841
842  **Note**: The rank of a tensor is not the same as the rank of a matrix. The
843  rank of a tensor is the number of indices required to uniquely select each
844  element of the tensor. Rank is also known as "order", "degree", or "ndims."
845
846  Args:
847    input: A `Tensor` or `SparseTensor`.
848    name: A name for the operation (optional).
849
850  Returns:
851    A `Tensor` of type `int32`.
852
853  @compatibility(numpy)
854  Equivalent to np.ndim
855  @end_compatibility
856  """
857  return rank_internal(input, name, optimize=True)
858
859
860def rank_internal(input, name=None, optimize=True):
861  # pylint: disable=redefined-builtin
862  """Returns the rank of a tensor.
863
864  Args:
865    input: A `Tensor` or `SparseTensor`.
866    name: A name for the operation (optional).
867    optimize: if true, encode the rank as a constant when possible.
868
869  Returns:
870    A `Tensor` of type `int32`.
871  """
872  with ops.name_scope(name, "Rank", [input]) as name:
873    if isinstance(
874        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
875      return gen_array_ops.size(input.dense_shape, name=name)
876    else:
877      input = ops.convert_to_tensor(input)
878      input_shape = input.get_shape()
879      if optimize and input_shape.ndims is not None:
880        return constant(input_shape.ndims, dtypes.int32, name=name)
881      return gen_array_ops.rank(input, name=name)
882
883
884_SLICE_TYPE_ERROR = (
885    "Only integers, slices (`:`), ellipsis (`...`), "
886    "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
887    "indices")
888
889_SUPPORTED_SLICE_DTYPES = (dtypes.int16, dtypes.int32, dtypes.int32_ref,
890                           dtypes.int64, dtypes.int64_ref)
891
892
893def _check_index(idx):
894  """Check if a given value is a valid index into a tensor."""
895  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
896    return
897
898  # Optimistic check. Assumptions:
899  # * any object with a dtype is supported
900  # * any object with a dtype has a sizeable shape attribute.
901  dtype = getattr(idx, "dtype", None)
902  if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
903      idx.shape and len(idx.shape) == 1):
904    # TODO(slebedev): IndexError seems more appropriate here, but it
905    # will break `_slice_helper` contract.
906    raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))
907
908
909def _is_undefined_dimension(d):
910  return isinstance(d, tensor_shape.Dimension) and d.value is None
911
912
913@tf_export("__operators__.getitem", v1=[])
914@dispatch.add_dispatch_support
915def _slice_helper(tensor, slice_spec, var=None):
916  """Overload for Tensor.__getitem__.
917
918  This operation extracts the specified region from the tensor.
919  The notation is similar to NumPy with the restriction that
920  it currently only supports basic indexing. That means that
921  using a non-scalar tensor as input is not currently allowed.
922
923  Some useful examples:
924
925  ```python
926  # Strip leading and trailing 2 elements
927  foo = tf.constant([1,2,3,4,5,6])
928  print(foo[2:-2])  # => [3,4]
929
930  # Skip every other row and reverse the order of the columns
931  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
932  print(foo[::2,::-1])  # => [[3,2,1], [9,8,7]]
933
934  # Use scalar tensors as indices on both dimensions
935  print(foo[tf.constant(0), tf.constant(2)])  # => 3
936
937  # Insert another dimension
938  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
939  print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
940  print(foo[:, tf.newaxis, :]) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
941  print(foo[:, :, tf.newaxis]) # => [[[1],[2],[3]], [[4],[5],[6]],
942                               #     [[7],[8],[9]]]
943
944  # Ellipses (3 equivalent operations)
945  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
946  print(foo[tf.newaxis, :, :])  # => [[[1,2,3], [4,5,6], [7,8,9]]]
947  print(foo[tf.newaxis, ...])  # => [[[1,2,3], [4,5,6], [7,8,9]]]
948  print(foo[tf.newaxis])  # => [[[1,2,3], [4,5,6], [7,8,9]]]
949
950  # Masks
951  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
952  print(foo[foo > 2])  # => [3, 4, 5, 6, 7, 8, 9]
953  ```
954
955  Notes:
956    - `tf.newaxis` is `None` as in NumPy.
957    - An implicit ellipsis is placed at the end of the `slice_spec`.
958    - NumPy advanced indexing is currently not supported.
959
960  Purpose in the API:
961
962    This method is exposed in TensorFlow's API so that library developers
963    can register dispatching for `Tensor.__getitem__` to allow it to handle
964    custom composite tensors & other custom objects.
965
966    The API symbol is not intended to be called by users directly and does
967    appear in TensorFlow's generated documentation.
968
969  Args:
970    tensor: An ops.Tensor object.
971    slice_spec: The arguments to Tensor.__getitem__.
972    var: In the case of variable slice assignment, the Variable object to slice
973      (i.e. tensor is the read-only view of this variable).
974
975  Returns:
976    The appropriate slice of "tensor", based on "slice_spec".
977
978  Raises:
979    ValueError: If a slice range is negative size.
980    TypeError: If the slice indices aren't int, slice, ellipsis,
981      tf.newaxis or scalar int32/int64 tensors.
982  """
983  tensor = ops.convert_to_tensor(tensor)
984  # TODO(wangpeng): Consider supporting var
985  if var is None and ops._numpy_style_slicing:  # pylint: disable=protected-access
986    return tensor._numpy_style_getitem(slice_spec)  # pylint: disable=protected-access
987
988  if isinstance(slice_spec, bool) or \
989  (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
990  (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
991    return boolean_mask(tensor=tensor, mask=slice_spec)
992
993  if not isinstance(slice_spec, (list, tuple)):
994    slice_spec = [slice_spec]
995
996  begin, end, strides = [], [], []
997  index = 0
998
999  new_axis_mask, shrink_axis_mask = 0, 0
1000  begin_mask, end_mask = 0, 0
1001  ellipsis_mask = 0
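  # Bit `index` of each mask corresponds to entry `slice_spec[index]`, using
  # the encoding described in `tf.strided_slice`.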
1002  for s in slice_spec:
1003    if isinstance(s, _BaseSlice):
1004      # Finds the best dtype for begin, end, and strides.
1005      dtype = None
1006      for t in [s.start, s.stop, s.step]:
1007        if t is None or not isinstance(t, ops.Tensor):
1008          continue
1009        if t.dtype == dtypes.int64:
1010          dtype = dtypes.int64
1011        elif t.dtype == dtypes.int32 and dtype != dtypes.int64:
1012          dtype = dtypes.int32
1013        elif t.dtype == dtypes.int16 and dtype is None:
1014          dtype = dtypes.int16
1015
1016      if s.start is not None and not _is_undefined_dimension(s.start):
1017        _check_index(s.start)
1018        begin.append(s.start)
1019      else:
1020        if dtype is not None:
1021          begin.append(constant_op.constant(0, dtype=dtype))
1022        else:
1023          begin.append(0)
1024        begin_mask |= (1 << index)
1025      if s.stop is not None and not _is_undefined_dimension(s.stop):
1026        _check_index(s.stop)
1027        end.append(s.stop)
1028      else:
1029        if dtype is not None:
1030          end.append(constant_op.constant(0, dtype=dtype))
1031        else:
1032          end.append(0)
1033        end_mask |= (1 << index)
1034      if s.step is not None and not _is_undefined_dimension(s.step):
1035        _check_index(s.step)
1036        strides.append(s.step)
1037      else:
1038        if dtype is not None:
1039          strides.append(constant_op.constant(1, dtype=dtype))
1040        else:
1041          strides.append(1)
1042    elif s is Ellipsis:
1043      begin.append(0)
1044      end.append(0)
1045      strides.append(1)
1046      ellipsis_mask |= (1 << index)
1047    elif s is newaxis:
1048      begin.append(0)
1049      end.append(0)
1050      strides.append(1)
1051      new_axis_mask |= (1 << index)
1052    else:
1053      _check_index(s)
1054      begin.append(s)
1055      end.append(s + 1)
1056      # TODO(mdan): Investigate why we can't set int32 here.
1057      if isinstance(s, ops.Tensor) and (s.dtype == dtypes.int16 or
1058                                        s.dtype == dtypes.int64):
1059        strides.append(constant_op.constant(1, dtype=s.dtype))
1060      else:
1061        strides.append(1)
1062      shrink_axis_mask |= (1 << index)
1063    index += 1
1064
1065  # stack possibly involves no tensors, so we must use op_scope with the correct graph.
1066  with ops.name_scope(
1067      None,
1068      "strided_slice", [tensor] + begin + end + strides,
1069      skip_on_eager=False) as name:
1070    if begin:
1071      packed_begin, packed_end, packed_strides = (stack(begin), stack(end),
1072                                                  stack(strides))
1073      # TODO(mdan): Instead of implicitly casting, it's better to enforce the
1074      # same dtypes.
1075      if (packed_begin.dtype == dtypes.int64 or
1076          packed_end.dtype == dtypes.int64 or
1077          packed_strides.dtype == dtypes.int64):
1078        if packed_begin.dtype != dtypes.int64:
1079          packed_begin = gen_math_ops.cast(packed_begin, dtypes.int64)
1080        if packed_end.dtype != dtypes.int64:
1081          packed_end = gen_math_ops.cast(packed_end, dtypes.int64)
1082        if packed_strides.dtype != dtypes.int64:
1083          packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64)
1084      elif (packed_begin.dtype == dtypes.int16 and
1085            packed_end.dtype == dtypes.int16 and
1086            packed_strides.dtype == dtypes.int16):
1087        if packed_begin.dtype != dtypes.int16:
1088          packed_begin = gen_math_ops.cast(packed_begin, dtypes.int16)
1089        if packed_end.dtype != dtypes.int16:
1090          packed_end = gen_math_ops.cast(packed_end, dtypes.int16)
1091        if packed_strides.dtype != dtypes.int16:
1092          packed_strides = gen_math_ops.cast(packed_strides, dtypes.int16)
1093    else:
1094      var_empty = constant([], dtype=dtypes.int32)
1095      packed_begin = packed_end = packed_strides = var_empty
1096    return strided_slice(
1097        tensor,
1098        packed_begin,
1099        packed_end,
1100        packed_strides,
1101        begin_mask=begin_mask,
1102        end_mask=end_mask,
1103        shrink_axis_mask=shrink_axis_mask,
1104        new_axis_mask=new_axis_mask,
1105        ellipsis_mask=ellipsis_mask,
1106        var=var,
1107        name=name)
1108
1109
1110# pylint: disable=undefined-variable,protected-access,redefined-outer-name
1111@tf_export("slice")
1112@dispatch.add_dispatch_support
1113def slice(input_, begin, size, name=None):
1114  # pylint: disable=redefined-builtin
1115  """Extracts a slice from a tensor.
1116
1117  See also `tf.strided_slice`.
1118
1119  This operation extracts a slice of size `size` from a tensor `input_` starting
1120  at the location specified by `begin`. The slice `size` is represented as a
1121  tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
1122  of `input_` that you want to slice. The starting location (`begin`) for the
1123  slice is represented as an offset in each dimension of `input_`. In other
1124  words, `begin[i]` is the offset into the i'th dimension of `input_` that you
1125  want to slice from.
1126
1127  Note that `tf.Tensor.__getitem__` is typically a more pythonic way to
1128  perform slices, as it allows you to write `foo[3:7, :-2]` instead of
1129  `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`.
1130
1131  `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
1132  all remaining elements in dimension i are included in the
1133  slice. In other words, this is equivalent to setting:
1134
1135  `size[i] = input_.dim_size(i) - begin[i]`
1136
1137  This operation requires that:
1138
1139  `0 <= begin[i] <= begin[i] + size[i] <= Di  for i in [0, n]`
1140
1141  For example:
1142
1143  ```python
1144  t = tf.constant([[[1, 1, 1], [2, 2, 2]],
1145                   [[3, 3, 3], [4, 4, 4]],
1146                   [[5, 5, 5], [6, 6, 6]]])
1147  tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
1148  tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
1149                                     #   [4, 4, 4]]]
1150  tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
1151                                     #  [[5, 5, 5]]]
1152  ```
1153
1154  Args:
1155    input_: A `Tensor`.
1156    begin: An `int32` or `int64` `Tensor`.
1157    size: An `int32` or `int64` `Tensor`.
1158    name: A name for the operation (optional).
1159
1160  Returns:
1161    A `Tensor` the same type as `input_`.
1162  """
1163  return gen_array_ops._slice(input_, begin, size, name=name)
1164
1165
1166# pylint: disable=invalid-name
1167@tf_export("strided_slice")
1168@dispatch.add_dispatch_support
1169def strided_slice(input_,
1170                  begin,
1171                  end,
1172                  strides=None,
1173                  begin_mask=0,
1174                  end_mask=0,
1175                  ellipsis_mask=0,
1176                  new_axis_mask=0,
1177                  shrink_axis_mask=0,
1178                  var=None,
1179                  name=None):
1180  """Extracts a strided slice of a tensor (generalized Python array indexing).
1181
1182  See also `tf.slice`.
1183
1184  **Instead of calling this op directly most users will want to use the
1185  NumPy-style slicing syntax (e.g. `tensor[..., 3:4:-1, tf.newaxis, 3]`), which
1186  is supported via `tf.Tensor.__getitem__` and `tf.Variable.__getitem__`.**
1187  The interface of this op is a low-level encoding of the slicing syntax.
1188
1189  Roughly speaking, this op extracts a slice of size `(end-begin)/stride`
1190  from the given `input_` tensor. Starting at the location specified by `begin`,
1191  the slice continues by adding `stride` to the index until all dimensions are
1192  not less than `end`.
1193  Note that a stride can be negative, which causes a reverse slice.
1194
1195  Given a Python slice `input[spec0, spec1, ..., specn]`,
1196  this function will be called as follows.
1197
1198  `begin`, `end`, and `strides` will be vectors of length n.
1199  n in general is not equal to the rank of the `input_` tensor.
1200
1201  In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`,
1202  `new_axis_mask`, `shrink_axis_mask`) the ith bit will correspond to
1203  the ith spec.
1204
1205  If the ith bit of `begin_mask` is set, `begin[i]` is ignored and
1206  the fullest possible range in that dimension is used instead.
1207  `end_mask` works analogously, except with the end range.
1208
1209  `foo[5:,:,:3]` on a 7x8x9 tensor is equivalent to `foo[5:7,0:8,0:3]`.
1210  `foo[::-1]` reverses a tensor with shape `[8]`.
1211
1212  If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions
1213  as needed will be inserted between other dimensions. Only one
1214  non-zero bit is allowed in `ellipsis_mask`.
1215
1216  For example `foo[3:5,...,4:5]` on a shape 10x3x3x10 tensor is
1217  equivalent to `foo[3:5,:,:,4:5]` and
1218  `foo[3:5,...]` is equivalent to `foo[3:5,:,:,:]`.
1219
1220  If the ith bit of `new_axis_mask` is set, then `begin`,
1221  `end`, and `stride` are ignored and a new length 1 dimension is
1222  added at this point in the output tensor.
1223
1224  For example,
1225  `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.
1226
1227  If the ith bit of `shrink_axis_mask` is set, it implies that the ith
1228  specification shrinks the dimensionality by 1, taking on the value at index
1229  `begin[i]`. `end[i]` and `strides[i]` are ignored in this case. For example in
1230  Python one might do `foo[:, 3, :]` which would result in `shrink_axis_mask`
1231  equal to 2.
1232
1233
1234  NOTE: `begin` and `end` are zero-indexed.
1235  `strides` entries must be non-zero.
1236
1237
1238  ```python
1239  t = tf.constant([[[1, 1, 1], [2, 2, 2]],
1240                   [[3, 3, 3], [4, 4, 4]],
1241                   [[5, 5, 5], [6, 6, 6]]])
1242  tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])  # [[[3, 3, 3]]]
1243  tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])  # [[[3, 3, 3],
1244                                                        #   [4, 4, 4]]]
1245  tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])  # [[[4, 4, 4],
1246                                                           #   [3, 3, 3]]]
1247  ```
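
  As an illustrative sketch of how these masks encode Python slicing, the
  expression `t[1, 0:2]` (with `t` from the example above) can be written
  directly with this op as:

  ```python
  tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1],
                   shrink_axis_mask=1)  # [[3, 3, 3], [4, 4, 4]]
  ```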
1248
1249  Args:
1250    input_: A `Tensor`.
1251    begin: An `int32` or `int64` `Tensor`.
1252    end: An `int32` or `int64` `Tensor`.
1253    strides: An `int32` or `int64` `Tensor`.
1254    begin_mask: An `int32` mask.
1255    end_mask: An `int32` mask.
1256    ellipsis_mask: An `int32` mask.
1257    new_axis_mask: An `int32` mask.
1258    shrink_axis_mask: An `int32` mask.
1259    var: The variable corresponding to `input_` or None
1260    name: A name for the operation (optional).
1261
1262  Returns:
1263    A `Tensor` the same type as `input_`.
1264  """
1265
1266  if strides is None:
1267    strides = ones_like(begin)
1268
1269  op = gen_array_ops.strided_slice(
1270      input=input_,
1271      begin=begin,
1272      end=end,
1273      strides=strides,
1274      name=name,
1275      begin_mask=begin_mask,
1276      end_mask=end_mask,
1277      ellipsis_mask=ellipsis_mask,
1278      new_axis_mask=new_axis_mask,
1279      shrink_axis_mask=shrink_axis_mask)
1280
1281  parent_name = name
1282
1283  if var is not None:
1284    def assign(val, name=None):
1285      """Closure that holds all the arguments to create an assignment."""
1286
1287      if name is None:
1288        name = parent_name + "_assign"
1289
1290      return var._strided_slice_assign(
1291          begin=begin,
1292          end=end,
1293          strides=strides,
1294          value=val,
1295          name=name,
1296          begin_mask=begin_mask,
1297          end_mask=end_mask,
1298          ellipsis_mask=ellipsis_mask,
1299          new_axis_mask=new_axis_mask,
1300          shrink_axis_mask=shrink_axis_mask)
1301
1302    op.assign = assign
1303
1304  return op
1305
1306
1307def _SliceHelperVar(var, slice_spec):
1308  """Creates a slice helper object given a variable.
1309
1310  This allows creating a sub-tensor from part of the current contents
1311  of a variable. See `tf.Tensor.__getitem__` for detailed examples
1312  of slicing.
1313
1314  This function in addition also allows assignment to a sliced range.
1315  This is similar to `__setitem__` functionality in Python. However,
1316  the syntax is different so that the user can capture the assignment
1317  operation for grouping or passing to `sess.run()` in TF1.
1318  For example,
1319
1320  ```python
1321  import tensorflow as tf
1322  A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
1323  print(A[:2, :2])  # => [[1,2], [4,5]]
1324
1325  A[:2, :2].assign(22. * tf.ones((2, 2)))
1326  print(A) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
1327  ```
1328
1329  Note that assignments currently do not support NumPy broadcasting
1330  semantics.
1331
1332  Args:
1333    var: An `ops.Variable` object.
1334    slice_spec: The arguments to `Tensor.__getitem__`.
1335
1336  Returns:
1337    The appropriate slice of "tensor", based on "slice_spec".
1338    As an operator. The operator also has a `assign()` method
1339    that can be used to generate an assignment operator.
1340
1341  Raises:
1342    ValueError: If a slice range is negative size.
1343    TypeError: If the slice indices aren't int, slice,
1344      ellipsis, tf.newaxis or int32/int64 tensors.
1345
1346  """
1347
1348  return _slice_helper(var.value(), slice_spec, var)
1349
1350
1351ops.Tensor._override_operator("__getitem__", _slice_helper)
1352
1353
1354@tf_export("parallel_stack")
1355@dispatch.add_dispatch_support
1356def parallel_stack(values, name="parallel_stack"):
1357  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.
1358
1359  Requires that the shape of inputs be known at graph construction time.
1360
1361  Packs the list of tensors in `values` into a tensor with rank one higher than
1362  each tensor in `values`, by packing them along the first dimension.
1363  Given a list of length `N` of tensors of shape `(A, B, C)`; the `output`
1364  tensor will have the shape `(N, A, B, C)`.
1365
1366  For example:
1367
1368  ```python
1369  x = tf.constant([1, 4])
1370  y = tf.constant([2, 5])
1371  z = tf.constant([3, 6])
1372  tf.parallel_stack([x, y, z])  # [[1, 4], [2, 5], [3, 6]]
1373  ```
1374
1375  The difference between `stack` and `parallel_stack` is that `stack` requires
1376  all the inputs be computed before the operation will begin but doesn't require
1377  that the input shapes be known during graph construction.
1378
1379  `parallel_stack` will copy pieces of the input into the output as they become
1380  available; in some situations this can provide a performance benefit.
1381
1382  Unlike `stack`, `parallel_stack` does NOT support backpropagation.
1383
1384  This is the opposite of unstack.  The numpy equivalent is
1385
1386      tf.parallel_stack([x, y, z]) = np.asarray([x, y, z])
1387
1388  @compatibility(eager)
1389  parallel_stack is not compatible with eager execution.
1390  @end_compatibility
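
  In TF2 the op can still be used inside a `tf.function` (a minimal sketch):

  ```python
  @tf.function
  def stacked():
    return tf.parallel_stack([tf.constant([1, 4]),
                              tf.constant([2, 5]),
                              tf.constant([3, 6])])

  stacked()  # [[1, 4], [2, 5], [3, 6]]
  ```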
1391
1392  Args:
1393    values: A list of `Tensor` objects with the same shape and type.
1394    name: A name for this operation (optional).
1395
1396  Returns:
1397    output: A stacked `Tensor` with the same type as `values`.
1398
1399  Raises:
1400    RuntimeError: if executed in eager mode.
1401  """
1402  if context.executing_eagerly():
1403    raise RuntimeError("tf.parallel_stack() is not compatible with "
1404                       "eager execution.")
1405  with ops.name_scope(name):
1406    value_t = ops.convert_to_tensor(values[0])
1407    value_shape = ops.convert_to_tensor(value_t).get_shape()
1408
1409    output_shape = tensor_shape.TensorShape([len(values)])
1410    output_shape = output_shape.concatenate(value_shape)
1411    # expand_dims converts concat to stack.
1412    return gen_array_ops.parallel_concat(
1413        [expand_dims(value, 0) for value in values], shape=output_shape)
1414
1415
1416@tf_export("stack")
1417@dispatch.add_dispatch_support
1418def stack(values, axis=0, name="stack"):
1419  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
1420
1421  See also `tf.concat`, `tf.tile`, `tf.repeat`.
1422
1423  Packs the list of tensors in `values` into a tensor with rank one higher than
1424  each tensor in `values`, by packing them along the `axis` dimension.
1425  Given a list of length `N` of tensors of shape `(A, B, C)`;
1426
1427  if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
1428  if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
1429  Etc.
1430
1431  For example:
1432
1433  >>> x = tf.constant([1, 4])
1434  >>> y = tf.constant([2, 5])
1435  >>> z = tf.constant([3, 6])
1436  >>> tf.stack([x, y, z])
1437  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
1438  array([[1, 4],
1439         [2, 5],
1440         [3, 6]], dtype=int32)>
1441  >>> tf.stack([x, y, z], axis=1)
1442  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
1443  array([[1, 2, 3],
1444         [4, 5, 6]], dtype=int32)>
1445
1446  This is the opposite of unstack.  The numpy equivalent is `np.stack`
1447
1448  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
1449  True
1450
1451  Args:
1452    values: A list of `Tensor` objects with the same shape and type.
1453    axis: An `int`. The axis to stack along. Defaults to the first dimension.
1454      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
1455    name: A name for this operation (optional).
1456
1457  Returns:
1458    output: A stacked `Tensor` with the same type as `values`.
1459
1460  Raises:
1461    ValueError: If `axis` is out of the range [-(R+1), R+1).
1462  """
1463  if axis == 0:
1464    try:
1465      # If the input is a constant list, it can be converted to a constant op
1466      return ops.convert_to_tensor(values, name=name)
1467    except (TypeError, ValueError, NotImplementedError):
1468      pass  # Input list contains non-constant tensors
1469
1470  value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple()  # pylint: disable=protected-access
1471  if value_shape is not None:
1472    expanded_num_dims = len(value_shape) + 1
1473    if axis < -expanded_num_dims or axis >= expanded_num_dims:
1474      raise ValueError(f"Argument `axis` = {axis} not in range "
1475                       f"[{-expanded_num_dims}, {expanded_num_dims})")
1476
1477  return gen_array_ops.pack(values, axis=axis, name=name)
1478
1479
1480# pylint: disable=invalid-name
1481def _autopacking_helper(list_or_tuple, dtype, name):
1482  """Converts the given list or tuple to a tensor by packing.
1483
1484  Args:
1485    list_or_tuple: A (possibly nested) list or tuple containing a tensor.
1486    dtype: The element type of the returned tensor.
1487    name: A name for the returned tensor.
1488
1489  Returns:
1490    A `tf.Tensor` with value equivalent to `list_or_tuple`.
1491  """
1492  if context.executing_eagerly():
1493    # NOTE: Fast path when all the items are tensors; this doesn't do any type
1494    # checking.
1495    if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
1496      return gen_array_ops.pack(list_or_tuple, name=name)
1497  must_pack = False
1498  converted_elems = []
1499  with ops.name_scope(name) as scope:
1500    for i, elem in enumerate(list_or_tuple):
1501      if isinstance(elem, core.Tensor):
1502        if dtype is not None and elem.dtype.base_dtype != dtype:
1503          raise TypeError(f"Cannot convert a list containing a tensor of dtype "
1504                          f"{elem.dtype} to {dtype} (Tensor is: {elem!r})")
1505        converted_elems.append(elem)
1506        must_pack = True
1507      elif isinstance(elem, (list, tuple)):
1508        converted_elem = _autopacking_helper(elem, dtype, str(i))
1509        if isinstance(converted_elem, core.Tensor):
1510          must_pack = True
1511        converted_elems.append(converted_elem)
1512      else:
1513        converted_elems.append(elem)
1514    if must_pack:
1515      elems_as_tensors = []
1516      for i, elem in enumerate(converted_elems):
1517        if isinstance(elem, core.Tensor):
1518          elems_as_tensors.append(elem)
1519        else:
1520          # NOTE(mrry): This is inefficient, but it enables us to
1521          # handle the case where the list arguments are other
1522          # convertible-to-tensor types, such as numpy arrays.
1523          elems_as_tensors.append(
1524              constant_op.constant(elem, dtype=dtype, name=str(i)))
1525      return gen_array_ops.pack(elems_as_tensors, name=scope)
1526    else:
1527      return converted_elems
1528
1529
1530def _get_dtype_from_nested_lists(list_or_tuple):
1531  """Returns the dtype of any tensor-like object in `list_or_tuple`, if found.
1532
1533  Args:
1534    list_or_tuple: A list or tuple representing an object that can be converted
1535      to a `tf.Tensor`.
1536
1537  Returns:
1538    The dtype of any tensor-like object in `list_or_tuple`, or `None` if no
1539    such object exists.
1540  """
1541  for elem in list_or_tuple:
1542    if isinstance(elem, core.Tensor):
1543      return elem.dtype.base_dtype
1544    elif isinstance(elem, (list, tuple)):
1545      maybe_dtype = _get_dtype_from_nested_lists(elem)
1546      if maybe_dtype is not None:
1547        return maybe_dtype
1548  return None
1549
1550
1551def _cast_nested_seqs_to_dtype(dtype):
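  """Returns a function that casts tensors to `dtype`, passing through
  non-tensor elements (for use with `nest.map_structure`)."""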
1552
1553  def _maybe_cast(elem):
1554    if isinstance(elem, core.Tensor):
1555      if dtype != elem.dtype.base_dtype:
1556        elem = gen_math_ops.cast(elem, dtype)
1557    return elem
1558
1559  return _maybe_cast
1560
1561
1562_NON_AUTOPACKABLE_TYPES = set(np.core.numerictypes.ScalarType)
1563_NON_AUTOPACKABLE_TYPES.add(np.ndarray)
1564
1565
1566def _should_not_autopack(v):
1567  # The condition we really want is
1568  #    any(isinstance(elem, core.Tensor))
1569  # but it is >5x slower due to abc.ABCMeta.__instancecheck__.
1570  # pylint: disable=unidiomatic-typecheck
1571  # TODO(slebedev): add nest.all?
1572  return all(type(elem) in _NON_AUTOPACKABLE_TYPES for elem in nest.flatten(v))
1573  # pylint: enable=unidiomatic-typecheck
1574
1575
1576def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
1577  """Tensor conversion function that automatically packs arguments."""
1578  if as_ref or _should_not_autopack(v):
1579    return NotImplemented
1580  inferred_dtype = _get_dtype_from_nested_lists(v)
1581  if inferred_dtype is None:
1582    # We did not find any tensor-like objects in the nested lists, so defer to
1583    # other conversion functions.
1584    return NotImplemented
1585  if dtype is None:
1586    dtype = inferred_dtype
1587  elif dtype != inferred_dtype:
1588    v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
1589  return _autopacking_helper(v, dtype, name or "packed")
1590
1591
1592# pylint: enable=invalid-name
1593
1594# NOTE: Register this conversion function to run *before* one that
1595# assumes every element is a value.
1596ops.register_tensor_conversion_function((list, tuple),
1597                                        _autopacking_conversion_function, 99)
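
# As a rough, non-executable illustration: with this registration, a nested
# Python list that contains tensors can be passed directly to an op and is
# packed implicitly, e.g.
#
#   x = constant(1.0)
#   y = constant(2.0)
#   # The list [x, y, 3.0] is converted by _autopacking_conversion_function,
#   # which is roughly equivalent to gen_array_ops.pack([x, y, constant(3.0)]).
#   z = math_ops.add([x, y, 3.0], 0.0)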
1598
1599
1600@tf_export("unstack")
1601@dispatch.add_dispatch_support
1602def unstack(value, num=None, axis=0, name="unstack"):
1603  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
1604
1605  Unpacks tensors from `value` by chipping it along the `axis` dimension.
1606
1607  >>> x = tf.reshape(tf.range(12), (3,4))
1608  >>>
1609  >>> p, q, r = tf.unstack(x)
1610  >>> p.shape.as_list()
1611  [4]
1612
1613  >>> i, j, k, l = tf.unstack(x, axis=1)
1614  >>> i.shape.as_list()
1615  [3]
1616
1617  This is the opposite of stack.
1618
1619  >>> x = tf.stack([i, j, k, l], axis=1)
1620
1621  More generally if you have a tensor of shape `(A, B, C, D)`:
1622
1623  >>> A, B, C, D = [2, 3, 4, 5]
1624  >>> t = tf.random.normal(shape=[A, B, C, D])
1625
1626  The number of tensors returned is equal to the length of the target `axis`:
1627
1628  >>> axis = 2
1629  >>> items = tf.unstack(t, axis=axis)
1630  >>> len(items) == t.shape[axis]
1631  True
1632
1633  The shape of each result tensor is equal to the shape of the input tensor,
1634  with the target `axis` removed.
1635
1636  >>> items[0].shape.as_list()  # [A, B, D]
1637  [2, 3, 5]
1638
1639  The value of each tensor `items[i]` is equal to the slice of `input` across
1640  `axis` at index `i`:
1641
1642  >>> for i in range(len(items)):
1643  ...   slice = t[:,:,i,:]
1644  ...   assert tf.reduce_all(slice == items[i])
1645
1646  #### Python iterable unpacking
1647
1648  With eager execution you _can_ unstack the 0th axis of a tensor using Python's
1649  iterable unpacking:
1650
1651  >>> t = tf.constant([1,2,3])
1652  >>> a,b,c = t
1653
1654  `unstack` is still necessary because iterable unpacking doesn't work in
1655  a `@tf.function`: symbolic tensors are not iterable.
1656
1657  You need to use `tf.unstack` here:
1658
1659  >>> @tf.function
1660  ... def bad(t):
1661  ...   a,b,c = t
1662  ...   return a
1663  >>>
1664  >>> bad(t)
1665  Traceback (most recent call last):
1666  ...
1667  OperatorNotAllowedInGraphError: ...
1668
1669  >>> @tf.function
1670  ... def good(t):
1671  ...   a,b,c = tf.unstack(t)
1672  ...   return a
1673  >>>
1674  >>> good(t).numpy()
1675  1
1676
1677  #### Unknown shapes
1678
1679  Eager tensors have concrete values, so their shape is always known.
1680  Inside a `tf.function` the symbolic tensors may have unknown shapes.
1681  If the length along `axis` is unknown, `tf.unstack` will fail because it
1682  cannot handle an unknown number of tensors:
1683
1684  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
1685  ... def bad(t):
1686  ...   tensors = tf.unstack(t)
1687  ...   return tensors[0]
1688  >>>
1689  >>> bad(tf.constant([1,2,3]))
1690  Traceback (most recent call last):
1691  ...
1692  ValueError: Cannot infer argument `num` from shape (None,)
1693
1694  If you know the length of `axis`, you can pass it as the `num` argument; it
1695  must be a constant value.
1696
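  As a small sketch (using a hypothetical `first_of_three` function with the
  same signature as the failing example above), passing `num` explicitly lets
  the function trace:

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def first_of_three(t):
  ...   a, b, c = tf.unstack(t, num=3)
  ...   return a
  >>>
  >>> first_of_three(tf.constant([1.0, 2.0, 3.0])).numpy()
  1.0
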
1697  If you actually need a variable number of tensors in a single `tf.function`
1698  trace, you will need to use explicit loops and a `tf.TensorArray` instead.
1699
1700  Args:
1701    value: A rank `R > 0` `Tensor` to be unstacked.
1702    num: An `int`. The length of the dimension `axis`. Automatically inferred if
1703      `None` (the default).
1704    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
1705      Negative values wrap around, so the valid range is `[-R, R)`.
1706    name: A name for the operation (optional).
1707
1708  Returns:
1709    The list of `Tensor` objects unstacked from `value`.
1710
1711  Raises:
1712    ValueError: If `axis` is out of the range `[-R, R)`.
1713    ValueError: If `num` is unspecified and cannot be inferred.
1714    InvalidArgumentError: If `num` does not match the shape of `value`.
1715  """
1716  if num is None:
1717    value = ops.convert_to_tensor(value)
1718    value_shape = value.get_shape()
1719    if value_shape.ndims is not None:
1720      if axis < -value_shape.ndims or axis >= value_shape.ndims:
1721        raise ValueError(f"Argument `axis` = {axis} not in range "
1722                         f"[{-value_shape.ndims}, {value_shape.ndims})")
1723      num = value_shape.dims[axis].value
1724    if num is None:
1725      raise ValueError(f"Cannot infer argument `num` from shape {value_shape}")
1726  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)
1727
1728
1729@tf_export("concat")
1730@dispatch.add_dispatch_support
1731def concat(values, axis, name="concat"):
1732  """Concatenates tensors along one dimension.
1733
1734  See also `tf.tile`, `tf.stack`, `tf.repeat`.
1735
1736  Concatenates the list of tensors `values` along dimension `axis`.  If
1737  `values[i].shape = [D0, D1, ... Daxis(i), ...Dn]`, the concatenated
1738  result has shape
1739
1740      [D0, D1, ... Raxis, ...Dn]
1741
1742  where
1743
1744      Raxis = sum(Daxis(i))
1745
1746  That is, the data from the input tensors is joined along the `axis`
1747  dimension.
1748
1749  The number of dimensions of the input tensors must match, and all dimensions
1750  except `axis` must be equal.
1751
1752  For example:
1753
1754  >>> t1 = [[1, 2, 3], [4, 5, 6]]
1755  >>> t2 = [[7, 8, 9], [10, 11, 12]]
1756  >>> tf.concat([t1, t2], 0)
1757  <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
1758  array([[ 1,  2,  3],
1759         [ 4,  5,  6],
1760         [ 7,  8,  9],
1761         [10, 11, 12]], dtype=int32)>
1762
1763  >>> tf.concat([t1, t2], 1)
1764  <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
1765  array([[ 1,  2,  3,  7,  8,  9],
1766         [ 4,  5,  6, 10, 11, 12]], dtype=int32)>
1767
1768  As in Python, the `axis` can also be negative. A negative `axis` is
1769  interpreted as counting from the end of the rank, i.e., as the
1770  `axis + rank(values)`-th dimension.
1771
1772  For example:
1773
1774  >>> t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
1775  >>> t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
1776  >>> tf.concat([t1, t2], -1)
1777  <tf.Tensor: shape=(2, 2, 4), dtype=int32, numpy=
1778    array([[[ 1,  2,  7,  4],
1779            [ 2,  3,  8,  4]],
1780           [[ 4,  4,  2, 10],
1781            [ 5,  3, 15, 11]]], dtype=int32)>
1782
1783  Note: If you are concatenating along a new axis, consider using `tf.stack`.
1784  E.g.
1785
1786  ```python
1787  tf.concat([tf.expand_dims(t, axis) for t in tensors], axis)
1788  ```
1789
1790  can be rewritten as
1791
1792  ```python
1793  tf.stack(tensors, axis=axis)
1794  ```
1795
1796  Args:
1797    values: A list of `Tensor` objects or a single `Tensor`.
1798    axis: 0-D `int32` `Tensor`.  Dimension along which to concatenate. Must be
1799      in the range `[-rank(values), rank(values))`. As in Python, indexing for
1800      axis is 0-based. A positive axis in the range `[0, rank(values))` refers
1801      to the `axis`-th dimension, and a negative axis refers to the `axis +
1802      rank(values)`-th dimension.
1803    name: A name for the operation (optional).
1804
1805  Returns:
1806    A `Tensor` resulting from concatenation of the input tensors.
1807  """
1808  if not isinstance(values, (list, tuple)):
1809    values = [values]
1810  # TODO(mrry): Change to return values?
1811  if len(values) == 1:  # Degenerate case of one tensor.
1812    # Make a throwaway call to convert_to_tensor to make sure
1813    # that axis is of the correct type, and make sure that
1814    # the returned tensor is a scalar.
1815    # TODO(keveman): Implement a standalone type and shape checker.
1816    with ops.name_scope(name) as scope:
1817      ops.convert_to_tensor(
1818          axis, name="concat_dim",
1819          dtype=dtypes.int32).get_shape().assert_has_rank(0)
1820      return identity(values[0], name=name)
1821  return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
1822
1823
1824@tf_export(v1=["boolean_mask"])
1825@dispatch.add_dispatch_support
1826def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
1827  """Apply boolean mask to tensor.
1828
1829  Numpy equivalent is `tensor[mask]`.
1830
1831  In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
1832  the first K dimensions of `tensor`'s shape.  We then have:
1833    `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
1834  where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
1835  The `axis` could be used with `mask` to indicate the axis to mask from.
1836  In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
1837  the first `axis + dim(mask)` dimensions of `tensor`'s shape.
1838
1839  See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
1840  ragged tensors, and can be used if you need to preserve the masked dimensions
1841  of `tensor` (rather than flattening them, as `tf.boolean_mask` does).
1842
1843  Examples:
1844
1845  ```python
1846  # 1-D example
1847  tensor = [0, 1, 2, 3]
1848  mask = np.array([True, False, True, False])
1849  tf.boolean_mask(tensor, mask)  # [0, 2]
1850
1851  # 2-D example
1852  tensor = [[1, 2], [3, 4], [5, 6]]
1853  mask = np.array([True, False, True])
1854  tf.boolean_mask(tensor, mask)  # [[1, 2], [5, 6]]
1855  ```
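
  As a rough sketch of the `axis` argument (masking along dimension 1 of a 2-D
  tensor):

  ```python
  tensor = [[1, 2, 3], [4, 5, 6]]
  mask = np.array([True, False, True])
  tf.boolean_mask(tensor, mask, axis=1)  # [[1, 3], [4, 6]]
  ```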
1856
1857  Args:
1858    tensor:  N-D Tensor.
1859    mask:  K-D boolean Tensor, K <= N and K must be known statically.
1860    name:  A name for this operation (optional).
1861    axis:  A 0-D int Tensor representing the axis in `tensor` to mask from. By
1862      default, axis is 0 which will mask from the first dimension. Otherwise K +
1863      axis <= N.
1864
1865  Returns:
1866    (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
1867    to `True` values in `mask`.
1868
1869  Raises:
1870    ValueError:  If shapes do not conform.
1871  """
1872
1873  def _apply_mask_1d(reshaped_tensor, mask, axis=None):
1874    """Mask tensor along dimension 0 with a 1-D mask."""
1875    indices = squeeze(where_v2(mask), axis=[1])
1876    return gather(reshaped_tensor, indices, axis=axis)
1877
1878  with ops.name_scope(name, values=[tensor, mask]):
1879    tensor = ops.convert_to_tensor(tensor, name="tensor")
1880    mask = ops.convert_to_tensor(mask, name="mask")
1881
1882    shape_mask = mask.get_shape()
1883    ndims_mask = shape_mask.ndims
1884    shape_tensor = tensor.get_shape()
1885    if ndims_mask == 0:
1886      raise ValueError("mask cannot be scalar.")
1887    if ndims_mask is None:
1888      raise ValueError(
1889          "Number of mask dimensions must be specified, even if some dimensions"
1890          " are None.  E.g. shape=[None] is ok, but shape=None is not.")
1891    axis = 0 if axis is None else axis
1892    axis_value = tensor_util.constant_value(axis)
1893    if axis_value is not None:
1894      axis = axis_value
1895      shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)
1896
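    # Flatten the `ndims_mask` masked dimensions of `tensor` into one dimension
    # so the (also flattened) boolean mask can be applied with a 1-D gather.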
1897    leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0])
1898    tensor = reshape(
1899        tensor,
1900        concat([
1901            shape(tensor)[:axis], [leading_size],
1902            shape(tensor)[axis + ndims_mask:]
1903        ], 0))
1904    # TODO(yongtang): tf.reshape in C++ kernel might have set the shape
1905    # correctly, so the following may not be needed? It still might be possible
1906    # that there are some edge case where tensor_util.constant_value resolves
1907    # more cases than ShapeInference of tf.reshape in C++ kernel.
1908    if axis_value is not None:
1909      first_dim = shape_tensor[axis:axis + ndims_mask].num_elements()
1910      tensor.set_shape(
1911          tensor_shape.as_shape(shape_tensor[:axis]).concatenate(
1912              [first_dim]).concatenate(shape_tensor[axis + ndims_mask:]))
1913
1914    mask = reshape(mask, [-1])
1915    return _apply_mask_1d(tensor, mask, axis)
1916
1917
1918@tf_export("boolean_mask", v1=[])
1919@dispatch.add_dispatch_support
1920def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
1921  """Apply boolean mask to tensor.
1922
1923  Numpy equivalent is `tensor[mask]`.
1924
1925  In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
1926  the first K dimensions of `tensor`'s shape.  We then have:
1927    `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
1928  where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
1929  The `axis` could be used with `mask` to indicate the axis to mask from.
1930  In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
1931  the first `axis + dim(mask)` dimensions of `tensor`'s shape.
1932
1933  See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
1934  ragged tensors, and can be used if you need to preserve the masked dimensions
1935  of `tensor` (rather than flattening them, as `tf.boolean_mask` does).
1936
1937  Examples:
1938
1939  >>> tensor = [0, 1, 2, 3]  # 1-D example
1940  >>> mask = np.array([True, False, True, False])
1941  >>> tf.boolean_mask(tensor, mask)
1942  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2], dtype=int32)>
1943
1944  >>> tensor = [[1, 2], [3, 4], [5, 6]] # 2-D example
1945  >>> mask = np.array([True, False, True])
1946  >>> tf.boolean_mask(tensor, mask)
1947  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
1948  array([[1, 2],
1949         [5, 6]], dtype=int32)>
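
  A small sketch of the `axis` argument (masking along dimension 1):

  >>> tensor = [[1, 2, 3], [4, 5, 6]]
  >>> mask = np.array([True, False, True])
  >>> tf.boolean_mask(tensor, mask, axis=1)
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[1, 3],
         [4, 6]], dtype=int32)>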
1950
1951  Args:
1952    tensor:  N-D Tensor.
1953    mask:  K-D boolean Tensor, K <= N and K must be known statically.
1954    axis:  A 0-D int Tensor representing the axis in `tensor` to mask from. By
1955      default, axis is 0 which will mask from the first dimension. Otherwise K +
1956      axis <= N.
1957    name:  A name for this operation (optional).
1958
1959  Returns:
1960    (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
1961    to `True` values in `mask`.
1962
1963  Raises:
1964    ValueError:  If shapes do not conform.
1974  """
1975  return boolean_mask(tensor, mask, name, axis)
1976
1977
1978@tf_export("sparse.mask", v1=["sparse.mask", "sparse_mask"])
1979@deprecation.deprecated_endpoints("sparse_mask")
1980def sparse_mask(a, mask_indices, name=None):
1981  """Masks elements of `IndexedSlices`.
1982
1983  Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
1984  contains a subset of the slices of `a`. Only the slices at indices not
1985  specified in `mask_indices` are returned.
1986
1987  This is useful when you need to extract a subset of slices in an
1988  `IndexedSlices` object.
1989
1990  For example:
1991
1992  ```python
1993  # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
1994  # with shape [1000, 10]
1995  a.indices  # [12, 26, 37, 45]
1996  tf.shape(a.values)  # [4, 10]
1997
1998  # `b` will be the subset of `a` slices at its second and third indices, so
1999  # we want to mask its first and last indices (which are at absolute
2000  # indices 12, 45)
2001  b = tf.sparse.mask(a, [12, 45])
2002
2003  b.indices  # [26, 37]
2004  tf.shape(b.values)  # [2, 10]
2005  ```
2006
2007  Args:
2008    a: An `IndexedSlices` instance.
2009    mask_indices: Indices of elements to mask.
2010    name: A name for the operation (optional).
2011
2012  Returns:
2013    The masked `IndexedSlices` instance.
2014  """
2015  with ops.name_scope(name, "sparse_mask", [a, mask_indices]) as name:
2016    indices = a.indices
2017    out_indices, to_gather = gen_array_ops.list_diff(indices, mask_indices)
2018    out_values = gather(a.values, to_gather, name=name)
2019    return indexed_slices.IndexedSlices(out_values, out_indices, a.dense_shape)
2020
2021
2022@tf_export("unique")
2023@dispatch.add_dispatch_support
2024def unique(x, out_idx=dtypes.int32, name=None):
2025  """Finds unique elements in a 1-D tensor.
2026
2027  See also `tf.unique_with_counts`.
2028
2029  This operation returns a tensor `y` containing all of the unique elements
2030  of `x`, in the order in which they first occur in `x`. This operation
2031  also returns a tensor `idx` the same size as `x` that contains the index
2032  of each value of `x` in the unique output `y`. In other words:
2033
2034
2035    y[idx[i]] = x[i] for i in [0, 1, ..., len(x) - 1]
2036
2037  Example usage:
2038
2039  >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
2040  >>> y, idx = unique(x)
2041  >>> y
2042  <tf.Tensor: id=5, shape=(5,), dtype=int32,
2043  numpy=array([1, 2, 4, 7, 8], dtype=int32)>
2044  >>> idx
2045  <tf.Tensor: id=6, shape=(9,), dtype=int32,
2046  numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
2047
2048  Args:
2049    x: A Tensor. 1-D.
2050    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
2051      tf.int32.
2052    name: A name for the operation (optional).
2053
2054  Returns:
2055    A tuple of Tensor objects (y, idx).
2056      y: A Tensor. Has the same type as x.
2057      idx: A Tensor of type out_idx.
2058
2059  """
2060  # TODO(yongtang): switch to v2 once the API deprecation
2061  # period (3 weeks) passes.
2062  # TODO(yongtang): The documentation should also
2063  # be updated when switching to v2.
2064  return gen_array_ops.unique(x, out_idx, name)
2065
2066
2067unique.__doc__ = gen_array_ops.unique.__doc__
2068
2069
2070@tf_export("unique_with_counts")
2071@dispatch.add_dispatch_support
2072def unique_with_counts(x, out_idx=dtypes.int32, name=None):
2073  """Finds unique elements in a 1-D tensor.
2074
2075  See also `tf.unique`.
2076
2077  This operation returns a tensor `y` containing all of the unique elements
2078  of `x`, in the order in which they first occur in `x`. This operation
2079  also returns a tensor `idx` the same size as `x` that contains the index
2080  of each value of `x` in the unique output `y`. Finally, it returns a
2081  third tensor `count` that contains the count of each element of `y`
2082  in `x`. In other words:
2083
2084    y[idx[i]] = x[i] for i in [0, 1, ..., len(x) - 1]
2085
2086  Example usage:
2087
2088  >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
2089  >>> y, idx, count = unique_with_counts(x)
2090  >>> y
2091  <tf.Tensor: id=8, shape=(5,), dtype=int32,
2092  numpy=array([1, 2, 4, 7, 8], dtype=int32)>
2093  >>> idx
2094  <tf.Tensor: id=9, shape=(9,), dtype=int32,
2095  numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
2096  >>> count
2097  <tf.Tensor: id=10, shape=(5,), dtype=int32,
2098  numpy=array([2, 1, 3, 1, 2], dtype=int32)>
2099
2100  Args:
2101    x: A Tensor. 1-D.
2102    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
2103      tf.int32.
2104    name: A name for the operation (optional).
2105
2106  Returns:
2107    A tuple of Tensor objects (y, idx, count).
2108      y: A Tensor. Has the same type as x.
2109      idx: A Tensor of type out_idx.
2110      count: A Tensor of type out_idx.
2111
2112  """
2113  # TODO(yongtang): switch to v2 once the API deprecation
2114  # period (3 weeks) passes.
2115  # TODO(yongtang): The documentation should also
2116  # be updated when switching to v2.
2117  return gen_array_ops.unique_with_counts(x, out_idx, name)
2118
2119
2120unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
2121
2122
2123@tf_export("split")
2124@dispatch.add_dispatch_support
2125def split(value, num_or_size_splits, axis=0, num=None, name="split"):
2126  """Splits a tensor `value` into a list of sub tensors.
2127
2128  See also `tf.unstack`.
2129
2130  If `num_or_size_splits` is an `int`,  then it splits `value` along the
2131  dimension `axis` into `num_or_size_splits` smaller tensors. This requires that
2132  `value.shape[axis]` is divisible by `num_or_size_splits`.
2133
2134  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
2135  `len(num_or_size_splits)` elements. The shape of the `i`-th
2136  element has the same size as the `value` except along dimension `axis` where
2137  the size is `num_or_size_splits[i]`.
2138
2139  For example:
2140
2141  >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
2142  >>>
2143  >>> # Split `x` into 3 tensors along dimension 1
2144  >>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
2145  >>> tf.shape(s0).numpy()
2146  array([ 5, 10], dtype=int32)
2147  >>>
2148  >>> # Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
2149  >>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
2150  >>> tf.shape(split0).numpy()
2151  array([5, 4], dtype=int32)
2152  >>> tf.shape(split1).numpy()
2153  array([ 5, 15], dtype=int32)
2154  >>> tf.shape(split2).numpy()
2155  array([ 5, 11], dtype=int32)
2156
2157  Args:
2158    value: The `Tensor` to split.
2159    num_or_size_splits: Either an `int` indicating the number of splits
2160      along `axis` or a 1-D integer `Tensor` or Python list containing the sizes
2161      of each output tensor along `axis`. If an `int`, then it must evenly
2162      divide `value.shape[axis]`; otherwise the sum of sizes along the split
2163      axis must match that of the `value`.
2164    axis: An `int` or scalar `int32` `Tensor`. The dimension along which
2165      to split. Must be in the range `[-rank(value), rank(value))`. Defaults to
2166      0.
2167    num: Optional, an `int`, used to specify the number of outputs when it
2168      cannot be inferred from the shape of `size_splits`.
2169    name: A name for the operation (optional).
2170
2171  Returns:
2172    If `num_or_size_splits` is an `int`, returns a list of
2173    `num_or_size_splits` `Tensor` objects; if `num_or_size_splits` is a 1-D
2174    list or 1-D `Tensor`, returns `num_or_size_splits.get_shape()[0]`
2175    `Tensor` objects resulting from splitting `value`.
2176
2177  Raises:
2178    ValueError: If `num` is unspecified and cannot be inferred.
2179    ValueError: If `num_or_size_splits` is a scalar `Tensor`.
2180  """
2181  if isinstance(num_or_size_splits,
2182                (numbers.Integral, tensor_shape.Dimension)):
2183    return gen_array_ops.split(
2184        axis=axis, num_split=num_or_size_splits, value=value, name=name)
2185
2186  size_splits = ops.convert_to_tensor(num_or_size_splits)
2187
2188  if size_splits._rank() == 0:
2189    raise ValueError(
2190        "Rank-0 tensors are not supported as the num_or_size_splits argument "
2191        "to split. Argument provided: %s" % (num_or_size_splits,))
2192
2193  if num is None:
2194    size_splits_shape = size_splits._shape_tuple()
2195    if size_splits_shape:
2196      num = size_splits_shape[0]
2197    if num is None:
2198      raise ValueError(
2199          f"Cannot infer argument `num` from shape {num_or_size_splits}")
2200
2201  return gen_array_ops.split_v(
2202      value=value, size_splits=size_splits, axis=axis, num_split=num, name=name)
2203
2204
2205@tf_export("transpose", v1=[])
2206@dispatch.add_dispatch_support
2207def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
2208  """Transposes `a`, where `a` is a Tensor.
2209
2210  Permutes the dimensions according to the value of `perm`.
2211
2212  The returned tensor's dimension `i` will correspond to the input dimension
2213  `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is the rank
2214  of the input tensor. Hence by default, this operation performs a regular
2215  matrix transpose on 2-D input Tensors.
2216
2217  If conjugate is `True` and `a.dtype` is either `complex64` or `complex128`
2218  then the values of `a` are conjugated and transposed.
2219
2220  @compatibility(numpy)
2221  In `numpy` transposes are memory-efficient constant time operations as they
2222  simply return a new view of the same data with adjusted `strides`.
2223
2224  TensorFlow does not support strides, so `transpose` returns a new tensor with
2225  the items permuted.
2226  @end_compatibility
2227
2228  For example:
2229
2230  >>> x = tf.constant([[1, 2, 3], [4, 5, 6]])
2231  >>> tf.transpose(x)
2232  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
2233  array([[1, 4],
2234         [2, 5],
2235         [3, 6]], dtype=int32)>
2236
2237  Equivalently, you could call `tf.transpose(x, perm=[1, 0])`.
2238
2239  If `x` is complex, setting conjugate=True gives the conjugate transpose:
2240
2241  >>> x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2242  ...                  [4 + 4j, 5 + 5j, 6 + 6j]])
2243  >>> tf.transpose(x, conjugate=True)
2244  <tf.Tensor: shape=(3, 2), dtype=complex128, numpy=
2245  array([[1.-1.j, 4.-4.j],
2246         [2.-2.j, 5.-5.j],
2247         [3.-3.j, 6.-6.j]])>
2248
2249  'perm' is more useful for n-dimensional tensors where n > 2:
2250
2251  >>> x = tf.constant([[[ 1,  2,  3],
2252  ...                   [ 4,  5,  6]],
2253  ...                  [[ 7,  8,  9],
2254  ...                   [10, 11, 12]]])
2255
2256  As above, simply calling `tf.transpose` will default to `perm=[2,1,0]`.
2257
2258  To take the transpose of the matrices in dimension-0 (such as when you are
2259  transposing matrices where 0 is the batch dimension), you would set
2260  `perm=[0,2,1]`.
2261
2262  >>> tf.transpose(x, perm=[0, 2, 1])
2263  <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
2264  array([[[ 1,  4],
2265          [ 2,  5],
2266          [ 3,  6]],
2267          [[ 7, 10],
2268          [ 8, 11],
2269          [ 9, 12]]], dtype=int32)>
2270
2271  Note: This operation has a shorthand, `tf.linalg.matrix_transpose`.
2272
2273  Args:
2274    a: A `Tensor`.
2275    perm: A permutation of the dimensions of `a`.  This should be a vector.
2276    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2277      to tf.math.conj(tf.transpose(input)).
2278    name: A name for the operation (optional).
2279
2280  Returns:
2281    A transposed `Tensor`.
2282  """
2283  return transpose(a=a, perm=perm, name=name, conjugate=conjugate)
2284
2285
2286@tf_export(v1=["transpose"])
2287@dispatch.add_dispatch_support
2288def transpose(a, perm=None, name="transpose", conjugate=False):
2289  """Transposes `a`.
2290
2291  Permutes the dimensions according to `perm`.
2292
2293  The returned tensor's dimension i will correspond to the input dimension
2294  `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
2295  the rank of the input tensor. Hence by default, this operation performs a
2296  regular matrix transpose on 2-D input Tensors. If conjugate is True and
2297  `a.dtype` is either `complex64` or `complex128` then the values of `a`
2298  are conjugated and transposed.
2299
2300  @compatibility(numpy)
2301  In `numpy` transposes are memory-efficient constant time operations as they
2302  simply return a new view of the same data with adjusted `strides`.
2303
2304  TensorFlow does not support strides, so `transpose` returns a new tensor with
2305  the items permuted.
2306  @end_compatibility
2307
2308  For example:
2309
2310  ```python
2311  x = tf.constant([[1, 2, 3], [4, 5, 6]])
2312  tf.transpose(x)  # [[1, 4]
2313                   #  [2, 5]
2314                   #  [3, 6]]
2315
2316  # Equivalently
2317  tf.transpose(x, perm=[1, 0])  # [[1, 4]
2318                                #  [2, 5]
2319                                #  [3, 6]]
2320
2321  # If x is complex, setting conjugate=True gives the conjugate transpose
2322  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2323                   [4 + 4j, 5 + 5j, 6 + 6j]])
2324  tf.transpose(x, conjugate=True)  # [[1 - 1j, 4 - 4j],
2325                                   #  [2 - 2j, 5 - 5j],
2326                                   #  [3 - 3j, 6 - 6j]]
2327
2328  # 'perm' is more useful for n-dimensional tensors, for n > 2
2329  x = tf.constant([[[ 1,  2,  3],
2330                    [ 4,  5,  6]],
2331                   [[ 7,  8,  9],
2332                    [10, 11, 12]]])
2333
2334  # Take the transpose of the matrices in dimension-0
2335  # (this common operation has a shorthand `linalg.matrix_transpose`)
2336  tf.transpose(x, perm=[0, 2, 1])  # [[[1,  4],
2337                                   #   [2,  5],
2338                                   #   [3,  6]],
2339                                   #  [[7, 10],
2340                                   #   [8, 11],
2341                                   #   [9, 12]]]
2342  ```
2343
2344  Args:
2345    a: A `Tensor`.
2346    perm: A permutation of the dimensions of `a`.
2347    name: A name for the operation (optional).
2348    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2349      to tf.math.conj(tf.transpose(input)).
2350
2351  Returns:
2352    A transposed `Tensor`.
2353  """
2354  with ops.name_scope(name, "transpose", [a]) as name:
2355    if not tensor_util.is_tf_type(a):
2356      a = ops.convert_to_tensor(a, name="a")
2357
2358    if conjugate and a.dtype.is_complex:
2359      transpose_fn = gen_array_ops.conjugate_transpose
2360    else:
2361      transpose_fn = gen_array_ops.transpose
2362
2363    if perm is not None:
2364      return transpose_fn(a, perm, name=name)
2365
2366    rank = a.shape.rank
2367    if rank is None:
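      # Rank is not known statically: build the reversed permutation
      # [rank-1, ..., 1, 0] with ops so it is computed at graph run time.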
2368      perm = gen_math_ops._range(gen_array_ops.rank(a) - 1, -1, -1)
2369    else:
2370      perm = np.arange(rank - 1, -1, -1, dtype=np.int32)
2371    return transpose_fn(a, perm, name=name)
2372
2373
2374# pylint: disable=invalid-name
2375@tf_export(
2376    "linalg.matrix_transpose",
2377    v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"])
2378@dispatch.add_dispatch_support
2379@deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose")
2380def matrix_transpose(a, name="matrix_transpose", conjugate=False):
2381  """Transposes last two dimensions of tensor `a`.
2382
2383  For example:
2384
2385  ```python
2386  x = tf.constant([[1, 2, 3], [4, 5, 6]])
2387  tf.linalg.matrix_transpose(x)  # [[1, 4],
2388                                 #  [2, 5],
2389                                 #  [3, 6]]
2390
2391  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2392                   [4 + 4j, 5 + 5j, 6 + 6j]])
2393  tf.linalg.matrix_transpose(x, conjugate=True)  # [[1 - 1j, 4 - 4j],
2394                                                 #  [2 - 2j, 5 - 5j],
2395                                                 #  [3 - 3j, 6 - 6j]]
2396
2397  # Matrix with two batch dimensions.
2398  # x.shape is [1, 2, 3, 4]
2399  # tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
2400  ```
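
  For instance, a small sketch of the batch case:

  ```python
  x = tf.zeros([1, 2, 3, 4])
  tf.linalg.matrix_transpose(x).shape  # TensorShape([1, 2, 4, 3])
  ```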
2401
2402  Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
2403  This is done with minimal cost, and is preferable to using this function. E.g.
2404
2405  ```python
2406  # Good!  Transpose is taken at minimal additional cost.
2407  tf.matmul(matrix, b, transpose_b=True)
2408
2409  # Inefficient!
2410  tf.matmul(matrix, tf.linalg.matrix_transpose(b))
2411  ```
2412
2413  @compatibility(numpy)
2414  In `numpy` transposes are memory-efficient constant time operations as they
2415  simply return a new view of the same data with adjusted `strides`.
2416
2417  TensorFlow does not support strides, `linalg.matrix_transpose` returns a new
2418  tensor with the items permuted.
2419  @end_compatibility
2420
2421  Args:
2422    a: A `Tensor` with `rank >= 2`.
2423    name: A name for the operation (optional).
2424    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2425      to tf.math.conj(tf.linalg.matrix_transpose(input)).
2426
2427  Returns:
2428    A transposed batch matrix `Tensor`.
2429
2430  Raises:
2431    ValueError:  If `a` is determined statically to have `rank < 2`.
2432  """
2433  with ops.name_scope(name, values=[a]):
2434    a = ops.convert_to_tensor(a, name="a")
2435
2436    # If we know the number of dimensions (statically), we can do two things:
2437    # 1. Check that `a` is a (batch) matrix.
2438    # 2. Use a Python list for perm.  This preserves static shape information
2439    #    and avoids extra computations.
2440    a_shape = a.get_shape()
2441    ndims = a_shape.ndims
2442    if ndims is not None:
2443      if ndims < 2:
2444        raise ValueError("Argument `a` should be a (batch) matrix with rank "
2445                         f">= 2.  Received `a` = {a} with shape: {a_shape}")
2446      perm = list(range(ndims - 2)) + [ndims - 1] + [ndims - 2]
2447    else:
2448      a_rank = rank(a)
2449      perm = concat(
2450          (gen_math_ops._range(0, a_rank - 2, 1), [a_rank - 1, a_rank - 2]), 0)
2451
2452    return transpose(a, perm=perm, conjugate=conjugate)
2453
2454
2455@tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"])
2456@dispatch.add_dispatch_support
2457@deprecation.deprecated_endpoints("matrix_diag")
2458def matrix_diag(diagonal,
2459                name="diag",
2460                k=0,
2461                num_rows=-1,
2462                num_cols=-1,
2463                padding_value=0,
2464                align="RIGHT_LEFT"):
2465  """Returns a batched diagonal tensor with given batched diagonal values.
2466
2467  Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
2468  diagonals of a matrix, with everything else padded with `padding_value`. `num_rows`
2469  and `num_cols` specify the dimension of the innermost matrix of the output. If
2470  neither is specified, the op assumes the innermost matrix is square and
2471  infers its size from `k` and the innermost dimension of `diagonal`. If only
2472  one of them is specified, the op assumes the unspecified value is the smallest
2473  possible based on other criteria.
2474
2475  Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor
2476  has rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only
2477  one diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has
2478  rank `r` with shape `[I, J, ..., L, num_rows, num_cols]`.
2479
2480  The second innermost dimension of `diagonal` has double meaning. When `k` is
2481  scalar or `k[0] == k[1]`, `M` is part of the batch size [I, J, ..., M], and
2482  the output tensor is:
2483
2484  ```
2485  output[i, j, ..., l, m, n]
2486    = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper
2487      padding_value                             ; otherwise
2488  ```
2489
2490  Otherwise, `M` is treated as the number of diagonals for the matrix in the
2491  same batch (`M = k[1]-k[0]+1`), and the output tensor is:
2492
2493  ```
2494  output[i, j, ..., l, m, n]
2495    = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
2496      padding_value                                     ; otherwise
2497  ```
2498  where `d = n - m`, `diag_index = k[1] - d`, and
2499  `index_in_diag = n - max(d, 0) + offset`.
2500
2501  `offset` is zero except when the alignment of the diagonal is to the right.
2502  ```
2503  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2504                                             and `d >= 0`) or
2505                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2506                                             and `d <= 0`)
2507           0                          ; otherwise
2508  ```
2509  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2510
2511  For example:
2512
2513  ```
2514  # The main diagonal.
2515  diagonal = np.array([[1, 2, 3, 4],            # Input shape: (2, 4)
2516                       [5, 6, 7, 8]])
2517  tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0],  # Output shape: (2, 4, 4)
2518                                 [0, 2, 0, 0],
2519                                 [0, 0, 3, 0],
2520                                 [0, 0, 0, 4]],
2521                                [[5, 0, 0, 0],
2522                                 [0, 6, 0, 0],
2523                                 [0, 0, 7, 0],
2524                                 [0, 0, 0, 8]]]
2525
2526  # A superdiagonal (per batch).
2527  diagonal = np.array([[1, 2, 3],  # Input shape: (2, 3)
2528                       [4, 5, 6]])
2529  tf.matrix_diag(diagonal, k = 1)
2530    ==> [[[0, 1, 0, 0],  # Output shape: (2, 4, 4)
2531          [0, 0, 2, 0],
2532          [0, 0, 0, 3],
2533          [0, 0, 0, 0]],
2534         [[0, 4, 0, 0],
2535          [0, 0, 5, 0],
2536          [0, 0, 0, 6],
2537          [0, 0, 0, 0]]]
2538
2539  # A tridiagonal band (per batch).
2540  diagonals = np.array([[[8, 9, 0],  # Input shape: (2, 2, 3)
2541                         [1, 2, 3],
2542                         [0, 4, 5]],
2543                        [[2, 3, 0],
2544                         [6, 7, 9],
2545                         [0, 9, 1]]])
2546  tf.matrix_diag(diagonals, k = (-1, 1))
2547    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
2548          [4, 2, 9],
2549          [0, 5, 3]],
2550         [[6, 2, 0],
2551          [9, 7, 3],
2552          [0, 1, 9]]]
2553
2554  # RIGHT_LEFT alignment.
2555  diagonals = np.array([[[0, 8, 9],  # Input shape: (2, 2, 3)
2556                         [1, 2, 3],
2557                         [4, 5, 0]],
2558                        [[0, 2, 3],
2559                         [6, 7, 9],
2560                         [9, 1, 0]]])
2561  tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT")
2562    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
2563          [4, 2, 9],
2564          [0, 5, 3]],
2565         [[6, 2, 0],
2566          [9, 7, 3],
2567          [0, 1, 9]]]
2568
2569  # Rectangular matrix.
2570  diagonal = np.array([1, 2])  # Input shape: (2)
2571  tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4)
2572    ==> [[0, 0, 0, 0],  # Output shape: (3, 4)
2573         [1, 0, 0, 0],
2574         [0, 2, 0, 0]]
2575
2576  # Rectangular matrix with inferred num_cols and padding_value = 9.
2577  tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
2578    ==> [[9, 9],  # Output shape: (3, 2)
2579         [1, 9],
2580         [9, 2]]
2581  ```
2582
2583  Args:
2584    diagonal: A `Tensor` with `rank k >= 1`.
2585    name: A name for the operation (optional).
2586    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2587      main diagonal, and negative value means subdiagonals. `k` can be a single
2588      integer (for a single diagonal) or a pair of integers specifying the low
2589      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2590    num_rows: The number of rows of the output matrix. If it is not provided,
2591      the op assumes the output matrix is a square matrix and infers the matrix
2592      size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
2593    num_cols: The number of columns of the output matrix. If it is not provided,
2594      the op assumes the output matrix is a square matrix and infers the matrix
2595      size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
2596    padding_value: The value to fill the area outside the specified diagonal
2597      band with. Default is 0.
2598    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2599      `align` is a string specifying how superdiagonals and subdiagonals should
2600      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2601      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2602      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2603      the left (right-pads the row). It is the packing format LAPACK uses.
2604      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2605
2606  Returns:
2607    A Tensor. Has the same type as `diagonal`.
2608  """
2609  # Special case to sidestep the tf.constant conversion error:
2610  # TypeError: Expected bool, got 0 of type 'int' instead.
2611  if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
2612    padding_value = bool(padding_value)
2613
2614  return gen_array_ops.matrix_diag_v3(
2615      diagonal=diagonal,
2616      k=k,
2617      num_rows=num_rows,
2618      num_cols=num_cols,
2619      padding_value=padding_value,
2620      align=align,
2621      name=name)
2622
2623
2624@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
2625@dispatch.add_dispatch_support
2626@deprecation.deprecated_endpoints("matrix_diag_part")
2627def matrix_diag_part(
2628    input,  # pylint:disable=redefined-builtin
2629    name="diag_part",
2630    k=0,
2631    padding_value=0,
2632    align="RIGHT_LEFT"):
2633  """Returns the batched diagonal part of a batched tensor.
2634
2635  Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
2636  `input`.
2637
2638  Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`.
2639  Let `max_diag_len` be the maximum length among all diagonals to be extracted,
2640  `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
2641  Let `num_diags` be the number of diagonals to extract,
2642  `num_diags = k[1] - k[0] + 1`.
2643
2644  If `num_diags == 1`, the output tensor is of rank `r - 1` with shape
2645  `[I, J, ..., L, max_diag_len]` and values:
2646
2647  ```
2648  diagonal[i, j, ..., l, n]
2649    = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
2650      padding_value                 ; otherwise.
2651  ```
2652  where `y = max(-k[1], 0)`, `x = max(k[1], 0)`.
2653
2654  Otherwise, the output tensor has rank `r` with dimensions
2655  `[I, J, ..., L, num_diags, max_diag_len]` with values:
2656
2657  ```
2658  diagonal[i, j, ..., l, m, n]
2659    = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
2660      padding_value                 ; otherwise.
2661  ```
2662  where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
2663
2664  `offset` is zero except when the alignment of the diagonal is to the right.
2665  ```
2666  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2667                                             and `d >= 0`) or
2668                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2669                                             and `d <= 0`)
2670           0                          ; otherwise
2671  ```
2672  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2673
2674  The input must be at least a matrix.
2675
2676  For example:
2677
2678  ```
2679  input = np.array([[[1, 2, 3, 4],  # Input shape: (2, 3, 4)
2680                     [5, 6, 7, 8],
2681                     [9, 8, 7, 6]],
2682                    [[5, 4, 3, 2],
2683                     [1, 2, 3, 4],
2684                     [5, 6, 7, 8]]])
2685
2686  # A main diagonal from each batch.
2687  tf.linalg.diag_part(input) ==> [[1, 6, 7],  # Output shape: (2, 3)
2688                                  [5, 2, 7]]
2689
2690  # A superdiagonal from each batch.
2691  tf.linalg.diag_part(input, k = 1)
2692    ==> [[2, 7, 6],  # Output shape: (2, 3)
2693         [4, 3, 8]]
2694
2695  # A band from each batch.
2696  tf.linalg.diag_part(input, k = (-1, 2))
2697    ==> [[[3, 8, 0],  # Output shape: (2, 4, 3)
2698          [2, 7, 6],
2699          [1, 6, 7],
2700          [0, 5, 8]],
2701         [[3, 4, 0],
2702          [4, 3, 8],
2703          [5, 2, 7],
2704          [0, 1, 6]]]
2705
2706  # RIGHT_LEFT alignment.
2707  tf.linalg.diag_part(input, k = (-1, 2), align="RIGHT_LEFT")
2708    ==> [[[0, 3, 8],  # Output shape: (2, 4, 3)
2709          [2, 7, 6],
2710          [1, 6, 7],
2711          [5, 8, 0]],
2712         [[0, 3, 4],
2713          [4, 3, 8],
2714          [5, 2, 7],
2715          [1, 6, 0]]]
2716
2717  # max_diag_len can be shorter than the main diagonal.
2718  tf.linalg.diag_part(input, k = (-2, -1))
2719    ==> [[[5, 8],
2720          [0, 9]],
2721         [[1, 6],
2722          [0, 5]]]
2723
2724  # padding_value = 9
2725  tf.linalg.diag_part(input, k = (1, 3), padding_value = 9)
2726    ==> [[[4, 9, 9],  # Output shape: (2, 3, 3)
2727          [3, 8, 9],
2728          [2, 7, 6]],
2729         [[2, 9, 9],
2730          [3, 4, 9],
2731          [4, 3, 8]]]
2732
2733  ```
2734
2735  Args:
2736    input: A `Tensor` with `rank k >= 2`.
2737    name: A name for the operation (optional).
2738    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2739      main diagonal, and negative value means subdiagonals. `k` can be a single
2740      integer (for a single diagonal) or a pair of integers specifying the low
2741      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2742    padding_value: The value to fill the area outside the specified diagonal
2743      band with. Default is 0.
2744    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2745      `align` is a string specifying how superdiagonals and subdiagonals should
2746      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2747      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2748      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2749      the left (right-pads the row). It is the packing format LAPACK uses.
2750      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2751
2752  Returns:
2753    A Tensor containing diagonals of `input`. Has the same type as `input`.
2754
2755  Raises:
2756    InvalidArgumentError: When `k` is out of bounds or when `k[0] > k[1]`.
2757  """
2758  # Special case to sidestep the tf.constant conversion error:
2759  # TypeError: Expected bool, got 0 of type 'int' instead.
2760  if hasattr(input, "dtype") and input.dtype == "bool":
2761    padding_value = bool(padding_value)
2762
2763  return gen_array_ops.matrix_diag_part_v3(
2764      input=input, k=k, padding_value=padding_value, align=align, name=name)
2765
2766
2767@tf_export(
2768    "linalg.tensor_diag_part", v1=["linalg.tensor_diag_part", "diag_part"])
2769@dispatch.add_dispatch_support
2770@deprecation.deprecated_endpoints("diag_part")
2771def tensor_diag_part(
2772    input,  # pylint:disable=redefined-builtin
2773    name=None):
2774  """Returns the diagonal part of the tensor.
2775
2776  This operation returns a tensor with the `diagonal` part
2777  of the `input`. The `diagonal` part is computed as follows:
2778
2779  Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a
2780  tensor of rank `k` with dimensions `[D1,..., Dk]` where:
2781
2782  `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`.
2783
2784  For a rank 2 tensor, `linalg.diag_part` and `linalg.tensor_diag_part`
2785  produce the same result. For rank 3 and higher, linalg.diag_part extracts
2786  the diagonal of each inner-most matrix in the tensor. An example where
2787  they differ is given below.
2788
2789  >>> x = [[[[1111,1112],[1121,1122]],
2790  ...       [[1211,1212],[1221,1222]]],
2791  ...      [[[2111, 2112], [2121, 2122]],
2792  ...       [[2211, 2212], [2221, 2222]]]
2793  ...      ]
2794  >>> tf.linalg.tensor_diag_part(x)
2795  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2796  array([[1111, 1212],
2797         [2121, 2222]], dtype=int32)>
2798  >>> tf.linalg.diag_part(x).shape
2799  TensorShape([2, 2, 2])
2800
2801  Args:
2802    input: A `Tensor` with rank `2k`.
2803    name: A name for the operation (optional).
2804
2805  Returns:
2806    A Tensor containing diagonals of `input`. Has the same type as `input`, and
2807    rank `k`.
2808  """
2809  return gen_array_ops.diag_part(input=input, name=name)
2810
2811
2812@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
2813@dispatch.add_dispatch_support
2814@deprecation.deprecated_endpoints("matrix_set_diag")
2815def matrix_set_diag(
2816    input,  # pylint:disable=redefined-builtin
2817    diagonal,
2818    name="set_diag",
2819    k=0,
2820    align="RIGHT_LEFT"):
2821  """Returns a batched matrix tensor with new batched diagonal values.
2822
2823  Given `input` and `diagonal`, this operation returns a tensor with the
2824  same shape and values as `input`, except for the specified diagonals of the
2825  innermost matrices. These will be overwritten by the values in `diagonal`.
2826
2827  `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
2828  `k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
2829  Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
2830  `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
2831  `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
2832  `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
2833
2834  The output is a tensor of rank `r+1` with dimensions `[I, J, ..., L, M, N]`.
2835  If `k` is scalar or `k[0] == k[1]`:
2836
2837  ```
2838  output[i, j, ..., l, m, n]
2839    = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
2840      input[i, j, ..., l, m, n]              ; otherwise
2841  ```
2842
2843  Otherwise,
2844
2845  ```
2846  output[i, j, ..., l, m, n]
2847    = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
2848      input[i, j, ..., l, m, n]                         ; otherwise
2849  ```
2850  where `d = n - m`, `diag_index = k[1] - d`, and
2851  `index_in_diag = n - max(d, 0) + offset`.
2852
2853  `offset` is zero except when the alignment of the diagonal is to the right.
2854  ```
2855  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2856                                             and `d >= 0`) or
2857                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2858                                             and `d <= 0`)
2859           0                          ; otherwise
2860  ```
2861  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2862
2863  For example:
2864
2865  ```
2866  # The main diagonal.
2867  input = np.array([[[7, 7, 7, 7],              # Input shape: (2, 3, 4)
2868                     [7, 7, 7, 7],
2869                     [7, 7, 7, 7]],
2870                    [[7, 7, 7, 7],
2871                     [7, 7, 7, 7],
2872                     [7, 7, 7, 7]]])
2873  diagonal = np.array([[1, 2, 3],               # Diagonal shape: (2, 3)
2874                       [4, 5, 6]])
2875  tf.matrix_set_diag(input, diagonal)
2876    ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
2877          [7, 2, 7, 7],
2878          [7, 7, 3, 7]],
2879         [[4, 7, 7, 7],
2880          [7, 5, 7, 7],
2881          [7, 7, 6, 7]]]
2882
2883  # A superdiagonal (per batch).
2884  tf.matrix_set_diag(input, diagonal, k = 1)
2885    ==> [[[7, 1, 7, 7],  # Output shape: (2, 3, 4)
2886          [7, 7, 2, 7],
2887          [7, 7, 7, 3]],
2888         [[7, 4, 7, 7],
2889          [7, 7, 5, 7],
2890          [7, 7, 7, 6]]]
2891
2892  # A band of diagonals.
2893  diagonals = np.array([[[9, 1, 0],  # Diagonal shape: (2, 4, 3)
2894                         [6, 5, 8],
2895                         [1, 2, 3],
2896                         [0, 4, 5]],
2897                        [[1, 2, 0],
2898                         [5, 6, 4],
2899                         [6, 1, 2],
2900                         [0, 3, 4]]])
2901  tf.matrix_set_diag(input, diagonals, k = (-1, 2))
2902    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
2903          [4, 2, 5, 1],
2904          [7, 5, 3, 8]],
2905         [[6, 5, 1, 7],
2906          [3, 1, 6, 2],
2907          [7, 4, 2, 4]]]
2908
2909  # RIGHT_LEFT alignment.
2910  diagonals = np.array([[[0, 9, 1],  # Diagonal shape: (2, 4, 3)
2911                         [6, 5, 8],
2912                         [1, 2, 3],
2913                         [4, 5, 0]],
2914                        [[0, 1, 2],
2915                         [5, 6, 4],
2916                         [6, 1, 2],
2917                         [3, 4, 0]]])
2918  tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT")
2919    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
2920          [4, 2, 5, 1],
2921          [7, 5, 3, 8]],
2922         [[6, 5, 1, 7],
2923          [3, 1, 6, 2],
2924          [7, 4, 2, 4]]]
2925
2926  ```
2927
2928  Args:
2929    input: A `Tensor` with rank `k + 1`, where `k >= 1`.
2930    diagonal:  A `Tensor` with rank `k`, when `d_lower == d_upper`, or `k + 1`,
2931      otherwise. `k >= 1`.
2932    name: A name for the operation (optional).
2933    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2934      main diagonal, and negative value means subdiagonals. `k` can be a single
2935      integer (for a single diagonal) or a pair of integers specifying the low
2936      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2937    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2938      `align` is a string specifying how superdiagonals and subdiagonals should
2939      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2940      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2941      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2942      the left (right-pads the row). It is the packing format LAPACK uses.
2943      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2944  """
2945  return gen_array_ops.matrix_set_diag_v3(
2946      input=input, diagonal=diagonal, k=k, align=align, name=name)
2947
2948
2949# pylint: enable=invalid-name
2950
2951
2952def _constant_if_small(value, shape, dtype, name):
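  """Returns a constant if `shape` is small and static, else `None`."""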
2953  try:
2954    if np.prod(shape) < 1000:
2955      return constant(value, shape=shape, dtype=dtype, name=name)
2956  except (NotImplementedError, TypeError):
2957    # Happens when shape is a Tensor, list with Tensor elements, etc.
2958    pass
2959  return None
2960
2961
2962def _tag_zeros_tensor(fun):
2963  """ Tags the result of function by setting _is_zeros_tensor attribute.
2964
2965  This is useful to compute Hessians of fused ops such as cross_entropy.
2966  """
2967
2968  def wrapped(*args, **kwargs):
2969    tensor = fun(*args, **kwargs)
2970    tensor._is_zeros_tensor = True
2971    return tensor
2972
2973  return tf_decorator.make_decorator(fun, wrapped)
2974
2975
2976@tf_export("zeros")
2977@dispatch.add_dispatch_support
2978@_tag_zeros_tensor
2979def zeros(shape, dtype=dtypes.float32, name=None):
2980  """Creates a tensor with all elements set to zero.
2981
2982  See also `tf.zeros_like`, `tf.ones`, `tf.fill`, `tf.eye`.
2983
2984  This operation returns a tensor of type `dtype` with shape `shape` and
2985  all elements set to zero.
2986
2987  >>> tf.zeros([3, 4], tf.int32)
2988  <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
2989  array([[0, 0, 0, 0],
2990         [0, 0, 0, 0],
2991         [0, 0, 0, 0]], dtype=int32)>
2992
2993  Args:
2994    shape: A `list` of integers, a `tuple` of integers, or
2995      a 1-D `Tensor` of type `int32`.
2996    dtype: The DType of an element in the resulting `Tensor`.
2997    name: Optional string. A name for the operation.
2998
2999  Returns:
3000    A `Tensor` with all elements set to zero.
3001  """
3002  dtype = dtypes.as_dtype(dtype).base_dtype
3003  with ops.name_scope(name, "zeros", [shape]) as name:
3004    if dtype == dtypes.bool:
3005      zero = False
3006    elif dtype == dtypes.string:
3007      zero = ""
3008    elif dtype.is_quantized:
3009      zero = np.zeros([]).astype(dtype.as_numpy_dtype)
3010    else:
3011      zero = 0
3012
3013    if not isinstance(shape, ops.Tensor):
3014      try:
3015        if not context.executing_eagerly():
3016          # Create a constant if it won't be very big. Otherwise create a fill
3017          # op to prevent serialized GraphDefs from becoming too large.
3018          output = _constant_if_small(zero, shape, dtype, name)
3019          if output is not None:
3020            return output
3021
3022        # Go through tensor shapes to get int64-if-needed semantics
3023        shape = constant_op._tensor_shape_tensor_conversion_function(
3024            tensor_shape.TensorShape(shape))
3025      except (TypeError, ValueError, errors.UnimplementedError):
3026        # Happens when shape is a list with tensor elements
3027        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
3028    if not shape._shape_tuple():
3029      shape = reshape(shape, [-1])  # Ensure it's a vector
3030    output = fill(shape, constant(zero, dtype=dtype), name=name)
3031  assert output.dtype.base_dtype == dtype
3032  return output
3033
3034
3035@tf_export(v1=["zeros_like"])
3036@dispatch.register_unary_elementwise_api
3037@dispatch.add_dispatch_support
3038def zeros_like(tensor, dtype=None, name=None, optimize=True):
3039  """Creates a tensor with all elements set to zero.
3040
3041  See also `tf.zeros`.
3042
3043  Given a single tensor (`tensor`), this operation returns a tensor of the
3044  same type and shape as `tensor` with all elements set to zero. Optionally,
3045  you can use `dtype` to specify a new type for the returned tensor.
3046
3047  Examples:
3048
3049    >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3050    >>> tf.zeros_like(tensor)
3051    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3052    array([[0, 0, 0],
3053           [0, 0, 0]], dtype=int32)>
3054
3055    >>> tf.zeros_like(tensor, dtype=tf.float32)
3056    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
3057    array([[0., 0., 0.],
3058           [0., 0., 0.]], dtype=float32)>
3059
3060  Args:
3061    tensor: A `Tensor`.
3062    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3063      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3064      `complex64`, `complex128`, `bool` or `string`. (optional)
3065    name: A name for the operation (optional).
3066    optimize: if `True`, attempt to statically determine the shape of `tensor`
3067      and encode it as a constant. (optional, defaults to `True`)
3068
3069  Returns:
3070    A `Tensor` with all elements set to zero.
3071  """
3072  return zeros_like_impl(tensor, dtype, name, optimize)
3073
3074
3075@tf_export("zeros_like", v1=[])
3076@dispatch.register_unary_elementwise_api
3077@dispatch.add_dispatch_support
3078def zeros_like_v2(
3079    input,  # pylint: disable=redefined-builtin
3080    dtype=None,
3081    name=None):
3082  """Creates a tensor with all elements set to zero.
3083
3084  See also `tf.zeros`.
3085
3086  Given a single tensor or array-like object (`input`), this operation returns
3087  a tensor of the same type and shape as `input` with all elements set to zero.
3088  Optionally, you can use `dtype` to specify a new type for the returned tensor.
3089
3090  Examples:
3091
3092    >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3093    >>> tf.zeros_like(tensor)
3094    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3095    array([[0, 0, 0],
3096           [0, 0, 0]], dtype=int32)>
3097
3098    >>> tf.zeros_like(tensor, dtype=tf.float32)
3099    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
3100    array([[0., 0., 0.],
3101           [0., 0., 0.]], dtype=float32)>
3102
3103    >>> tf.zeros_like([[1, 2, 3], [4, 5, 6]])
3104    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3105    array([[0, 0, 0],
3106           [0, 0, 0]], dtype=int32)>
3107
3108  Args:
3109    input: A `Tensor` or array-like object.
3110    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3111      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3112      `complex64`, `complex128`, `bool` or `string` (optional).
3113    name: A name for the operation (optional).
3114
3115  Returns:
3116    A `Tensor` with all elements set to zero.
3117  """
3118  return zeros_like_impl(input, dtype, name, optimize=True)
3119
3120
3121@_tag_zeros_tensor
3122def zeros_like_impl(tensor, dtype, name, optimize=True):
3123  """Internal implementation for the v1/v2 zeros_like API calls."""
3124  with ops.name_scope(name, "zeros_like", [tensor]) as name:
3125    if not tensor_util.is_tf_type(tensor):
3126      tensor = ops.convert_to_tensor(tensor, name="tensor")
3127    tensor_shape = tensor.shape
3128    tensor_dtype = tensor.dtype
3129
3130    if context.executing_eagerly():
3131      if dtype is not None and dtype != tensor_dtype:
3132        return zeros(
3133            shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
3134      return gen_array_ops.zeros_like(tensor, name=name)
3135
3136    # For now, variant types must be created via zeros_like, as we need to
3137    # pass the input variant object to the proper zeros callback.
3138
3139    if (optimize and tensor_shape.is_fully_defined() and
3140        tensor_dtype != dtypes.variant):
3141      # We can produce a zeros tensor independent of the value of 'tensor',
3142      # since the shape is known statically.
3143      return zeros(tensor_shape, dtype=dtype or tensor_dtype, name=name)
3144
3145    if dtype is not None and dtype != tensor_dtype and dtype != dtypes.variant:
3146      return zeros(
3147          shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
3148    else:
3149      return gen_array_ops.zeros_like(tensor, name=name)
3150
3151
3152@tf_export(v1=["ones_like"])
3153@dispatch.register_unary_elementwise_api
3154@dispatch.add_dispatch_support
3155def ones_like(tensor, dtype=None, name=None, optimize=True):
3156  """Creates a tensor with all elements set to 1.
3157
3158  See also `tf.ones`.
3159
3160  Given a single tensor (`tensor`), this operation returns a tensor of the same
3161  type and shape as `tensor` with all elements set to 1. Optionally, you can
3162  specify a new type (`dtype`) for the returned tensor.
3163
3164  For example:
3165
3166  ```python
3167  tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3168  tf.ones_like(tensor)  # [[1, 1, 1], [1, 1, 1]]
3169  ```
3170
3171  Args:
3172    tensor: A `Tensor`.
3173    dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
3174      `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, `complex64`,
3175      `complex128` or `bool`.
3176    name: A name for the operation (optional).
3177    optimize: if true, attempt to statically determine the shape of 'tensor' and
3178      encode it as a constant.
3179
3180  Returns:
3181    A `Tensor` with all elements set to 1.
3182  """
3183  return ones_like_impl(tensor, dtype, name, optimize)
3184
3185
3186@tf_export("ones_like", v1=[])
3187@dispatch.register_unary_elementwise_api
3188@dispatch.add_dispatch_support
3189def ones_like_v2(
3190    input,  # pylint: disable=redefined-builtin
3191    dtype=None,
3192    name=None):
3193  """Creates a tensor of all ones that has the same shape as the input.
3194
3195  See also `tf.ones`.
3196
3197  Given a single tensor or array-like object (`input`), this operation returns
3198  a tensor of the same type and shape as `input` with all elements set to 1.
3199  Optionally, you can use `dtype` to specify a new type for the returned tensor.
3200
3201  For example:
3202
3203  >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3204  >>> tf.ones_like(tensor)
3205  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3206    array([[1, 1, 1],
3207           [1, 1, 1]], dtype=int32)>
3208
3209  Args:
3210    input: A `Tensor`.
3211    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3212      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3213      `complex64`, `complex128`, `bool` or `string`.
3214    name: A name for the operation (optional).
3215
3216  Returns:
3217    A `Tensor` with all elements set to one.
3218  """
3219  return ones_like_impl(input, dtype, name, optimize=True)
3220
3221
3222def ones_like_impl(tensor, dtype, name, optimize=True):
3223  """Internal implementation for the v1/v2 ones_like API calls."""
3224  with ops.name_scope(name, "ones_like", [tensor]) as name:
3225    tensor = ops.convert_to_tensor(tensor, name="tensor")
3226    ones_shape = shape_internal(tensor, optimize=optimize)
3227    if dtype is None:
3228      dtype = tensor.dtype
3229    ret = ones(ones_shape, dtype=dtype, name=name)
3230    if not context.executing_eagerly():
3231      ret.set_shape(tensor.get_shape())
3232    return ret
3233
3234
3235@tf_export("ones")
3236@dispatch.add_dispatch_support
3237def ones(shape, dtype=dtypes.float32, name=None):
3238  """Creates a tensor with all elements set to one (1).
3239
3240  See also `tf.ones_like`, `tf.zeros`, `tf.fill`, `tf.eye`.
3241
3242  This operation returns a tensor of type `dtype` with shape `shape` and
3243  all elements set to one.
3244
3245  >>> tf.ones([3, 4], tf.int32)
3246  <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3247  array([[1, 1, 1, 1],
3248         [1, 1, 1, 1],
3249         [1, 1, 1, 1]], dtype=int32)>
3250
3251  Args:
3252    shape: A `list` of integers, a `tuple` of integers, or
3253      a 1-D `Tensor` of type `int32`.
3254    dtype: Optional DType of an element in the resulting `Tensor`. Default is
3255      `tf.float32`.
3256    name: Optional string. A name for the operation.
3257
3258  Returns:
3259    A `Tensor` with all elements set to one (1).
3260  """
3261  dtype = dtypes.as_dtype(dtype).base_dtype
3262  with ops.name_scope(name, "ones", [shape]) as name:
3263    if dtype == dtypes.bool:
3264      one = True
3265    elif dtype.is_quantized:
3266      one = np.ones([]).astype(dtype.as_numpy_dtype)
3267    else:
3268      one = 1
3269    if not isinstance(shape, ops.Tensor):
3270      try:
3271        if not context.executing_eagerly():
3272          # Create a constant if it won't be very big. Otherwise create a fill
3273          # op to prevent serialized GraphDefs from becoming too large.
3274          output = _constant_if_small(one, shape, dtype, name)
3275          if output is not None:
3276            return output
3277
3278        # Go through tensor shapes to get int64-if-needed semantics
3279        shape = constant_op._tensor_shape_tensor_conversion_function(
3280            tensor_shape.TensorShape(shape))
3281      except (TypeError, ValueError):
3282        # Happens when shape is a list with tensor elements
3283        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
3284    if not shape._shape_tuple():
3285      shape = reshape(shape, [-1])  # Ensure it's a vector
3286    output = fill(shape, constant(one, dtype=dtype), name=name)
3287  assert output.dtype.base_dtype == dtype
3288  return output
3289
3290
3291@tf_export(v1=["placeholder"])
3292def placeholder(dtype, shape=None, name=None):
3293  """Inserts a placeholder for a tensor that will be always fed.
3294
3295  **Important**: This tensor will produce an error if evaluated. Its value must
3296  be fed using the `feed_dict` optional argument to `Session.run()`,
3297  `Tensor.eval()`, or `Operation.run()`.
3298
3299  For example:
3300
3301  ```python
3302  x = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))
3303  y = tf.matmul(x, x)
3304
3305  with tf.compat.v1.Session() as sess:
3306    print(sess.run(y))  # ERROR: will fail because x was not fed.
3307
3308    rand_array = np.random.rand(1024, 1024)
3309    print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.
3310  ```
3311
3312  Args:
3313    dtype: The type of elements in the tensor to be fed.
3314    shape: The shape of the tensor to be fed (optional). If the shape is not
3315      specified, you can feed a tensor of any shape.
3316    name: A name for the operation (optional).
3317
3318  Returns:
3319    A `Tensor` that may be used as a handle for feeding a value, but not
3320    evaluated directly.
3321
3322  Raises:
3323    RuntimeError: if eager execution is enabled
3324
3325  @compatibility(TF2)
3326  This API is not compatible with eager execution and `tf.function`. To migrate
3327  to TF2, rewrite the code to be compatible with eager execution. Check the
3328  [migration
3329  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
3330  on replacing `Session.run` calls. In TF2, you can just pass tensors directly
3331  into ops and layers. If you want to explicitly set up your inputs, also see
3332  [Keras functional API](https://www.tensorflow.org/guide/keras/functional) on
3333  how to use `tf.keras.Input` to replace `tf.compat.v1.placeholder`.
3334  `tf.function` arguments also do the job of `tf.compat.v1.placeholder`.
3335  For more details please read [Better
3336  performance with tf.function](https://www.tensorflow.org/guide/function).
3337  @end_compatibility
3338  """
3339  if context.executing_eagerly():
3340    raise RuntimeError("tf.placeholder() is not compatible with "
3341                       "eager execution.")
3342
3343  return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
3344
3345
3346@tf_export(v1=["placeholder_with_default"])
3347def placeholder_with_default(input, shape, name=None):  # pylint: disable=redefined-builtin
3348  """A placeholder op that passes through `input` when its output is not fed.
3349
3350  @compatibility(TF2)
3351  This API is strongly discouraged for use with eager execution and
3352  `tf.function`. The primary use of this API is for testing computation wrapped
3353  within a `tf.function` where the input tensors might not have statically known
3354  fully-defined shapes. The same can be achieved by creating a
3355  [concrete function](
3356  https://www.tensorflow.org/guide/function#obtaining_concrete_functions)
3357  from the `tf.function` with a `tf.TensorSpec` input which has partially
3358  defined shapes. For example, the code
3359
3360  >>> @tf.function
3361  ... def f():
3362  ...   x = tf.compat.v1.placeholder_with_default(
3363  ...       tf.constant([[1., 2., 3.], [4., 5., 6.]]), [None, 3])
3364  ...   y = tf.constant([[1.],[2.], [3.]])
3365  ...   z = tf.matmul(x, y)
3366  ...   assert z.shape[0] == None
3367  ...   assert z.shape[1] == 1
3368
3369  >>> f()
3370
3371  can easily be replaced by
3372
3373  >>> @tf.function
3374  ... def f(x):
3375  ...   y = tf.constant([[1.],[2.], [3.]])
3376  ...   z = tf.matmul(x, y)
3377  ...   assert z.shape[0] == None
3378  ...   assert z.shape[1] == 1
3379
3380  >>> g = f.get_concrete_function(tf.TensorSpec([None, 3]))
3381
3382  You can learn more about `tf.function` at [Better
3383  performance with tf.function](https://www.tensorflow.org/guide/function).
3384  @end_compatibility
3385
3386  Args:
3387    input: A `Tensor`. The default value to produce when output is not fed.
3388    shape: A `tf.TensorShape` or list of `int`s. The (possibly partial) shape of
3389      the tensor.
3390    name: A name for the operation (optional).
3391
3392  Returns:
3393    A `Tensor`. Has the same type as `input`.
3394  """
3395  return gen_array_ops.placeholder_with_default(input, shape, name)
3396
3397
3398@tf_export(v1=["sparse.placeholder", "sparse_placeholder"])
3399@deprecation.deprecated_endpoints("sparse_placeholder")
3400def sparse_placeholder(dtype, shape=None, name=None):
3401  """Inserts a placeholder for a sparse tensor that will be always fed.
3402
3403  **Important**: This sparse tensor will produce an error if evaluated.
3404  Its value must be fed using the `feed_dict` optional argument to
3405  `Session.run()`, `Tensor.eval()`, or `Operation.run()`.
3406
3407  For example:
3408
3409  ```python
3410  x = tf.compat.v1.sparse.placeholder(tf.float32)
3411  y = tf.sparse.reduce_sum(x)
3412
3413  with tf.compat.v1.Session() as sess:
3414    print(sess.run(y))  # ERROR: will fail because x was not fed.
3415
3416    indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
3417    values = np.array([1.0, 2.0], dtype=np.float32)
3418    shape = np.array([7, 9, 2], dtype=np.int64)
3419    print(sess.run(y, feed_dict={
3420      x: tf.compat.v1.SparseTensorValue(indices, values, shape)}))
3421    # Will succeed.
3422    print(sess.run(y, feed_dict={
3423      x: (indices, values, shape)}))  # Will succeed.
3424
3425    sp = tf.sparse.SparseTensor(indices=indices, values=values,
3426                                dense_shape=shape)
3427    sp_value = sp.eval(session=sess)
3428    print(sess.run(y, feed_dict={x: sp_value}))  # Will succeed.
3429  ```
3430
3431
3432  Args:
3433    dtype: The type of `values` elements in the tensor to be fed.
3434    shape: The shape of the tensor to be fed (optional). If the shape is not
3435      specified, you can feed a sparse tensor of any shape.
3436    name: A name for prefixing the operations (optional).
3437
3438  Returns:
3439    A `SparseTensor` that may be used as a handle for feeding a value, but not
3440    evaluated directly.
3441
3442  Raises:
3443    RuntimeError: if eager execution is enabled
3444
3445  @compatibility(TF2)
3446  This API is not compatible with eager execution and `tf.function`. To migrate
3447  to TF2, rewrite the code to be compatible with eager execution. Check the
3448  [migration
3449  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
3450  on replacing `Session.run` calls. In TF2, you can just pass tensors directly
3451  into ops and layers. If you want to explicitly set up your inputs, also see
3452  [Keras functional API](https://www.tensorflow.org/guide/keras/functional) on
3453  how to use `tf.keras.Input` to replace `tf.compat.v1.sparse_placeholder`.
3454  `tf.function` arguments also do the job of `tf.compat.v1.sparse_placeholder`.
3455  For more details please read [Better
3456  performance with tf.function](https://www.tensorflow.org/guide/function).
3457  @end_compatibility
3458  """
3459  if context.executing_eagerly():
3460    raise RuntimeError("`sparse_placeholder` is not compatible with "
3461                       "eager execution.")
3462
3463  shape_name = (name + "/shape") if name is not None else None
3464  default_shape_name = (name + "/shape_default") if name is not None else None
3465  if shape is None:
3466    rank = None
3467    dense_shape = placeholder(dtypes.int64, shape=[rank], name=shape_name)
3468    dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
3469  else:
3470    if isinstance(shape, ops.Tensor):
3471      rank = shape.get_shape()[0]
3472      dense_shape_default = tensor_util.constant_value_as_shape(shape)
3473    else:
3474      rank = len(shape)
3475      # determine the shape, to override the `.shape` property of the
3476      # `SparseTensor`
3477      dense_shape_default = tensor_shape.TensorShape(
3478          tuple(None if dim == -1 else dim for dim in shape))
3479      shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
3480      shape = tuple(-1 if dim is None else dim for dim in shape)
3481      shape = ops.convert_to_tensor(
3482          shape, dtype=dtypes.int64, name=default_shape_name)
3483
3484    # `dense_shape` needs to be feedable (for users that treat this as an
3485    # actual placeholder). `constant_value_as_shape` sets constants to
3486    # not-feedable. placeholder_with_default works, but blocks `SparseTensor`
3487    # from reading the default value back out.
3488    dense_shape = placeholder_with_default(
3489        shape, shape=shape.shape, name=shape_name)
3490
3491  result = sparse_tensor.SparseTensor(
3492      values=placeholder(
3493          dtype,
3494          shape=[None],
3495          name=(name + "/values") if name is not None else None),
3496      indices=placeholder(
3497          dtypes.int64,
3498          shape=[None, rank],
3499          name=(name + "/indices") if name is not None else None),
3500      dense_shape=dense_shape)
3501
3502  # Now the SparseTensor.shape is a list of `None`s, since it couldn't read the
3503  # default shape out of the placeholder. Override that
3504  # shape to be the value determined here, so partial shapes can be
3505  # propagated.
3506  result._dense_shape_default = dense_shape_default
3507  return result
3508
3509# pylint: enable=redefined-outer-name
3510
3511
3512@tf_export("pad", v1=[])
3513@dispatch.add_dispatch_support
3514def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
3515  """Pads a tensor.
3516
3517  This operation pads a `tensor` according to the `paddings` you specify.
3518  `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3519  `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3520  many values to add before the contents of `tensor` in that dimension, and
3521  `paddings[D, 1]` indicates how many values to add after the contents of
3522  `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3523  and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3524  `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3525  no greater than `tensor.dim_size(D)`.
3526
3527  The padded size of each dimension D of the output is:
3528
3529  `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3530
3531  For example:
3532
3533  ```python
3534  t = tf.constant([[1, 2, 3], [4, 5, 6]])
3535  paddings = tf.constant([[1, 1,], [2, 2]])
3536  # 'constant_values' is 0.
3537  # rank of 't' is 2.
3538  tf.pad(t, paddings, "CONSTANT")  # [[0, 0, 0, 0, 0, 0, 0],
3539                                   #  [0, 0, 1, 2, 3, 0, 0],
3540                                   #  [0, 0, 4, 5, 6, 0, 0],
3541                                   #  [0, 0, 0, 0, 0, 0, 0]]
3542
3543  tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
3544                                  #  [3, 2, 1, 2, 3, 2, 1],
3545                                  #  [6, 5, 4, 5, 6, 5, 4],
3546                                  #  [3, 2, 1, 2, 3, 2, 1]]
3547
3548  tf.pad(t, paddings, "SYMMETRIC")  # [[2, 1, 1, 2, 3, 3, 2],
3549                                    #  [2, 1, 1, 2, 3, 3, 2],
3550                                    #  [5, 4, 4, 5, 6, 6, 5],
3551                                    #  [5, 4, 4, 5, 6, 6, 5]]
3552  ```
3553
3554  Args:
3555    tensor: A `Tensor`.
3556    paddings: A `Tensor` of type `int32`.
3557    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
3558    constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3559      same type as `tensor`.
3560    name: A name for the operation (optional).
3561
3562  Returns:
3563    A `Tensor`. Has the same type as `tensor`.
3564
3565  Raises:
3566    ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3567  """
3568  return pad(tensor, paddings, mode, name, constant_values)
3569
3570
3571@tf_export(v1=["pad"])
3572@dispatch.add_dispatch_support
3573def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pylint: disable=invalid-name
3574  """Pads a tensor.
3575
3576  This operation pads a `tensor` according to the `paddings` you specify.
3577  `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3578  `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3579  many values to add before the contents of `tensor` in that dimension, and
3580  `paddings[D, 1]` indicates how many values to add after the contents of
3581  `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3582  and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3583  `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3584  no greater than `tensor.dim_size(D)`.
3585
3586  The padded size of each dimension D of the output is:
3587
3588  `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3589
3590  For example:
3591
3592  ```python
3593  t = tf.constant([[1, 2, 3], [4, 5, 6]])
3594  paddings = tf.constant([[1, 1,], [2, 2]])
3595  # 'constant_values' is 0.
3596  # rank of 't' is 2.
3597  tf.pad(t, paddings, "CONSTANT")  # [[0, 0, 0, 0, 0, 0, 0],
3598                                   #  [0, 0, 1, 2, 3, 0, 0],
3599                                   #  [0, 0, 4, 5, 6, 0, 0],
3600                                   #  [0, 0, 0, 0, 0, 0, 0]]
3601
3602  tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
3603                                  #  [3, 2, 1, 2, 3, 2, 1],
3604                                  #  [6, 5, 4, 5, 6, 5, 4],
3605                                  #  [3, 2, 1, 2, 3, 2, 1]]
3606
3607  tf.pad(t, paddings, "SYMMETRIC")  # [[2, 1, 1, 2, 3, 3, 2],
3608                                    #  [2, 1, 1, 2, 3, 3, 2],
3609                                    #  [5, 4, 4, 5, 6, 6, 5],
3610                                    #  [5, 4, 4, 5, 6, 6, 5]]
3611  ```
3612
3613  Args:
3614    tensor: A `Tensor`.
3615    paddings: A `Tensor` of type `int32`.
3616    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
3617    name: A name for the operation (optional).
3618    constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3619      same type as `tensor`.
3620
3621  Returns:
3622    A `Tensor`. Has the same type as `tensor`.
3623
3624  Raises:
3625    ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3626  """
3627
3628  # Convert lower/mixed case to upper for NumPy compatibility
3629  # NumPy uses all lower-case modes.
3630  mode = mode.upper()
3631  if mode == "CONSTANT":
3632    # TODO(rjryan): Once the forward compatibility period (3 weeks) has passed,
3633    # remove the "Pad" fallback here.
3634    if (not tensor_util.is_tf_type(constant_values) and
3635        np.ndim(constant_values) == 0 and
3636        constant_values == np.zeros_like(constant_values)):
3637      result = gen_array_ops.pad(tensor, paddings, name=name)
3638    else:
3639      result = gen_array_ops.pad_v2(
3640          tensor, paddings, constant_values, name=name)
3641  elif mode == "REFLECT":
3642    result = gen_array_ops.mirror_pad(
3643        tensor, paddings, mode="REFLECT", name=name)
3644  elif mode == "SYMMETRIC":
3645    result = gen_array_ops.mirror_pad(
3646        tensor, paddings, mode="SYMMETRIC", name=name)
3647  else:
3648    raise ValueError("Value of argument `mode` expected to be "
3649                     """one of "CONSTANT", "REFLECT", or "SYMMETRIC". """
3650                     f"Received `mode` = {mode}")
3651
3652  # Restore shape information where possible.
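  # (For example, padding a [2, 3] input with paddings [[1, 1], [2, 2]] gives a
  # statically known output shape of [4, 7].)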
3653  if not context.executing_eagerly():
3654    paddings_constant = _get_paddings_constant(paddings)
3655    input_shape = (
3656        tensor_shape.TensorShape(tensor.shape)
3657        if isinstance(tensor, ops.Tensor) else result.op.inputs[0].shape)
3658    if (input_shape.ndims is not None and
3659        not result.shape.is_fully_defined() and paddings_constant is not None):
3660      new_shape = []
3661      for padding, dim in zip(paddings_constant, input_shape.as_list()):
3662        if padding is None or dim is None or any((x is None for x in padding)):
3663          new_shape.append(None)
3664        else:
3665          new_shape.append(sum(padding) + dim)
3666      result.set_shape(new_shape)
3667
3668  return result
3669
3670
3671def _get_paddings_constant(paddings):
3672  """Helper to get the constant values of the paddings arg to pad().
3673
3674  Used under V1 graph mode to facilitate computation of the shape of the output
3675  tensor of `pad()`.
3676
3677  Args:
3678    paddings: The same paddings arg as passed to pad(). Can be a Tensor, or
3679      a nested list or tuple of Tensor and/or numbers.
3680
3681  Returns:
3682    A nested list of numbers and/or `None`, in which `None` indicates an unknown
3683    padding size.
3684  """
3685  if isinstance(paddings, ops.Tensor):
3686    return tensor_util.constant_value(paddings, partial=True)
3687  elif isinstance(paddings, (list, tuple)):
3688    return [_get_paddings_constant(x) for x in paddings]
3689  else:
3690    return paddings
3691
3692
3693@tf_export("meshgrid")
3694@dispatch.add_dispatch_support
3695def meshgrid(*args, **kwargs):
3696  """Broadcasts parameters for evaluation on an N-D grid.
3697
3698  Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
3699  of N-D coordinate arrays for evaluating expressions on an N-D grid.
3700
3701  Notes:
3702
3703  `meshgrid` supports Cartesian ('xy') and matrix ('ij') indexing conventions.
3704  When the `indexing` argument is set to 'xy' (the default), the broadcasting
3705  instructions for the first two dimensions are swapped.
3706
3707  Examples:
3708
3709  Calling `X, Y = meshgrid(x, y)` with the tensors
3710
3711  ```python
3712  x = [1, 2, 3]
3713  y = [4, 5, 6]
3714  X, Y = tf.meshgrid(x, y)
3715  # X = [[1, 2, 3],
3716  #      [1, 2, 3],
3717  #      [1, 2, 3]]
3718  # Y = [[4, 4, 4],
3719  #      [5, 5, 5],
3720  #      [6, 6, 6]]
3721  ```
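
  For comparison, a sketch of the 'ij' (matrix) indexing convention, where the
  same inputs produce the transposed grids:

  ```python
  X, Y = tf.meshgrid(x, y, indexing='ij')
  # X = [[1, 1, 1],
  #      [2, 2, 2],
  #      [3, 3, 3]]
  # Y = [[4, 5, 6],
  #      [4, 5, 6],
  #      [4, 5, 6]]
  ```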
3722
3723  Args:
3724    *args: `Tensor`s with rank 1.
3725    **kwargs:
3726      - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
3727      - name: A name for the operation (optional).
3728
3729  Returns:
3730    outputs: A list of N `Tensor`s with rank N.
3731
3732  Raises:
3733    TypeError: When an unsupported keyword argument is passed in `kwargs`.
3734    ValueError: When indexing keyword argument is not one of `xy` or `ij`.
3735  """
3736
3737  indexing = kwargs.pop("indexing", "xy")
3738  name = kwargs.pop("name", "meshgrid")
3739  if kwargs:
3740    key = list(kwargs.keys())[0]
3741    raise TypeError("'{}' is an invalid keyword argument "
3742                    "for this function".format(key))
3743
3744  if indexing not in ("xy", "ij"):
3745    raise ValueError("Argument `indexing` parameter must be either "
3746                     f"'xy' or 'ij', got '{indexing}'")
3747
3748  with ops.name_scope(name, "meshgrid", args) as name:
3749    ndim = len(args)
3750    s0 = (1,) * ndim
3751
3752    if not ndim:
3753      return []
3754
3755    # Prepare reshape by inserting dimensions with size 1 where needed
3756    output = []
3757    for i, x in enumerate(args):
3758      output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
3759    # Create parameters for broadcasting each tensor to the full size
3760    shapes = [size(x) for x in args]
3761
3762    output_dtype = ops.convert_to_tensor(args[0]).dtype.base_dtype
3763
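    # For 'xy' (Cartesian) indexing, swap the first two dimensions in both the
    # reshapes and the broadcast shapes to match the NumPy meshgrid convention.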
3764    if indexing == "xy" and ndim > 1:
3765      output[0] = reshape(output[0], (1, -1) + (1,) * (ndim - 2))
3766      output[1] = reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
3767      shapes[0], shapes[1] = shapes[1], shapes[0]
3768
3769    # TODO(nolivia): improve performance with a broadcast
3770    mult_fact = ones(shapes, output_dtype)
3771    return [x * mult_fact for x in output]
3772
3773
3774NEW_AXIS = -1
3775SHRINK_AXIS = -2
3776
3777
3778# PEP-8 naming
3779# pylint: disable=invalid-name,redefined-outer-name
3780def _compute_size_of_strided_dim(shrink, spec, size):
3781  """Computes the size of a single strided slice dimension."""
3782
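  # Worked example (illustrative only): for a dimension of static size 10 and
  # the spec `slice(2, None, 3)`, `canonical` below yields begin=2 and end=10,
  # so the result is 8 // 3 + 1 = 3, the same as len(range(10)[2::3]).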
3783  unknown = None  # Sentinel: the dimension size cannot be inferred statically.
3784  use_full_range = None  # Sentinel: the slice bound was omitted; use the full range.
3785  # if this is a shrink axis (i.e. a non-range index)
3786  # it either will produce an error or return 1
3787  if shrink:
3788    return 1
3789  if size is unknown or size.value is unknown:
3790    return unknown
3791  size = size.value
3792  stride = spec.step
3793  if stride is not unknown:
3794    if stride == 0:
3795      return unknown
3796    stride = spec.step
3797    valid_range = [0, size] if stride > 0 else [-1, size - 1]
3798
3799    # PEP-8 naming
3800    # pylint: disable=invalid-name
3801    def canonical(x, c):
3802      if x is use_full_range:
3803        return valid_range[c] if stride > 0 else valid_range[(c + 1) & 1]
3804      else:
3805        x_fwd = size + x if x < 0 else x  # make negative indices positive
3806        return max(valid_range[0], min(valid_range[1], x_fwd))
3807
3808    begin = canonical(spec.start, 0)
3809    end = canonical(spec.stop, 1)
3810    interval_length = end - begin
3811    if interval_length == 0 or ((interval_length < 0) != (stride < 0)):
3812      return 0
3813    else:
3814      remainder = 1 if interval_length % stride != 0 else 0
3815      return interval_length // stride + remainder
3816  else:
3817    return unknown  # unknown because stride is unknown
3818
3819
3820def _TileGradShape(op):
3821  """Shape function for the TileGrad op."""
3822  multiples_shape = op.inputs[1].get_shape().with_rank(1)
3823  input_shape = op.inputs[0].get_shape().with_rank(multiples_shape[0])
3824  # NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
3825  # it is a vector of non-negative integers, and (ii) doing so allows
3826  # us to handle partially-known multiples.
3827  multiples = tensor_util.constant_value_as_shape(op.inputs[1]).with_rank(
3828      input_shape.ndims)
3829  if multiples.ndims is None:
3830    return [tensor_shape.unknown_shape()]
3831  else:
3832    output_dims = []
3833    for dim, multiple in zip(input_shape.dims, multiples.dims):
3834      output_dims.append(dim // multiple)
3835    return [tensor_shape.TensorShape(output_dims)]
3836
3837
3838@tf_export("edit_distance")
3839@dispatch.add_dispatch_support
3840def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
3841  """Computes the Levenshtein distance between sequences.
3842
3843  This operation takes variable-length sequences (`hypothesis` and `truth`),
3844  each provided as a `SparseTensor`, and computes the Levenshtein distance.
3845  You can normalize the edit distance by length of `truth` by setting
3846  `normalize` to true.
3847
3848  For example:
3849
3850  Given the following input,
3851  * `hypothesis` is a `tf.SparseTensor` of shape `[2, 1, 1]`
3852  * `truth` is a `tf.SparseTensor` of shape `[2, 2, 2]`
3853
3854  >>> hypothesis = tf.SparseTensor(
3855  ...   [[0, 0, 0],
3856  ...    [1, 0, 0]],
3857  ...   ["a", "b"],
3858  ...   (2, 1, 1))
3859  >>> truth = tf.SparseTensor(
3860  ...   [[0, 1, 0],
3861  ...    [1, 0, 0],
3862  ...    [1, 0, 1],
3863  ...    [1, 1, 0]],
3864  ...    ["a", "b", "c", "a"],
3865  ...    (2, 2, 2))
3866  >>> tf.edit_distance(hypothesis, truth, normalize=True)
3867  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3868  array([[inf, 1. ],
3869         [0.5, 1. ]], dtype=float32)>
3870
3871  The operation returns a dense Tensor of shape `[2, 2]` with
3872  edit distances normalized by `truth` lengths.
3873
3874  **Note**: It is possible to calculate edit distance between two
3875  sparse tensors with variable-length values. However, attempting to create
3876  them while eager execution is enabled will result in a `ValueError`.
3877
3878  For the following inputs,
3879
3880  ```python
3881  # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
3882  #   (0,0) = ["a"]
3883  #   (1,0) = ["b"]
3884  hypothesis = tf.sparse.SparseTensor(
3885      [[0, 0, 0],
3886       [1, 0, 0]],
3887      ["a", "b"],
3888      (2, 1, 1))
3889
3890  # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
3891  #   (0,0) = []
3892  #   (0,1) = ["a"]
3893  #   (1,0) = ["b", "c"]
3894  #   (1,1) = ["a"]
3895  truth = tf.sparse.SparseTensor(
3896      [[0, 1, 0],
3897       [1, 0, 0],
3898       [1, 0, 1],
3899       [1, 1, 0]],
3900      ["a", "b", "c", "a"],
3901      (2, 2, 2))
3902
3903  normalize = True
3904
3905  # The output would be a dense Tensor of shape `(2,)`, with edit distances
3906  # normalized by 'truth' lengths.
3907  # output => array([0., 0.5], dtype=float32)
3908  ```
3909
3910  Args:
3911    hypothesis: A `SparseTensor` containing hypothesis sequences.
3912    truth: A `SparseTensor` containing truth sequences.
3913    normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
3914      the length of `truth`.
3915    name: A name for the operation (optional).
3916
3917  Returns:
3918    A dense `Tensor` with rank `R - 1`, where R is the rank of the
3919    `SparseTensor` inputs `hypothesis` and `truth`.
3920
3921  Raises:
3922    TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
3923  """
3924  if not isinstance(
3925      hypothesis,
3926      (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3927    raise TypeError("Hypothesis must be a SparseTensor.")
3928  if not isinstance(
3929      truth, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3930    raise TypeError("Truth must be a SparseTensor.")
3931
3932  return gen_array_ops.edit_distance(
3933      hypothesis.indices,
3934      hypothesis.values,
3935      hypothesis.dense_shape,
3936      truth.indices,
3937      truth.values,
3938      truth.dense_shape,
3939      normalize=normalize,
3940      name=name)
3941
3942
3943@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
3944def _FakeQuantWithMinMaxArgsGradient(op, grad):
3945  """Gradient for FakeQuantWithMinMaxArgs op."""
3946  return fake_quant_with_min_max_args_gradient(
3947      grad,
3948      op.inputs[0],
3949      min=op.get_attr("min"),
3950      max=op.get_attr("max"),
3951      num_bits=op.get_attr("num_bits"),
3952      narrow_range=op.get_attr("narrow_range"))
3953
3954
3955@ops.RegisterGradient("FakeQuantWithMinMaxVars")
3956def _FakeQuantWithMinMaxVarsGradient(op, grad):
3957  """Gradient for FakeQuantWithMinMaxVars op."""
3958  return fake_quant_with_min_max_vars_gradient(
3959      grad,
3960      op.inputs[0],
3961      op.inputs[1],
3962      op.inputs[2],
3963      num_bits=op.get_attr("num_bits"),
3964      narrow_range=op.get_attr("narrow_range"))
3965
3966
3967@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
3968def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
3969  """Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
3970  return fake_quant_with_min_max_vars_per_channel_gradient(
3971      grad,
3972      op.inputs[0],
3973      op.inputs[1],
3974      op.inputs[2],
3975      num_bits=op.get_attr("num_bits"),
3976      narrow_range=op.get_attr("narrow_range"))
3977
3978
3979@ops.RegisterGradient("QuantizeAndDequantizeV4")
3980def _QuantizeAndDequantizeV4Grad(op, grad):
3981  """Gradient for QuantizeAndDequantizeV4 op."""
3982  return quantize_and_dequantize_v4_grad(
3983      grad,
3984      op.inputs[0],
3985      op.inputs[1],
3986      op.inputs[2],
3987      axis=op.get_attr("axis"))
3988
3989
3990@ops.RegisterGradient("QuantizeAndDequantizeV4Grad")
3991def _QuantizeAndDequantizeV4GradGrad(op, grad):
3992  """Gradient for QuantizeAndDequantizeV4Grad op."""
3993  return _QuantizeAndDequantizeV4Grad(op, grad)
3994
3995
3996@tf_export("required_space_to_batch_paddings")
3997def required_space_to_batch_paddings(input_shape,
3998                                     block_shape,
3999                                     base_paddings=None,
4000                                     name=None):
4001  """Calculate padding required to make block_shape divide input_shape.
4002
4003  This function can be used to calculate a suitable paddings argument for use
4004  with space_to_batch_nd and batch_to_space_nd.
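
  For example, a rough sketch (the values follow from the relations listed under
  Returns, assuming no `base_paddings`):

  ```python
  paddings, crops = tf.required_space_to_batch_paddings(
      input_shape=[3, 5], block_shape=[2, 4])
  # paddings == [[0, 1], [0, 3]]  (padded sizes 4 and 8 are divisible by 2 and 4)
  # crops    == [[0, 1], [0, 3]]
  ```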
4005
4006  Args:
4007    input_shape: int32 Tensor of shape [N].
4008    block_shape: int32 Tensor of shape [N].
4009    base_paddings: Optional int32 Tensor of shape [N, 2].  Specifies the minimum
4010      amount of padding to use.  All elements must be >= 0.  If not specified,
4011      defaults to 0.
4012    name: string.  Optional name prefix.
4013
4014  Returns:
4015    (paddings, crops), where:
4016
4017    `paddings` and `crops` are int32 Tensors of rank 2 and shape [N, 2]
4018    satisfying:
4019
4020        paddings[i, 0] = base_paddings[i, 0].
4021        0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
4022        (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0
4023
4024        crops[i, 0] = 0
4025        crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]
4026
4027  Raises: ValueError if called with incompatible shapes.
4028  """
4029  with ops.name_scope(name, "required_space_to_batch_paddings",
4030                      [input_shape, block_shape]):
4031    input_shape = ops.convert_to_tensor(
4032        input_shape, dtype=dtypes.int32, name="input_shape")
4033    block_shape = ops.convert_to_tensor(
4034        block_shape, dtype=dtypes.int32, name="block_shape")
4035
4036    block_shape.get_shape().assert_is_fully_defined()
4037    block_shape.get_shape().assert_has_rank(1)
4038    num_block_dims = block_shape.get_shape().dims[0].value
4039    if num_block_dims == 0:
4040      return zeros([0, 2], dtypes.int32), zeros([0, 2], dtypes.int32)
4041
4042    input_shape.get_shape().assert_is_compatible_with([num_block_dims])
4043
4044    if base_paddings is not None:
4045      base_paddings = ops.convert_to_tensor(
4046          base_paddings, dtype=dtypes.int32, name="base_paddings")
4047      base_paddings.get_shape().assert_is_compatible_with([num_block_dims, 2])
4048    else:
4049      base_paddings = zeros([num_block_dims, 2], dtypes.int32)
4050
4051    const_block_shape = tensor_util.constant_value(block_shape)
4052    const_input_shape = tensor_util.constant_value(input_shape)
4053    const_base_paddings = tensor_util.constant_value(base_paddings)
4054    if (const_block_shape is not None and const_input_shape is not None and
4055        const_base_paddings is not None):
4056      block_shape = const_block_shape
4057      input_shape = const_input_shape
4058      base_paddings = const_base_paddings
4059
4060    # Use same expression for both constant and non-constant case.
4061    pad_start = base_paddings[:, 0]
4062    orig_pad_end = base_paddings[:, 1]
4063    full_input_shape = input_shape + pad_start + orig_pad_end
4064    pad_end_extra = (block_shape - full_input_shape % block_shape) % block_shape
4065    pad_end = orig_pad_end + pad_end_extra
4066
4067    result_paddings = stack(
4068        [[pad_start[i], pad_end[i]] for i in range(num_block_dims)],
4069        name="paddings")
4070    result_crops = stack([[0, pad_end_extra[i]] for i in range(num_block_dims)],
4071                         name="crops")
4072    return result_paddings, result_crops
4073
4074
4075@tf_export(v1=["nn.space_to_batch", "space_to_batch"])
4076@dispatch.add_dispatch_support
4077@deprecation.deprecated_endpoints("space_to_batch")
4078def space_to_batch(  # pylint: disable=missing-docstring
4079    input,  # pylint: disable=redefined-builtin
4080    paddings,
4081    block_size=None,
4082    name=None,
4083    block_shape=None):  # pylint: disable=redefined-builtin
4084  block_size = deprecation.deprecated_argument_lookup("block_shape",
4085                                                      block_shape, "block_size",
4086                                                      block_size)
4087  result = space_to_batch_nd(
4088      input,
4089      paddings=paddings,
4090      block_shape=np.array([block_size, block_size], dtype=np.int64),
4091      name=name)
4092  result.set_shape(result.get_shape().with_rank(4))
4093  return result
4094
4095
4096space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
4097
4098
4099@tf_export("space_to_batch", "nn.space_to_batch", v1=[])
4100@dispatch.add_dispatch_support
4101def space_to_batch_v2(input, block_shape, paddings, name=None):  # pylint: disable=redefined-builtin
4102  return space_to_batch_nd(input, block_shape, paddings, name)
4103
4104
4105space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__
4106
4107
4108@tf_export(v1=["nn.space_to_depth", "space_to_depth"])
4109@dispatch.add_dispatch_support
4110@deprecation.deprecated_endpoints("space_to_depth")
4111def space_to_depth(input, block_size, name=None, data_format="NHWC"):  # pylint: disable=redefined-builtin
4112  return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
4113
4114
4115space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
4116
4117
4118@tf_export("nn.space_to_depth", v1=[])
4119@dispatch.add_dispatch_support
4120def space_to_depth_v2(input, block_size, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4121  return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
4122
4123
4124space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__
4125
4126
4127@tf_export(v1=["nn.depth_to_space", "depth_to_space"])
4128@dispatch.add_dispatch_support
4129@deprecation.deprecated_endpoints("depth_to_space")
4130def depth_to_space(input, block_size, name=None, data_format="NHWC"):  # pylint: disable=redefined-builtin
4131  return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
4132
4133
4134depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
4135
4136
4137@tf_export("nn.depth_to_space", v1=[])
4138@dispatch.add_dispatch_support
4139def depth_to_space_v2(input, block_size, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4140  return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
4141
4142
4143depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__
4144
4145
4146@tf_export(v1=["batch_to_space"])
4147@dispatch.add_dispatch_support
4148def batch_to_space(input, crops, block_size, name=None, block_shape=None):  # pylint: disable=redefined-builtin,missing-docstring
4149  block_size = deprecation.deprecated_argument_lookup("block_shape",
4150                                                      block_shape, "block_size",
4151                                                      block_size)
4152  result = batch_to_space_nd(
4153      input,
4154      crops=crops,
4155      block_shape=np.array([block_size, block_size], dtype=np.int64),
4156      name=name)
4157  result.set_shape(result.get_shape().with_rank(4))
4158  return result
4159
4160
4161batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
4162
4163
4164@tf_export("batch_to_space", v1=[])
4165@dispatch.add_dispatch_support
4166def batch_to_space_v2(input, block_shape, crops, name=None):  # pylint: disable=redefined-builtin
4167  """BatchToSpace for N-D tensors of type T.
4168
4169  This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
4170  shape `block_shape + [batch]`, interleaves these blocks back into the grid
4171  defined by the spatial dimensions `[1, ..., M]`, to obtain a result with the
4172  same rank as the input.  The spatial dimensions of this intermediate result
4173  are then optionally cropped according to `crops` to produce the output.  This
4174  is the reverse of SpaceToBatch (see `tf.space_to_batch`).
4175
4176  Args:
4177    input: A N-D `Tensor` with shape `input_shape = [batch] + spatial_shape +
4178      remaining_shape`, where `spatial_shape` has M dimensions.
4179    block_shape: A 1-D `Tensor` with shape [M]. Must be one of the following
4180      types: `int32`, `int64`. All values must be >= 1. For backwards
4181      compatibility with TF 1.0, this parameter may be an int, in which case it
4182      is converted to
4183      `numpy.array([block_shape, block_shape],
4184      dtype=numpy.int64)`.
4185    crops: A  2-D `Tensor` with shape `[M, 2]`. Must be one of the
4186      following types: `int32`, `int64`. All values must be >= 0.
4187      `crops[i] = [crop_start, crop_end]` specifies the amount to crop from
4188      input dimension `i + 1`, which corresponds to spatial dimension `i`.
4189      It is required that
4190      `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`.
4191      This operation is equivalent to the following steps:
4192      1. Reshape `input` to `reshaped` of shape: [block_shape[0], ...,
4193        block_shape[M-1], batch / prod(block_shape), input_shape[1], ...,
4194        input_shape[N-1]]
4195      2. Permute dimensions of `reshaped` to produce `permuted` of shape
4196         [batch / prod(block_shape),  input_shape[1], block_shape[0], ...,
4197         input_shape[M], block_shape[M-1], input_shape[M+1],
4198        ..., input_shape[N-1]]
4199      3. Reshape `permuted` to produce `reshaped_permuted` of shape
4200         [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
4201         input_shape[M] * block_shape[M-1], input_shape[M+1], ...,
4202         input_shape[N-1]]
4203      4. Crop the start and end of dimensions `[1, ..., M]` of
4204         `reshaped_permuted` according to `crops` to produce the output
4205         of shape:
4206         [batch / prod(block_shape),  input_shape[1] *
4207           block_shape[0] - crops[0,0] - crops[0,1], ..., input_shape[M] *
4208           block_shape[M-1] - crops[M-1,0] - crops[M-1,1],  input_shape[M+1],
4209           ..., input_shape[N-1]]
4210    name: A name for the operation (optional).
4211
4212  Examples:
4213
4214  1. For the following input of shape `[4, 1, 1, 1]`,
4215     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4216
4217     ```python
4218     [[[[1]]],
4219      [[[2]]],
4220      [[[3]]],
4221      [[[4]]]]
4222     ```
4223
4224    The output tensor has shape `[1, 2, 2, 1]` and value:
4225
4226     ```
4227     x = [[[[1], [2]],
4228         [[3], [4]]]]
4229     ```
4230
4231  2. For the following input of shape `[4, 1, 1, 3]`,
4232     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4233
4234     ```python
4235     [[[1,  2,   3]],
4236      [[4,  5,   6]],
4237      [[7,  8,   9]],
4238      [[10, 11, 12]]]
4239     ```
4240
4241    The output tensor has shape `[1, 2, 2, 3]` and value:
4242
4243    ```python
4244     x = [[[[1, 2, 3], [4,  5,  6 ]],
4245           [[7, 8, 9], [10, 11, 12]]]]
4246     ```
4247
4248  3. For the following
4249     input of shape `[4, 2, 2, 1]`,
4250     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4251
4252     ```python
4253     x = [[[[1], [3]], [[ 9], [11]]],
4254          [[[2], [4]], [[10], [12]]],
4255          [[[5], [7]], [[13], [15]]],
4256          [[[6], [8]], [[14], [16]]]]
4257     ```
4258
4259    The output tensor has shape `[1, 4, 4, 1]` and value:
4260
4261    ```python
4262     x = [[[1],  [2],  [ 3], [ 4]],
4263          [[5],  [6],  [ 7], [ 8]],
4264          [[9],  [10], [11], [12]],
4265          [[13], [14], [15], [16]]]
4266     ```
4267
4268  4. For the following input of shape
4269      `[8, 1, 3, 1]`,
4270      `block_shape = [2, 2]`, and `crops = [[0, 0], [2, 0]]`:
4271
4272      ```python
4273      x = [[[[0], [ 1], [ 3]]],
4274           [[[0], [ 9], [11]]],
4275           [[[0], [ 2], [ 4]]],
4276           [[[0], [10], [12]]],
4277           [[[0], [ 5], [ 7]]],
4278           [[[0], [13], [15]]],
4279           [[[0], [ 6], [ 8]]],
4280           [[[0], [14], [16]]]]
4281      ```
4282
4283      The output tensor has shape `[2, 2, 4, 1]` and value:
4284
4285      ```python
4286      x = [[[[ 1], [ 2], [ 3], [ 4]],
4287            [[ 5], [ 6], [ 7], [ 8]]],
4288           [[[ 9], [10], [11], [12]],
4289            [[13], [14], [15], [16]]]]
4290      ```
4291
4292  Returns:
4293    A `Tensor`. Has the same type as `input`.
4294  """
4295  if isinstance(block_shape, int):
4296    block_shape = np.array([block_shape, block_shape], dtype=np.int64)
4297
4298  return batch_to_space_nd(
4299      input=input, block_shape=block_shape, crops=crops, name=name)
4300
4301
4302@tf_export("one_hot")
4303@dispatch.add_dispatch_support
4304def one_hot(indices,
4305            depth,
4306            on_value=None,
4307            off_value=None,
4308            axis=None,
4309            dtype=None,
4310            name=None):
4311  """Returns a one-hot tensor.
4312
4313  See also `tf.fill`, `tf.eye`.
4314
4315  The locations represented by indices in `indices` take value `on_value`,
4316  while all other locations take value `off_value`.
4317
4318  `on_value` and `off_value` must have matching data types. If `dtype` is also
4319  provided, they must be the same data type as specified by `dtype`.
4320
4321  If `on_value` is not provided, it will default to the value `1` with type
4322  `dtype`
4323
4324  If `off_value` is not provided, it will default to the value `0` with type
4325  `dtype`
4326
4327  If the input `indices` is rank `N`, the output will have rank `N+1`. The
4328  new axis is created at dimension `axis` (default: the new axis is appended
4329  at the end).
4330
4331  If `indices` is a scalar, the output shape will be a vector of length `depth`.
4332
4333  If `indices` is a vector of length `features`, the output shape will be:
4334
4335  ```
4336    features x depth if axis == -1
4337    depth x features if axis == 0
4338  ```
4339
4340  If `indices` is a matrix (batch) with shape `[batch, features]`, the output
4341  shape will be:
4342
4343  ```
4344    batch x features x depth if axis == -1
4345    batch x depth x features if axis == 1
4346    depth x batch x features if axis == 0
4347  ```
4348
4349  If `indices` is a RaggedTensor, the 'axis' argument must be positive and refer
4350  to a non-ragged axis. The output will be equivalent to applying 'one_hot' on
4351  the values of the RaggedTensor, and creating a new RaggedTensor from the
4352  result.
4353
4354  If `dtype` is not provided, it will attempt to assume the data type of
4355  `on_value` or `off_value`, if one or both are passed in. If none of
4356  `on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
4357  value `tf.float32`.
4358
4359  Note: If a non-numeric data type output is desired (`tf.string`, `tf.bool`,
4360  etc.), both `on_value` and `off_value` _must_ be provided to `one_hot`.
4361
4362  For example:
4363
4364  ```python
4365  indices = [0, 1, 2]
4366  depth = 3
4367  tf.one_hot(indices, depth)  # output: [3 x 3]
4368  # [[1., 0., 0.],
4369  #  [0., 1., 0.],
4370  #  [0., 0., 1.]]
4371
4372  indices = [0, 2, -1, 1]
4373  depth = 3
4374  tf.one_hot(indices, depth,
4375             on_value=5.0, off_value=0.0,
4376             axis=-1)  # output: [4 x 3]
4377  # [[5.0, 0.0, 0.0],  # one_hot(0)
4378  #  [0.0, 0.0, 5.0],  # one_hot(2)
4379  #  [0.0, 0.0, 0.0],  # one_hot(-1)
4380  #  [0.0, 5.0, 0.0]]  # one_hot(1)
4381
4382  indices = [[0, 2], [1, -1]]
4383  depth = 3
4384  tf.one_hot(indices, depth,
4385             on_value=1.0, off_value=0.0,
4386             axis=-1)  # output: [2 x 2 x 3]
4387  # [[[1.0, 0.0, 0.0],   # one_hot(0)
4388  #   [0.0, 0.0, 1.0]],  # one_hot(2)
4389  #  [[0.0, 1.0, 0.0],   # one_hot(1)
4390  #   [0.0, 0.0, 0.0]]]  # one_hot(-1)
4391
4392  indices = tf.ragged.constant([[0, 1], [2]])
4393  depth = 3
4394  tf.one_hot(indices, depth)  # output: [2 x None x 3]
4395  # [[[1., 0., 0.],
4396  #   [0., 1., 0.]],
4397  #  [[0., 0., 1.]]]
4398  ```
4399
4400  Args:
4401    indices: A `Tensor` of indices.
4402    depth: A scalar defining the depth of the one hot dimension.
4403    on_value: A scalar defining the value to fill in output when `indices[j]
4404      = i`. (default: 1)
4405    off_value: A scalar defining the value to fill in output when `indices[j]
4406      != i`. (default: 0)
4407    axis: The axis to fill (default: -1, a new inner-most axis).
4408    dtype: The data type of the output tensor.
4409    name: A name for the operation (optional).
4410
4411  Returns:
4412    output: The one-hot tensor.
4413
4414  Raises:
4415    TypeError: If the dtype of either `on_value` or `off_value` doesn't match `dtype`.
4416    TypeError: If the dtypes of `on_value` and `off_value` don't match one another.
4417  """
4418  with ops.name_scope(
4419      name, "one_hot",
4420      [indices, depth, on_value, off_value, axis, dtype]) as name:
4421    on_exists = on_value is not None
4422    off_exists = off_value is not None
4423
4424    if on_exists:
4425      on_value = ops.convert_to_tensor(on_value, dtype_hint=dtype)
4426    if off_exists:
4427      off_value = ops.convert_to_tensor(off_value, dtype_hint=dtype)
4428
4429    on_dtype = on_value.dtype.base_dtype if on_exists else None
4430    off_dtype = off_value.dtype.base_dtype if off_exists else None
4431
4432    if on_exists or off_exists:
4433      if dtype is not None:
4434        # Ensure provided on_value and/or off_value match dtype
4435        if on_exists and on_dtype != dtype:
4436          raise TypeError("dtype {0} of on_value does not match "
4437                          "dtype parameter {1}".format(on_dtype, dtype))
4438        if off_exists and off_dtype != dtype:
4439          raise TypeError("dtype {0} of off_value does not match "
4440                          "dtype parameter {1}".format(off_dtype, dtype))
4441      else:
4442        # dtype not provided: automatically assign it
4443        dtype = on_dtype if on_exists else off_dtype
4444    elif dtype is None:
4445      # None of on_value, off_value, or dtype provided. Default dtype to float32
4446      dtype = dtypes.float32
4447
4448    if not on_exists:
4449      # on_value not provided: assign to value 1 of type dtype
4450      on_value = ops.convert_to_tensor(1, dtype, name="on_value")
4451      on_dtype = dtype
4452    if not off_exists:
4453      # off_value not provided: assign to value 0 of type dtype
4454      off_value = ops.convert_to_tensor(0, dtype, name="off_value")
4455      off_dtype = dtype
4456
4457    if on_dtype != off_dtype:
4458      raise TypeError("dtype {0} of on_value does not match "
4459                      "dtype {1} of off_value".format(on_dtype, off_dtype))
4460
4461    return gen_array_ops.one_hot(indices, depth, on_value, off_value, axis,
4462                                 name)
4463
4464
4465def _all_dimensions(x):
4466  """Returns a 1D-tensor listing all dimensions in x."""
4467  # Fast path: avoid creating Rank and Range ops if ndims is known.
4468  if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
4469    return constant_op.constant(
4470        np.arange(x.get_shape().ndims), dtype=dtypes.int32)
4471  if (isinstance(x, sparse_tensor.SparseTensor) and
4472      x.dense_shape.get_shape().is_fully_defined()):
4473    r = x.dense_shape.get_shape().dims[0].value  # sparse.dense_shape is 1-D.
4474    return constant_op.constant(np.arange(r), dtype=dtypes.int32)
4475
4476  # Otherwise, we rely on `range` and `rank` to do the right thing at runtime.
4477  return gen_math_ops._range(0, rank(x), 1)
4478
4479
4480@tf_export("sequence_mask")
4481@dispatch.add_dispatch_support
4482def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
4483  """Returns a mask tensor representing the first N positions of each cell.
4484
4485  If `lengths` has shape `[d_1, d_2, ..., d_n]` the resulting tensor `mask` has
4486  dtype `dtype` and shape `[d_1, d_2, ..., d_n, maxlen]`, with
4487
4488  ```
4489  mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
4490  ```
4491
4492  Examples:
4493
4494  ```python
4495  tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
4496                                  #  [True, True, True, False, False],
4497                                  #  [True, True, False, False, False]]
4498
4499  tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
4500                                    #   [True, True, True]],
4501                                    #  [[True, True, False],
4502                                    #   [False, False, False]]]
4503  ```
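
  A non-boolean `dtype` casts the boolean mask to that type; an illustrative
  sketch:

  ```python
  tf.sequence_mask([1, 2], maxlen=3, dtype=tf.float32)
  # [[1., 0., 0.],
  #  [1., 1., 0.]]
  ```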
4504
4505  Args:
4506    lengths: integer tensor, all its values <= maxlen.
4507    maxlen: scalar integer tensor, size of last dimension of returned tensor.
4508      Default is the maximum value in `lengths`.
4509    dtype: output type of the resulting tensor.
4510    name: name of the op.
4511
4512  Returns:
4513    A mask tensor of shape `lengths.shape + (maxlen,)`, cast to specified dtype.
4514  Raises:
4515    ValueError: if `maxlen` is not a scalar.
4516  """
4517  with ops.name_scope(name, "SequenceMask", [lengths, maxlen]):
4518    lengths = ops.convert_to_tensor(lengths)
4519
4520    if maxlen is None:
4521      maxlen = gen_math_ops._max(lengths, _all_dimensions(lengths))
4522      maxlen = gen_math_ops.maximum(constant(0, maxlen.dtype), maxlen)
4523    else:
4524      maxlen = ops.convert_to_tensor(maxlen)
4525    if maxlen.get_shape().ndims is not None and maxlen.get_shape().ndims != 0:
4526      raise ValueError("Argument `maxlen` must be scalar for sequence_mask, "
4527                       f"received `maxlen` = {maxlen} "
4528                       f"with shape '{maxlen.get_shape()}' instead")
4529
4530    # The basic idea is to compare a range row vector of size maxlen:
4531    # [0, 1, 2, 3, 4]
4532    # to length as a matrix with 1 column: [[1], [3], [2]].
4533    # Because of broadcasting on both arguments this comparison results
4534    # in a matrix of size (len(lengths), maxlen)
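    # For example (illustrative), with lengths = [1, 3, 2] and maxlen = 5:
    #   [0, 1, 2, 3, 4] < [[1], [3], [2]]
    #   -> [[True, False, False, False, False],
    #       [True, True,  True,  False, False],
    #       [True, True,  False, False, False]]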
4535    row_vector = gen_math_ops._range(
4536        constant(0, maxlen.dtype), maxlen, constant(1, maxlen.dtype))
4537    # Since maxlen >= max(lengths), it is safe to use maxlen's dtype as the
4538    # authoritative cast type: whenever maxlen fits into tf.int32, so do the lengths.
4539    matrix = gen_math_ops.cast(expand_dims(lengths, -1), maxlen.dtype)
4540    result = row_vector < matrix
4541    if dtype is None or result.dtype.is_compatible_with(dtype):
4542      return result
4543    else:
4544      return gen_math_ops.cast(result, dtype)
4545
4546
4547@tf_export(v1=["squeeze"])
4548@dispatch.add_dispatch_support
4549@deprecation.deprecated_args(None, "Use the `axis` argument instead",
4550                             "squeeze_dims")
4551def squeeze(input, axis=None, name=None, squeeze_dims=None):
4552  # pylint: disable=redefined-builtin
4553  """Removes dimensions of size 1 from the shape of a tensor.
4554
4555  Given a tensor `input`, this operation returns a tensor of the same type with
4556  all dimensions of size 1 removed. If you don't want to remove all size 1
4557  dimensions, you can remove specific size 1 dimensions by specifying
4558  `axis`.
4559
4560  For example:
4561
4562  >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4563  >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4564  >>> print(tf.shape(tf.squeeze(t)).numpy())
4565  [2 3]
4566
4567  Or, to remove specific size 1 dimensions:
4568
4569  >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4570  >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4571  >>> print(tf.shape(tf.squeeze(t, [2, 4])).numpy())
4572  [1 2 3 1]
4573
4574  Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4575  time, where `N` is the number of elements in the squeezed dimensions.
4576
4577  Args:
4578    input: A `Tensor`. The `input` to squeeze.
4579    axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4580      squeezes the dimensions listed. The dimension index starts at 0. It is an
4581      error to squeeze a dimension that is not 1. Must be in the range
4582      `[-rank(input), rank(input))`. Must be specified if `input` is a
4583      `RaggedTensor`.
4584    name: A name for the operation (optional).
4585    squeeze_dims: Deprecated keyword argument that is now axis.
4586
4587  Returns:
4588    A `Tensor`. Has the same type as `input`.
4589    Contains the same data as `input`, but has one or more dimensions of
4590    size 1 removed.
4591
4592  Raises:
4593    ValueError: When both `squeeze_dims` and `axis` are specified.
4594  """
4595  axis = deprecation.deprecated_argument_lookup("axis", axis, "squeeze_dims",
4596                                                squeeze_dims)
4597  if np.isscalar(axis):
4598    axis = [axis]
4599  return gen_array_ops.squeeze(input, axis, name)
4600
4601
4602@tf_export("squeeze", v1=[])
4603@dispatch.add_dispatch_support
4604def squeeze_v2(input, axis=None, name=None):
4605  """Removes dimensions of size 1 from the shape of a tensor.
4606
4607  Given a tensor `input`, this operation returns a tensor of the same type with
4608  all dimensions of size 1 removed. If you don't want to remove all size 1
4609  dimensions, you can remove specific size 1 dimensions by specifying
4610  `axis`.
4611
4612  For example:
4613
4614  ```python
4615  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4616  tf.shape(tf.squeeze(t))  # [2, 3]
4617  ```
4618
4619  Or, to remove specific size 1 dimensions:
4620
4621  ```python
4622  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4623  tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]
4624  ```
4625
4626  Unlike the older op `tf.compat.v1.squeeze`, this op does not accept a
4627  deprecated `squeeze_dims` argument.
4628
4629  Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4630  time, where `N` is the number of elements in the squeezed dimensions.
4631
4632  Note: If squeeze is performed on dimensions of unknown sizes, then the
4633  returned Tensor will be of unknown shape. A common situation is when the
4634  first (batch) dimension is of size `None`; `tf.squeeze` then returns an
4635  `<unknown>` shape, which may be a surprise. Specify the `axis=` argument
4636  to get the expected result, as illustrated in the following example:
4637
4638  ```python
4639  @tf.function
4640  def func(x):
4641    print('x.shape:', x.shape)
4642    known_axes = [i for i, size in enumerate(x.shape) if size == 1]
4643    y = tf.squeeze(x, axis=known_axes)
4644    print('shape of tf.squeeze(x, axis=known_axes):', y.shape)
4645    y = tf.squeeze(x)
4646    print('shape of tf.squeeze(x):', y.shape)
4647    return 0
4648
4649  _ = func.get_concrete_function(tf.TensorSpec([None, 1, 2], dtype=tf.int32))
4650  # Output is:
4651  # x.shape: (None, 1, 2)
4652  # shape of tf.squeeze(x, axis=known_axes): (None, 2)
4653  # shape of tf.squeeze(x): <unknown>
4654  ```
4655
4656  Args:
4657    input: A `Tensor`. The `input` to squeeze.
4658    axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4659      squeezes the dimensions listed. The dimension index starts at 0. It is an
4660      error to squeeze a dimension that is not 1. Must be in the range
4661      `[-rank(input), rank(input))`. Must be specified if `input` is a
4662      `RaggedTensor`.
4663    name: A name for the operation (optional).
4664
4665  Returns:
4666    A `Tensor`. Has the same type as `input`.
4667    Contains the same data as `input`, but has one or more dimensions of
4668    size 1 removed.
4669
4670  Raises:
4671    ValueError: The input cannot be converted to a tensor, or the specified
4672      axis cannot be squeezed.
4673  """
4674  # pylint: disable=redefined-builtin
4675  return squeeze(input, axis, name)
4676
4677
4678@tf_export(v1=["where"])
4679@dispatch.add_dispatch_support
4680def where(condition, x=None, y=None, name=None):
4681  """Return the elements, either from `x` or `y`, depending on the `condition`.
4682
4683  If both `x` and `y` are None, then this operation returns the coordinates of
4684  true elements of `condition`.  The coordinates are returned in a 2-D tensor
4685  where the first dimension (rows) represents the number of true elements, and
4686  the second dimension (columns) represents the coordinates of the true
4687  elements. Keep in mind, the shape of the output tensor can vary depending on
4688  how many true values there are in input. Indices are output in row-major
4689  order.
4690
4691  If both `x` and `y` are non-None, they must have the same shape.
4692  The `condition` tensor must be a scalar if `x` and `y` are scalar.
4693  If `x` and `y` are tensors of higher rank, then `condition` must be either a
4694  vector with size matching the first dimension of `x`, or must have the same
4695  shape as `x`.
4696
4697  The `condition` tensor acts as a mask that chooses, based on the value at each
4698  element, whether the corresponding element / row in the output should be taken
4699  from `x` (if true) or `y` (if false).
4700
4701  If `condition` is a vector and `x` and `y` are higher rank matrices, then it
4702  chooses which row (outer dimension) to copy from `x` and `y`. If `condition`
4703  has the same shape as `x` and `y`, then it chooses which element to copy from
4704  `x` and `y`.
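
  For example (an illustrative sketch of the two modes):

  ```python
  # Coordinates of the true elements when `x` and `y` are omitted:
  tf.compat.v1.where([[True, False], [False, True]])
  # [[0, 0],
  #  [1, 1]]

  # Element-wise selection when both `x` and `y` are given:
  tf.compat.v1.where([True, False, True], x=[1, 2, 3], y=[10, 20, 30])
  # [1, 20, 3]
  ```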
4705
4706  Args:
4707    condition: A `Tensor` of type `bool`
4708    x: A Tensor which may have the same shape as `condition`. If `condition` is
4709      rank 1, `x` may have higher rank, but its first dimension must match the
4710      size of `condition`.
4711    y: A `tensor` with the same shape and type as `x`.
4712    name: A name of the operation (optional)
4713
4714  Returns:
4715    A `Tensor` with the same type and shape as `x`, `y` if they are non-None.
4716    Otherwise, a `Tensor` with shape `(num_true, rank(condition))`.
4717
4718  Raises:
4719    ValueError: When exactly one of `x` or `y` is non-None.
4720
4721  @compatibility(TF2)
4722
4723  This API is compatible with eager execution and `tf.function`. However, this
4724  is still a legacy API endpoint originally designed for TF1. To migrate to
4725  fully-native TF2, please replace its usage with `tf.where` instead, which is
4726  directly backwards compatible with `tf.compat.v1.where`.
4727
4728  However, `tf.compat.v1.where` is more restrictive than `tf.where`, requiring
4729  `x` and `y` to have the same shape, and returning a `Tensor` with the same
4730  type and shape as `x`, `y` (if they are both non-None).
4731
4732  `tf.where` will accept `x`, `y` that are not the same shape as long as they
4733  are broadcastable with one another and with `condition`, and will return a
4734  `Tensor` with shape broadcast from `condition`, `x`, and `y`.
4735
4736  For example, the following works with `tf.where` but not `tf.compat.v1.where`:
4737
4738  >>> tf.where([True, False, False, True], [1,2,3,4], [100])
4739  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
4740  dtype=int32)>
4741
4742  >>> tf.where(True, [1,2,3,4], 100)
4743  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
4744  dtype=int32)>
4745
4746  @end_compatibility
4747  """
4748  if x is None and y is None:
4749    with ops.name_scope(name, "Where", [condition]) as name:
4750      condition = ops.convert_to_tensor(
4751          condition, preferred_dtype=dtypes.bool, name="condition")
4752      return gen_array_ops.where(condition=condition, name=name)
4753  elif x is not None and y is not None:
4754    return gen_math_ops.select(condition=condition, x=x, y=y, name=name)
4755  else:
4756    raise ValueError("x and y must both be non-None or both be None.")
4757
4758
4759@tf_export("where", v1=["where_v2"])
4760@dispatch.add_dispatch_support
4761def where_v2(condition, x=None, y=None, name=None):
4762  """Returns the indices of non-zero elements, or multiplexes `x` and `y`.
4763
4764  This operation has two modes:
4765
4766  1. **Return the indices of non-zero elements** - When only
4767     `condition` is provided the result is an `int64` tensor where each row is
4768     the index of a non-zero element of `condition`. The result's shape
4769     is `[tf.math.count_nonzero(condition), tf.rank(condition)]`.
4770  2. **Multiplex `x` and `y`** - When both `x` and `y` are provided the
4771     result has the shape of `x`, `y`, and `condition` broadcast together. The
4772     result is taken from `x` where `condition` is non-zero
4773     or `y` where `condition` is zero.
4774
4775  #### 1. Return the indices of non-zero elements
4776
4777  Note: In this mode `condition` can have a dtype of `bool` or any numeric
4778  dtype.
4779
4780  If `x` and `y` are not provided (both are None):
4781
4782  `tf.where` will return the indices of `condition` that are non-zero,
4783  in the form of a 2-D tensor with shape `[n, d]`, where `n` is the number of
4784  non-zero elements in `condition` (`tf.count_nonzero(condition)`), and `d` is
4785  the number of axes of `condition` (`tf.rank(condition)`).
4786
4787  Indices are output in row-major order. The `condition` can have a `dtype` of
4788  `tf.bool`, or any numeric `dtype`.
4789
4790  Here `condition` is a 1-axis `bool` tensor with 2 `True` values. The result
4791  has a shape of `[2,1]`
4792
4793  >>> tf.where([True, False, False, True]).numpy()
4794  array([[0],
4795         [3]])
4796
4797  Here `condition` is a 2-axis integer tensor, with 3 non-zero values. The
4798  result has a shape of `[3, 2]`.
4799
4800  >>> tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
4801  array([[0, 0],
4802         [1, 0],
4803         [1, 2]])
4804
4805  Here `condition` is a 3-axis float tensor, with 5 non-zero values. The output
4806  shape is `[5, 3]`.
4807
4808  >>> float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
4809  ...                 [[0,   0], [0,   0], [99,    0]]]
4810  >>> tf.where(float_tensor).numpy()
4811  array([[0, 0, 0],
4812         [0, 1, 1],
4813         [0, 2, 0],
4814         [0, 2, 1],
4815         [1, 2, 0]])
4816
4817  These indices are the same that `tf.sparse.SparseTensor` would use to
4818  represent the condition tensor:
4819
4820  >>> sparse = tf.sparse.from_dense(float_tensor)
4821  >>> sparse.indices.numpy()
4822  array([[0, 0, 0],
4823         [0, 1, 1],
4824         [0, 2, 0],
4825         [0, 2, 1],
4826         [1, 2, 0]])
4827
4828  A complex number is considered non-zero if either the real or imaginary
4829  component is non-zero:
4830
4831  >>> tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
4832  array([[1],
4833         [2],
4834         [3]])
4835
4836  #### 2. Multiplex `x` and `y`
4837
4838  Note: In this mode `condition` must have a dtype of `bool`.
4839
4840  If `x` and `y` are also provided (both have non-None values) the `condition`
4841  tensor acts as a mask that chooses whether the corresponding
4842  element / row in the output should be taken from `x` (if the element in
4843  `condition` is `True`) or `y` (if it is `False`).
4844
4845  The shape of the result is formed by
4846  [broadcasting](https://docs.scipy.org/doc/numpy/reference/ufuncs.html)
4847  together the shapes of `condition`, `x`, and `y`.
4848
4849  When all three inputs have the same size, each is handled element-wise.
4850
4851  >>> tf.where([True, False, False, True],
4852  ...          [1, 2, 3, 4],
4853  ...          [100, 200, 300, 400]).numpy()
4854  array([  1, 200, 300,   4], dtype=int32)
4855
4856  There are two main rules for broadcasting:
4857
4858  1. If a tensor has fewer axes than the others, length-1 axes are added to the
4859     left of the shape.
4860  2. Axes with length-1 are stretched to match the corresponding axes of the other
4861     tensors.
4862
4863  A length-1 vector is stretched to match the other vectors:
4864
4865  >>> tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
4866  array([  1, 100, 100,   4], dtype=int32)
4867
4868  A scalar is expanded to match the other arguments:
4869
4870  >>> tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
4871  array([[  1, 100], [100,   4]], dtype=int32)
4872  >>> tf.where([[True, False], [False, True]], 1, 100).numpy()
4873  array([[  1, 100], [100,   1]], dtype=int32)
4874
4875  A scalar `condition` returns the complete `x` or `y` tensor, with
4876  broadcasting applied.
4877
4878  >>> tf.where(True, [1, 2, 3, 4], 100).numpy()
4879  array([1, 2, 3, 4], dtype=int32)
4880  >>> tf.where(False, [1, 2, 3, 4], 100).numpy()
4881  array([100, 100, 100, 100], dtype=int32)
4882
4883  For a non-trivial example of broadcasting, here `condition` has a shape of
4884  `[3]`, `x` has a shape of `[3,3]`, and `y` has a shape of `[3,1]`.
4885  Broadcasting first expands the shape of `condition` to `[1,3]`. The final
4886  broadcast shape is `[3,3]`. `condition` will select columns from `x` and `y`.
4887  Since `y` only has one column, all columns from `y` will be identical.
4888
4889  >>> tf.where([True, False, True],
4890  ...          x=[[1, 2, 3],
4891  ...             [4, 5, 6],
4892  ...             [7, 8, 9]],
4893  ...          y=[[100],
4894  ...             [200],
4895  ...             [300]]
4896  ... ).numpy()
4897  array([[ 1, 100, 3],
4898         [ 4, 200, 6],
4899         [ 7, 300, 9]], dtype=int32)
4900
4901  Note that if the gradient of either branch of the `tf.where` generates
4902  a `NaN`, then the gradient of the entire `tf.where` will be `NaN`. This is
4903  because the gradient calculation for `tf.where` combines the two branches, for
4904  performance reasons.
4905
4906  A workaround is to use an inner `tf.where` to ensure the function has
4907  no asymptote, and to avoid computing a value whose gradient is `NaN` by
4908  replacing dangerous inputs with safe inputs.
4909
4910  Instead of this,
4911
4912  >>> x = tf.constant(0., dtype=tf.float32)
4913  >>> with tf.GradientTape() as tape:
4914  ...   tape.watch(x)
4915  ...   y = tf.where(x < 1., 0., 1. / x)
4916  >>> print(tape.gradient(y, x))
4917  tf.Tensor(nan, shape=(), dtype=float32)
4918
4919  Although the `1. / x` values are never used, their gradient is `NaN` when
4920  `x = 0`. Instead, we should guard that computation with another `tf.where`:
4921
4922  >>> x = tf.constant(0., dtype=tf.float32)
4923  >>> with tf.GradientTape() as tape:
4924  ...   tape.watch(x)
4925  ...   safe_x = tf.where(tf.equal(x, 0.), 1., x)
4926  ...   y = tf.where(x < 1., 0., 1. / safe_x)
4927  >>> print(tape.gradient(y, x))
4928  tf.Tensor(0.0, shape=(), dtype=float32)
4929
4930  See also:
4931
4932  * `tf.sparse` - The indices returned by the first form of `tf.where` can be
4933     useful in `tf.sparse.SparseTensor` objects.
4934  * `tf.gather_nd`, `tf.scatter_nd`, and related ops - Given the list of
4935    indices returned from `tf.where`, the `scatter` and `gather` family of ops
4936    can be used to fetch or insert values at those indices, as sketched below.
4937  * `tf.strings.length` - `tf.string` is not an allowed dtype for the
4938    `condition`. Use the string length instead.
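
  For instance, a minimal sketch of pairing `tf.where` with `tf.gather_nd` to
  fetch the non-zero values:

  >>> x = tf.constant([[1, 0], [0, 3]])
  >>> tf.gather_nd(x, tf.where(x != 0)).numpy()
  array([1, 3], dtype=int32)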
4939
4940  Args:
4941    condition: A `tf.Tensor` of dtype bool, or any numeric dtype. `condition`
4942      must have dtype `bool` when `x` and `y` are provided.
4943    x: If provided, a Tensor which is of the same type as `y`, and has a shape
4944      broadcastable with `condition` and `y`.
4945    y: If provided, a Tensor which is of the same type as `x`, and has a shape
4946      broadcastable with `condition` and `x`.
4947    name: A name of the operation (optional).
4948
4949  Returns:
4950    If `x` and `y` are provided:
4951      A `Tensor` with the same type as `x` and `y`, and shape that
4952      is broadcast from `condition`, `x`, and `y`.
4953    Otherwise, a `Tensor` with shape `[tf.math.count_nonzero(condition),
4954    tf.rank(condition)]`.
4955
4956  Raises:
4957    ValueError: When exactly one of `x` or `y` is non-None, or the shapes
4958      are not all broadcastable.
4959  """
4960  if x is None and y is None:
4961    with ops.name_scope(name, "Where", [condition]) as name:
4962      condition = ops.convert_to_tensor(
4963          condition, preferred_dtype=dtypes.bool, name="condition")
4964      return gen_array_ops.where(condition=condition, name=name)
4965  elif x is not None and y is not None:
4966    return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
4967  else:
4968    raise ValueError("x and y must both be non-None or both be None.")
4969
4970
4971# pylint: disable=redefined-builtin
4972@tf_export(v1=["reverse_sequence"])
4973@deprecation.deprecated_args(None,
4974                             "seq_dim is deprecated, use seq_axis instead",
4975                             "seq_dim")
4976@deprecation.deprecated_args(None,
4977                             "batch_dim is deprecated, use batch_axis instead",
4978                             "batch_dim")
4979def reverse_sequence(input,
4980                     seq_lengths,
4981                     seq_axis=None,
4982                     batch_axis=None,
4983                     name=None,
4984                     seq_dim=None,
4985                     batch_dim=None):
4986  """Reverses variable length slices.
4987
4988  This op first slices `input` along the dimension `batch_axis`, and for
4989  each slice `i`, reverses the first `seq_lengths[i]` elements along the
4990  dimension `seq_axis`.
4991
4992  The elements of `seq_lengths` must obey `seq_lengths[i] <=
4993  input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4994  `input.dims[batch_axis]`.
4995
4996  The output slice `i` along dimension `batch_axis` is then given by
4997  input slice `i`, with the first `seq_lengths[i]` slices along
4998  dimension `seq_axis` reversed.
4999
5000  Example usage:
5001
5002  >>> seq_lengths = [7, 2, 3, 5]
5003  >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
5004  ...          [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
5005  >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
5006  >>> output
5007  <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
5008  array([[0, 0, 5, 4, 3, 2, 1, 0],
5009         [2, 1, 0, 0, 0, 0, 0, 0],
5010         [3, 2, 1, 4, 0, 0, 0, 0],
5011         [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
5012
5013  Args:
5014    input: A `Tensor`. The input to reverse.
5015    seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
5016      `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
5017      input.dims(seq_axis)`
5018    seq_axis: An `int`. The dimension which is partially reversed.
5019    batch_axis: An optional `int`. Defaults to `0`. The dimension along which
5020      reversal is performed.
5021    name: A name for the operation (optional).
5022
5023  Returns:
5024    A Tensor. Has the same type as input.
5025  """
5026  seq_axis = deprecation.deprecated_argument_lookup("seq_axis", seq_axis,
5027                                                    "seq_dim", seq_dim)
5028  batch_axis = deprecation.deprecated_argument_lookup("batch_axis", batch_axis,
5029                                                      "batch_dim", batch_dim)
5030  return gen_array_ops.reverse_sequence(
5031      input=input,
5032      seq_lengths=seq_lengths,
5033      seq_dim=seq_axis,
5034      batch_dim=batch_axis,
5035      name=name)
5036
5037
5038@tf_export("reverse_sequence", v1=[])
5039@dispatch.add_dispatch_support
5040def reverse_sequence_v2(input,
5041                        seq_lengths,
5042                        seq_axis=None,
5043                        batch_axis=None,
5044                        name=None):
5045  """Reverses variable length slices.
5046
5047  This op first slices `input` along the dimension `batch_axis`, and for
5048  each slice `i`, reverses the first `seq_lengths[i]` elements along the
5049  dimension `seq_axis`.
5050
5051  The elements of `seq_lengths` must obey `seq_lengths[i] <=
5052  input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
5053  `input.dims[batch_axis]`.
5054
5055  The output slice `i` along dimension `batch_axis` is then given by
5056  input slice `i`, with the first `seq_lengths[i]` slices along
5057  dimension `seq_axis` reversed.
5058
5059  Example usage:
5060
5061  >>> seq_lengths = [7, 2, 3, 5]
5062  >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
5063  ...          [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
5064  >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
5065  >>> output
5066  <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
5067  array([[0, 0, 5, 4, 3, 2, 1, 0],
5068         [2, 1, 0, 0, 0, 0, 0, 0],
5069         [3, 2, 1, 4, 0, 0, 0, 0],
5070         [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
5071
5072  Args:
5073    input: A `Tensor`. The input to reverse.
5074    seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
5075      `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
5076      input.dims(seq_axis)`
5077    seq_axis: An `int`. The dimension which is partially reversed.
5078    batch_axis: An optional `int`. Defaults to `0`. The dimension along which
5079      reversal is performed.
5080    name: A name for the operation (optional).
5081
5082  Returns:
5083    A Tensor. Has the same type as input.
5084  """
5085  return gen_array_ops.reverse_sequence(
5086      input=input,
5087      seq_lengths=seq_lengths,
5088      seq_dim=seq_axis,
5089      batch_dim=batch_axis,
5090      name=name)
5091
5092# pylint: enable=redefined-builtin
5093
5094
5095@tf_export(v1=["gather"])
5096@dispatch.add_dispatch_support
5097@deprecation.deprecated_args(None,
5098                             ("The `validate_indices` argument has no effect. "
5099                              "Indices are always validated on CPU and never "
5100                              "validated on GPU."),
5101                             ("validate_indices", None))
5102def gather(params,
5103           indices,
5104           validate_indices=None,
5105           name=None,
5106           axis=None,
5107           batch_dims=0):  # pylint: disable=g-doc-args
5108  r"""Gather slices from params axis `axis` according to indices.
5109
5110  Gather slices from `params` axis `axis` according to `indices`.  `indices`
5111  must be an integer tensor of any dimension (often 1-D).
5112
5113  `Tensor.__getitem__` works for scalars, `tf.newaxis`, and
5114  [python slices](https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing)
5115
5116  `tf.gather` extends indexing to handle tensors of indices.
5117
5118  In the simplest case it's identical to scalar indexing:
5119
5120  >>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
5121  >>> params[3].numpy()
5122  b'p3'
5123  >>> tf.gather(params, 3).numpy()
5124  b'p3'
5125
5126  The most common case is to pass a single axis tensor of indices (this
5127  can't be expressed as a python slice because the indices are not sequential):
5128
5129  >>> indices = [2, 0, 2, 5]
5130  >>> tf.gather(params, indices).numpy()
5131  array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
5132
5133  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
5134  <img style="width:100%" src="https://www.tensorflow.org/images/Gather.png"
5135  alt>
5136  </div>
5137
5138  The indices can have any shape. When the `params` has 1 axis, the
5139  output shape is equal to the input shape:
5140
5141  >>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
5142  array([[b'p2', b'p0'],
5143         [b'p2', b'p5']], dtype=object)
5144
5145  The `params` may also have any shape. `gather` can select slices
5146  across any axis depending on the `axis` argument (which defaults to 0).
5147  Below it is used to gather first rows, then columns from a matrix:
5148
5149  >>> params = tf.constant([[0, 1.0, 2.0],
5150  ...                       [10.0, 11.0, 12.0],
5151  ...                       [20.0, 21.0, 22.0],
5152  ...                       [30.0, 31.0, 32.0]])
5153  >>> tf.gather(params, indices=[3,1]).numpy()
5154  array([[30., 31., 32.],
5155         [10., 11., 12.]], dtype=float32)
5156  >>> tf.gather(params, indices=[2,1], axis=1).numpy()
5157  array([[ 2.,  1.],
5158         [12., 11.],
5159         [22., 21.],
5160         [32., 31.]], dtype=float32)
5161
5162  More generally: The output has the same shape as the input, with the
5163  indexed axis replaced by the shape of the indices.
5164
5165  >>> def result_shape(p_shape, i_shape, axis=0):
5166  ...   return p_shape[:axis] + i_shape + p_shape[axis+1:]
5167  >>>
5168  >>> result_shape([1, 2, 3], [], axis=1)
5169  [1, 3]
5170  >>> result_shape([1, 2, 3], [7], axis=1)
5171  [1, 7, 3]
5172  >>> result_shape([1, 2, 3], [7, 5], axis=1)
5173  [1, 7, 5, 3]
5174
5175  Here are some examples:
5176
5177  >>> params.shape.as_list()
5178  [4, 3]
5179  >>> indices = tf.constant([[0, 2]])
5180  >>> tf.gather(params, indices=indices, axis=0).shape.as_list()
5181  [1, 2, 3]
5182  >>> tf.gather(params, indices=indices, axis=1).shape.as_list()
5183  [4, 1, 2]
5184
5185  >>> params = tf.random.normal(shape=(5, 6, 7, 8))
5186  >>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
5187  >>> result = tf.gather(params, indices, axis=2)
5188  >>> result.shape.as_list()
5189  [5, 6, 10, 11, 8]
5190
5191  This is because each index takes a slice from `params`, and
5192  places it at the corresponding location in the output. For the above example
5193
5194  >>> # For any location in indices
5195  >>> a, b = 0, 1
5196  >>> tf.reduce_all(
5197  ...     # the corresponding slice of the result
5198  ...     result[:, :, a, b, :] ==
5199  ...     # is equal to the slice of `params` along `axis` at the index.
5200  ...     params[:, :, indices[a, b], :]
5201  ... ).numpy()
5202  True
5203
5204  ### Batching:
5205
5206  The `batch_dims` argument lets you gather different items from each element
5207  of a batch.
5208
5209  Using `batch_dims=1` is equivalent to having an outer loop over the first
5210  axis of `params` and `indices`:
5211
5212  >>> params = tf.constant([
5213  ...     [0, 0, 1, 0, 2],
5214  ...     [3, 0, 0, 0, 4],
5215  ...     [0, 5, 0, 6, 0]])
5216  >>> indices = tf.constant([
5217  ...     [2, 4],
5218  ...     [0, 4],
5219  ...     [1, 3]])
5220
5221  >>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
5222  array([[1, 2],
5223         [3, 4],
5224         [5, 6]], dtype=int32)
5225
5226  This is equivalent to:
5227
5228  >>> def manually_batched_gather(params, indices, axis):
5229  ...   batch_dims=1
5230  ...   result = []
5231  ...   for p,i in zip(params, indices):
5232  ...     r = tf.gather(p, i, axis=axis-batch_dims)
5233  ...     result.append(r)
5234  ...   return tf.stack(result)
5235  >>> manually_batched_gather(params, indices, axis=1).numpy()
5236  array([[1, 2],
5237         [3, 4],
5238         [5, 6]], dtype=int32)
5239
5240  Higher values of `batch_dims` are equivalent to multiple nested loops over
5241  the outer axes of `params` and `indices`. So the overall shape function is
5242
5243  >>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
5244  ...   return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]
5245  >>>
5246  >>> batched_result_shape(
5247  ...     p_shape=params.shape.as_list(),
5248  ...     i_shape=indices.shape.as_list(),
5249  ...     axis=1,
5250  ...     batch_dims=1)
5251  [3, 2]
5252
5253  >>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
5254  [3, 2]
5255
5256  This comes up naturally if you need to use the indices of an operation like
5257  `tf.argsort`, or `tf.math.top_k` where the last dimension of the indices
5258  indexes into the last dimension of input, at the corresponding location.
5259  In this case you can use `tf.gather(values, indices, batch_dims=-1)`.
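
  For example, an illustrative sketch with `tf.math.top_k`:

  >>> x = tf.constant([[3., 1., 2.],
  ...                  [6., 5., 4.]])
  >>> top = tf.math.top_k(x, k=2)
  >>> tf.gather(x, top.indices, batch_dims=-1).numpy()
  array([[3., 2.],
         [6., 5.]], dtype=float32)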
5260
5261  See also:
5262
5263  * `tf.Tensor.__getitem__`: The direct tensor index operation (`t[]`), handles
5264    scalars and python-slices `tensor[..., 7, 1:-1]`
5265  * `tf.scatter`: A collection of operations similar to `__setitem__`
5266    (`t[i] = x`)
5267  * `tf.gather_nd`: An operation similar to `tf.gather` but gathers across
5268    multiple axes at once (it can gather elements of a matrix instead of rows
5269    or columns)
5270  * `tf.boolean_mask`, `tf.where`: Binary indexing.
5271  * `tf.slice` and `tf.strided_slice`: For lower level access to the
5272    implementation of `__getitem__`'s python-slice handling (`t[1:-1:2]`)
5273
5274  Args:
5275    params: The `Tensor` from which to gather values. Must be at least rank
5276      `axis + 1`.
5277    indices: The index `Tensor`.  Must be one of the following types: `int32`,
5278      `int64`. The values must be in range `[0, params.shape[axis])`.
5279    validate_indices: Deprecated, does nothing. Indices are always validated on
5280      CPU, never validated on GPU.
5281
5282      Caution: On CPU, if an out of bound index is found, an error is raised.
5283      On GPU, if an out of bound index is found, a 0 is stored in the
5284      corresponding output value.
5285    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5286      `axis` in `params` to gather `indices` from. Must be greater than or equal
5287      to `batch_dims`.  Defaults to the first non-batch dimension. Supports
5288      negative indexes.
5289    batch_dims: An `integer`.  The number of batch dimensions.  Must be less
5290      than or equal to `rank(indices)`.
5291    name: A name for the operation (optional).
5292
5293  Returns:
5294    A `Tensor`. Has the same type as `params`.
5295  """
5296  del validate_indices
5297
5298  if axis is None:
5299    axis = batch_dims
5300  if tensor_util.constant_value(axis) != 0:
5301    return gen_array_ops.gather_v2(
5302        params, indices, axis, batch_dims=batch_dims, name=name)
5303  try:
5304    # TODO(apassos) find a less bad way of detecting resource variables
5305    # without introducing a circular dependency.
5306    return params.sparse_read(indices, name=name)
5307  except AttributeError:
5308    return gen_array_ops.gather_v2(params, indices, axis, name=name)
5309
5310
5311@tf_export("gather", v1=[])
5312@dispatch.add_dispatch_support
5313def gather_v2(params,
5314              indices,
5315              validate_indices=None,
5316              axis=None,
5317              batch_dims=0,
5318              name=None):
5319  return gather(
5320      params,
5321      indices,
5322      validate_indices=validate_indices,
5323      name=name,
5324      axis=axis,
5325      batch_dims=batch_dims)
5326
5327
5328gather_v2.__doc__ = gather.__doc__
5329
5330
5331@tf_export(v1=["batch_gather"])
5332@dispatch.add_dispatch_support
5333@deprecation.deprecated(
5334    "2017-10-25", "`tf.batch_gather` is deprecated, please use `tf.gather` "
5335    "with `batch_dims=tf.rank(indices) - 1` instead.")  # pylint: disable=missing-docstring
5336def batch_gather(params, indices, name=None):
5337  """Gather slices from params according to indices with leading batch dims."""
5338  with ops.name_scope(name, "BatchGather", [params, indices]):
5339    indices = ops.convert_to_tensor(indices, name="indices")
5340    params = ops.convert_to_tensor(params, name="params")
5341    if indices.shape.ndims is None:
5342      raise ValueError(
5343          "batch_gather does not allow indices with unknown shape.")
5344    return _batch_gather(params, indices, batch_dims=indices.shape.ndims - 1)
5345
5346
5347def _batch_gather(params, indices, batch_dims, axis=None):
5348  r"""Gather slices from params according to indices with leading batch dims.
5349
5350  This operation assumes that the leading `batch_dims` dimensions of `indices`
5351  and `params` are batch dimensions; and performs a `tf.gather` operation within
5352  each batch. (If `batch_dims` is not specified, then it defaults to
5353  `rank(indices)-1`.)  In the case in which `batch_dims==0`, this operation
5354  is equivalent to `tf.gather`.
5355
5356  Args:
5357    params: A Tensor. The tensor from which to gather values.
5358    indices: A Tensor. Must be one of the following types: int32, int64. Index
5359      tensor. Must be in range `[0, params.shape[batch_dims]]`.
5360    batch_dims: An integer or none.  The number of batch dimensions.  Must be
5361      less than `rank(indices)`.  Defaults to `rank(indices) - 1` if None.
5362    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5363      `axis` in `params` to gather `indices` from. Must be greater than or equal
5364      to `batch_dims`.  Defaults to the first non-batch dimension. Supports
5365      negative indexes.
5366
5367  Returns:
5368    A Tensor. Has the same type as `params`.
5369
5370  Raises:
5371    ValueError: if `indices` has an unknown shape.
5372  """
5373  if batch_dims is not None and not isinstance(batch_dims, int):
5374    raise TypeError("Argument `batch_dims` must be an int. "
5375                    f"Received `batch_dims` = {batch_dims} instead")
5376  indices = ops.convert_to_tensor(indices, name="indices")
5377  params = ops.convert_to_tensor(params, name="params")
5378
5379  indices_ndims = indices.shape.ndims
5380  if indices_ndims is None:
5381    raise ValueError("tf.gather does not allow indices with unknown "
5382                     "rank when batch_dims is specified.")
5383  if batch_dims is None:
5384    batch_dims = indices_ndims - 1
5385  if batch_dims < 0:
5386    batch_dims += indices_ndims
5387  if batch_dims < 0 or batch_dims >= indices_ndims:
5388    raise ValueError(f"Argument `batch_dims` = {batch_dims} must be less than "
5389                     f"rank(`indices`) = {indices_ndims}")
5390  if params.shape.ndims is not None and batch_dims >= params.shape.ndims:
5391    raise ValueError(f"Argument `batch_dims` = {batch_dims} must be less than "
5392                     f"rank(`params`) = {params.shape.ndims}")
5393
5394  # Handle axis by transposing the axis dimension to be the first non-batch
5395  # dimension, recursively calling batch_gather with axis=0, and then
5396  # transposing the result to put the pre-axis dimensions before the indices
5397  # dimensions.
5398  if axis is not None and axis != batch_dims:
5399    # Adjust axis to be positive.
5400    if not isinstance(axis, int):
5401      axis = where_v2(axis < 0, axis + rank(params), axis)
5402    elif axis < 0 and params.shape.ndims is None:
5403      axis = axis + rank(params)
5404    else:
5405      if (axis < -params.shape.ndims) or (axis >= params.shape.ndims):
5406        raise ValueError(f"Argument `axis` = {axis} out of range "
5407                         f"[{-params.shape.ndims}, {params.shape.ndims})")
5408      if axis < 0:
5409        axis += params.shape.ndims
5410      if axis < batch_dims:
5411        raise ValueError(f"Argument `batch_dims` = {batch_dims} must be less "
5412                         f"than or equal to argument `axis` = {axis}")
5413
5414    # Move params[axis] up to params[batch_dims].
5415    perm = [
5416        list(range(batch_dims)), [axis],
5417        gen_math_ops._range(batch_dims, axis, 1),
5418        gen_math_ops._range(axis + 1, rank(params), 1)
5419    ]
5420    params = transpose(params, concat(perm, axis=0))
5421
5422    result = _batch_gather(params, indices, batch_dims=batch_dims)
5423
5424    # Move the result dimensions corresponding to params[batch_dims:axis]
5425    # to just before the dimensions corresponding to indices[batch_dims:].
5426    params_start = indices_ndims + axis - batch_dims
5427    perm = [
5428        list(range(batch_dims)),
5429        gen_math_ops._range(indices_ndims, params_start, 1),
5430        list(range(batch_dims, indices_ndims)),
5431        gen_math_ops._range(params_start, rank(result), 1)
5432    ]
5433    return transpose(result, perm=concat(perm, axis=0))
5434
5435  indices_shape = shape(indices)
5436  params_shape = shape(params)
5437  batch_indices = indices
5438  indices_dtype = indices.dtype.base_dtype
5439  accum_dim_value = ones((), dtype=indices_dtype)
5440  # Use correct type for offset index computation
5441  casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
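  # The loop below folds the leading batch dimensions into the indices: for each
  # batch axis d it adds an offset of
  #   position_along_axis_d * prod(params_shape[d + 1:batch_dims + 1])
  # so that, once the first `batch_dims + 1` axes of `params` are flattened,
  # a single `gather` call can look up every batched index at once.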
5442  for dim in range(batch_dims, 0, -1):
5443    dim_value = casted_params_shape[dim - 1]
5444    accum_dim_value *= casted_params_shape[dim]
5445    start = zeros((), dtype=indices_dtype)
5446    step = ones((), dtype=indices_dtype)
5447    dim_indices = gen_math_ops._range(start, dim_value, step)
5448    dim_indices *= accum_dim_value
5449    dim_shape = stack(
5450        [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
5451    batch_indices += reshape(dim_indices, dim_shape)
5452
5453  flat_indices = reshape(batch_indices, [-1])
5454  outer_shape = params_shape[batch_dims + 1:]
5455  flat_inner_shape = gen_math_ops.prod(params_shape[:batch_dims + 1], [0],
5456                                       False)
5457
5458  flat_params = reshape(params, concat([[flat_inner_shape], outer_shape],
5459                                       axis=0))
5460  flat_result = gather(flat_params, flat_indices)
5461  result = reshape(flat_result, concat([indices_shape, outer_shape], axis=0))
5462  final_shape = indices.get_shape()[:batch_dims].merge_with(
5463      params.get_shape()[:batch_dims])
5464  final_shape = final_shape.concatenate(indices.get_shape().dims[batch_dims:])
5465  final_shape = final_shape.concatenate(params.get_shape()[batch_dims + 1:])
5466  result.set_shape(final_shape)
5467  return result
5468
5469
5470@tf_export(v1=["gather_nd", "manip.gather_nd"])
5471@dispatch.add_dispatch_support
5472@deprecated_endpoints("manip.gather_nd")
5473def gather_nd(params, indices, name=None, batch_dims=0):
5474  r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
5475
5476  `indices` is a `Tensor` of indices into `params`. The index vectors are
5477  arranged along the last axis of `indices`.
5478
5479  This is similar to `tf.gather`, in which `indices` defines slices into the
5480  first dimension of `params`. In `tf.gather_nd`, `indices` defines slices into
5481  the first `N` dimensions of `params`, where `N = indices.shape[-1]`.
5482
5483  Caution: On CPU, if an out of bound index is found, an error is returned.
5484  On GPU, if an out of bound index is found, a 0 is stored in the
5485  corresponding output value.
5486
5487  ## Gathering scalars
5488
5489  In the simplest case the vectors in `indices` index the full rank of `params`:
5490
5491  >>> tf.gather_nd(
5492  ...     indices=[[0, 0],
5493  ...              [1, 1]],
5494  ...     params = [['a', 'b'],
5495  ...               ['c', 'd']]).numpy()
5496  array([b'a', b'd'], dtype=object)
5497
5498  In this case the result has 1-axis fewer than `indices`, and each index vector
5499  is replaced by the scalar indexed from `params`.
5500
5501  In this case the shape relationship is:
5502
5503  ```
5504  index_depth = indices.shape[-1]
5505  assert index_depth == params.shape.rank
5506  result_shape = indices.shape[:-1]
5507  ```
5508
5509  If `indices` has a rank of `K`, it is helpful to think of `indices` as a
5510  (K-1)-dimensional tensor of indices into `params`.
5511
5512  ## Gathering slices
5513
5514  If the index vectors do not index the full rank of `params` then each location
5515  in the result contains a slice of params. This example collects rows from a
5516  matrix:
5517
5518  >>> tf.gather_nd(
5519  ...     indices = [[1],
5520  ...                [0]],
5521  ...     params = [['a', 'b', 'c'],
5522  ...               ['d', 'e', 'f']]).numpy()
5523  array([[b'd', b'e', b'f'],
5524         [b'a', b'b', b'c']], dtype=object)
5525
5526  Here `indices` contains `[2]` index vectors, each with a length of `1`.
5527  The index vectors each refer to rows of the `params` matrix. Each
5528  row has a shape of `[3]` so the output shape is `[2, 3]`.
5529
5530  In this case, the relationship between the shapes is:
5531
5532  ```
5533  index_depth = indices.shape[-1]
5534  outer_shape = indices.shape[:-1]
5535  assert index_depth <= params.shape.rank
5536  inner_shape = params.shape[index_depth:]
5537  output_shape = outer_shape + inner_shape
5538  ```
5539
5540  It is helpful to think of the results in this case as tensors-of-tensors.
5541  The shape of the outer tensor is set by the leading dimensions of `indices`,
5542  while the shape of the inner tensors is the shape of a single slice.
5543
5544  ## Batches
5545
5546  Additionally both `params` and `indices` can have `M` leading batch
5547  dimensions that exactly match. In this case `batch_dims` must be set to `M`.
5548
5549  For example, to collect one row from each of a batch of matrices you could
5550  set the leading elements of the index vectors to be their location in the
5551  batch:
5552
5553  >>> tf.gather_nd(
5554  ...     indices = [[0, 1],
5555  ...                [1, 0],
5556  ...                [2, 4],
5557  ...                [3, 2],
5558  ...                [4, 1]],
5559  ...     params=tf.zeros([5, 7, 3])).shape.as_list()
5560  [5, 3]
5561
5562  The `batch_dims` argument lets you omit those leading location dimensions
5563  from the index:
5564
5565  >>> tf.gather_nd(
5566  ...     batch_dims=1,
5567  ...     indices = [[1],
5568  ...                [0],
5569  ...                [4],
5570  ...                [2],
5571  ...                [1]],
5572  ...     params=tf.zeros([5, 7, 3])).shape.as_list()
5573  [5, 3]
5574
5575  This is equivalent to calling a separate `gather_nd` for each location in the
5576  batch dimensions.
5577
5578
5579  >>> params=tf.zeros([5, 7, 3])
5580  >>> indices=tf.zeros([5, 1])
5581  >>> batch_dims = 1
5582  >>>
5583  >>> index_depth = indices.shape[-1]
5584  >>> batch_shape = indices.shape[:batch_dims]
5585  >>> assert params.shape[:batch_dims] == batch_shape
5586  >>> outer_shape = indices.shape[batch_dims:-1]
5587  >>> assert index_depth <= params.shape.rank
5588  >>> inner_shape = params.shape[batch_dims + index_depth:]
5589  >>> output_shape = batch_shape + outer_shape + inner_shape
5590  >>> output_shape.as_list()
5591  [5, 3]
5592
5593  ### More examples
5594
5595  Indexing into a 3-tensor:
5596
5597  >>> tf.gather_nd(
5598  ...     indices = [[1]],
5599  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5600  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5601  array([[[b'a1', b'b1'],
5602          [b'c1', b'd1']]], dtype=object)
5603
5604
5605
5606  >>> tf.gather_nd(
5607  ...     indices = [[0, 1], [1, 0]],
5608  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5609  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5610  array([[b'c0', b'd0'],
5611         [b'a1', b'b1']], dtype=object)
5612
5613
5614  >>> tf.gather_nd(
5615  ...     indices = [[0, 0, 1], [1, 0, 1]],
5616  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5617  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5618  array([b'b0', b'b1'], dtype=object)
5619
5620  The examples below are for the case when only indices have leading extra
5621  dimensions. If both 'params' and 'indices' have leading batch dimensions, use
5622  the 'batch_dims' parameter to run gather_nd in batch mode.
5623
5624  Batched indexing into a matrix:
5625
5626  >>> tf.gather_nd(
5627  ...     indices = [[[0, 0]], [[0, 1]]],
5628  ...     params = [['a', 'b'], ['c', 'd']]).numpy()
5629  array([[b'a'],
5630         [b'b']], dtype=object)
5631
5632
5633
5634  Batched slice indexing into a matrix:
5635
5636  >>> tf.gather_nd(
5637  ...     indices = [[[1]], [[0]]],
5638  ...     params = [['a', 'b'], ['c', 'd']]).numpy()
5639  array([[[b'c', b'd']],
5640         [[b'a', b'b']]], dtype=object)
5641
5642
5643  Batched indexing into a 3-tensor:
5644
5645  >>> tf.gather_nd(
5646  ...     indices = [[[1]], [[0]]],
5647  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5648  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5649  array([[[[b'a1', b'b1'],
5650           [b'c1', b'd1']]],
5651         [[[b'a0', b'b0'],
5652           [b'c0', b'd0']]]], dtype=object)
5653
5654
5655  >>> tf.gather_nd(
5656  ...     indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
5657  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5658  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5659  array([[[b'c0', b'd0'],
5660          [b'a1', b'b1']],
5661         [[b'a0', b'b0'],
5662          [b'c1', b'd1']]], dtype=object)
5663
5664  >>> tf.gather_nd(
5665  ...     indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
5666  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5667  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5668  array([[b'b0', b'b1'],
5669         [b'd0', b'c1']], dtype=object)
5670
5671
5672  Examples with batched 'params' and 'indices':
5673
5674  >>> tf.gather_nd(
5675  ...     batch_dims = 1,
5676  ...     indices = [[1],
5677  ...                [0]],
5678  ...     params = [[['a0', 'b0'],
5679  ...                ['c0', 'd0']],
5680  ...               [['a1', 'b1'],
5681  ...                ['c1', 'd1']]]).numpy()
5682  array([[b'c0', b'd0'],
5683         [b'a1', b'b1']], dtype=object)
5684
5685
5686  >>> tf.gather_nd(
5687  ...     batch_dims = 1,
5688  ...     indices = [[[1]], [[0]]],
5689  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5690  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5691  array([[[b'c0', b'd0']],
5692         [[b'a1', b'b1']]], dtype=object)
5693
5694  >>> tf.gather_nd(
5695  ...     batch_dims = 1,
5696  ...     indices = [[[1, 0]], [[0, 1]]],
5697  ...     params = [[['a0', 'b0'], ['c0', 'd0']],
5698  ...               [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5699  array([[b'c0'],
5700         [b'b1']], dtype=object)
5701
5702
5703  See also `tf.gather`.
5704
5705  Args:
5706    params: A `Tensor`. The tensor from which to gather values.
5707    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
5708      Index tensor.
5709    name: A name for the operation (optional).
5710    batch_dims: An integer or a scalar 'Tensor'. The number of batch dimensions.
5711
5712  Returns:
5713    A `Tensor`. Has the same type as `params`.
5714  """
5715  batch_dims_ = tensor_util.constant_value(batch_dims)
5716  if batch_dims_ is not None:
5717    batch_dims = int(batch_dims_)
5718  if batch_dims == 0:
5719    try:
5720      # TODO(apassos) find a less bad way of detecting resource variables
5721      # without introducing a circular dependency.
5722      return params.gather_nd(indices, name=name)
5723    except AttributeError:
5724      return gen_array_ops.gather_nd(params, indices, name=name)
5725  else:
5726    return batch_gather_nd(params, indices, batch_dims=batch_dims, name=name)
5727
5728
5729@tf_export("gather_nd", v1=[])
5730@dispatch.add_dispatch_support
5731def gather_nd_v2(params, indices, batch_dims=0, name=None):
5732  return gather_nd(params, indices, name=name, batch_dims=batch_dims)
5733
5734
5735gather_nd_v2.__doc__ = gather_nd.__doc__
5736
5737
5738def batch_gather_nd(params, indices, batch_dims, name=None):
5739  """gather_nd implementation with batch support."""
5740  with ops.name_scope(name, "BatchGatherND", [params, indices]):
5741    indices = ops.convert_to_tensor(indices, name="indices")
5742    params = ops.convert_to_tensor(params, name="params")
5743
5744    if not isinstance(batch_dims, int):
5745      raise TypeError(f"Argument `batch_dims` must be an int; got {batch_dims}")
5746    if batch_dims < 0:
5747      raise ValueError("tf.gather_nd does not allow negative batch_dims.")
5748    params_ndims = params.shape.ndims
5749    indices_ndims = indices.shape.ndims
5750    if indices_ndims is not None and batch_dims >= indices_ndims:
5751      raise ValueError(f"Argument `batch_dims` = {batch_dims} must be "
5752                       f"less than rank(`indices`) = {indices_ndims}")
5753    if params_ndims is not None and batch_dims >= params_ndims:
5754      raise ValueError(f"Argument `batch_dims` = {batch_dims} must be "
5755                       f"less than rank(`params`) = {params_ndims}")
5756
5757    expand = batch_dims == 0
5758    if expand:
5759      # Normally gather_nd will be called when batch_dims == 0.
5760      # But if this function is called with batch_dims = 0, e.g. for testing
5761      # purposes, this adds a dummy batch dimension to make batch_dims = 1.
5762      params = expand_dims(params, axis=0)
5763      indices = expand_dims(indices, axis=0)
5764      batch_dims = 1
5765
5766    params_shape = shape(params)
5767    indices_shape = shape(indices)
5768    batch_shape = params_shape[:batch_dims]
5769    batch_size = gen_math_ops.prod(batch_shape, [0])
5770    index_internal_ndims = rank(indices) - batch_dims - 1
5771    indices_internal_shape = indices_shape[batch_dims:-1]
5772
5773    # Assuming a 'params' with shape [b1, ..., bM, g1, ..., gN] and an 'indices'
5774    # with shape [b1, ..., bM, i1, ..., iK, C], where C <= N, we need to modify
5775    # 'indices' s.t. it has shape [i1, ..., iK, D], where D <= M + N and slices
5776    # to the entire 'params' tensor.
5777    # Assuming we have a batch of shape [B1, B2], we use meshgrid to create a
5778    # grid of size B1 x B2.
5779    batch_dim_list = unstack(batch_shape, axis=0)
5780    dim_ranges = [
5781        gen_math_ops.cast(gen_math_ops._range(0, x, 1), indices.dtype)
5782        for x in batch_dim_list
5783    ]
5784    mesh_list = meshgrid(*dim_ranges, indexing="ij") if dim_ranges else []
5785    # Then we flatten and stack the tensors to form a (B1.B2) by 2 matrix.
5786    flat_list = [reshape(x, shape=(-1,)) for x in mesh_list]
5787    index_grid = transpose(stack(flat_list, axis=0))
5788    # We need to concatenate these batch coordinates with the internal indices.
5789    # concat -> index_grid [B1.B2, 2] with indices [i1, ..., iK, C]
5790    # So we reshape them both to [(B1.B2), i1, ..., iK, *]
5791    index_grid_shape = shape(index_grid)
5792    index_grid = reshape(
5793        index_grid,
5794        concat([
5795            index_grid_shape[:1],
5796            ones(index_internal_ndims, dtype=dtypes.int32), index_grid_shape[1:]
5797        ],
5798               axis=0))
5799    tile_shape = concat(((1,), indices_internal_shape, (1,)), axis=0)
5800    index_grid = tile(index_grid, multiples=tile_shape)
5801    # index_grid now has shape [(B1.B2), i1, ..., iK, 2]
5802    flat_shape = concat(([batch_size], indices_shape[batch_dims:]), axis=0)
5803    flat_indices = reshape(indices, shape=flat_shape)
5804    # flat_indices now has shape [(B1.B2), i1, ..., iK, C]
5805    indices = concat((index_grid, flat_indices), axis=-1)
5806    # indices has shape [(B1.B2), i1, ..., iK, 2+C]
5807    out = gen_array_ops.gather_nd(params, indices)
5808    # out has shape [(B1.B2), i1, ..., iK, N-C]. Now we reshape batch to
5809    # its original form.
5810    out_shape = shape(out)
5811    out = reshape(out, shape=concat((batch_shape, out_shape[1:]), axis=0))
5812    if expand:
5813      out = squeeze(out, axis=0)
5814  return out
5815
5816
5817@deprecation.deprecated_endpoints("tensor_scatter_update")
5818@tf_export(
5819    "tensor_scatter_nd_update",
5820    v1=["tensor_scatter_nd_update", "tensor_scatter_update"])
5821@dispatch.add_dispatch_support
5822def tensor_scatter_nd_update(tensor, indices, updates, name=None):
5823  """Scatter `updates` into an existing tensor according to `indices`.
5824
5825  This operation creates a new tensor by applying sparse `updates` to the
5826  input `tensor`. This is similar to an index assignment.
5827
5828  ```
5829  # Not implemented: tensors cannot be updated inplace.
5830  tensor[indices] = updates
5831  ```
5832
  If an out-of-bounds index is found on CPU, an error is returned.
5834
5835  > **WARNING**: There are some GPU specific semantics for this operation.
5836  >
  > - If an out-of-bounds index is found, the index is ignored.
5838  > - The order in which updates are applied is nondeterministic, so the output
5839  >   will be nondeterministic if `indices` contains duplicates.
5840
5841  This operation is very similar to `tf.scatter_nd`, except that the updates are
5842  scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
5843  for the existing tensor cannot be re-used, a copy is made and updated.
5844
5845  In general:
5846
5847  * `indices` is an integer tensor - the indices to update in `tensor`.
  * `indices` has **at least two** axes; the last axis is the depth of the
    index vectors.
5850  * For each index vector in `indices` there is a corresponding entry in
5851    `updates`.
5852  * If the length of the index vectors matches the rank of the `tensor`, then
5853    the index vectors each point to scalars in `tensor` and each update is a
5854    scalar.
  * If the length of the index vectors is less than the rank of `tensor`, then
    the index vectors each point to slices of `tensor`, and the shape of each
    update must match that slice.
5858
5859  Overall this leads to the following shape constraints:
5860
5861  ```
5862  assert tf.rank(indices) >= 2
5863  index_depth = indices.shape[-1]
5864  batch_shape = indices.shape[:-1]
5865  assert index_depth <= tf.rank(tensor)
5866  outer_shape = tensor.shape[:index_depth]
5867  inner_shape = tensor.shape[index_depth:]
5868  assert updates.shape == batch_shape + inner_shape
5869  ```
5870
5871  Typical usage is often much simpler than this general form, and it
5872  can be better understood starting with simple examples:
5873
5874  ### Scalar updates
5875
5876  The simplest usage inserts scalar elements into a tensor by index.
5877  In this case, the `index_depth` must equal the rank of the
  input `tensor`, since each column of `indices` is an index into an axis of the
  input `tensor`.
5880
5881  In this simplest case the shape constraints are:
5882
5883  ```
5884  num_updates, index_depth = indices.shape.as_list()
5885  assert updates.shape == [num_updates]
  assert index_depth == tf.rank(tensor)
5887  ```
5888
5889  For example, to insert 4 scattered elements in a rank-1 tensor with
5890  8 elements.
5891
5892  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
5893  <img style="width:100%"
5894    src="https://www.tensorflow.org/images/ScatterNd1.png">
5895  </div>
5896
5897  This scatter operation would look like this:
5898
5899  >>> tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
5900  >>> indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
5901  >>> updates = [9, 10, 11, 12]            # num_updates == 4
5902  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
  tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
5904
5905  The length (first axis) of `updates` must equal the length of the `indices`:
5906  `num_updates`. This is the number of updates being inserted. Each scalar
5907  update is inserted into `tensor` at the indexed location.
5908
5909  For a higher rank input `tensor` scalar updates can be inserted by using an
5910  `index_depth` that matches `tf.rank(tensor)`:
5911
5912  >>> tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
5913  >>> indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
5914  >>> updates = [5, 10]                    # num_updates == 2
5915  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5916  tf.Tensor(
5917      [[ 1  5]
5918       [ 1  1]
5919       [10  1]], shape=(3, 2), dtype=int32)
5920
5921  ### Slice updates
5922
  When the input `tensor` has more than one axis, scatter can be used to update
5924  entire slices.
5925
5926  In this case it's helpful to think of the input `tensor` as being a two level
5927  array-of-arrays. The shape of this two level array is split into the
5928  `outer_shape` and the `inner_shape`.
5929
  `indices` indexes into the outer level of the input tensor (`outer_shape`)
5931  and replaces the sub-array at that location with the corresponding item from
5932  the `updates` list. The shape of each update is `inner_shape`.
5933
5934  When updating a list of slices the shape constraints are:
5935
5936  ```
5937  num_updates, index_depth = indices.shape.as_list()
5938  outer_shape = tensor.shape[:index_depth]
5939  inner_shape = tensor.shape[index_depth:]
  assert updates.shape == [num_updates] + inner_shape
5941  ```
5942
5943  For example, to update rows of a `(6, 3)` `tensor`:
5944
5945  >>> tensor = tf.zeros([6, 3], dtype=tf.int32)
5946
5947  Use an index depth of one.
5948
5949  >>> indices = tf.constant([[2], [4]])     # num_updates == 2, index_depth == 1
5950  >>> num_updates, index_depth = indices.shape.as_list()
5951
  The `outer_shape` is `[6]`, the `inner_shape` is `[3]`:
5953
5954  >>> outer_shape = tensor.shape[:index_depth]
5955  >>> inner_shape = tensor.shape[index_depth:]
5956
5957  2 rows are being indexed so 2 `updates` must be supplied.
5958  Each update must be shaped to match the `inner_shape`.
5959
5960  >>> # num_updates == 2, inner_shape==3
5961  >>> updates = tf.constant([[1, 2, 3],
5962  ...                        [4, 5, 6]])
5963
5964  Altogether this gives:
5965
5966  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5967  array([[0, 0, 0],
5968         [0, 0, 0],
5969         [1, 2, 3],
5970         [0, 0, 0],
5971         [4, 5, 6],
5972         [0, 0, 0]], dtype=int32)
5973
5974  #### More slice update examples
5975
5976  A tensor representing a batch of uniformly sized video clips naturally has 5
5977  axes: `[batch_size, time, width, height, channels]`.
5978
5979  For example:
5980
5981  >>> batch_size, time, width, height, channels = 13,11,7,5,3
5982  >>> video_batch = tf.zeros([batch_size, time, width, height, channels])
5983
5984  To replace a selection of video clips:
5985    * Use an `index_depth` of 1 (indexing the `outer_shape`: `[batch_size]`)
5986    * Provide updates each with a shape matching the `inner_shape`:
5987      `[time, width, height, channels]`.
5988
5989  To replace the first two clips with ones:
5990
5991  >>> indices = [[0],[1]]
5992  >>> new_clips = tf.ones([2, time, width, height, channels])
  >>> video_batch = tf.tensor_scatter_nd_update(
  ...     video_batch, indices, new_clips)
5994
5995  To replace a selection of frames in the videos:
5996
5997  * `indices` must have an `index_depth` of 2 for the `outer_shape`:
5998    `[batch_size, time]`.
  * `updates` must be shaped like a list of images.  Each update must have a
    shape matching the `inner_shape`: `[width, height, channels]`.
6001
6002  To replace the first frame of the first three video clips:
6003
6004  >>> indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
6005  >>> new_images = tf.ones([
6006  ...   # num_updates=3, inner_shape=(width, height, channels)
6007  ...   3, width, height, channels])
  >>> video_batch = tf.tensor_scatter_nd_update(
  ...     video_batch, indices, new_images)
6009
6010  ### Folded indices
6011
6012  In simple cases it's convenient to think of `indices` and `updates` as
6013  lists, but this is not a strict requirement. Instead of a flat `num_updates`,
6014  the `indices` and `updates` can be folded into a `batch_shape`. This
6015  `batch_shape` is all axes of the `indices`, except for the innermost
6016  `index_depth` axis.
6017
6018  ```
6019  index_depth = indices.shape[-1]
6020  batch_shape = indices.shape[:-1]
6021  ```
6022
6023  Note: The one exception is that the `batch_shape` cannot be `[]`. You can't
6024  update a single index by passing indices with shape `[index_depth]`.
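
  For example, to update a single element, give `indices` (and `updates`) a
  leading batch axis of size 1 (the tensor values below are only illustrative):

  ```
  tensor = tf.zeros([4], dtype=tf.int32)
  indices = [[2]]   # batch_shape == [1], index_depth == 1
  updates = [7]     # batch_shape + inner_shape == [1] + []
  tf.tensor_scatter_nd_update(tensor, indices, updates)  # [0, 0, 7, 0]
  ```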
6025
6026  `updates` must have a matching `batch_shape` (the axes before `inner_shape`).
6027
6028  ```
6029  assert updates.shape == batch_shape + inner_shape
6030  ```
6031
6032  Note: The result is equivalent to flattening the `batch_shape` axes of
6033  `indices` and `updates`. This generalization just avoids the need
6034  for reshapes when it is more natural to construct "folded" indices and
6035  updates.
6036
6037  With this generalization the full shape constraints are:
6038
6039  ```
6040  assert tf.rank(indices) >= 2
6041  index_depth = indices.shape[-1]
6042  batch_shape = indices.shape[:-1]
6043  assert index_depth <= tf.rank(tensor)
6044  outer_shape = tensor.shape[:index_depth]
6045  inner_shape = tensor.shape[index_depth:]
6046  assert updates.shape == batch_shape + inner_shape
6047  ```
6048
  For example, to draw an `X` on a `(5,5)` matrix, start with these indices:
6050
6051  >>> tensor = tf.zeros([5,5])
6052  >>> indices = tf.constant([
6053  ...  [[0,0],
6054  ...   [1,1],
6055  ...   [2,2],
6056  ...   [3,3],
6057  ...   [4,4]],
6058  ...  [[0,4],
6059  ...   [1,3],
6060  ...   [2,2],
6061  ...   [3,1],
6062  ...   [4,0]],
6063  ... ])
6064  >>> indices.shape.as_list()  # batch_shape == [2, 5], index_depth == 2
6065  [2, 5, 2]
6066
6067  Here the `indices` do not have a shape of `[num_updates, index_depth]`, but a
6068  shape of `batch_shape+[index_depth]`.
6069
6070  Since the `index_depth` is equal to the rank of `tensor`:
6071
6072  * `outer_shape` is `(5,5)`
6073  * `inner_shape` is `()` - each update is scalar
  * `updates.shape` is `batch_shape + inner_shape == (2,5) + ()`
6075
6076  >>> updates = [
6077  ...   [1,1,1,1,1],
6078  ...   [1,1,1,1,1],
6079  ... ]
6080
6081  Putting this together gives:
6082
6083  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
6084  array([[1., 0., 0., 0., 1.],
6085         [0., 1., 0., 1., 0.],
6086         [0., 0., 1., 0., 0.],
6087         [0., 1., 0., 1., 0.],
6088         [1., 0., 0., 0., 1.]], dtype=float32)
6089
6090  Args:
6091    tensor: Tensor to copy/update.
6092    indices: Indices to update.
6093    updates: Updates to apply at the indices.
6094    name: Optional name for the operation.
6095
6096  Returns:
6097    A new tensor with the given shape and updates applied according to the
6098    indices.
6099  """
6100  return gen_array_ops.tensor_scatter_update(
6101      tensor=tensor, indices=indices, updates=updates, name=name)
6102
6103
6104# Define quantize_v2 here in order to make name the second-to-last attribute,
6105# because round_mode was added later.
6106# (And also now because of 'axis' processing).
6107@tf_export(v1=["quantize_v2"])
6108@dispatch.add_dispatch_support
6109@deprecation.deprecated(
6110    "2017-10-25",
6111    "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
6112    "instead.")  # pylint: disable=missing-docstring
6113def quantize_v2(
6114    input,  # pylint: disable=redefined-builtin
6115    min_range,
6116    max_range,
6117    T,
6118    mode="MIN_COMBINED",
6119    name=None,
6120    round_mode="HALF_AWAY_FROM_ZERO",
6121    narrow_range=False,
6122    axis=None,
6123    ensure_minimum_range=0.01):
6124  if axis is None:
6125    axis = -1
6126  elif axis < 0:
6127    if input.shape.ndims is None:
6128      raise ValueError("input should have known rank to use negative axis.")
6129    axis %= input.shape.ndims
6130
6131  if ensure_minimum_range != 0.01:
6132    return gen_array_ops.quantize_v2(
6133        input,
6134        min_range,
6135        max_range,
6136        T=T,
6137        mode=mode,
6138        name=name,
6139        round_mode=round_mode,
6140        narrow_range=narrow_range,
6141        axis=axis,
6142        ensure_minimum_range=ensure_minimum_range)
6143  return gen_array_ops.quantize_v2(
6144      input,
6145      min_range,
6146      max_range,
6147      T=T,
6148      mode=mode,
6149      name=name,
6150      round_mode=round_mode,
6151      narrow_range=narrow_range,
6152      axis=axis)
6153
6154
6155quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead."""
6156
6157
# We want to expose tf.quantization.quantize instead of tf.quantize; we can
# deprecate tf.quantize in the next version of TensorFlow.
6161@tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"])
6162@dispatch.add_dispatch_support
6163@deprecation.deprecated_endpoints("quantize")
6164def quantize(
6165    input,  # pylint: disable=redefined-builtin
6166    min_range,
6167    max_range,
6168    T,
6169    mode="MIN_COMBINED",
6170    round_mode="HALF_AWAY_FROM_ZERO",
6171    name=None,
6172    narrow_range=False,
6173    axis=None,
6174    ensure_minimum_range=0.01):
6175  """Quantize the input tensor."""
6176  if ensure_minimum_range != 0.01:
6177    return quantize_v2(
6178        input,
6179        min_range,
6180        max_range,
6181        T,
6182        mode=mode,
6183        round_mode=round_mode,
6184        name=name,
6185        narrow_range=narrow_range,
6186        axis=axis,
6187        ensure_minimum_range=ensure_minimum_range)
6188  return quantize_v2(
6189      input,
6190      min_range,
6191      max_range,
6192      T,
6193      mode=mode,
6194      round_mode=round_mode,
6195      name=name,
6196      narrow_range=narrow_range,
6197      axis=axis)
6198
6199
6200@tf_export("quantization.dequantize", v1=["quantization.dequantize",
6201                                          "dequantize"])
6202@dispatch.add_dispatch_support
6203@deprecation.deprecated_endpoints("dequantize")
6204def dequantize(  # pylint: disable=missing-docstring
6205    input,  # pylint: disable=redefined-builtin
6206    min_range,
6207    max_range,
6208    mode="MIN_COMBINED",
6209    name=None,
6210    axis=None,
6211    narrow_range=False,
6212    dtype=dtypes.float32):
6213  if axis is None:
6214    axis = -1
6215  elif axis < 0:
6216    if input.shape.ndims is None:
6217      raise ValueError("input should have known rank to use negative axis.")
6218    axis %= input.shape.ndims
6219
6220  if axis >= 0 or narrow_range:
6221    return gen_array_ops.dequantize(
6222        input,
6223        min_range,
6224        max_range,
6225        mode=mode,
6226        name=name,
6227        narrow_range=narrow_range,
6228        axis=axis,
6229        dtype=dtype)
6230  return gen_array_ops.dequantize(
6231      input, min_range, max_range, mode=mode, name=name, dtype=dtype)
6232
6233
6234dequantize.__doc__ = gen_array_ops.dequantize.__doc__
6235
6236
6237@tf_export("quantization.quantize_and_dequantize")
6238@dispatch.add_dispatch_support
6239@deprecation.deprecated(None,
                        "This Op has been deprecated. Use " +
                        "`quantize_and_dequantize_v2` instead. To " +
                        "simulate the V1 behavior of " +
6243                        "tf.quantization.quantize_and_dequantize(...) use " +
6244                        "tf.grad_pass_through(" +
6245                        "tf.quantization.quantize_and_dequantize_v2)(...).")
6246def quantize_and_dequantize(
6247    input,  # pylint: disable=redefined-builtin
6248    input_min,
6249    input_max,
6250    signed_input=True,
6251    num_bits=8,
6252    range_given=False,
6253    round_mode="HALF_TO_EVEN",
6254    name=None,
6255    narrow_range=False,
6256    axis=None):
6257  """Quantizes then dequantizes a tensor.
6258
6259  Args:
6260    input: A `Tensor` to quantize and dequantize.
    input_min: If range_given=True, the minimum input value that needs to be
6262      represented in the quantized representation. If axis is specified, this
6263      should be a vector of minimum values for each slice along axis.
6264    input_max: If range_given=True, the maximum input value that needs to be
6265      represented in the quantized representation. If axis is specified, this
6266      should be a vector of maximum values for each slice along axis.
    signed_input: True if the quantization is signed, False if it is unsigned.
6268    num_bits: The bitwidth of the quantization.
6269    range_given: If true use `input_min` and `input_max` for the range of the
6270      input, otherwise determine min and max from the input `Tensor`.
    round_mode: Rounding mode when rounding from float values to quantized
      ones; one of ['HALF_TO_EVEN', 'HALF_UP'].
6273    name: Optional name for the operation.
6274    narrow_range: If true, then the absolute value of the quantized minimum
6275      value is the same as the quantized maximum value, instead of 1 greater.
6276      i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
6277    axis: Integer. If specified, refers to a dimension of the input tensor, such
6278      that quantization will be per slice along that dimension.
6279
6280  Returns:
6281    A `Tensor`. Each element is the result of quantizing and dequantizing the
6282    corresponding element of `input`.
6283  """
6284  if axis is None:
6285    axis = -1
6286  elif axis < 0:
6287    if input.shape.ndims is None:
6288      raise ValueError("input should have known rank to use negative axis.")
6289    axis %= input.shape.ndims
6290
6291  return gen_array_ops.quantize_and_dequantize_v2(
6292      input,
6293      input_min=input_min,
6294      input_max=input_max,
6295      signed_input=signed_input,
6296      num_bits=num_bits,
6297      range_given=range_given,
6298      round_mode=round_mode,
6299      narrow_range=narrow_range,
6300      axis=axis,
6301      name=name)
6302
6303
6304@tf_export("quantization.quantize_and_dequantize_v2")
6305@dispatch.add_dispatch_support
6306def quantize_and_dequantize_v2(
6307    input,  # pylint: disable=redefined-builtin
6308    input_min,
6309    input_max,
6310    signed_input=True,
6311    num_bits=8,
6312    range_given=False,
6313    round_mode="HALF_TO_EVEN",
6314    name=None,
6315    narrow_range=False,
6316    axis=None):
6317  """Quantizes then dequantizes a tensor.
6318
  Updates the gradient definition for quantization that is outside the range to
  be 0. To simulate the V1 behavior of
  `tf.quantization.quantize_and_dequantize(...)` use
  `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.
6323
6324  Example usage:
6325
6326  ```python
  def quantize_op(input_tensor, min_threshold, max_threshold):
    return tf.quantization.quantize_and_dequantize_v2(
        input_tensor, input_min=min_threshold, input_max=max_threshold,
        range_given=True)

  # To simulate the V1 behavior (gradients pass straight through):
  def quantize_op_v1_gradients(input_tensor, min_threshold, max_threshold):
    def f(x):
      return tf.quantization.quantize_and_dequantize_v2(
          x, input_min=min_threshold, input_max=max_threshold,
          range_given=True)
    return tf.grad_pass_through(f)(input_tensor)
6344  ```
6345
6346  Args:
6347    input: A `Tensor` to quantize and dequantize.
    input_min: If range_given=True, the minimum input value that needs to be
6349      represented in the quantized representation. If axis is specified, this
6350      should be a vector of minimum values for each slice along axis.
6351    input_max: If range_given=True, the maximum input value that needs to be
6352      represented in the quantized representation. If axis is specified, this
6353      should be a vector of maximum values for each slice along axis.
    signed_input: True if the quantization is signed, False if it is unsigned.
6355    num_bits: The bitwidth of the quantization.
6356    range_given: If true use `input_min` and `input_max` for the range of the
6357      input, otherwise determine min and max from the input `Tensor`.
    round_mode: Rounding mode when rounding from float values to quantized
      ones; one of ['HALF_TO_EVEN', 'HALF_UP'].
6360    name: Optional name for the operation.
6361    narrow_range: If true, then the absolute value of the quantized minimum
6362      value is the same as the quantized maximum value, instead of 1 greater.
6363      i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
6364    axis: Integer. If specified, refers to a dimension of the input tensor, such
6365      that quantization will be per slice along that dimension.
6366
6367  Returns:
6368    A `Tensor`. Each element is the result of quantizing and dequantizing the
6369    corresponding element of `input`.
6370  """
6371  if axis is None:
6372    axis = -1
6373  elif axis < 0:
6374    if input.shape.ndims is None:
6375      raise ValueError("input should have known rank to use negative axis.")
6376    axis %= input.shape.ndims
6377
6378  return gen_array_ops.quantize_and_dequantize_v4(
6379      input,
6380      input_min=input_min,
6381      input_max=input_max,
6382      signed_input=signed_input,
6383      num_bits=num_bits,
6384      range_given=range_given,
6385      round_mode=round_mode,
6386      narrow_range=narrow_range,
6387      axis=axis,
6388      name=name)
6389
6390
6391@tf_export("searchsorted")
6392@dispatch.add_dispatch_support
6393def searchsorted(sorted_sequence,
6394                 values,
6395                 side="left",
6396                 out_type=dtypes.int32,
6397                 name=None):
6398  """Searches for where a value would go in a sorted sequence.
6399
6400  This is not a method for checking containment (like python `in`).
6401
6402  The typical use case for this operation is "binning", "bucketing", or
6403  "discretizing". The `values` are assigned to bucket-indices based on the
6404  **edges** listed in `sorted_sequence`. This operation
6405  returns the bucket-index for each value.
6406
6407  >>> edges = [-1, 3.3, 9.1, 10.0]
6408  >>> values = [0.0, 4.1, 12.0]
6409  >>> tf.searchsorted(edges, values).numpy()
6410  array([1, 2, 4], dtype=int32)
6411
6412  The `side` argument controls which index is returned if a value lands exactly
6413  on an edge:
6414
6415  >>> seq = [0, 3, 9, 10, 10]
6416  >>> values = [0, 4, 10]
6417  >>> tf.searchsorted(seq, values).numpy()
6418  array([0, 2, 3], dtype=int32)
6419  >>> tf.searchsorted(seq, values, side="right").numpy()
6420  array([1, 2, 5], dtype=int32)
6421
6422  The `axis` is not settable for this operation. It always operates on the
6423  innermost dimension (`axis=-1`). The operation will accept any number of
6424  outer dimensions. Here it is applied to the rows of a matrix:
6425
6426  >>> sorted_sequence = [[0., 3., 8., 9., 10.],
6427  ...                    [1., 2., 3., 4., 5.]]
6428  >>> values = [[9.8, 2.1, 4.3],
6429  ...           [0.1, 6.6, 4.5, ]]
6430  >>> tf.searchsorted(sorted_sequence, values).numpy()
6431  array([[4, 1, 2],
6432         [0, 5, 4]], dtype=int32)
6433
6434  Note: This operation assumes that `sorted_sequence` **is sorted** along the
  innermost axis, for example by using `tf.sort(..., axis=-1)`. **If the
  sequence is not sorted, no error is raised** and the content of the returned
  tensor is not well defined.
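
  A small sketch of guarding against an unsorted sequence (sorting first, then
  bucketing):

  ```
  sorted_sequence = tf.sort(sorted_sequence, axis=-1)
  buckets = tf.searchsorted(sorted_sequence, values)
  ```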
6438
6439  Args:
6440    sorted_sequence: N-D `Tensor` containing a sorted sequence.
6441    values: N-D `Tensor` containing the search values.
6442    side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
6443      upper_bound.
6444    out_type: The output type (`int32` or `int64`).  Default is `tf.int32`.
6445    name: Optional name for the operation.
6446
6447  Returns:
6448    An N-D `Tensor` the size of `values` containing the result of applying
6449    either lower_bound or upper_bound (depending on side) to each value.  The
6450    result is not a global index to the entire `Tensor`, but the index in the
6451    last dimension.
6452
6453  Raises:
    ValueError: If the last dimension of `sorted_sequence` has `2^31-1` or
                more elements.
6455                If the total size of `values` exceeds `2^31 - 1` elements.
6456                If the first `N-1` dimensions of the two tensors don't match.
6457  """
6458  sequence_size = shape_internal(sorted_sequence)[-1]
6459  values_size = shape_internal(values)[-1]
6460  sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
6461  values_2d = reshape(values, [-1, values_size])
6462  if side == "right":
6463    output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
6464                                       name)
6465  elif side == "left":
6466    output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
6467                                       name)
6468  else:
6469    raise ValueError("Argument `side` must be either 'right' or 'left'. "
6470                     f"Received: `side` = '{side}'.")
6471  return reshape(output, shape_internal(values))
6472
6473
6474quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
6475
6476
6477@tf_export("image.extract_patches")
6478@dispatch.add_dispatch_support
6479def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
6480  r"""Extract `patches` from `images`.
6481
6482  This op collects patches from the input image, as if applying a
6483  convolution. All extracted patches are stacked in the depth (last) dimension
6484  of the output.
6485
6486  Specifically, the op extracts patches of shape `sizes` which are `strides`
6487  apart in the input image. The output is subsampled using the `rates` argument,
6488  in the same manner as "atrous" or "dilated" convolutions.
6489
6490  The result is a 4D tensor which is indexed by batch, row, and column.
6491  `output[i, x, y]` contains a flattened patch of size `sizes[1], sizes[2]`
6492  which is taken from the input starting at
6493  `images[i, x*strides[1], y*strides[2]]`.
6494
6495  Each output patch can be reshaped to `sizes[1], sizes[2], depth`, where
6496  `depth` is `images.shape[3]`.
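
  A small sketch of recovering image-shaped patches (here `batch`, `out_rows`
  and `out_cols` are illustrative placeholders for the actual output
  dimensions):

  ```
    patches = tf.image.extract_patches(images, sizes, strides, rates, padding)
    # patches: [batch, out_rows, out_cols, sizes[1] * sizes[2] * depth]
    patch_images = tf.reshape(
        patches, [batch, out_rows, out_cols, sizes[1], sizes[2], depth])
  ```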
6497
6498  The output elements are taken from the input at intervals given by the `rate`
6499  argument, as in dilated convolutions.
6500
  The `padding` argument has no effect on the size of each patch; it determines
6502  how many patches are extracted. If `VALID`, only patches which are fully
6503  contained in the input image are included. If `SAME`, all patches whose
6504  starting point is inside the input are included, and areas outside the input
6505  default to zero.
6506
6507  Example:
6508
6509  ```
6510    n = 10
6511    # images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100
6512    images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]
6513
6514    # We generate two outputs as follows:
6515    # 1. 3x3 patches with stride length 5
6516    # 2. Same as above, but the rate is increased to 2
6517    tf.image.extract_patches(images=images,
6518                             sizes=[1, 3, 3, 1],
6519                             strides=[1, 5, 5, 1],
6520                             rates=[1, 1, 1, 1],
6521                             padding='VALID')
6522
6523    # Yields:
6524    [[[[ 1  2  3 11 12 13 21 22 23]
6525       [ 6  7  8 16 17 18 26 27 28]]
6526      [[51 52 53 61 62 63 71 72 73]
6527       [56 57 58 66 67 68 76 77 78]]]]
6528  ```
6529
6530  If we mark the pixels in the input image which are taken for the output with
6531  `*`, we see the pattern:
6532
6533  ```
6534     *  *  *  4  5  *  *  *  9 10
6535     *  *  * 14 15  *  *  * 19 20
6536     *  *  * 24 25  *  *  * 29 30
6537    31 32 33 34 35 36 37 38 39 40
6538    41 42 43 44 45 46 47 48 49 50
6539     *  *  * 54 55  *  *  * 59 60
6540     *  *  * 64 65  *  *  * 69 70
6541     *  *  * 74 75  *  *  * 79 80
6542    81 82 83 84 85 86 87 88 89 90
6543    91 92 93 94 95 96 97 98 99 100
6544  ```
6545
6546  ```
6547    tf.image.extract_patches(images=images,
6548                             sizes=[1, 3, 3, 1],
6549                             strides=[1, 5, 5, 1],
6550                             rates=[1, 2, 2, 1],
6551                             padding='VALID')
6552
6553    # Yields:
6554    [[[[  1   3   5  21  23  25  41  43  45]
6555       [  6   8  10  26  28  30  46  48  50]]
6556
6557      [[ 51  53  55  71  73  75  91  93  95]
6558       [ 56  58  60  76  78  80  96  98 100]]]]
6559  ```
6560
6561  We can again draw the effect, this time using the symbols `*`, `x`, `+` and
6562  `o` to distinguish the patches:
6563
6564  ```
6565     *  2  *  4  *  x  7  x  9  x
6566    11 12 13 14 15 16 17 18 19 20
6567     * 22  * 24  *  x 27  x 29  x
6568    31 32 33 34 35 36 37 38 39 40
6569     * 42  * 44  *  x 47  x 49  x
6570     + 52  + 54  +  o 57  o 59  o
6571    61 62 63 64 65 66 67 68 69 70
6572     + 72  + 74  +  o 77  o 79  o
6573    81 82 83 84 85 86 87 88 89 90
6574     + 92  + 94  +  o 97  o 99  o
6575  ```
6576
6577  Args:
6578    images: A 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
6579    sizes: The size of the extracted patches. Must be
6580      `[1, size_rows, size_cols, 1]`.
6581    strides: A 1-D Tensor of length 4. How far the centers of two consecutive
6582      patches are in the images. Must be: `[1, stride_rows, stride_cols, 1]`.
6583    rates: A 1-D Tensor of length 4. Must be: `[1, rate_rows, rate_cols, 1]`.
6584      This is the input stride, specifying how far two consecutive patch samples
6585      are in the input. Equivalent to extracting patches with `patch_sizes_eff =
6586      patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by subsampling
6587      them spatially by a factor of `rates`. This is equivalent to `rate` in
6588      dilated (a.k.a. Atrous) convolutions.
6589    padding: The type of padding algorithm to use.
6590    name: A name for the operation (optional).
6591
6592  Returns:
6593    A 4-D Tensor of the same type as the input.
6594  """
6595  return gen_array_ops.extract_image_patches(images, sizes, strides, rates,
6596                                             padding, name)
6597
6598
6599@tf_export(v1=["image.extract_image_patches", "extract_image_patches"])
6600@dispatch.add_dispatch_support
6601@deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead",
6602                             "ksizes")
6603def extract_image_patches(  # pylint: disable=missing-docstring
6604    images,
6605    ksizes=None,
6606    strides=None,
6607    rates=None,
6608    padding=None,
6609    name=None,
6610    sizes=None):
6611  """Extract patches from images and put them in the "depth" output dimension.
6612
  We specify the size-related attributes as:

  ```
  ksizes = [1, ksize_rows, ksize_cols, 1]
  strides = [1, strides_rows, strides_cols, 1]
  rates = [1, rates_rows, rates_cols, 1]
  ```

  Args:
    images: A `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`,
      `uint16`, `half`, `uint32`, `uint64`. 4-D Tensor with shape
      `[batch, in_rows, in_cols, depth]`.
    ksizes: A list of `ints` that has length `>= 4`. The size of the sliding
      window for each dimension of `images`.
    strides: A list of `ints` that has length `>= 4`. 1-D of length 4. How far
      the centers of two consecutive patches are in the images. Must be:
      `[1, stride_rows, stride_cols, 1]`.
    rates: A list of `ints` that has length `>= 4`. 1-D of length 4. Must be:
      `[1, rate_rows, rate_cols, 1]`. This is the input stride, specifying how
      far two consecutive patch samples are in the input. Equivalent to
      extracting patches with
      `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`,
      followed by subsampling them spatially by a factor of `rates`. This is
      equivalent to `rate` in dilated (a.k.a. Atrous) convolutions.
    padding: A `string` from: `"SAME"`, `"VALID"`. The type of padding
      algorithm to use.
    name: A name for the operation (optional).
6635
6636  Returns:
6637    A Tensor. Has the same type as images.
6638  """
6639  ksizes = deprecation.deprecated_argument_lookup("sizes", sizes, "ksizes",
6640                                                  ksizes)
6641  return gen_array_ops.extract_image_patches(images, ksizes, strides, rates,
6642                                             padding, name)
6643
6644
6645extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__
6646
6647
6648@tf_export("fingerprint")
6649@dispatch.add_dispatch_support
6650def fingerprint(data, method="farmhash64", name=None):
6651  r"""Generates fingerprint values.
6652
6653  Generates fingerprint values of `data`.
6654
6655  Fingerprint op considers the first dimension of `data` as the batch dimension,
6656  and `output[i]` contains the fingerprint value generated from contents in
6657  `data[i, ...]` for all `i`.
6658
6659  Fingerprint op writes fingerprint values as byte arrays. For example, the
6660  default method `farmhash64` generates a 64-bit fingerprint value at a time.
  This 8-byte value is written out as a `tf.uint8` array of size 8, in
6662  little-endian order.
6663
6664  For example, suppose that `data` has data type `tf.int32` and shape (2, 3, 4),
6665  and that the fingerprint method is `farmhash64`. In this case, the output
6666  shape is (2, 8), where 2 is the batch dimension size of `data`, and 8 is the
6667  size of each fingerprint value in bytes. `output[0, :]` is generated from
6668  12 integers in `data[0, :, :]` and similarly `output[1, :]` is generated from
  the other 12 integers in `data[1, :, :]`.
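
  A short sketch of the shapes described above (the tensor contents here are
  only illustrative):

  ```python
  data = tf.ones([2, 3, 4], dtype=tf.int32)
  fp = tf.fingerprint(data)  # default method "farmhash64"
  print(fp.shape)            # (2, 8): one 8-byte fingerprint per batch element
  ```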
6670
  Note that this op fingerprints the raw underlying buffer, and it does not
  fingerprint the Tensor's metadata such as data type and/or shape. For
  example, the fingerprint values are invariant under reshapes and bitcasts as
  long as the batch dimension remains the same:
6675
6676  ```python
6677  tf.fingerprint(data) == tf.fingerprint(tf.reshape(data, ...))
6678  tf.fingerprint(data) == tf.fingerprint(tf.bitcast(data, ...))
6679  ```
6680
6681  For string data, one should expect `tf.fingerprint(data) !=
  tf.fingerprint(tf.strings.reduce_join(data))` in general.
6683
6684  Args:
6685    data: A `Tensor`. Must have rank 1 or higher.
6686    method: A `Tensor` of type `tf.string`. Fingerprint method used by this op.
      Currently, the only available method is `farmhash64`.
6688    name: A name for the operation (optional).
6689
6690  Returns:
    A two-dimensional `Tensor` of type `tf.uint8`. The first dimension equals
6692    `data`'s first dimension, and the second dimension size depends on the
6693    fingerprint algorithm.
6694  """
6695  return gen_array_ops.fingerprint(data, method, name)
6696
6697
6698def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
6699  """Converts the given value to an integer Tensor."""
6700  tensor = ops.convert_to_tensor(
6701      tensor, name=name, preferred_dtype=dtype or dtypes.int32)
6702  if tensor.dtype.is_integer:
6703    if dtype is not None:
6704      tensor = gen_math_ops.cast(tensor, dtype)
6705  else:
6706    raise TypeError(f"Argument `tensor` (name: {name}) must be of type integer."
6707                    f" Received `tensor` = {tensor} of dtype: {tensor.dtype}")
6708  return tensor
6709
6710
6711def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
6712  """Validate an `axis` parameter, and normalize it to be positive.
6713
6714  If `ndims` is known (i.e., not `None`), then check that `axis` is in the
6715  range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
6716  `axis + ndims` (otherwise).
6717  If `ndims` is not known, and `axis` is positive, then return it as-is.
6718  If `ndims` is not known, and `axis` is negative, then report an error.
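
  For example (a small illustrative sketch of the behavior described above):

  ```
  get_positive_axis(-1, ndims=3)     # returns 2
  get_positive_axis(2, ndims=3)      # returns 2
  get_positive_axis(-1, ndims=None)  # raises ValueError
  ```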
6719
6720  Args:
6721    axis: An integer constant
6722    ndims: An integer constant, or `None`
6723    axis_name: The name of `axis` (for error messages).
6724    ndims_name: The name of `ndims` (for error messages).
6725
6726  Returns:
6727    The normalized `axis` value.
6728
6729  Raises:
6730    ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
6731      `ndims is None`.
6732  """
6733  if not isinstance(axis, int):
6734    raise TypeError(f"{axis_name} must be an int; got {type(axis).__name__}")
6735  if ndims is not None:
6736    if 0 <= axis < ndims:
6737      return axis
6738    elif -ndims <= axis < 0:
6739      return axis + ndims
6740    else:
6741      raise ValueError(f"{axis_name}={axis} out of bounds: "
6742                       f"expected {-ndims}<={axis_name}<{ndims}")
6743  elif axis < 0:
6744    raise ValueError(f"{axis_name}={axis} may only be negative "
6745                     f"if {ndims_name} is statically known.")
6746  return axis
6747
6748
6749# This op is intended to exactly match the semantics of numpy.repeat, with
6750# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
6751# when axis is not specified.  Rather than implement that special behavior, we
6752# simply make `axis` be a required argument.
6753#
6754# External (OSS) `tf.repeat` feature request:
6755# https://github.com/tensorflow/tensorflow/issues/8246
6756def repeat_with_axis(data, repeats, axis, name=None):
6757  """Repeats elements of `data`.
6758
6759  Args:
6760    data: An `N`-dimensional tensor.
6761    repeats: A 1-D integer tensor specifying how many times each element in
6762      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
6763      Supports broadcasting from a scalar value.
6764    axis: `int`.  The axis along which to repeat values.  Must be less than
6765      `max(N, 1)`.
6766    name: A name for the operation.
6767
6768  Returns:
6769    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
6770    except that dimension `axis` has size `sum(repeats)`.
6771
6772  Example usage:
6773
6774  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6775  <tf.Tensor: shape=(5,), dtype=string,
6776  numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6777  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6778  <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6779  array([[1, 2],
6780         [1, 2],
6781         [3, 4],
6782         [3, 4],
6783         [3, 4]], dtype=int32)>
6784  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6785  <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6786  array([[1, 1, 2, 2, 2],
6787         [3, 3, 4, 4, 4]], dtype=int32)>
6788
6789  """
6790  # Whether the execution uses the optimized non-XLA implementation below.
6791  # TODO(b/236387200): Separate the implementations at a lower level, so that
6792  # non-XLA path gets the performance benefits and the XLA path is not broken
6793  # after loading a saved model with the optimization.
6794  use_optimized_non_xla_implementation = False
6795
6796  if not isinstance(axis, int):
6797    raise TypeError("Argument `axis` must be an int. "
6798                    f"Received `axis` = {axis} of type {type(axis).__name__}")
6799
6800  with ops.name_scope(name, "Repeat", [data, repeats]):
6801    data = ops.convert_to_tensor(data, name="data")
6802    # Note: We want to pass dtype=None to convert_to_int_tensor so that the
6803    # existing type is maintained instead of force-casting to int32. However,
6804    # this is not compatible with the implementation used on the XLA path.
6805    if not use_optimized_non_xla_implementation:
6806      repeats = convert_to_int_tensor(repeats, name="repeats")
6807    else:
6808      repeats = convert_to_int_tensor(repeats, name="repeats", dtype=None)
6809
6810    repeats.shape.with_rank_at_most(1)
6811
6812    # If `data` is a scalar, then upgrade it to a vector.
6813    data = _with_nonzero_rank(data)
6814    data_shape = shape(data, out_type=repeats.dtype)
6815
6816    # If `axis` is negative, then convert it to a positive value.
6817    axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
6818
6819    # If we know that `repeats` is a scalar, then we can just tile & reshape.
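    # E.g., data = [1, 2, 3], repeats = 2, axis = 0: expand to [[1], [2], [3]],
    # tile to [[1, 1], [2, 2], [3, 3]], then reshape to [1, 1, 2, 2, 3, 3].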
6820    if repeats.shape.num_elements() == 1:
6821      repeats = reshape(repeats, [])
6822      expanded = expand_dims(data, axis + 1)
6823      tiled = tile_one_dimension(expanded, axis + 1, repeats)
6824      result_shape = concat([
6825          data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
6826      ],
6827                            axis=0)
6828      return reshape(tiled, result_shape)
6829
6830    # Check data Tensor shapes.
6831    if repeats.shape.ndims == 1:
6832      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
6833
6834    repeats = broadcast_to(repeats, [data_shape[axis]])
6835
6836    # The implementation on the else branch has better performance. However, it
6837    # does not work on the XLA path since it relies on the range op with a
6838    # shape that is not a compile-time constant.
6839    if not use_optimized_non_xla_implementation:
6840      repeats_original = repeats
6841
6842      # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
6843      if repeats.shape.ndims != axis + 1:
6844        repeats_shape = shape(repeats)
6845        repeats_ndims = rank(repeats)
6846        broadcast_shape = concat(
6847            [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
6848        repeats = broadcast_to(repeats, broadcast_shape)
6849        repeats.set_shape([None] * (axis + 1))
6850
6851      # Create a "sequence mask" based on `repeats`, where slices across `axis`
6852      # contain one `True` value for each repetition.  E.g., if
6853      # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
6854      max_repeat = gen_math_ops._max(repeats, _all_dimensions(repeats))
6855      max_repeat = gen_math_ops.maximum(
6856          ops.convert_to_tensor(0, name="zero", dtype=max_repeat.dtype),
6857          max_repeat)
6858
6859      mask = sequence_mask(repeats, max_repeat)
6860
6861      # Add a new dimension around each value that needs to be repeated, and
6862      # then tile that new dimension to match the maximum number of repetitions.
6863      expanded = expand_dims(data, axis + 1)
6864      tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
6865
6866      # Use `boolean_mask` to discard the extra repeated values.  This also
6867      # flattens all dimensions up through `axis`.
6868      masked = boolean_mask(tiled, mask)
6869
6870      # Reshape the output tensor to add the outer dimensions back.
6871      if axis == 0:
6872        result = masked
6873      else:
6874        repeated_dim_size = gen_math_ops._sum(
6875            repeats_original,
6876            axis=gen_math_ops._range(0, rank(repeats_original), 1))
6877        result_shape = concat(
6878            [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
6879            axis=0)
6880        result = reshape(masked, result_shape)
6881
6882      # Preserve shape information.
6883      if data.shape.ndims is not None:
6884        new_axis_size = 0 if repeats.shape[0] == 0 else None
6885        result.set_shape(data.shape[:axis].concatenate(
6886            [new_axis_size]).concatenate(data.shape[axis + 1:]))
6887
6888      return result
6889
6890    else:
6891      # Non-XLA path implementation
6892      # E.g., repeats = [3, 4, 0, 2, 1].
6893      # E.g., repeats_scan = [3, 7, 7, 9, 10].
6894      repeats_scan = math_ops.cumsum(repeats)
6895      # This concat just prepends 0 to handle the case when repeats is empty.
6896      # E.g., output_size = [0, 3, 7, 7, 9, 10][-1] = 10.
6897      output_size = concat([zeros(1, dtype=repeats_scan.dtype), repeats_scan],
6898                           axis=0)[-1]
6899      # E.g., output_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9].
6900      output_indices = math_ops.range(output_size, dtype=repeats.dtype)
6901      # E.g., gather_indices = [0, 0, 0, 1, 1, 1, 1, 3, 3, 4].
6902      gather_indices = searchsorted(
6903          repeats_scan, output_indices, side="right", out_type=repeats.dtype)
6904      return gather(data, gather_indices, axis=axis)
6905
6906
6907def tile_one_dimension(data, axis, multiple):
6908  """Tiles a single dimension of a tensor."""
6909  # Assumes axis is a nonnegative int.
6910  if data.shape.ndims is not None:
6911    multiples = [1] * data.shape.ndims
6912    multiples[axis] = multiple
6913  else:
6914    ones_value = ones(rank(data), dtypes.int32)
6915    multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]],
6916                       axis=0)
6917  return tile(data, multiples)
6918
6919
6920def _with_nonzero_rank(data):
6921  """If `data` is scalar, then add a dimension; otherwise return as-is."""
6922  if data.shape.ndims is not None:
6923    if data.shape.ndims == 0:
6924      return stack([data])
6925    else:
6926      return data
6927  else:
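    # Rank is not known statically. Prepending 1 to the dynamic shape and then
    # keeping only the last `rank(data)` entries leaves non-scalars unchanged,
    # while a scalar (rank 0) keeps the whole `[1]` and so gains a dimension.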
6928    data_shape = shape(data)
6929    data_ndims = rank(data)
6930    return reshape(data, concat([[1], data_shape], axis=0)[-data_ndims:])
6931
6932
6933@tf_export("repeat")
6934@dispatch.add_dispatch_support
6935def repeat(input, repeats, axis=None, name=None):  # pylint: disable=redefined-builtin
6936  """Repeat elements of `input`.
6937
6938  See also `tf.concat`, `tf.stack`, `tf.tile`.
6939
6940  Args:
6941    input: An `N`-dimensional Tensor.
    repeats: A 1-D `int` Tensor. The number of repetitions for each element.
6943      repeats is broadcasted to fit the shape of the given axis. `len(repeats)`
6944      must equal `input.shape[axis]` if axis is not None.
6945    axis: An int. The axis along which to repeat values. By default (axis=None),
6946      use the flattened input array, and return a flat output array.
6947    name: A name for the operation.
6948
6949  Returns:
6950    A Tensor which has the same shape as `input`, except along the given axis.
6951      If axis is None then the output array is flattened to match the flattened
6952      input array.
6953
6954  Example usage:
6955
6956  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6957  <tf.Tensor: shape=(5,), dtype=string,
6958  numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6959
6960  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6961  <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6962  array([[1, 2],
6963         [1, 2],
6964         [3, 4],
6965         [3, 4],
6966         [3, 4]], dtype=int32)>
6967
6968  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6969  <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6970  array([[1, 1, 2, 2, 2],
6971         [3, 3, 4, 4, 4]], dtype=int32)>
6972
6973  >>> repeat(3, repeats=4)
6974  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([3, 3, 3, 3], dtype=int32)>
6975
6976  >>> repeat([[1,2], [3,4]], repeats=2)
6977  <tf.Tensor: shape=(8,), dtype=int32,
6978  numpy=array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)>
6979
6980  """
6981  if axis is None:
6982    input = reshape(input, [-1])
6983    axis = 0
6984  return repeat_with_axis(input, repeats, axis, name)
6985
6986
6987@tf_export("guarantee_const")
6988@deprecation.deprecated(None, "Not for public use.")
6989def guarantee_const(input, name=None):    # pylint: disable=redefined-builtin
6990  """Promise to the TF runtime that the input tensor is a constant.
6991
6992  The runtime is then free to make optimizations based on this.
6993
6994  Returns the input tensor without modification.
6995
6996  Args:
6997    input: A `Tensor`.
6998    name: A name for this operation.
6999
7000  Returns:
7001    A `Tensor`. Has the same dtype as `input`.
7002  """
7003  return gen_array_ops.guarantee_const(input=input, name=name)
7004
7005
7006@tf_export("stop_gradient")
7007@dispatch.add_dispatch_support
7008def stop_gradient(input, name=None):  # pylint: disable=redefined-builtin
7009  """Stops gradient computation.
7010
7011  NOTE: This docstring is patched out below. See
7012  tensorflow/core/api_def/base_api/api_def_StopGradient.pbtxt for the full
7013  docstring. That file determines the public documentation page.
7014
7015  Args:
7016    input: A `Tensor`.
7017    name: A name for this operation.
7018
7019  Returns:
7020    A `Tensor`. Has the same dtype as `input`.
7021  """
7022  # Don't expand ResourceVariables, so stop_gradient(variable) will return a
7023  # Tensor.
7024  if (isinstance(input, composite_tensor.CompositeTensor) and
7025      not _pywrap_utils.IsResourceVariable(input)):
7026    return nest.map_structure(stop_gradient, input, expand_composites=True)
7027  # The StopGradient op has a gradient function registered which returns None
7028  # (meaning statically known to be zero). For correctness, that's all we
7029  # need. However, tf.GradientTape often makes decisions about what to keep in
7030  # memory based on which forward-pass tensors are currently being watched, and
7031  # returning None in a gradient is not sufficient to stop watching a tensor
7032  # since the backward function doesn't run in the forward pass. Pausing the
7033  # tape around this op instructs any tf.GradientTapes to ignore the
7034  # forward-pass output of StopGradient, which may be much more efficient.
7035  with tape.stop_recording():
7036    return gen_array_ops.stop_gradient(input, name=name)
7037
7038
7039stop_gradient.__doc__ = gen_array_ops.stop_gradient.__doc__
7040
7041
7042# Register elementwise ops that don't have Python wrappers.
7043dispatch.register_unary_elementwise_api(gen_array_ops.check_numerics)
7044