# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Tests for this file live in python/kernel_tests/array_ops_test.py
"""Support for manipulating tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numbers
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
# 'Constant' gets imported in the module 'array_ops'.
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse  # pylint: disable=unused-import
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import

# Used in slicing to specify a new dimension of size 1.
newaxis = None
tf_export("newaxis").export_constant(__name__, "newaxis")
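# Example (a sketch): `t[:, tf.newaxis]` inserts a new axis of size 1 at
# position 1, so a tensor of shape `[2, 3]` becomes shape `[2, 1, 3]`.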

# We override the 'slice' for the "slice" op, so we keep Python's
# existing 'slice' for later use in this module.
_BaseSlice = slice


@tf_export("reshape", v1=["reshape", "manip.reshape"])
@dispatch.add_dispatch_support
def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
  r"""Reshapes a tensor.

  Given `tensor`, this operation returns a new `tf.Tensor` that has the same
  values as `tensor` in the same order, except with a new shape given by
  `shape`.

  >>> t1 = [[1, 2, 3],
  ...       [4, 5, 6]]
  >>> print(tf.shape(t1).numpy())
  [2 3]
  >>> t2 = tf.reshape(t1, [6])
  >>> t2
  <tf.Tensor: shape=(6,), dtype=int32,
    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
  >>> tf.reshape(t2, [3, 2])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

  `tf.reshape` does not change the order of or the total number of elements
  in the tensor, and so it can reuse the underlying data buffer. This makes it
  a fast operation regardless of how big a tensor it is operating on.

  >>> tf.reshape([1, 2, 3], [2, 2])
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Input to reshape is a tensor with 3 values, but the
  requested shape has 4

  To instead reorder the data to rearrange the dimensions of a tensor, see
  `tf.transpose`.

  >>> t = [[1, 2, 3],
  ...      [4, 5, 6]]
  >>> tf.reshape(t, [3, 2]).numpy()
  array([[1, 2],
         [3, 4],
         [5, 6]], dtype=int32)
  >>> tf.transpose(t, perm=[1, 0]).numpy()
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)

  If one component of `shape` is the special value -1, the size of that
  dimension is computed so that the total size remains constant.  In particular,
  a `shape` of `[-1]` flattens into 1-D.  At most one component of `shape` can
  be -1.

  >>> t = [[1, 2, 3],
  ...      [4, 5, 6]]
  >>> tf.reshape(t, [-1])
  <tf.Tensor: shape=(6,), dtype=int32,
    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
  >>> tf.reshape(t, [3, -1])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>
  >>> tf.reshape(t, [-1, 2])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

  `tf.reshape(t, [])` reshapes a tensor `t` with one element to a scalar.

  >>> tf.reshape([7], []).numpy()
  7

  More examples:

  >>> t = [1, 2, 3, 4, 5, 6, 7, 8, 9]
  >>> print(tf.shape(t).numpy())
  [9]
  >>> tf.reshape(t, [3, 3])
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]], dtype=int32)>

  >>> t = [[[1, 1], [2, 2]],
  ...      [[3, 3], [4, 4]]]
  >>> print(tf.shape(t).numpy())
  [2 2 2]
  >>> tf.reshape(t, [2, 4])
  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
    array([[1, 1, 2, 2],
           [3, 3, 4, 4]], dtype=int32)>

  >>> t = [[[1, 1, 1],
  ...       [2, 2, 2]],
  ...      [[3, 3, 3],
  ...       [4, 4, 4]],
  ...      [[5, 5, 5],
  ...       [6, 6, 6]]]
  >>> print(tf.shape(t).numpy())
  [3 2 3]
  >>> # Pass '[-1]' to flatten 't'.
  >>> tf.reshape(t, [-1])
  <tf.Tensor: shape=(18,), dtype=int32,
    numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
    dtype=int32)>
  >>> # -- Using -1 to infer the shape --
  >>> # Here -1 is inferred to be 9:
  >>> tf.reshape(t, [2, -1])
  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
  >>> # -1 is inferred to be 2:
  >>> tf.reshape(t, [-1, 9])
  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
  >>> # -1 is inferred to be 3:
  >>> tf.reshape(t, [2, -1, 3])
  <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
    array([[[1, 1, 1],
            [2, 2, 2],
            [3, 3, 3]],
           [[4, 4, 4],
            [5, 5, 5],
            [6, 6, 6]]], dtype=int32)>

  Args:
    tensor: A `Tensor`.
    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      Defines the shape of the output tensor.
    name: Optional string. A name for the operation.

  Returns:
    A `Tensor`. Has the same type as `tensor`.
  """
  result = gen_array_ops.reshape(tensor, shape, name)
  tensor_util.maybe_set_static_shape(result, shape)
  return result


@tf_export("fill")
@dispatch.add_dispatch_support
def fill(dims, value, name=None):
  r"""Creates a tensor filled with a scalar value.

  See also `tf.ones`, `tf.zeros`, `tf.one_hot`, `tf.eye`.

  This operation creates a tensor of shape `dims` and fills it with `value`.

  For example:

  >>> tf.fill([2, 3], 9)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[9, 9, 9],
         [9, 9, 9]], dtype=int32)>

  `tf.fill` evaluates at graph runtime and supports dynamic shapes based on
  other runtime `tf.Tensors`, unlike `tf.constant(value, shape=dims)`, which
  embeds the value as a `Const` node.
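
  For example, `dims` can come from the runtime shape of another tensor:

  >>> x = tf.zeros([2, 3])
  >>> tf.fill(tf.shape(x), 9)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[9, 9, 9],
         [9, 9, 9]], dtype=int32)>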

  Args:
    dims: A 1-D sequence of non-negative numbers. Represents the shape of the
      output `tf.Tensor`. Entries should be of type: `int32`, `int64`.
    value: A value to fill the returned `tf.Tensor`.
    name: Optional string. The name of the output `tf.Tensor`.

  Returns:
    A `tf.Tensor` with shape `dims` and the same dtype as `value`.

  Raises:
    InvalidArgumentError: `dims` contains negative entries.
    NotFoundError: `dims` contains non-integer entries.

  @compatibility(numpy)
  Similar to `np.full`. In `numpy`, more parameters are supported. Passing a
  number argument as the shape (`np.full(5, value)`) is valid in `numpy` for
  specifying a 1-D shaped result, while TensorFlow does not support this syntax.
  @end_compatibility
  """
  result = gen_array_ops.fill(dims, value, name=name)
  tensor_util.maybe_set_static_shape(result, dims)
  return result


@tf_export("identity")
@dispatch.add_dispatch_support
def identity(input, name=None):  # pylint: disable=redefined-builtin
  r"""Return a Tensor with the same shape and contents as input.

  The return value is not the same Tensor as the original, but contains the same
  values.  This operation is fast when used on the same device.

  For example:

  >>> a = tf.constant([0.78])
  >>> a_identity = tf.identity(a)
  >>> a.numpy()
  array([0.78], dtype=float32)
  >>> a_identity.numpy()
  array([0.78], dtype=float32)

  Calling `tf.identity` on a variable will make a Tensor that represents the
  value of that variable at the time it is called. This is equivalent to calling
  `<variable>.read_value()`.

  >>> a = tf.Variable(5)
  >>> a_identity = tf.identity(a)
  >>> a.assign_add(1)
  <tf.Variable ... shape=() dtype=int32, numpy=6>
  >>> a.numpy()
  6
  >>> a_identity.numpy()
  5

  Args:
    input: A `Tensor`, a `Variable`, a `CompositeTensor` or anything that can be
      converted to a tensor using `tf.convert_to_tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or CompositeTensor. Has the same type and contents as `input`.
  """
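  # For composite tensors (e.g. `tf.SparseTensor`), apply `identity` to each
  # component tensor and rebuild the composite structure.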
  if isinstance(input, composite_tensor.CompositeTensor):
    return nest.map_structure(identity, input, expand_composites=True)
  if context.executing_eagerly() and not hasattr(input, "graph"):
    # Make sure we get an input with handle data attached from resource
    # variables. Variables have correct handle data when graph building.
    input = ops.convert_to_tensor(input)
  ret = gen_array_ops.identity(input, name=name)
  # Propagate handle data for happier shape inference for resource variables.
  if hasattr(input, "_handle_data"):
    ret._handle_data = input._handle_data  # pylint: disable=protected-access
  return ret


# pylint: disable=redefined-builtin,protected-access
@tf_export(v1=["expand_dims"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
def expand_dims(input, axis=None, name=None, dim=None):
  """Returns a tensor with a length 1 axis inserted at index `axis`.

  Given a tensor `input`, this operation inserts a dimension of length 1 at the
  dimension index `axis` of `input`'s shape. The dimension index follows Python
  indexing rules: it is zero-based, and a negative index is counted backward
  from the end.

  This operation is useful to:

  * Add an outer "batch" dimension to a single element.
  * Align axes for broadcasting.
  * Add an inner vector length axis to a tensor of scalars.

  For example:

  If you have a single image of shape `[height, width, channels]`:

  >>> image = tf.zeros([10,10,3])

  You can add an outer `batch` axis by passing `axis=0`:

  >>> tf.expand_dims(image, axis=0).shape.as_list()
  [1, 10, 10, 3]

  The new axis location matches Python `list.insert(axis, 1)`:

  >>> tf.expand_dims(image, axis=1).shape.as_list()
  [10, 1, 10, 3]

  Following standard Python indexing rules, a negative `axis` counts from the
  end so `axis=-1` adds an innermost dimension:

  >>> tf.expand_dims(image, -1).shape.as_list()
  [10, 10, 3, 1]

  This operation requires that `axis` is a valid index for `input.shape`,
  following Python indexing rules:

  ```
  -1-tf.rank(input) <= axis <= tf.rank(input)
  ```

  This operation is related to:

  * `tf.squeeze`, which removes dimensions of size 1.
  * `tf.reshape`, which provides more flexible reshaping capability.
  * `tf.sparse.expand_dims`, which provides this functionality for
    `tf.SparseTensor`

  Args:
    input: A `Tensor`.
    axis: 0-D (scalar). Specifies the dimension index at which to expand the
      shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`.
    name: The name of the output `Tensor` (optional).
    dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.

  Returns:
    A `Tensor` with the same data as `input`, but its shape has an additional
    dimension of size 1 added.

  Raises:
    ValueError: if either both or neither of `dim` and `axis` are specified.
  """
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  if axis is None:
    raise ValueError("Must specify an axis argument to tf.expand_dims()")
  return expand_dims_v2(input, axis, name)


@tf_export("expand_dims", v1=[])
@dispatch.add_dispatch_support
def expand_dims_v2(input, axis, name=None):
  """Returns a tensor with a length 1 axis inserted at index `axis`.

  Given a tensor `input`, this operation inserts a dimension of length 1 at the
  dimension index `axis` of `input`'s shape. The dimension index follows Python
  indexing rules: it is zero-based, and a negative index is counted backward
  from the end.

  This operation is useful to:

  * Add an outer "batch" dimension to a single element.
  * Align axes for broadcasting.
  * Add an inner vector length axis to a tensor of scalars.

  For example:

  If you have a single image of shape `[height, width, channels]`:

  >>> image = tf.zeros([10,10,3])

  You can add an outer `batch` axis by passing `axis=0`:

  >>> tf.expand_dims(image, axis=0).shape.as_list()
  [1, 10, 10, 3]

  The new axis location matches Python `list.insert(axis, 1)`:

  >>> tf.expand_dims(image, axis=1).shape.as_list()
  [10, 1, 10, 3]

  Following standard Python indexing rules, a negative `axis` counts from the
  end so `axis=-1` adds an innermost dimension:

  >>> tf.expand_dims(image, -1).shape.as_list()
  [10, 10, 3, 1]

  This operation requires that `axis` is a valid index for `input.shape`,
  following Python indexing rules:

  ```
  -1-tf.rank(input) <= axis <= tf.rank(input)
  ```

  This operation is related to:

  * `tf.squeeze`, which removes dimensions of size 1.
  * `tf.reshape`, which provides more flexible reshaping capability.
  * `tf.sparse.expand_dims`, which provides this functionality for
    `tf.SparseTensor`

  Args:
    input: A `Tensor`.
    axis: Integer specifying the dimension index at which to expand the
      shape of `input`. Given an input of D dimensions, `axis` must be in range
      `[-(D+1), D]` (inclusive).
    name: Optional string. The name of the output `Tensor`.

  Returns:
    A tensor with the same data as `input`, with an additional dimension
    inserted at the index specified by `axis`.

  Raises:
    ValueError: If `axis` is not specified.
    InvalidArgumentError: If `axis` is out of range `[-(D+1), D]`.
  """
  return gen_array_ops.expand_dims(input, axis, name)


# pylint: enable=redefined-builtin,protected-access


# Aliases for some automatically-generated names.
# pylint: disable=protected-access
@deprecation.deprecated("2016-11-30",
                        "This op will be removed after the deprecation date. "
                        "Please switch to tf.setdiff1d().")
def listdiff(x, y, out_idx=None, name=None):
  return gen_array_ops.list_diff(x, y, out_idx, name)


listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__

# pylint: enable=protected-access


# pylint: disable=undefined-variable
@deprecation.deprecated("2018-11-30",
                        "This op will be removed after the deprecation date. "
                        "Please switch to tf.sets.difference().")
@tf_export(v1=["setdiff1d"])
@dispatch.add_dispatch_support
def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
  """Computes the difference between two lists of numbers or strings.

  Given a list x and a list y, this operation returns a list out that
  represents all values that are in x but not in y. The returned list
  out is sorted in the same order that the numbers appear in x
  (duplicates are preserved). This operation also returns a list idx
  that represents the position of each out element in x.

  In other words:

  ```python
  out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]
  ```

  Example usage:

  >>> x = [1, 2, 3, 4, 5, 6]
  >>> y = [1, 3, 5]
  >>> setdiff1d(x,y)
  ListDiff(out=<tf.Tensor: id=2, shape=(3,), dtype=int32,
  numpy=array([2, 4, 6], dtype=int32)>, idx=<tf.Tensor: id=3,
  shape=(3,), dtype=int32, numpy=array([1, 3, 5], dtype=int32)>)

  Args:
    x: A Tensor. 1-D. Values to keep.
    y: A Tensor. Must have the same type as x. 1-D. Values to remove.
    index_dtype: An optional tf.DType from: tf.int32, tf.int64. Defaults to
      tf.int32.
    name: A name for the operation (optional).

  Returns:
    A tuple of Tensor objects (out, idx).
    out: A Tensor. Has the same type as x.
    idx: A Tensor of type index_dtype.
  """
  return gen_array_ops.list_diff(x, y, index_dtype, name)


setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__


@tf_export("broadcast_dynamic_shape")
@dispatch.add_dispatch_support
def broadcast_dynamic_shape(shape_x, shape_y):
  """Computes the shape of a broadcast given symbolic shapes.

  When `shape_x` and `shape_y` are Tensors representing shapes (i.e. the result
  of calling tf.shape on another Tensor) this computes a Tensor which is the
  shape of the result of a broadcasting op applied to tensors of shapes
  `shape_x` and `shape_y`.

  This is useful when validating the result of a broadcasting operation when the
  tensors do not have statically known shapes.

  Example:

  >>> shape_x = (1, 2, 3)
  >>> shape_y = (5, 1, 3)
  >>> tf.broadcast_dynamic_shape(shape_x, shape_y)
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([5, 2, 3], ...>

  Args:
    shape_x: A rank 1 integer `Tensor`, representing the shape of x.
    shape_y: A rank 1 integer `Tensor`, representing the shape of y.

  Returns:
    A rank 1 integer `Tensor` representing the broadcasted shape.

  Raises:
    InvalidArgumentError: If the two shapes are incompatible for
      broadcasting.
  """
  return gen_array_ops.broadcast_args(shape_x, shape_y)


@tf_export("broadcast_static_shape")
@dispatch.add_dispatch_support
def broadcast_static_shape(shape_x, shape_y):
  """Computes the shape of a broadcast given known shapes.

  When `shape_x` and `shape_y` are fully known `TensorShape`s this computes a
  `TensorShape` which is the shape of the result of a broadcasting op applied
  to tensors of shapes `shape_x` and `shape_y`.

  For example, if shape_x is `TensorShape([1, 2, 3])` and shape_y is
  `TensorShape([5, 1, 3])`, the result is a TensorShape whose value is
  `TensorShape([5, 2, 3])`.

  This is useful when validating the result of a broadcasting operation when the
  tensors have statically known shapes.

  Example:

  >>> shape_x = tf.TensorShape([1, 2, 3])
  >>> shape_y = tf.TensorShape([5, 1, 3])
  >>> tf.broadcast_static_shape(shape_x, shape_y)
  TensorShape([5, 2, 3])

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  return common_shapes.broadcast_shape(shape_x, shape_y)


@tf_export("shape", v1=[])
@dispatch.add_dispatch_support
def shape_v2(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns a tensor containing the shape of the input tensor.

  See also `tf.size`, `tf.rank`.

  `tf.shape` returns a 1-D integer tensor representing the shape of `input`.
  For a scalar input, the tensor returned has a shape of (0,) and its value is
  the empty vector (i.e. []).

  For example:

  >>> tf.shape(1.)
  <tf.Tensor: shape=(0,), dtype=int32, numpy=array([], dtype=int32)>

  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  >>> tf.shape(t)
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 2, 3], dtype=int32)>

  Note: When using symbolic tensors, such as when using the Keras API,
  tf.shape() will return the shape of the symbolic tensor.

  >>> a = tf.keras.layers.Input((None, 10))
  >>> tf.shape(a)
  <... shape=(3,) dtype=int32...>

  In these cases, using `tf.Tensor.shape` will return more informative results.

  >>> a.shape
  TensorShape([None, None, 10])

  (The first `None` represents the as yet unknown batch size.)

  `tf.shape` and `Tensor.shape` should be identical in eager mode.  Within
  `tf.function` or within a `compat.v1` context, not all dimensions may be
  known until execution time. Hence when defining custom layers and models
  for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`.

  Args:
    input: A `Tensor` or `SparseTensor`.
    out_type: (Optional) The specified output type of the operation (`int32` or
      `int64`). Defaults to `tf.int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  return shape(input, name, out_type)


@tf_export(v1=["shape"])
@dispatch.add_dispatch_support
def shape(input, name=None, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the shape of a tensor.

  This operation returns a 1-D integer tensor representing the shape of `input`.

  For example:

  ```python
  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  tf.shape(t)  # [2, 2, 3]
  ```

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    out_type: (Optional) The specified output type of the operation (`int32`
      or `int64`). Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`.
  """
  return shape_internal(input, name, optimize=True, out_type=out_type)


def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the shape of a tensor.

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    optimize: if true, encode the shape as a constant when possible.
    out_type: (Optional) The specified output type of the operation (`int32` or
      `int64`). Defaults to tf.int32.

  Returns:
    A `Tensor` of type `out_type`.

  """
  with ops.name_scope(name, "Shape", [input]) as name:
    if isinstance(
        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      return gen_math_ops.cast(input.dense_shape, out_type)
    else:
      if not context.executing_eagerly():
        input = ops.convert_to_tensor(input)
        input_shape = input.get_shape()
        if optimize and input_shape.is_fully_defined():
          return constant(input_shape.as_list(), out_type, name=name)
      return gen_array_ops.shape(input, name=name, out_type=out_type)


@tf_export("shape_n")
@dispatch.add_dispatch_support
def shape_n(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns shape of tensors.

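  For example (a small usage sketch; each result is an `int32` shape tensor,
  shown here by value):

  ```python
  a = tf.ones([1, 2, 3])
  b = tf.ones([4, 5])
  tf.shape_n([a, b])  # [[1, 2, 3], [4, 5]]
  ```
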
  Args:
    input: A list of at least one `Tensor` object, all with the same type.
    out_type: The specified output type of the operation (`int32` or `int64`).
      Defaults to `tf.int32` (optional).
    name: A name for the operation (optional).

  Returns:
    A list with the same length as `input` of `Tensor` objects with
      type `out_type`.
  """

  return gen_array_ops.shape_n(input, out_type=out_type, name=name)


@tf_export("size", v1=[])
@dispatch.add_dispatch_support
def size_v2(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  See also `tf.shape`.

  Returns a 0-D `Tensor` representing the number of elements in `input`
  of type `out_type`. Defaults to tf.int32.

  For example:

  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  >>> tf.size(t)
  <tf.Tensor: shape=(), dtype=int32, numpy=12>

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    out_type: (Optional) The specified non-quantized numeric output type of the
      operation. Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`. Defaults to `tf.int32`.

  @compatibility(numpy)
  Equivalent to np.size()
  @end_compatibility
  """

  return size(input, name, out_type)


@tf_export(v1=["size"])
@dispatch.add_dispatch_support
def size(input, name=None, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  Returns a 0-D `Tensor` representing the number of elements in `input`
  of type `out_type`. Defaults to tf.int32.

  For example:

  ```python
  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  tf.size(t)  # 12
  ```

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    out_type: (Optional) The specified non-quantized numeric output type of the
      operation. Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`. Defaults to `tf.int32`.

  @compatibility(numpy)
  Equivalent to np.size()
  @end_compatibility
  """
  return size_internal(input, name, optimize=True, out_type=out_type)


def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin,protected-access
  """Returns the size of a tensor.

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    optimize: if true, encode the size as a constant when possible.
    out_type: (Optional) The specified non-quantized numeric output type of the
      operation. Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
  """
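  # Eager fast path: a concrete dense tensor's element count is just the
  # product of its static shape, so avoid building a `Size` op.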
  if (context.executing_eagerly() and not hasattr(input, "graph") and
      not isinstance(
          input,
          (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))):
    input = ops.convert_to_tensor(input)
    np_out_type = out_type.as_numpy_dtype
    num_elements = np.prod(input._shape_tuple(), dtype=np_out_type)  # pylint: disable=protected-access
    return ops.convert_to_tensor(num_elements, dtype=out_type)
  with ops.name_scope(name, "Size", [input]) as name:
    if isinstance(
        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      return gen_math_ops.prod(
          gen_math_ops.cast(input.dense_shape, out_type), 0, name=name)
    else:
      input = ops.convert_to_tensor(input)
      input_shape = input.get_shape()
      if optimize:
        if input_shape.is_fully_defined():
          return constant(input_shape.num_elements(), out_type, name=name)
        if input_shape.dims and any(dim == 0 for dim in input_shape.dims):
          return constant(0, out_type, name=name)
      return gen_array_ops.size(input, name=name, out_type=out_type)


@tf_export("rank")
@dispatch.add_dispatch_support
def rank(input, name=None):
  # pylint: disable=redefined-builtin
  """Returns the rank of a tensor.

  See also `tf.shape`.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  For example:

  ```python
  # shape of tensor 't' is [2, 2, 3]
  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  tf.rank(t)  # 3
  ```

  **Note**: The rank of a tensor is not the same as the rank of a matrix. The
  rank of a tensor is the number of indices required to uniquely select each
  element of the tensor. Rank is also known as "order", "degree", or "ndims."

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.

  @compatibility(numpy)
  Equivalent to np.ndim
  @end_compatibility
  """
  return rank_internal(input, name, optimize=True)


def rank_internal(input, name=None, optimize=True):
  # pylint: disable=redefined-builtin
  """Returns the rank of a tensor.

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    optimize: if true, encode the rank as a constant when possible.

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, "Rank", [input]) as name:
    if isinstance(
        input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      return gen_array_ops.size(input.dense_shape, name=name)
    else:
      input = ops.convert_to_tensor(input)
      input_shape = input.get_shape()
      if optimize and input_shape.ndims is not None:
        return constant(input_shape.ndims, dtypes.int32, name=name)
      return gen_array_ops.rank(input, name=name)


_SLICE_TYPE_ERROR = (
    "Only integers, slices (`:`), ellipsis (`...`), "
    "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
    "indices")

_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
                           dtypes.int64_ref)


def _check_index(idx):
  """Check if a given value is a valid index into a tensor."""
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return

  # Optimistic check. Assumptions:
  # * any object with a dtype is supported
  # * any object with a dtype has a sizeable shape attribute.
  dtype = getattr(idx, "dtype", None)
  if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
      idx.shape and len(idx.shape) == 1):
    # TODO(slebedev): IndexError seems more appropriate here, but it
    # will break `_slice_helper` contract.
    raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))


def _is_undefined_dimension(d):
  return isinstance(d, tensor_shape.Dimension) and d.value is None


@tf_export("__operators__.getitem", v1=[])
@dispatch.add_dispatch_support
def _slice_helper(tensor, slice_spec, var=None):
  """Overload for Tensor.__getitem__.

  This operation extracts the specified region from the tensor.
  The notation is similar to NumPy with the restriction that
  currently only basic indexing is supported. That means that
  using a non-scalar tensor as input is not currently allowed.

  Some useful examples:

  ```python
  # Strip leading and trailing 2 elements
  foo = tf.constant([1,2,3,4,5,6])
  print(foo[2:-2].eval())  # => [3,4]

  # Skip every other row and reverse the order of the columns
  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
  print(foo[::2,::-1].eval())  # => [[3,2,1], [9,8,7]]

  # Use scalar tensors as indices on both dimensions
  print(foo[tf.constant(0), tf.constant(2)].eval())  # => 3

  # Insert another dimension
  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
  print(foo[tf.newaxis, :, :].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
  print(foo[:, tf.newaxis, :].eval())  # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
  print(foo[:, :, tf.newaxis].eval())  # => [[[1],[2],[3]], [[4],[5],[6]],
                                       #     [[7],[8],[9]]]

  # Ellipses (3 equivalent operations)
  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
  print(foo[tf.newaxis, :, :].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
  print(foo[tf.newaxis, ...].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
  print(foo[tf.newaxis].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]

  # Masks
  foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
  print(foo[foo > 2].eval())  # => [3, 4, 5, 6, 7, 8, 9]
  ```

  Notes:
    - `tf.newaxis` is `None` as in NumPy.
    - An implicit ellipsis is placed at the end of the `slice_spec`.
    - NumPy advanced indexing is currently not supported.

  Purpose in the API:

    This method is exposed in TensorFlow's API so that library developers
    can register dispatching for `Tensor.__getitem__` to allow it to handle
    custom composite tensors & other custom objects.

    The API symbol is not intended to be called by users directly, though it
    does appear in TensorFlow's generated documentation.

  Args:
    tensor: An ops.Tensor object.
    slice_spec: The arguments to Tensor.__getitem__.
    var: In the case of variable slice assignment, the Variable object to slice
      (i.e. tensor is the read-only view of this variable).

  Returns:
    The appropriate slice of "tensor", based on "slice_spec".

  Raises:
    ValueError: If a slice range is negative size.
    TypeError: If the slice indices aren't int, slice, ellipsis,
      tf.newaxis or scalar int32/int64 tensors.
  """
  tensor = ops.convert_to_tensor(tensor)
  # TODO(wangpeng): Consider supporting var
  if var is None and ops._numpy_style_slicing:  # pylint: disable=protected-access
    return tensor._numpy_style_getitem(slice_spec)  # pylint: disable=protected-access

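  # A boolean `slice_spec` (a Python bool, a bool `Tensor`, or a bool
  # `np.ndarray`) selects elements by mask rather than by index.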
  if isinstance(slice_spec, bool) or \
  (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
  (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
    return boolean_mask(tensor=tensor, mask=slice_spec)

  if not isinstance(slice_spec, (list, tuple)):
    slice_spec = [slice_spec]

  begin, end, strides = [], [], []
  index = 0

  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  for s in slice_spec:
    if isinstance(s, _BaseSlice):
      if s.start is not None and not _is_undefined_dimension(s.start):
        _check_index(s.start)
        begin.append(s.start)
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None and not _is_undefined_dimension(s.stop):
        _check_index(s.stop)
        end.append(s.stop)
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None and not _is_undefined_dimension(s.step):
        _check_index(s.step)
        strides.append(s.step)
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      _check_index(s)
      begin.append(s)
      end.append(s + 1)
      strides.append(1)
      shrink_axis_mask |= (1 << index)
    index += 1

  # `stack` may involve no tensors, so we must use op_scope to select the
  # correct graph.
  with ops.name_scope(
      None,
      "strided_slice", [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (stack(begin), stack(end),
                                                  stack(strides))
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = gen_math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = gen_math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    return strided_slice(
        tensor,
        packed_begin,
        packed_end,
        packed_strides,
        begin_mask=begin_mask,
        end_mask=end_mask,
        shrink_axis_mask=shrink_axis_mask,
        new_axis_mask=new_axis_mask,
        ellipsis_mask=ellipsis_mask,
        var=var,
        name=name)


# pylint: disable=undefined-variable,protected-access,redefined-outer-name
@tf_export("slice")
@dispatch.add_dispatch_support
def slice(input_, begin, size, name=None):
  # pylint: disable=redefined-builtin
  """Extracts a slice from a tensor.

  See also `tf.strided_slice`.

  This operation extracts a slice of size `size` from a tensor `input_` starting
  at the location specified by `begin`. The slice `size` is represented as a
  tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
  of `input_` that you want to slice. The starting location (`begin`) for the
  slice is represented as an offset in each dimension of `input_`. In other
  words, `begin[i]` is the offset into the i'th dimension of `input_` that you
  want to slice from.

  Note that `tf.Tensor.__getitem__` is typically a more pythonic way to
  perform slices, as it allows you to write `foo[3:7, :-2]` instead of
  `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`.

  `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
  all remaining elements in dimension i are included in the
  slice. In other words, this is equivalent to setting:

  `size[i] = input_.dim_size(i) - begin[i]`

  This operation requires that:

  `0 <= begin[i] <= begin[i] + size[i] <= Di  for i in [0, n)`

  For example:

  ```python
  t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                   [[3, 3, 3], [4, 4, 4]],
                   [[5, 5, 5], [6, 6, 6]]])
  tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
  tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
                                     #   [4, 4, 4]]]
  tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
                                     #  [[5, 5, 5]]]
  ```

  Args:
    input_: A `Tensor`.
    begin: An `int32` or `int64` `Tensor`.
    size: An `int32` or `int64` `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` the same type as `input_`.
  """
  return gen_array_ops._slice(input_, begin, size, name=name)


# pylint: disable=invalid-name
@tf_export("strided_slice")
@dispatch.add_dispatch_support
def strided_slice(input_,
                  begin,
                  end,
                  strides=None,
                  begin_mask=0,
                  end_mask=0,
                  ellipsis_mask=0,
                  new_axis_mask=0,
                  shrink_axis_mask=0,
                  var=None,
                  name=None):
  """Extracts a strided slice of a tensor (generalized Python array indexing).

  See also `tf.slice`.

  **Instead of calling this op directly most users will want to use the
  NumPy-style slicing syntax (e.g. `tensor[..., 3:4:-1, tf.newaxis, 3]`), which
  is supported via `tf.Tensor.__getitem__` and `tf.Variable.__getitem__`.**
  The interface of this op is a low-level encoding of the slicing syntax.

  Roughly speaking, this op extracts a slice of size `(end-begin)/stride`
  from the given `input_` tensor. Starting at the location specified by `begin`
  the slice continues by adding `stride` to the index until all dimensions are
  not less than `end`.
  Note that a stride can be negative, which causes a reverse slice.

  Given a Python slice `input[spec0, spec1, ..., specn]`,
  this function will be called as follows.

  `begin`, `end`, and `strides` will be vectors of length n.
  n in general is not equal to the rank of the `input_` tensor.

  In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`,
  `new_axis_mask`, `shrink_axis_mask`) the ith bit will correspond to
  the ith spec.

  If the ith bit of `begin_mask` is set, `begin[i]` is ignored and
  the fullest possible range in that dimension is used instead.
  `end_mask` works analogously, except with the end range.

  `foo[5:,:,:3]` on a 7x8x9 tensor is equivalent to `foo[5:7,0:8,0:3]`.
  `foo[::-1]` reverses a tensor of shape `[8]`.

  If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions
  as needed will be inserted between other dimensions. Only one
  non-zero bit is allowed in `ellipsis_mask`.

  For example `foo[3:5,...,4:5]` on a shape 10x3x3x10 tensor is
  equivalent to `foo[3:5,:,:,4:5]` and
  `foo[3:5,...]` is equivalent to `foo[3:5,:,:,:]`.

  If the ith bit of `new_axis_mask` is set, then `begin`,
  `end`, and `stride` are ignored and a new length 1 dimension is
  added at this point in the output tensor.

  For example,
  `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.

  If the ith bit of `shrink_axis_mask` is set, it implies that the ith
  specification shrinks the dimensionality by 1, taking on the value at index
  `begin[i]`. `end[i]` and `strides[i]` are ignored in this case. For example in
  Python one might do `foo[:, 3, :]` which would result in `shrink_axis_mask`
  equal to 2.


  NOTE: `begin` and `end` are zero-indexed.
  `strides` entries must be non-zero.


  ```python
  t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                   [[3, 3, 3], [4, 4, 4]],
                   [[5, 5, 5], [6, 6, 6]]])
  tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])  # [[[3, 3, 3]]]
  tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])  # [[[3, 3, 3],
                                                        #   [4, 4, 4]]]
  tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])  # [[[4, 4, 4],
                                                           #   [3, 3, 3]]]
  ```
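
  As a sketch of how the masks encode Python slicing, using the same `t`,
  `t[:, 1]` can be written as the call below: bit 0 of `begin_mask` and
  `end_mask` marks the first range as "full", and bit 1 of `shrink_axis_mask`
  drops the second axis.

  ```python
  tf.strided_slice(t, [0, 1], [0, 2], [1, 1],
                   begin_mask=1, end_mask=1,
                   shrink_axis_mask=2)  # [[2, 2, 2], [4, 4, 4], [6, 6, 6]]
  ```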

  Args:
    input_: A `Tensor`.
    begin: An `int32` or `int64` `Tensor`.
    end: An `int32` or `int64` `Tensor`.
    strides: An `int32` or `int64` `Tensor`.
    begin_mask: An `int32` mask.
    end_mask: An `int32` mask.
    ellipsis_mask: An `int32` mask.
    new_axis_mask: An `int32` mask.
    shrink_axis_mask: An `int32` mask.
    var: The variable corresponding to `input_` or None
    name: A name for the operation (optional).

  Returns:
    A `Tensor` the same type as `input_`.
  """

  if strides is None:
    strides = ones_like(begin)

  op = gen_array_ops.strided_slice(
      input=input_,
      begin=begin,
      end=end,
      strides=strides,
      name=name,
      begin_mask=begin_mask,
      end_mask=end_mask,
      ellipsis_mask=ellipsis_mask,
      new_axis_mask=new_axis_mask,
      shrink_axis_mask=shrink_axis_mask)

  parent_name = name

  if var is not None:
    def assign(val, name=None):
      """Closure that holds all the arguments to create an assignment."""

      if name is None:
        name = parent_name + "_assign"

      return var._strided_slice_assign(
          begin=begin,
          end=end,
          strides=strides,
          value=val,
          name=name,
          begin_mask=begin_mask,
          end_mask=end_mask,
          ellipsis_mask=ellipsis_mask,
          new_axis_mask=new_axis_mask,
          shrink_axis_mask=shrink_axis_mask)

    op.assign = assign

  return op


def _SliceHelperVar(var, slice_spec):
  """Creates a slice helper object given a variable.

  This allows creating a sub-tensor from part of the current contents
  of a variable. See `tf.Tensor.__getitem__` for detailed examples
  of slicing.

  In addition, this function allows assignment to a sliced range. This is
  similar to `__setitem__` functionality in Python. However, the syntax is
  different so that the user can capture the assignment operation for
  grouping or passing to `sess.run()`.
  For example,

  ```python
  import tensorflow as tf
  A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    print(sess.run(A[:2, :2]))  # => [[1,2], [4,5]]

    op = A[:2,:2].assign(22. * tf.ones((2, 2)))
    print(sess.run(op))  # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
  ```

  Note that assignments currently do not support NumPy broadcasting
  semantics.

  Args:
    var: An `ops.Variable` object.
    slice_spec: The arguments to `Tensor.__getitem__`.

  Returns:
    The appropriate slice of "tensor", based on "slice_spec", as an operator.
    The operator also has an `assign()` method that can be used to generate
    an assignment operator.

  Raises:
    ValueError: If a slice range is negative size.
    TypeError: If the slice indices aren't int, slice, ellipsis,
      tf.newaxis or int32/int64 tensors.

  """

  return _slice_helper(var.value(), slice_spec, var)


ops.Tensor._override_operator("__getitem__", _slice_helper)


@tf_export("parallel_stack")
@dispatch.add_dispatch_support
def parallel_stack(values, name="parallel_stack"):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.

  Requires that the shape of inputs be known at graph construction time.

  Packs the list of tensors in `values` into a tensor with rank one higher than
  each tensor in `values`, by packing them along the first dimension.
  Given a list of length `N` of tensors of shape `(A, B, C)`; the `output`
  tensor will have the shape `(N, A, B, C)`.

  For example:

  ```python
  x = tf.constant([1, 4])
  y = tf.constant([2, 5])
  z = tf.constant([3, 6])
  tf.parallel_stack([x, y, z])  # [[1, 4], [2, 5], [3, 6]]
  ```

  The difference between `stack` and `parallel_stack` is that `stack` requires
  that all the inputs be computed before the operation will begin, but it
  doesn't require that the input shapes be known during graph construction.

  `parallel_stack` will copy pieces of the input into the output as they
  become available; in some situations this can provide a performance benefit.

  Unlike `stack`, `parallel_stack` does NOT support backpropagation.

  This is the opposite of unstack.  The numpy equivalent is

      tf.parallel_stack([x, y, z]) = np.asarray([x, y, z])

  @compatibility(eager)
  parallel_stack is not compatible with eager execution.
  @end_compatibility

  Args:
    values: A list of `Tensor` objects with the same shape and type.
    name: A name for this operation (optional).

  Returns:
    output: A stacked `Tensor` with the same type as `values`.

  Raises:
    RuntimeError: if executed in eager mode.
  """
  if context.executing_eagerly():
    raise RuntimeError("tf.parallel_stack() is not compatible with "
                       "eager execution.")
  with ops.name_scope(name):
    value_t = ops.convert_to_tensor(values[0])
    value_shape = ops.convert_to_tensor(value_t).get_shape()

    output_shape = tensor_shape.TensorShape([len(values)])
    output_shape = output_shape.concatenate(value_shape)
    # expand_dims on each value turns the parallel_concat below into a stack.
    return gen_array_ops.parallel_concat(
        [expand_dims(value, 0) for value in values], shape=output_shape)


@tf_export("stack")
@dispatch.add_dispatch_support
def stack(values, axis=0, name="stack"):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

  See also `tf.concat`, `tf.tile`, `tf.repeat`.

  Packs the list of tensors in `values` into a tensor with rank one higher than
  each tensor in `values`, by packing them along the `axis` dimension.
  Given a list of length `N` of tensors of shape `(A, B, C)`;

  if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
  if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
  Etc.

  For example:

  >>> x = tf.constant([1, 4])
  >>> y = tf.constant([2, 5])
  >>> z = tf.constant([3, 6])
  >>> tf.stack([x, y, z])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)>
  >>> tf.stack([x, y, z], axis=1)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[1, 2, 3],
         [4, 5, 6]], dtype=int32)>

  This is the opposite of unstack.  The numpy equivalent is `np.stack`

  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
  True

  Args:
    values: A list of `Tensor` objects with the same shape and type.
    axis: An `int`. The axis to stack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
    name: A name for this operation (optional).

  Returns:
    output: A stacked `Tensor` with the same type as `values`.

  Raises:
    ValueError: If `axis` is out of the range [-(R+1), R+1).
  """
  if axis == 0:
    try:
      # If the input is a constant list, it can be converted to a constant op
      return ops.convert_to_tensor(values, name=name)
    except (TypeError, ValueError):
      pass  # Input list contains non-constant tensors

  value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple()  # pylint: disable=protected-access
  if value_shape is not None:
    expanded_num_dims = len(value_shape) + 1
    if axis < -expanded_num_dims or axis >= expanded_num_dims:
      raise ValueError("axis = %d not in [%d, %d)" %
                       (axis, -expanded_num_dims, expanded_num_dims))

  return gen_array_ops.pack(values, axis=axis, name=name)


# pylint: disable=invalid-name
def _autopacking_helper(list_or_tuple, dtype, name):
  """Converts the given list or tuple to a tensor by packing.

  Args:
    list_or_tuple: A (possibly nested) list or tuple containing a tensor.
    dtype: The element type of the returned tensor.
    name: A name for the returned tensor.

  Returns:
    A `tf.Tensor` with value equivalent to `list_or_tuple`.
  """
  if context.executing_eagerly():
    # NOTE: Fast path when all the items are tensors, this doesn't do any type
    # checking.
    if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
      return gen_array_ops.pack(list_or_tuple, name=name)
  must_pack = False
  converted_elems = []
  with ops.name_scope(name) as scope:
    for i, elem in enumerate(list_or_tuple):
      if isinstance(elem, core.Tensor):
        if dtype is not None and elem.dtype.base_dtype != dtype:
          raise TypeError("Cannot convert a list containing a tensor of dtype "
                          "%s to %s (Tensor is: %r)" %
                          (elem.dtype, dtype, elem))
        converted_elems.append(elem)
        must_pack = True
      elif isinstance(elem, (list, tuple)):
        converted_elem = _autopacking_helper(elem, dtype, str(i))
        if isinstance(converted_elem, core.Tensor):
          must_pack = True
        converted_elems.append(converted_elem)
      else:
        converted_elems.append(elem)
    if must_pack:
      elems_as_tensors = []
      for i, elem in enumerate(converted_elems):
        if isinstance(elem, core.Tensor):
          elems_as_tensors.append(elem)
        else:
          # NOTE(mrry): This is inefficient, but it enables us to
          # handle the case where the list arguments are other
          # convertible-to-tensor types, such as numpy arrays.
          elems_as_tensors.append(
              constant_op.constant(elem, dtype=dtype, name=str(i)))
      return gen_array_ops.pack(elems_as_tensors, name=scope)
    else:
      return converted_elems


def _get_dtype_from_nested_lists(list_or_tuple):
  """Returns the dtype of any tensor-like object in `list_or_tuple`, if found.

  Args:
    list_or_tuple: A list or tuple representing an object that can be converted
      to a `tf.Tensor`.

  Returns:
    The dtype of any tensor-like object in `list_or_tuple`, or `None` if no
    such object exists.
  """
  for elem in list_or_tuple:
    if isinstance(elem, core.Tensor):
      return elem.dtype.base_dtype
    elif isinstance(elem, (list, tuple)):
      maybe_dtype = _get_dtype_from_nested_lists(elem)
      if maybe_dtype is not None:
        return maybe_dtype
  return None


def _cast_nested_seqs_to_dtype(dtype):
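  """Returns an elementwise function that casts `Tensor` elements to `dtype`."""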
1499
1500  def _maybe_cast(elem):
1501    if isinstance(elem, core.Tensor):
1502      if dtype != elem.dtype.base_dtype:
1503        elem = gen_math_ops.cast(elem, dtype)
1504    return elem
1505
1506  return _maybe_cast
1507
1508
1509_NON_AUTOPACKABLE_TYPES = set(np.core.numerictypes.ScalarType)
1510_NON_AUTOPACKABLE_TYPES.add(np.ndarray)
1511
1512
1513def _should_not_autopack(v):
1514  # The condition we really want is
1515  #    any(isinstance(elem, core.Tensor))
1516  # but it is >5x slower due to abc.ABCMeta.__instancecheck__.
1517  # pylint: disable=unidiomatic-typecheck
1518  # TODO(slebedev): add nest.all?
1519  return all(type(elem) in _NON_AUTOPACKABLE_TYPES for elem in nest.flatten(v))
1520  # pylint: enable=unidiomatic-typecheck
1521

def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
  """Tensor conversion function that automatically packs arguments."""
  if as_ref or _should_not_autopack(v):
    return NotImplemented
  inferred_dtype = _get_dtype_from_nested_lists(v)
  if inferred_dtype is None:
    # We did not find any tensor-like objects in the nested lists, so defer to
    # other conversion functions.
    return NotImplemented
  if dtype is None:
    dtype = inferred_dtype
  elif dtype != inferred_dtype:
    v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
  return _autopacking_helper(v, dtype, name or "packed")


# pylint: enable=invalid-name

# NOTE: Register this conversion function to run *before* one that
# assumes every element is a value.
ops.register_tensor_conversion_function((list, tuple),
                                        _autopacking_conversion_function, 99)
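# For example (illustrative): with this registration in place, converting a
# list that mixes `Tensor`s and Python scalars packs it into a single tensor,
# so, assuming an eager context,
#   ops.convert_to_tensor([constant_op.constant(1.0), 2.0])
# returns a float32 tensor of shape (2,).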


@tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"):
  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

  Unpacks tensors from `value` by chipping it along the `axis` dimension.

  >>> x = tf.reshape(tf.range(12), (3,4))
  >>>
  >>> p, q, r = tf.unstack(x)
  >>> p.shape.as_list()
  [4]

  >>> i, j, k, l = tf.unstack(x, axis=1)
  >>> i.shape.as_list()
  [3]

  This is the opposite of `tf.stack`.

  >>> x = tf.stack([i, j, k, l], axis=1)

  More generally, if you have a tensor of shape `(A, B, C, D)`:

  >>> A, B, C, D = [2, 3, 4, 5]
  >>> t = tf.random.normal(shape=[A, B, C, D])

  The number of tensors returned is equal to the length of the target `axis`:

  >>> axis = 2
  >>> items = tf.unstack(t, axis=axis)
  >>> len(items) == t.shape[axis]
  True

  The shape of each result tensor is equal to the shape of the input tensor,
  with the target `axis` removed.

  >>> items[0].shape.as_list()  # [A, B, D]
  [2, 3, 5]

  The value of each tensor `items[i]` is equal to the slice of `input` across
  `axis` at index `i`:

  >>> for i in range(len(items)):
  ...   slice = t[:,:,i,:]
  ...   assert tf.reduce_all(slice == items[i])

  #### Python iterable unpacking

  With eager execution you _can_ unstack the 0th axis of a tensor using
  Python's iterable unpacking:

  >>> t = tf.constant([1,2,3])
  >>> a,b,c = t

  `unstack` is still necessary because iterable unpacking doesn't work in
  a `@tf.function`: symbolic tensors are not iterable.

  You need to use `tf.unstack` here:

  >>> @tf.function
  ... def bad(t):
  ...   a,b,c = t
  ...   return a
  >>>
  >>> bad(t)
  Traceback (most recent call last):
  ...
  OperatorNotAllowedInGraphError: ...

  >>> @tf.function
  ... def good(t):
  ...   a,b,c = tf.unstack(t)
  ...   return a
  >>>
  >>> good(t).numpy()
  1

  #### Unknown shapes

  Eager tensors have concrete values, so their shape is always known.
  Inside a `tf.function` the symbolic tensors may have unknown shapes.
  If the length of `axis` is unknown, `tf.unstack` will fail because it cannot
  handle an unknown number of tensors:

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def bad(t):
  ...   tensors = tf.unstack(t)
  ...   return tensors[0]
  >>>
  >>> bad(tf.constant([1,2,3]))
  Traceback (most recent call last):
  ...
  ValueError: Cannot infer num from shape (None,)

  If you know the `axis` length you can pass it as the `num` argument. But this
  must be a constant value.

  If you actually need a variable number of tensors in a single `tf.function`
  trace, you will need to use explicit loops and a `tf.TensorArray` instead, as
  sketched below.

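  A minimal sketch of that pattern (illustrative; `double_chips` is a made-up
  example that doubles each chip of its input):

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def double_chips(t):
  ...   n = tf.shape(t)[0]
  ...   ta = tf.TensorArray(tf.float32, size=n)
  ...   for i in tf.range(n):
  ...     ta = ta.write(i, t[i] * 2.0)
  ...   return ta.stack()
  >>> double_chips(tf.constant([1., 2., 3.])).numpy()
  array([2., 4., 6.], dtype=float32)
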
  Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred if
      `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).

  Returns:
    The list of `Tensor` objects unstacked from `value`.

  Raises:
    ValueError: If `axis` is out of the range `[-R, R)`.
    ValueError: If `num` is unspecified and cannot be inferred.
    InvalidArgumentError: If `num` does not match the shape of `value`.
  """
  if num is None:
    value = ops.convert_to_tensor(value)
    value_shape = value.get_shape()
    if value_shape.ndims is not None:
      if axis < -value_shape.ndims or axis >= value_shape.ndims:
        raise ValueError("axis = %d not in [%d, %d)" %
                         (axis, -value_shape.ndims, value_shape.ndims))
      num = value_shape.dims[axis].value
  if num is None:
    raise ValueError("Cannot infer num from shape %s" % value_shape)
  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)


@tf_export("concat")
@dispatch.add_dispatch_support
def concat(values, axis, name="concat"):
  """Concatenates tensors along one dimension.

  See also `tf.tile`, `tf.stack`, `tf.repeat`.

  Concatenates the list of tensors `values` along dimension `axis`.  If
  `values[i].shape = [D0, D1, ... Daxis(i), ...Dn]`, the concatenated
  result has shape

      [D0, D1, ... Raxis, ...Dn]

  where

      Raxis = sum(Daxis(i))

  That is, the data from the input tensors is joined along the `axis`
  dimension.

  The number of dimensions of the input tensors must match, and all dimensions
  except `axis` must be equal.

  For example:

  >>> t1 = [[1, 2, 3], [4, 5, 6]]
  >>> t2 = [[7, 8, 9], [10, 11, 12]]
  >>> tf.concat([t1, t2], 0)
  <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
  array([[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]], dtype=int32)>

  >>> tf.concat([t1, t2], 1)
  <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
  array([[ 1,  2,  3,  7,  8,  9],
         [ 4,  5,  6, 10, 11, 12]], dtype=int32)>

  As in Python, `axis` can also be negative. A negative `axis` is interpreted
  as counting from the end of the rank, i.e. as the
  `axis + rank(values)`-th dimension.

  For example:

  >>> t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
  >>> t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
  >>> tf.concat([t1, t2], -1)
  <tf.Tensor: shape=(2, 2, 4), dtype=int32, numpy=
    array([[[ 1,  2,  7,  4],
            [ 2,  3,  8,  4]],
           [[ 4,  4,  2, 10],
            [ 5,  3, 15, 11]]], dtype=int32)>

  Note: If you are concatenating along a new axis, consider using `tf.stack`.
  E.g.

  ```python
  tf.concat([tf.expand_dims(t, axis) for t in tensors], axis)
  ```

  can be rewritten as

  ```python
  tf.stack(tensors, axis=axis)
  ```

  Args:
    values: A list of `Tensor` objects or a single `Tensor`.
    axis: 0-D `int32` `Tensor`.  Dimension along which to concatenate. Must be
      in the range `[-rank(values), rank(values))`. As in Python, indexing for
      `axis` is 0-based. A positive `axis` in the range `[0, rank(values))`
      refers to the `axis`-th dimension, and a negative `axis` refers to the
      `axis + rank(values)`-th dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` resulting from concatenation of the input tensors.
  """
  if not isinstance(values, (list, tuple)):
    values = [values]
  # TODO(mrry): Change to return values?
  if len(values) == 1:  # Degenerate case of one tensor.
    # Make a throwaway call to convert_to_tensor to make sure
    # that axis is of the correct type, and make sure that
    # the returned tensor is a scalar.
    # TODO(keveman): Implement a standalone type and shape checker.
    with ops.name_scope(name) as scope:
      ops.convert_to_tensor(
          axis, name="concat_dim",
          dtype=dtypes.int32).get_shape().assert_has_rank(0)
      return identity(values[0], name=name)
  return gen_array_ops.concat_v2(values=values, axis=axis, name=name)


@tf_export(v1=["boolean_mask"])
@dispatch.add_dispatch_support
def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
  """Apply boolean mask to tensor.

  Numpy equivalent is `tensor[mask]`.

  In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
  the first K dimensions of `tensor`'s shape.  We then have:
    `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
  where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
  The `axis` argument can be used with `mask` to indicate the axis to mask
  from. In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must
  match the first `axis + dim(mask)` dimensions of `tensor`'s shape.

  See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
  ragged tensors, and can be used if you need to preserve the masked dimensions
  of `tensor` (rather than flattening them, as `tf.boolean_mask` does).

  Examples:

  ```python
  # 1-D example
  tensor = [0, 1, 2, 3]
  mask = np.array([True, False, True, False])
  tf.boolean_mask(tensor, mask)  # [0, 2]

  # 2-D example
  tensor = [[1, 2], [3, 4], [5, 6]]
  mask = np.array([True, False, True])
  tf.boolean_mask(tensor, mask)  # [[1, 2], [5, 6]]
  ```
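
  The `axis` argument selects where masking starts; for instance, masking
  along the second dimension keeps the selected columns (illustrative):

  ```python
  tensor = [[1, 2, 3], [4, 5, 6]]
  mask = np.array([True, False, True])
  tf.boolean_mask(tensor, mask, axis=1)  # [[1, 3], [4, 6]]
  ```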

  Args:
    tensor:  N-D Tensor.
    mask:  K-D boolean Tensor, K <= N and K must be known statically.
    name:  A name for this operation (optional).
    axis:  A 0-D int Tensor representing the axis in `tensor` to mask from. By
      default, axis is 0 which will mask from the first dimension. Otherwise K +
      axis <= N.

  Returns:
    (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
    to `True` values in `mask`.

  Raises:
    ValueError:  If shapes do not conform.
  """

  def _apply_mask_1d(reshaped_tensor, mask, axis=None):
    """Mask tensor along dimension 0 with a 1-D mask."""
    indices = squeeze(where_v2(mask), axis=[1])
    return gather(reshaped_tensor, indices, axis=axis)

  with ops.name_scope(name, values=[tensor, mask]):
    tensor = ops.convert_to_tensor(tensor, name="tensor")
    mask = ops.convert_to_tensor(mask, name="mask")

    shape_mask = mask.get_shape()
    ndims_mask = shape_mask.ndims
    shape_tensor = tensor.get_shape()
    if ndims_mask == 0:
      raise ValueError("mask cannot be scalar.")
    if ndims_mask is None:
      raise ValueError(
          "Number of mask dimensions must be specified, even if some dimensions"
          " are None.  E.g. shape=[None] is ok, but shape=None is not.")
    axis = 0 if axis is None else axis
    axis_value = tensor_util.constant_value(axis)
    if axis_value is not None:
      axis = axis_value
      shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)

    leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0])
    tensor = reshape(
        tensor,
        concat([
            shape(tensor)[:axis], [leading_size],
            shape(tensor)[axis + ndims_mask:]
        ], 0))
    # TODO(yongtang): tf.reshape in C++ kernel might have set the shape
    # correctly, so the following may not be needed? It still might be possible
    # that there are some edge cases where tensor_util.constant_value resolves
    # more cases than ShapeInference of tf.reshape in C++ kernel.
    if axis_value is not None:
      first_dim = shape_tensor[axis:axis + ndims_mask].num_elements()
      tensor.set_shape(
          tensor_shape.as_shape(shape_tensor[:axis]).concatenate(
              [first_dim]).concatenate(shape_tensor[axis + ndims_mask:]))

    mask = reshape(mask, [-1])
    return _apply_mask_1d(tensor, mask, axis)


@tf_export("boolean_mask", v1=[])
@dispatch.add_dispatch_support
def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
  """Apply boolean mask to tensor.

  Numpy equivalent is `tensor[mask]`.

  In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
  the first K dimensions of `tensor`'s shape.  We then have:
    `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
  where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
  The `axis` argument can be used with `mask` to indicate the axis to mask
  from. In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must
  match the first `axis + dim(mask)` dimensions of `tensor`'s shape.

  See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
  ragged tensors, and can be used if you need to preserve the masked dimensions
  of `tensor` (rather than flattening them, as `tf.boolean_mask` does).

  Examples:

  >>> tensor = [0, 1, 2, 3]  # 1-D example
  >>> mask = np.array([True, False, True, False])
  >>> tf.boolean_mask(tensor, mask)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2], dtype=int32)>

  >>> tensor = [[1, 2], [3, 4], [5, 6]] # 2-D example
  >>> mask = np.array([True, False, True])
  >>> tf.boolean_mask(tensor, mask)
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[1, 2],
         [5, 6]], dtype=int32)>

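  Masking along a later axis keeps that axis's selected slices; for instance
  (illustrative):

  >>> tensor = [[1, 2, 3], [4, 5, 6]]
  >>> mask = np.array([True, False, True])
  >>> tf.boolean_mask(tensor, mask, axis=1)
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[1, 3],
         [4, 6]], dtype=int32)>
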
  Args:
    tensor:  N-D Tensor.
    mask:  K-D boolean Tensor, K <= N and K must be known statically.
    axis:  A 0-D int Tensor representing the axis in `tensor` to mask from. By
      default, axis is 0 which will mask from the first dimension. Otherwise K +
      axis <= N.
    name:  A name for this operation (optional).

  Returns:
    (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
    to `True` values in `mask`.

  Raises:
    ValueError:  If shapes do not conform.
  """
  return boolean_mask(tensor, mask, name, axis)


@tf_export("sparse.mask", v1=["sparse.mask", "sparse_mask"])
@deprecation.deprecated_endpoints("sparse_mask")
def sparse_mask(a, mask_indices, name=None):
  """Masks elements of `IndexedSlices`.

  Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
  contains a subset of the slices of `a`. Only the slices at indices not
  specified in `mask_indices` are returned.

  This is useful when you need to extract a subset of slices in an
  `IndexedSlices` object.

  For example:

  ```python
  # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
  # with shape [1000, 10]
  a.indices  # [12, 26, 37, 45]
  tf.shape(a.values)  # [4, 10]

  # `b` will be the subset of `a` slices at its second and third indices, so
  # we want to mask its first and last indices (which are at absolute
  # indices 12, 45)
  b = tf.sparse.mask(a, [12, 45])

  b.indices  # [26, 37]
  tf.shape(b.values)  # [2, 10]
  ```

  Args:
    a: An `IndexedSlices` instance.
    mask_indices: Indices of elements to mask.
    name: A name for the operation (optional).

  Returns:
    The masked `IndexedSlices` instance.
  """
  with ops.name_scope(name, "sparse_mask", [a, mask_indices]) as name:
    indices = a.indices
    out_indices, to_gather = gen_array_ops.list_diff(indices, mask_indices)
    out_values = gather(a.values, to_gather, name=name)
    return ops.IndexedSlices(out_values, out_indices, a.dense_shape)


@tf_export("unique")
@dispatch.add_dispatch_support
def unique(x, out_idx=dtypes.int32, name=None):
  """Finds unique elements in a 1-D tensor.

  See also `tf.unique_with_counts`.

  This operation returns a tensor `y` containing all of the unique elements
  of `x` sorted in the same order that they occur in `x`. This operation
  also returns a tensor `idx` the same size as `x` that contains the index
  of each value of `x` in the unique output `y`. In other words:

    y[idx[i]] = x[i] for i in [0, 1, ..., len(x) - 1]

  Example usage:

  >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
  >>> y, idx = unique(x)
  >>> y
  <tf.Tensor: id=5, shape=(5,), dtype=int32,
  numpy=array([1, 2, 4, 7, 8], dtype=int32)>
  >>> idx
  <tf.Tensor: id=6, shape=(9,), dtype=int32,
  numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>

  Args:
    x: A Tensor. 1-D.
    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
      tf.int32.
    name: A name for the operation (optional).

  Returns:
    A tuple of Tensor objects (y, idx).
      y: A Tensor. Has the same type as x.
      idx: A Tensor of type out_idx.
  """
  # TODO(yongtang): switch to v2 once the API deprecation
  # period (3 weeks) passes.
  # TODO(yongtang): The documentation should also
  # be updated when switching to v2.
  return gen_array_ops.unique(x, out_idx, name)


unique.__doc__ = gen_array_ops.unique.__doc__


@tf_export("unique_with_counts")
@dispatch.add_dispatch_support
def unique_with_counts(x, out_idx=dtypes.int32, name=None):
  """Finds unique elements in a 1-D tensor.

  See also `tf.unique`.

  This operation returns a tensor `y` containing all of the unique elements
  of `x` sorted in the same order that they occur in `x`. This operation
  also returns a tensor `idx` the same size as `x` that contains the index
  of each value of `x` in the unique output `y`. Finally, it returns a
  third tensor `count` that contains the count of each element of `y`
  in `x`. In other words:

    y[idx[i]] = x[i] for i in [0, 1, ..., len(x) - 1]

  Example usage:

  >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
  >>> y, idx, count = unique_with_counts(x)
  >>> y
  <tf.Tensor: id=8, shape=(5,), dtype=int32,
  numpy=array([1, 2, 4, 7, 8], dtype=int32)>
  >>> idx
  <tf.Tensor: id=9, shape=(9,), dtype=int32,
  numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
  >>> count
  <tf.Tensor: id=10, shape=(5,), dtype=int32,
  numpy=array([2, 1, 3, 1, 2], dtype=int32)>

  Args:
    x: A Tensor. 1-D.
    out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
      tf.int32.
    name: A name for the operation (optional).

  Returns:
    A tuple of Tensor objects (y, idx, count).
      y: A Tensor. Has the same type as x.
      idx: A Tensor of type out_idx.
      count: A Tensor of type out_idx.
  """
  # TODO(yongtang): switch to v2 once the API deprecation
  # period (3 weeks) passes.
  # TODO(yongtang): The documentation should also
  # be updated when switching to v2.
  return gen_array_ops.unique_with_counts(x, out_idx, name)


unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__


@tf_export("split")
@dispatch.add_dispatch_support
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
  """Splits a tensor `value` into a list of subtensors.

  See also `tf.unstack`.

  If `num_or_size_splits` is an integer, then `value` is split along the
  dimension `axis` into `num_or_size_splits` smaller tensors. This requires that
  `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
  `len(num_or_size_splits)` elements. The shape of the `i`-th
  element has the same size as the `value` except along dimension `axis` where
  the size is `num_or_size_splits[i]`.

  For example:

  >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
  >>>
  >>> # Split `x` into 3 tensors along dimension 1
  >>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
  >>> tf.shape(s0).numpy()
  array([ 5, 10], dtype=int32)
  >>>
  >>> # Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
  >>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
  >>> tf.shape(split0).numpy()
  array([5, 4], dtype=int32)
  >>> tf.shape(split1).numpy()
  array([ 5, 15], dtype=int32)
  >>> tf.shape(split2).numpy()
  array([ 5, 11], dtype=int32)

  Args:
    value: The `Tensor` to split.
    num_or_size_splits: Either an integer indicating the number of splits along
      `axis` or a 1-D integer `Tensor` or Python list containing the sizes of
      each output tensor along `axis`. If a scalar, then it must evenly divide
      `value.shape[axis]`; otherwise the sum of sizes along the split axis
      must match that of the `value`.
    axis: An integer or scalar `int32` `Tensor`. The dimension along which to
      split. Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
    num: Optional, used to specify the number of outputs when it cannot be
      inferred from the shape of `size_splits`.
    name: A name for the operation (optional).

  Returns:
    If `num_or_size_splits` is a scalar, returns a list of `num_or_size_splits`
    `Tensor` objects; if `num_or_size_splits` is a 1-D Tensor, returns
    `num_or_size_splits.get_shape()[0]` `Tensor` objects resulting from
    splitting `value`.

  Raises:
    ValueError: If `num` is unspecified and cannot be inferred.
  """
  if isinstance(num_or_size_splits,
                (numbers.Integral, tensor_shape.Dimension)):
    return gen_array_ops.split(
        axis=axis, num_split=num_or_size_splits, value=value, name=name)

  size_splits = ops.convert_to_tensor(num_or_size_splits)

  if size_splits._rank() == 0:
    raise ValueError(
        "Rank-0 tensors are not supported as the num_or_size_splits argument "
        "to split. Argument provided: %s" % (num_or_size_splits,))

  if num is None:
    size_splits_shape = size_splits._shape_tuple()
    if size_splits_shape:
      num = size_splits_shape[0]
    if num is None:
      raise ValueError("Cannot infer num from shape %s" % num_or_size_splits)

  return gen_array_ops.split_v(
      value=value, size_splits=size_splits, axis=axis, num_split=num, name=name)


@tf_export("transpose", v1=[])
@dispatch.add_dispatch_support
def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
  """Transposes `a`, where `a` is a Tensor.

  Permutes the dimensions according to the value of `perm`.

  The returned tensor's dimension `i` will correspond to the input dimension
  `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is the rank
  of the input tensor. Hence by default, this operation performs a regular
  matrix transpose on 2-D input Tensors.

  If conjugate is `True` and `a.dtype` is either `complex64` or `complex128`
  then the values of `a` are conjugated and transposed.

  @compatibility(numpy)
  In `numpy` transposes are memory-efficient constant time operations as they
  simply return a new view of the same data with adjusted `strides`.

  TensorFlow does not support strides, so `transpose` returns a new tensor with
  the items permuted.
  @end_compatibility

  For example:

  >>> x = tf.constant([[1, 2, 3], [4, 5, 6]])
  >>> tf.transpose(x)
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)>

  Equivalently, you could call `tf.transpose(x, perm=[1, 0])`.

  If `x` is complex, setting conjugate=True gives the conjugate transpose:

  >>> x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
  ...                  [4 + 4j, 5 + 5j, 6 + 6j]])
  >>> tf.transpose(x, conjugate=True)
  <tf.Tensor: shape=(3, 2), dtype=complex128, numpy=
  array([[1.-1.j, 4.-4.j],
         [2.-2.j, 5.-5.j],
         [3.-3.j, 6.-6.j]])>

  'perm' is more useful for n-dimensional tensors where n > 2:

  >>> x = tf.constant([[[ 1,  2,  3],
  ...                   [ 4,  5,  6]],
  ...                  [[ 7,  8,  9],
  ...                   [10, 11, 12]]])

  As above, simply calling `tf.transpose` will default to `perm=[2,1,0]`.

  To take the transpose of the matrices in dimension-0 (such as when you are
  transposing matrices where 0 is the batch dimension), you would set
  `perm=[0,2,1]`.

  >>> tf.transpose(x, perm=[0, 2, 1])
  <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
  array([[[ 1,  4],
          [ 2,  5],
          [ 3,  6]],
          [[ 7, 10],
          [ 8, 11],
          [ 9, 12]]], dtype=int32)>

  Note: This common operation has a shorthand, `tf.linalg.matrix_transpose`.

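  As an illustrative check of that equivalence on the 3-D `x` above:

  >>> y = tf.linalg.matrix_transpose(x)
  >>> bool(tf.reduce_all(y == tf.transpose(x, perm=[0, 2, 1])))
  True
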
  Args:
    a: A `Tensor`.
    perm: A permutation of the dimensions of `a`.  This should be a vector.
    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
      to tf.math.conj(tf.transpose(input)).
    name: A name for the operation (optional).

  Returns:
    A transposed `Tensor`.
  """
  return transpose(a=a, perm=perm, name=name, conjugate=conjugate)


@tf_export(v1=["transpose"])
@dispatch.add_dispatch_support
def transpose(a, perm=None, name="transpose", conjugate=False):
  """Transposes `a`.

  Permutes the dimensions according to `perm`.

  The returned tensor's dimension i will correspond to the input dimension
  `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
  the rank of the input tensor. Hence by default, this operation performs a
  regular matrix transpose on 2-D input Tensors. If conjugate is True and
  `a.dtype` is either `complex64` or `complex128` then the values of `a`
  are conjugated and transposed.

  @compatibility(numpy)
  In `numpy` transposes are memory-efficient constant time operations as they
  simply return a new view of the same data with adjusted `strides`.

  TensorFlow does not support strides, so `transpose` returns a new tensor with
  the items permuted.
  @end_compatibility

  For example:

  ```python
  x = tf.constant([[1, 2, 3], [4, 5, 6]])
  tf.transpose(x)  # [[1, 4]
                   #  [2, 5]
                   #  [3, 6]]

  # Equivalently
  tf.transpose(x, perm=[1, 0])  # [[1, 4]
                                #  [2, 5]
                                #  [3, 6]]

  # If x is complex, setting conjugate=True gives the conjugate transpose
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.transpose(x, conjugate=True)  # [[1 - 1j, 4 - 4j],
                                   #  [2 - 2j, 5 - 5j],
                                   #  [3 - 3j, 6 - 6j]]

  # 'perm' is more useful for n-dimensional tensors, for n > 2
  x = tf.constant([[[ 1,  2,  3],
                    [ 4,  5,  6]],
                   [[ 7,  8,  9],
                    [10, 11, 12]]])

  # Take the transpose of the matrices in dimension-0
  # (this common operation has a shorthand `linalg.matrix_transpose`)
  tf.transpose(x, perm=[0, 2, 1])  # [[[1,  4],
                                   #   [2,  5],
                                   #   [3,  6]],
                                   #  [[7, 10],
                                   #   [8, 11],
                                   #   [9, 12]]]
  ```

  Args:
    a: A `Tensor`.
    perm: A permutation of the dimensions of `a`.
    name: A name for the operation (optional).
    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
      to tf.math.conj(tf.transpose(input)).

  Returns:
    A transposed `Tensor`.
  """
  with ops.name_scope(name, "transpose", [a]) as name:
    if not tensor_util.is_tf_type(a):
      a = ops.convert_to_tensor(a, name="a")

    if conjugate and a.dtype.is_complex:
      transpose_fn = gen_array_ops.conjugate_transpose
    else:
      transpose_fn = gen_array_ops.transpose

    if perm is not None:
      return transpose_fn(a, perm, name=name)

    rank = a.shape.rank
    if rank is None:
      perm = gen_math_ops._range(gen_array_ops.rank(a) - 1, -1, -1)
    else:
      perm = np.arange(rank - 1, -1, -1, dtype=np.int32)
    return transpose_fn(a, perm, name=name)


# pylint: disable=invalid-name
@tf_export(
    "linalg.matrix_transpose",
    v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose")
def matrix_transpose(a, name="matrix_transpose", conjugate=False):
  """Transposes last two dimensions of tensor `a`.

  For example:

  ```python
  x = tf.constant([[1, 2, 3], [4, 5, 6]])
  tf.linalg.matrix_transpose(x)  # [[1, 4],
                                 #  [2, 5],
                                 #  [3, 6]]

  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.matrix_transpose(x, conjugate=True)  # [[1 - 1j, 4 - 4j],
                                                 #  [2 - 2j, 5 - 5j],
                                                 #  [3 - 3j, 6 - 6j]]

  # Matrix with two batch dimensions.
  # x.shape is [1, 2, 3, 4]
  # tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
  ```

  Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
  This is done with minimal cost, and is preferable to using this function. E.g.

  ```python
  # Good!  Transpose is taken at minimal additional cost.
  tf.matmul(matrix, b, transpose_b=True)

  # Inefficient!
  tf.matmul(matrix, tf.linalg.matrix_transpose(b))
  ```

  @compatibility(numpy)
  In `numpy` transposes are memory-efficient constant time operations as they
  simply return a new view of the same data with adjusted `strides`.

  TensorFlow does not support strides, so `linalg.matrix_transpose` returns a
  new tensor with the items permuted.
  @end_compatibility

  Args:
    a: A `Tensor` with `rank >= 2`.
    name: A name for the operation (optional).
    conjugate: Optional bool. Setting it to `True` is mathematically equivalent
      to tf.math.conj(tf.linalg.matrix_transpose(input)).

  Returns:
    A transposed batch matrix `Tensor`.

  Raises:
    ValueError:  If `a` is determined statically to have `rank < 2`.
  """
  with ops.name_scope(name, values=[a]):
    a = ops.convert_to_tensor(a, name="a")

    # If we know the number of dimensions (statically), we can do two things:
    # 1. Check that `a` is a (batch) matrix.
    # 2. Use a Python list for perm.  This preserves static shape information
    #    and avoids extra computations.
    a_shape = a.get_shape()
    ndims = a_shape.ndims
    if ndims is not None:
      if ndims < 2:
        raise ValueError(
            "Argument 'a' should be a (batch) matrix, with rank >= 2.  Found: "
            "%s" % a_shape)
      perm = list(range(ndims - 2)) + [ndims - 1] + [ndims - 2]
    else:
      a_rank = rank(a)
      perm = concat(
          (gen_math_ops._range(0, a_rank - 2, 1), [a_rank - 1, a_rank - 2]), 0)

    return transpose(a, perm=perm, conjugate=conjugate)


@tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_diag")
def matrix_diag(diagonal,
                name="diag",
                k=0,
                num_rows=-1,
                num_cols=-1,
                padding_value=0,
                align="RIGHT_LEFT"):
  """Returns a batched diagonal tensor with given batched diagonal values.

  Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
  diagonals of a matrix, with everything else padded with `padding_value`.
  `num_rows` and `num_cols` specify the dimension of the innermost matrix of
  the output. If both are not specified, the op assumes the innermost matrix is
  square and infers its size from `k` and the innermost dimension of
  `diagonal`. If only one of them is specified, the op assumes the unspecified
  value is the smallest possible based on other criteria.

  Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor
  has rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only
  one diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has
  rank `r` with shape `[I, J, ..., L, num_rows, num_cols]`.

  The second innermost dimension of `diagonal` has double meaning. When `k` is
  scalar or `k[0] == k[1]`, `M` is part of the batch size [I, J, ..., M], and
  the output tensor is:

  ```
  output[i, j, ..., l, m, n]
    = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper
      padding_value                             ; otherwise
  ```

  Otherwise, `M` is treated as the number of diagonals for the matrix in the
  same batch (`M = k[1]-k[0]+1`), and the output tensor is:

  ```
  output[i, j, ..., l, m, n]
    = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
      padding_value                                     ; otherwise
  ```
  where `d = n - m`, `diag_index = k[1] - d`, and
  `index_in_diag = n - max(d, 0) + offset`.

  `offset` is zero except when the alignment of the diagonal is to the right.
  ```
  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
                                             and `d >= 0`) or
                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
                                             and `d <= 0`)
           0                          ; otherwise
  ```
  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.

  For example:

  ```
  # The main diagonal.
  diagonal = np.array([[1, 2, 3, 4],            # Input shape: (2, 4)
                       [5, 6, 7, 8]])
  tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0],  # Output shape: (2, 4, 4)
                                 [0, 2, 0, 0],
                                 [0, 0, 3, 0],
                                 [0, 0, 0, 4]],
                                [[5, 0, 0, 0],
                                 [0, 6, 0, 0],
                                 [0, 0, 7, 0],
                                 [0, 0, 0, 8]]]

  # A superdiagonal (per batch).
  diagonal = np.array([[1, 2, 3],  # Input shape: (2, 3)
                       [4, 5, 6]])
  tf.matrix_diag(diagonal, k = 1)
    ==> [[[0, 1, 0, 0],  # Output shape: (2, 4, 4)
          [0, 0, 2, 0],
          [0, 0, 0, 3],
          [0, 0, 0, 0]],
         [[0, 4, 0, 0],
          [0, 0, 5, 0],
          [0, 0, 0, 6],
          [0, 0, 0, 0]]]

  # A tridiagonal band (per batch).
  diagonals = np.array([[[8, 9, 0],  # Input shape: (2, 2, 3)
                         [1, 2, 3],
                         [0, 4, 5]],
                        [[2, 3, 0],
                         [6, 7, 9],
                         [0, 9, 1]]])
  tf.matrix_diag(diagonals, k = (-1, 1))
    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
          [4, 2, 9],
          [0, 5, 3]],
         [[6, 2, 0],
          [9, 7, 3],
          [0, 1, 9]]]

  # RIGHT_LEFT alignment.
  diagonals = np.array([[[0, 8, 9],  # Input shape: (2, 2, 3)
                         [1, 2, 3],
                         [4, 5, 0]],
                        [[0, 2, 3],
                         [6, 7, 9],
                         [9, 1, 0]]])
  tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT")
    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
          [4, 2, 9],
          [0, 5, 3]],
         [[6, 2, 0],
          [9, 7, 3],
          [0, 1, 9]]]

  # Rectangular matrix.
  diagonal = np.array([1, 2])  # Input shape: (2)
  tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4)
    ==> [[0, 0, 0, 0],  # Output shape: (3, 4)
         [1, 0, 0, 0],
         [0, 2, 0, 0]]

  # Rectangular matrix with inferred num_cols and padding_value = 9.
  tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
    ==> [[9, 9],  # Output shape: (3, 2)
         [1, 9],
         [9, 2]]
  ```

  Args:
    diagonal: A `Tensor` with `rank k >= 1`.
    name: A name for the operation (optional).
    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
      main diagonal, and negative value means subdiagonals. `k` can be a single
      integer (for a single diagonal) or a pair of integers specifying the low
      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
    num_rows: The number of rows of the output matrix. If it is not provided,
      the op assumes the output matrix is a square matrix and infers the matrix
      size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
    num_cols: The number of columns of the output matrix. If it is not provided,
      the op assumes the output matrix is a square matrix and infers the matrix
      size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
    padding_value: The value to fill the area outside the specified diagonal
      band with. Default is 0.
    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
      `align` is a string specifying how superdiagonals and subdiagonals should
      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
      the left (right-pads the row). It is the packing format LAPACK uses.
      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.

  Returns:
    A Tensor. Has the same type as `diagonal`.
  """
  # Special case to sidestep the tf.constant conversion error:
  # TypeError: Expected bool, got 0 of type 'int' instead.
  if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
    padding_value = bool(padding_value)

  return gen_array_ops.matrix_diag_v3(
      diagonal=diagonal,
      k=k,
      num_rows=num_rows,
      num_cols=num_cols,
      padding_value=padding_value,
      align=align,
      name=name)


@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_diag_part")
def matrix_diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  """Returns the batched diagonal part of a batched tensor.

  Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
  `input`.

  Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`.
  Let `max_diag_len` be the maximum length among all diagonals to be extracted,
  `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`.
  Let `num_diags` be the number of diagonals to extract,
  `num_diags = k[1] - k[0] + 1`.

  If `num_diags == 1`, the output tensor is of rank `r - 1` with shape
  `[I, J, ..., L, max_diag_len]` and values:

  ```
  diagonal[i, j, ..., l, n]
    = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
      padding_value                 ; otherwise.
  ```
  where `y = max(-k[1], 0)`, `x = max(k[1], 0)`.

  Otherwise, the output tensor has rank `r` with dimensions
  `[I, J, ..., L, num_diags, max_diag_len]` with values:

  ```
  diagonal[i, j, ..., l, m, n]
    = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
      padding_value                 ; otherwise.
  ```
  where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.

  `offset` is zero except when the alignment of the diagonal is to the right.
  ```
  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
                                             and `d >= 0`) or
                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
                                             and `d <= 0`)
           0                          ; otherwise
  ```
  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.

  The input must be at least a matrix.

  For example:

  ```
  input = np.array([[[1, 2, 3, 4],  # Input shape: (2, 3, 4)
                     [5, 6, 7, 8],
                     [9, 8, 7, 6]],
                    [[5, 4, 3, 2],
                     [1, 2, 3, 4],
                     [5, 6, 7, 8]]])

  # A main diagonal from each batch.
  tf.linalg.diag_part(input) ==> [[1, 6, 7],  # Output shape: (2, 3)
                                  [5, 2, 7]]

  # A superdiagonal from each batch.
  tf.linalg.diag_part(input, k = 1)
    ==> [[2, 7, 6],  # Output shape: (2, 3)
         [4, 3, 8]]

  # A band from each batch.
  tf.linalg.diag_part(input, k = (-1, 2))
    ==> [[[3, 8, 0],  # Output shape: (2, 4, 3)
          [2, 7, 6],
          [1, 6, 7],
          [0, 5, 8]],
         [[3, 4, 0],
          [4, 3, 8],
          [5, 2, 7],
          [0, 1, 6]]]

  # RIGHT_LEFT alignment.
  tf.linalg.diag_part(input, k = (-1, 2), align="RIGHT_LEFT")
    ==> [[[0, 3, 8],  # Output shape: (2, 4, 3)
          [2, 7, 6],
          [1, 6, 7],
          [5, 8, 0]],
         [[0, 3, 4],
          [4, 3, 8],
          [5, 2, 7],
          [1, 6, 0]]]

  # max_diag_len can be shorter than the main diagonal.
  tf.linalg.diag_part(input, k = (-2, -1))
    ==> [[[5, 8],
          [0, 9]],
         [[1, 6],
          [0, 5]]]

  # padding_value = 9
  tf.linalg.diag_part(input, k = (1, 3), padding_value = 9)
    ==> [[[4, 9, 9],  # Output shape: (2, 3, 3)
          [3, 8, 9],
          [2, 7, 6]],
         [[2, 9, 9],
          [3, 4, 9],
          [4, 3, 8]]]
  ```

  Args:
    input: A `Tensor` with `rank k >= 2`.
    name: A name for the operation (optional).
    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
      main diagonal, and negative value means subdiagonals. `k` can be a single
      integer (for a single diagonal) or a pair of integers specifying the low
      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
    padding_value: The value to fill the area outside the specified diagonal
      band with. Default is 0.
    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
      `align` is a string specifying how superdiagonals and subdiagonals should
      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
      the left (right-pads the row). It is the packing format LAPACK uses.
      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.

  Returns:
    A Tensor containing diagonals of `input`. Has the same type as `input`.
  """
  # Special case to sidestep the tf.constant conversion error:
  # TypeError: Expected bool, got 0 of type 'int' instead.
  if hasattr(input, "dtype") and input.dtype == "bool":
    padding_value = bool(padding_value)

  return gen_array_ops.matrix_diag_part_v3(
      input=input, k=k, padding_value=padding_value, align=align, name=name)


@tf_export(
    "linalg.tensor_diag_part", v1=["linalg.tensor_diag_part", "diag_part"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("diag_part")
def tensor_diag_part(
    input,  # pylint:disable=redefined-builtin
    name=None):
  """Returns the diagonal part of the tensor.

  This operation returns a tensor with the `diagonal` part
  of the `input`. The `diagonal` part is computed as follows:

  Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a
  tensor of rank `k` with dimensions `[D1,..., Dk]` where:

  `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`.

  For a rank 2 tensor, `linalg.diag_part` and `linalg.tensor_diag_part`
  produce the same result. For rank 3 and higher, `linalg.diag_part` extracts
  the diagonal of each inner-most matrix in the tensor. An example where
  they differ is given below.

  >>> x = [[[[1111,1112],[1121,1122]],
  ...       [[1211,1212],[1221,1222]]],
  ...      [[[2111, 2112], [2121, 2122]],
  ...       [[2211, 2212], [2221, 2222]]]
  ...      ]
  >>> tf.linalg.tensor_diag_part(x)
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[1111, 1212],
         [2121, 2222]], dtype=int32)>
  >>> tf.linalg.diag_part(x).shape
  TensorShape([2, 2, 2])

  Args:
    input: A `Tensor` with rank `2k`.
    name: A name for the operation (optional).

  Returns:
    A Tensor containing diagonals of `input`. Has the same type as `input`, and
    rank `k`.
  """
  return gen_array_ops.diag_part(input=input, name=name)


@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_set_diag")
def matrix_set_diag(
    input,  # pylint:disable=redefined-builtin
    diagonal,
    name="set_diag",
    k=0,
    align="RIGHT_LEFT"):
  """Returns a batched matrix tensor with new batched diagonal values.

  Given `input` and `diagonal`, this operation returns a tensor with the
  same shape and values as `input`, except for the specified diagonals of the
  innermost matrices. These will be overwritten by the values in `diagonal`.

  `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
  `k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
  Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
  `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
  `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
  `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`.

  The output is a tensor of rank `r+1` with dimensions `[I, J, ..., L, M, N]`.
  If `k` is scalar or `k[0] == k[1]`:

  ```
  output[i, j, ..., l, m, n]
    = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
      input[i, j, ..., l, m, n]              ; otherwise
  ```

  Otherwise,

  ```
  output[i, j, ..., l, m, n]
    = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
      input[i, j, ..., l, m, n]                         ; otherwise
  ```
  where `d = n - m`, `diag_index = k[1] - d`, and
  `index_in_diag = n - max(d, 0) + offset`.

  `offset` is zero except when the alignment of the diagonal is to the right.
  ```
  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
                                             and `d >= 0`) or
                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
                                             and `d <= 0`)
           0                          ; otherwise
  ```
  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.

  For example:

  ```
  # The main diagonal.
  input = np.array([[[7, 7, 7, 7],              # Input shape: (2, 3, 4)
                     [7, 7, 7, 7],
                     [7, 7, 7, 7]],
                    [[7, 7, 7, 7],
                     [7, 7, 7, 7],
                     [7, 7, 7, 7]]])
  diagonal = np.array([[1, 2, 3],               # Diagonal shape: (2, 3)
                       [4, 5, 6]])
  tf.matrix_set_diag(input, diagonal)
    ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
          [7, 2, 7, 7],
          [7, 7, 3, 7]],
         [[4, 7, 7, 7],
          [7, 5, 7, 7],
          [7, 7, 6, 7]]]

  # A superdiagonal (per batch).
  tf.matrix_set_diag(input, diagonal, k = 1)
    ==> [[[7, 1, 7, 7],  # Output shape: (2, 3, 4)
          [7, 7, 2, 7],
          [7, 7, 7, 3]],
         [[7, 4, 7, 7],
          [7, 7, 5, 7],
          [7, 7, 7, 6]]]

  # A band of diagonals.
  diagonals = np.array([[[9, 1, 0],  # Diagonal shape: (2, 4, 3)
                         [6, 5, 8],
                         [1, 2, 3],
                         [0, 4, 5]],
                        [[1, 2, 0],
                         [5, 6, 4],
                         [6, 1, 2],
                         [0, 3, 4]]])
  tf.matrix_set_diag(input, diagonals, k = (-1, 2))
    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
          [4, 2, 5, 1],
          [7, 5, 3, 8]],
         [[6, 5, 1, 7],
          [3, 1, 6, 2],
          [7, 4, 2, 4]]]

  # RIGHT_LEFT alignment.
  diagonals = np.array([[[0, 9, 1],  # Diagonal shape: (2, 4, 3)
                         [6, 5, 8],
                         [1, 2, 3],
                         [4, 5, 0]],
                        [[0, 1, 2],
                         [5, 6, 4],
                         [6, 1, 2],
                         [3, 4, 0]]])
  tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT")
    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
          [4, 2, 5, 1],
          [7, 5, 3, 8]],
         [[6, 5, 1, 7],
          [3, 1, 6, 2],
          [7, 4, 2, 4]]]
  ```

  Args:
    input: A `Tensor` with rank `k + 1`, where `k >= 1`.
    diagonal:  A `Tensor` with rank `k`, when `d_lower == d_upper`, or `k + 1`,
      otherwise. `k >= 1`.
    name: A name for the operation (optional).
    k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
      main diagonal, and negative value means subdiagonals. `k` can be a single
      integer (for a single diagonal) or a pair of integers specifying the low
      and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
      `align` is a string specifying how superdiagonals and subdiagonals should
      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
      the left (right-pads the row). It is the packing format LAPACK uses.
      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
  """
  return gen_array_ops.matrix_set_diag_v3(
      input=input, diagonal=diagonal, k=k, align=align, name=name)


# pylint: enable=invalid-name


def _constant_if_small(value, shape, dtype, name):
  """Returns a constant for small static shapes, else `None` to use `fill`."""
  try:
    if np.prod(shape) < 1000:
      return constant(value, shape=shape, dtype=dtype, name=name)
  except TypeError:
    # Happens when shape is a Tensor, list with Tensor elements, etc.
    pass
  return None


def _tag_zeros_tensor(fun):
  """Tags the result of `fun` by setting the `_is_zeros_tensor` attribute.

  This is useful to compute Hessians of fused ops such as cross_entropy.
  """

  def wrapped(*args, **kwargs):
    tensor = fun(*args, **kwargs)
    tensor._is_zeros_tensor = True
    return tensor

  return tf_decorator.make_decorator(fun, wrapped)


@tf_export("zeros")
@dispatch.add_dispatch_support
@_tag_zeros_tensor
def zeros(shape, dtype=dtypes.float32, name=None):
  """Creates a tensor with all elements set to zero.

  See also `tf.zeros_like`, `tf.ones`, `tf.fill`, `tf.eye`.

  This operation returns a tensor of type `dtype` with shape `shape` and
  all elements set to zero.

  >>> tf.zeros([3, 4], tf.int32)
  <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
  array([[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]], dtype=int32)>

  Args:
    shape: A `list` of integers, a `tuple` of integers, or
      a 1-D `Tensor` of type `int32`.
    dtype: The DType of an element in the resulting `Tensor`.
    name: Optional string. A name for the operation.

  Returns:
    A `Tensor` with all elements set to zero.
  """
  dtype = dtypes.as_dtype(dtype).base_dtype
  with ops.name_scope(name, "zeros", [shape]) as name:
    if dtype == dtypes.bool:
      zero = False
    elif dtype == dtypes.string:
      zero = ""
    elif dtype.is_quantized:
      zero = np.zeros([]).astype(dtype.as_numpy_dtype)
    else:
      zero = 0

    if not isinstance(shape, ops.Tensor):
      try:
        if not context.executing_eagerly():
          # Create a constant if it won't be very big. Otherwise create a fill
          # op to prevent serialized GraphDefs from becoming too large.
          output = _constant_if_small(zero, shape, dtype, name)
          if output is not None:
            return output

        # Go through tensor shapes to get int64-if-needed semantics
        shape = constant_op._tensor_shape_tensor_conversion_function(
            tensor_shape.TensorShape(shape))
      except (TypeError, ValueError):
        # Happens when shape is a list with tensor elements
        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
    if not shape._shape_tuple():
      shape = reshape(shape, [-1])  # Ensure it's a vector
    output = fill(shape, constant(zero, dtype=dtype), name=name)
  assert output.dtype.base_dtype == dtype
  return output


@tf_export(v1=["zeros_like"])
@dispatch.add_dispatch_support
def zeros_like(tensor, dtype=None, name=None, optimize=True):
  """Creates a tensor with all elements set to zero.

  See also `tf.zeros`.

  Given a single tensor (`tensor`), this operation returns a tensor of the
  same type and shape as `tensor` with all elements set to zero. Optionally,
  you can use `dtype` to specify a new type for the returned tensor.

  Examples:

    >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
    >>> tf.zeros_like(tensor)
    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
    array([[0, 0, 0],
           [0, 0, 0]], dtype=int32)>

    >>> tf.zeros_like(tensor, dtype=tf.float32)
    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
    array([[0., 0., 0.],
           [0., 0., 0.]], dtype=float32)>

  Args:
    tensor: A `Tensor`.
    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
      `complex64`, `complex128`, `bool` or `string`. (optional)
    name: A name for the operation (optional).
    optimize: if `True`, attempt to statically determine the shape of `tensor`
      and encode it as a constant. (optional, defaults to `True`)

  Returns:
    A `Tensor` with all elements set to zero.
  """
  return zeros_like_impl(tensor, dtype, name, optimize)


3016@tf_export("zeros_like", v1=[])
3017@dispatch.add_dispatch_support
3018def zeros_like_v2(
3019    input,  # pylint: disable=redefined-builtin
3020    dtype=None,
3021    name=None):
3022  """Creates a tensor with all elements set to zero.
3023
3024  See also `tf.zeros`.
3025
3026  Given a single tensor or array-like object (`input`), this operation returns
3027  a tensor of the same type and shape as `input` with all elements set to zero.
3028  Optionally, you can use `dtype` to specify a new type for the returned tensor.
3029
3030  Examples:
3031
3032    >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3033    >>> tf.zeros_like(tensor)
3034    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3035    array([[0, 0, 0],
3036           [0, 0, 0]], dtype=int32)>
3037
3038    >>> tf.zeros_like(tensor, dtype=tf.float32)
3039    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
3040    array([[0., 0., 0.],
3041           [0., 0., 0.]], dtype=float32)>
3042
3043    >>> tf.zeros_like([[1, 2, 3], [4, 5, 6]])
3044    <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3045    array([[0, 0, 0],
3046           [0, 0, 0]], dtype=int32)>
3047
3048  Args:
3049    input: A `Tensor` or array-like object.
3050    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3051      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3052      `complex64`, `complex128`, `bool` or `string` (optional).
3053    name: A name for the operation (optional).
3054
3055  Returns:
3056    A `Tensor` with all elements set to zero.
3057  """
3058  return zeros_like_impl(input, dtype, name, optimize=True)
3059
3060
3061@_tag_zeros_tensor
3062def zeros_like_impl(tensor, dtype, name, optimize=True):
3063  """Internal implementation for the v1/v2 zeros_like API calls."""
3064  with ops.name_scope(name, "zeros_like", [tensor]) as name:
3065    if not tensor_util.is_tf_type(tensor):
3066      tensor = ops.convert_to_tensor(tensor, name="tensor")
3067    tensor_shape = tensor.shape
3068    tensor_dtype = tensor.dtype
3069
3070    if context.executing_eagerly():
3071      if dtype is not None and dtype != tensor_dtype:
3072        return zeros(
3073            shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
3074      return gen_array_ops.zeros_like(tensor, name=name)
3075
    # For now, variant types must be created via zeros_like, since we need to
    # pass the input variant object to the proper zeros callback.
3078
3079    if (optimize and tensor_shape.is_fully_defined() and
3080        tensor_dtype != dtypes.variant):
3081      # We can produce a zeros tensor independent of the value of 'tensor',
3082      # since the shape is known statically.
3083      return zeros(tensor_shape, dtype=dtype or tensor_dtype, name=name)
3084
3085    if dtype is not None and dtype != tensor_dtype and dtype != dtypes.variant:
3086      return zeros(
3087          shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
3088    else:
3089      return gen_array_ops.zeros_like(tensor, name=name)
3090
3091
3092@tf_export(v1=["ones_like"])
3093@dispatch.add_dispatch_support
3094def ones_like(tensor, dtype=None, name=None, optimize=True):
3095  """Creates a tensor with all elements set to 1.
3096
3097  See also `tf.ones`.
3098
3099  Given a single tensor (`tensor`), this operation returns a tensor of the same
3100  type and shape as `tensor` with all elements set to 1. Optionally, you can
3101  specify a new type (`dtype`) for the returned tensor.
3102
3103  For example:
3104
3105  ```python
3106  tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3107  tf.ones_like(tensor)  # [[1, 1, 1], [1, 1, 1]]
3108  ```
3109
3110  Args:
3111    tensor: A `Tensor`.
3112    dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
3113      `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, `complex64`,
3114      `complex128` or `bool`.
3115    name: A name for the operation (optional).
    optimize: if `True`, attempt to statically determine the shape of `tensor`
      and encode it as a constant. (optional, defaults to `True`)
3118
3119  Returns:
3120    A `Tensor` with all elements set to 1.
3121  """
3122  return ones_like_impl(tensor, dtype, name, optimize)
3123
3124
3125@tf_export("ones_like", v1=[])
3126@dispatch.add_dispatch_support
3127def ones_like_v2(
3128    input,  # pylint: disable=redefined-builtin
3129    dtype=None,
3130    name=None):
3131  """Creates a tensor of all ones that has the same shape as the input.
3132
3133  See also `tf.ones`.
3134
  Given a single tensor (`input`), this operation returns a tensor of the
  same type and shape as `input` with all elements set to 1. Optionally,
  you can use `dtype` to specify a new type for the returned tensor.
3138
3139  For example:
3140
3141  >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3142  >>> tf.ones_like(tensor)
3143  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3144    array([[1, 1, 1],
3145           [1, 1, 1]], dtype=int32)>
3146
3147  Args:
3148    input: A `Tensor`.
3149    dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3150      `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3151      `complex64`, `complex128`, `bool` or `string`.
3152    name: A name for the operation (optional).
3153
3154  Returns:
3155    A `Tensor` with all elements set to one.
3156  """
3157  return ones_like_impl(input, dtype, name, optimize=True)
3158
3159
3160def ones_like_impl(tensor, dtype, name, optimize=True):
3161  """Internal implementation for the v1/v2 ones_like API calls."""
3162  with ops.name_scope(name, "ones_like", [tensor]) as name:
3163    tensor = ops.convert_to_tensor(tensor, name="tensor")
3164    ones_shape = shape_internal(tensor, optimize=optimize)
3165    if dtype is None:
3166      dtype = tensor.dtype
3167    ret = ones(ones_shape, dtype=dtype, name=name)
3168    if not context.executing_eagerly():
3169      ret.set_shape(tensor.get_shape())
3170    return ret
3171
3172
3173@tf_export("ones")
3174@dispatch.add_dispatch_support
3175def ones(shape, dtype=dtypes.float32, name=None):
3176  """Creates a tensor with all elements set to one (1).
3177
3178  See also `tf.ones_like`, `tf.zeros`, `tf.fill`, `tf.eye`.
3179
3180  This operation returns a tensor of type `dtype` with shape `shape` and
3181  all elements set to one.
3182
3183  >>> tf.ones([3, 4], tf.int32)
3184  <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3185  array([[1, 1, 1, 1],
3186         [1, 1, 1, 1],
3187         [1, 1, 1, 1]], dtype=int32)>
3188
3189  Args:
3190    shape: A `list` of integers, a `tuple` of integers, or
3191      a 1-D `Tensor` of type `int32`.
3192    dtype: Optional DType of an element in the resulting `Tensor`. Default is
3193      `tf.float32`.
3194    name: Optional string. A name for the operation.
3195
3196  Returns:
3197    A `Tensor` with all elements set to one (1).
3198  """
3199  dtype = dtypes.as_dtype(dtype).base_dtype
3200  with ops.name_scope(name, "ones", [shape]) as name:
3201    if dtype == dtypes.bool:
3202      one = True
3203    elif dtype.is_quantized:
3204      one = np.ones([]).astype(dtype.as_numpy_dtype)
3205    else:
3206      one = 1
3207    if not isinstance(shape, ops.Tensor):
3208      try:
3209        if not context.executing_eagerly():
3210          # Create a constant if it won't be very big. Otherwise create a fill
3211          # op to prevent serialized GraphDefs from becoming too large.
3212          output = _constant_if_small(one, shape, dtype, name)
3213          if output is not None:
3214            return output
3215
3216        # Go through tensor shapes to get int64-if-needed semantics
3217        shape = constant_op._tensor_shape_tensor_conversion_function(
3218            tensor_shape.TensorShape(shape))
3219      except (TypeError, ValueError):
3220        # Happens when shape is a list with tensor elements
3221        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
3222    if not shape._shape_tuple():
3223      shape = reshape(shape, [-1])  # Ensure it's a vector
3224    output = fill(shape, constant(one, dtype=dtype), name=name)
3225  assert output.dtype.base_dtype == dtype
3226  return output
3227
3228
3229@tf_export(v1=["placeholder"])
3230def placeholder(dtype, shape=None, name=None):
3231  """Inserts a placeholder for a tensor that will be always fed.
3232
3233  **Important**: This tensor will produce an error if evaluated. Its value must
3234  be fed using the `feed_dict` optional argument to `Session.run()`,
3235  `Tensor.eval()`, or `Operation.run()`.
3236
3237  For example:
3238
3239  ```python
3240  x = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))
3241  y = tf.matmul(x, x)
3242
3243  with tf.compat.v1.Session() as sess:
3244    print(sess.run(y))  # ERROR: will fail because x was not fed.
3245
3246    rand_array = np.random.rand(1024, 1024)
3247    print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.
3248  ```
3249
3250  @compatibility(eager)
3251  Placeholders are not compatible with eager execution.
3252  @end_compatibility
3253
3254  Args:
3255    dtype: The type of elements in the tensor to be fed.
3256    shape: The shape of the tensor to be fed (optional). If the shape is not
3257      specified, you can feed a tensor of any shape.
3258    name: A name for the operation (optional).
3259
3260  Returns:
3261    A `Tensor` that may be used as a handle for feeding a value, but not
3262    evaluated directly.
3263
3264  Raises:
3265    RuntimeError: if eager execution is enabled
3266  """
3267  if context.executing_eagerly():
3268    raise RuntimeError("tf.placeholder() is not compatible with "
3269                       "eager execution.")
3270
3271  return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
3272
3273
3274@tf_export(v1=["placeholder_with_default"])
3275def placeholder_with_default(input, shape, name=None):  # pylint: disable=redefined-builtin
3276  """A placeholder op that passes through `input` when its output is not fed.
3277
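  For example, an illustrative graph-mode sketch (the values are arbitrary):

  ```python
  x = tf.compat.v1.placeholder_with_default(tf.constant([1, 2, 3]), shape=[3])
  y = x * 2

  with tf.compat.v1.Session() as sess:
    print(sess.run(y))  # [2 4 6]: the default value was used.
    print(sess.run(y, feed_dict={x: [3, 2, 1]}))  # [6 4 2]: the fed value.
  ```
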
3278  Args:
3279    input: A `Tensor`. The default value to produce when output is not fed.
3280    shape: A `tf.TensorShape` or list of `int`s. The (possibly partial) shape of
3281      the tensor.
3282    name: A name for the operation (optional).
3283
3284  Returns:
3285    A `Tensor`. Has the same type as `input`.
3286  """
3287  return gen_array_ops.placeholder_with_default(input, shape, name)
3288
3289
3290@tf_export(v1=["sparse.placeholder", "sparse_placeholder"])
3291@deprecation.deprecated_endpoints("sparse_placeholder")
3292def sparse_placeholder(dtype, shape=None, name=None):
3293  """Inserts a placeholder for a sparse tensor that will be always fed.
3294
3295  **Important**: This sparse tensor will produce an error if evaluated.
3296  Its value must be fed using the `feed_dict` optional argument to
3297  `Session.run()`, `Tensor.eval()`, or `Operation.run()`.
3298
3299  For example:
3300
3301  ```python
3302  x = tf.compat.v1.sparse.placeholder(tf.float32)
3303  y = tf.sparse.reduce_sum(x)
3304
3305  with tf.compat.v1.Session() as sess:
3306    print(sess.run(y))  # ERROR: will fail because x was not fed.
3307
3308    indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
3309    values = np.array([1.0, 2.0], dtype=np.float32)
3310    shape = np.array([7, 9, 2], dtype=np.int64)
3311    print(sess.run(y, feed_dict={
      x: tf.compat.v1.SparseTensorValue(indices, values, shape)}))  # Will succeed.
3314    print(sess.run(y, feed_dict={
3315      x: (indices, values, shape)}))  # Will succeed.
3316
3317    sp = tf.sparse.SparseTensor(indices=indices, values=values,
3318                                dense_shape=shape)
3319    sp_value = sp.eval(session=sess)
3320    print(sess.run(y, feed_dict={x: sp_value}))  # Will succeed.
3321  ```
3322
  @compatibility(eager)
  Placeholders are not compatible with eager execution.
  @end_compatibility
3324
3325  Args:
3326    dtype: The type of `values` elements in the tensor to be fed.
3327    shape: The shape of the tensor to be fed (optional). If the shape is not
3328      specified, you can feed a sparse tensor of any shape.
3329    name: A name for prefixing the operations (optional).
3330
3331  Returns:
3332    A `SparseTensor` that may be used as a handle for feeding a value, but not
3333    evaluated directly.
3334
3335  Raises:
3336    RuntimeError: if eager execution is enabled
3337  """
3338  if context.executing_eagerly():
3339    raise RuntimeError("`sparse_placeholder` is not compatible with "
3340                       "eager execution.")
3341
3342  shape_name = (name + "/shape") if name is not None else None
3343  default_shape_name = (name + "/shape_default") if name is not None else None
3344  if shape is None:
3345    rank = None
3346    dense_shape = placeholder(dtypes.int64, shape=[rank], name=shape_name)
3347    dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
3348  else:
3349    if isinstance(shape, ops.Tensor):
3350      rank = shape.get_shape()[0]
3351      dense_shape_default = tensor_util.constant_value_as_shape(shape)
3352    else:
3353      rank = len(shape)
3354      # determine the shape, to override the `.shape` property of the
3355      # `SparseTensor`
3356      dense_shape_default = tensor_shape.TensorShape(
3357          tuple(None if dim == -1 else dim for dim in shape))
3358      shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
3359      shape = tuple(-1 if dim is None else dim for dim in shape)
3360      shape = ops.convert_to_tensor(
3361          shape, dtype=dtypes.int64, name=default_shape_name)
3362
3363    # `dense_shape` needs to be feedable (for users that treat this as an
3364    # actual placeholder). `constant_value_as_shape` sets constants to
3365    # not-feedable. placeholder_with_default works, but blocks `SparseTensor`
3366    # from reading the default value back out.
3367    dense_shape = placeholder_with_default(
3368        shape, shape=shape.shape, name=shape_name)
3369
3370  result = sparse_tensor.SparseTensor(
3371      values=placeholder(
3372          dtype,
3373          shape=[None],
3374          name=(name + "/values") if name is not None else None),
3375      indices=placeholder(
3376          dtypes.int64,
3377          shape=[None, rank],
3378          name=(name + "/indices") if name is not None else None),
3379      dense_shape=dense_shape)
3380
3381  # Now the SparseTensor.shape is a list of `None`s, since it couldn't read the
3382  # default shape out of the placeholder. Override that
3383  # shape to be the value determined here, so partial shapes can be
3384  # propagated.
3385  result._dense_shape_default = dense_shape_default
3386  return result
3387
3388# pylint: enable=redefined-outer-name
3389
3390
3391@tf_export("pad", v1=[])
3392@dispatch.add_dispatch_support
3393def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
3394  """Pads a tensor.
3395
3396  This operation pads a `tensor` according to the `paddings` you specify.
3397  `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3398  `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3399  many values to add before the contents of `tensor` in that dimension, and
3400  `paddings[D, 1]` indicates how many values to add after the contents of
3401  `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3402  and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3403  `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3404  no greater than `tensor.dim_size(D)`.
3405
3406  The padded size of each dimension D of the output is:
3407
3408  `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3409
3410  For example:
3411
3412  ```python
3413  t = tf.constant([[1, 2, 3], [4, 5, 6]])
3414  paddings = tf.constant([[1, 1,], [2, 2]])
3415  # 'constant_values' is 0.
3416  # rank of 't' is 2.
3417  tf.pad(t, paddings, "CONSTANT")  # [[0, 0, 0, 0, 0, 0, 0],
3418                                   #  [0, 0, 1, 2, 3, 0, 0],
3419                                   #  [0, 0, 4, 5, 6, 0, 0],
3420                                   #  [0, 0, 0, 0, 0, 0, 0]]
3421
3422  tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
3423                                  #  [3, 2, 1, 2, 3, 2, 1],
3424                                  #  [6, 5, 4, 5, 6, 5, 4],
3425                                  #  [3, 2, 1, 2, 3, 2, 1]]
3426
3427  tf.pad(t, paddings, "SYMMETRIC")  # [[2, 1, 1, 2, 3, 3, 2],
3428                                    #  [2, 1, 1, 2, 3, 3, 2],
3429                                    #  [5, 4, 4, 5, 6, 6, 5],
3430                                    #  [5, 4, 4, 5, 6, 6, 5]]
3431  ```
3432
3433  Args:
3434    tensor: A `Tensor`.
3435    paddings: A `Tensor` of type `int32`.
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive).
3437    constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3438      same type as `tensor`.
3439    name: A name for the operation (optional).
3440
3441  Returns:
3442    A `Tensor`. Has the same type as `tensor`.
3443
3444  Raises:
3445    ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3446  """
3447  return pad(tensor, paddings, mode, name, constant_values)
3448
3449
3450@tf_export(v1=["pad"])
3451@dispatch.add_dispatch_support
3452def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pylint: disable=invalid-name
3453  """Pads a tensor.
3454
3455  This operation pads a `tensor` according to the `paddings` you specify.
3456  `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3457  `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3458  many values to add before the contents of `tensor` in that dimension, and
3459  `paddings[D, 1]` indicates how many values to add after the contents of
3460  `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3461  and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3462  `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3463  no greater than `tensor.dim_size(D)`.
3464
3465  The padded size of each dimension D of the output is:
3466
3467  `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3468
3469  For example:
3470
3471  ```python
3472  t = tf.constant([[1, 2, 3], [4, 5, 6]])
3473  paddings = tf.constant([[1, 1,], [2, 2]])
3474  # 'constant_values' is 0.
3475  # rank of 't' is 2.
3476  tf.pad(t, paddings, "CONSTANT")  # [[0, 0, 0, 0, 0, 0, 0],
3477                                   #  [0, 0, 1, 2, 3, 0, 0],
3478                                   #  [0, 0, 4, 5, 6, 0, 0],
3479                                   #  [0, 0, 0, 0, 0, 0, 0]]
3480
3481  tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
3482                                  #  [3, 2, 1, 2, 3, 2, 1],
3483                                  #  [6, 5, 4, 5, 6, 5, 4],
3484                                  #  [3, 2, 1, 2, 3, 2, 1]]
3485
3486  tf.pad(t, paddings, "SYMMETRIC")  # [[2, 1, 1, 2, 3, 3, 2],
3487                                    #  [2, 1, 1, 2, 3, 3, 2],
3488                                    #  [5, 4, 4, 5, 6, 6, 5],
3489                                    #  [5, 4, 4, 5, 6, 6, 5]]
3490  ```
3491
3492  Args:
3493    tensor: A `Tensor`.
3494    paddings: A `Tensor` of type `int32`.
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive).
3496    name: A name for the operation (optional).
3497    constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3498      same type as `tensor`.
3499
3500  Returns:
3501    A `Tensor`. Has the same type as `tensor`.
3502
3503  Raises:
3504    ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3505  """
3506
3507  # Convert lower/mixed case to upper for NumPy compatibility
3508  # NumPy uses all lower-case modes.
3509  mode = mode.upper()
3510  if mode == "CONSTANT":
    # TODO(rjryan): Once the forward compatibility period (3 weeks) has passed,
    # remove the "Pad" fallback here.
3513    if not tensor_util.is_tf_type(constant_values) and constant_values == 0:
3514      result = gen_array_ops.pad(tensor, paddings, name=name)
3515    else:
3516      result = gen_array_ops.pad_v2(
3517          tensor, paddings, constant_values, name=name)
3518  elif mode == "REFLECT":
3519    result = gen_array_ops.mirror_pad(
3520        tensor, paddings, mode="REFLECT", name=name)
3521  elif mode == "SYMMETRIC":
3522    result = gen_array_ops.mirror_pad(
3523        tensor, paddings, mode="SYMMETRIC", name=name)
3524  else:
3525    raise ValueError("Unknown padding mode: %s" % mode)
3526
3527  # Restore shape information where possible.
3528  if not context.executing_eagerly():
3529    paddings_constant = _get_paddings_constant(paddings)
3530    input_shape = (
3531        tensor_shape.TensorShape(tensor.shape)
3532        if isinstance(tensor, ops.Tensor) else result.op.inputs[0].shape)
3533    if (input_shape.ndims is not None and
3534        not result.shape.is_fully_defined() and paddings_constant is not None):
3535      new_shape = []
3536      for padding, dim in zip(paddings_constant, input_shape.as_list()):
3537        if padding is None or dim is None or any((x is None for x in padding)):
3538          new_shape.append(None)
3539        else:
3540          new_shape.append(sum(padding) + dim)
3541      result.set_shape(new_shape)
3542
3543  return result
3544
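# An illustrative sketch of the static-shape restoration above: padding a
# tensor of shape [2, 3] with paddings [[1, 1], [2, 2]] lets pad() set the
# static result shape to [1 + 2 + 1, 2 + 3 + 2] == [4, 7] before the op runs.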
3545
3546def _get_paddings_constant(paddings):
3547  """Helper to get the constant values of the paddings arg to pad().
3548
3549  Used under V1 graph mode to facilitate computation of the shape of the output
3550  tensor of `pad()`.
3551
3552  Args:
3553    paddings: The same paddings arg as passed to pad(). Can be a Tensor, or
3554      a nested list or tuple of Tensor and/or numbers.
3555
3556  Returns:
3557    A nested list or numbers or `None`, in which `None` indicates unknown
3558    padding size.
3559  """
3560  if isinstance(paddings, ops.Tensor):
3561    return tensor_util.constant_value(paddings, partial=True)
3562  elif isinstance(paddings, (list, tuple)):
3563    return [_get_paddings_constant(x) for x in paddings]
3564  else:
3565    return paddings
3566
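# An illustrative sketch: a mixed paddings arg such as
#   [[1, 1], [tf.constant(2), 2]]
# resolves to [[1, 1], [2, 2]] when the inner tensor's value is statically
# known; entries whose values cannot be determined resolve to None.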
3567
3568@tf_export("meshgrid")
3569@dispatch.add_dispatch_support
3570def meshgrid(*args, **kwargs):
3571  """Broadcasts parameters for evaluation on an N-D grid.
3572
3573  Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
3574  of N-D coordinate arrays for evaluating expressions on an N-D grid.
3575
3576  Notes:
3577
  `meshgrid` supports Cartesian ('xy') and matrix ('ij') indexing conventions.
3579  When the `indexing` argument is set to 'xy' (the default), the broadcasting
3580  instructions for the first two dimensions are swapped.
3581
3582  Examples:
3583
3584  Calling `X, Y = meshgrid(x, y)` with the tensors
3585
3586  ```python
3587  x = [1, 2, 3]
3588  y = [4, 5, 6]
3589  X, Y = tf.meshgrid(x, y)
3590  # X = [[1, 2, 3],
3591  #      [1, 2, 3],
3592  #      [1, 2, 3]]
3593  # Y = [[4, 4, 4],
3594  #      [5, 5, 5],
3595  #      [6, 6, 6]]
3596  ```
3597
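  With matrix ('ij') indexing, the first two outputs are transposed relative
  to 'xy'. An illustrative sketch with the same `x` and `y` as above:

  ```python
  X, Y = tf.meshgrid(x, y, indexing='ij')
  # X = [[1, 1, 1],
  #      [2, 2, 2],
  #      [3, 3, 3]]
  # Y = [[4, 5, 6],
  #      [4, 5, 6],
  #      [4, 5, 6]]
  ```
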
3598  Args:
3599    *args: `Tensor`s with rank 1.
3600    **kwargs:
3601      - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
3602      - name: A name for the operation (optional).
3603
3604  Returns:
3605    outputs: A list of N `Tensor`s with rank N.
3606
3607  Raises:
3608    TypeError: When no keyword arguments (kwargs) are passed.
3609    ValueError: When indexing keyword argument is not one of `xy` or `ij`.
3610  """
3611
3612  indexing = kwargs.pop("indexing", "xy")
3613  name = kwargs.pop("name", "meshgrid")
3614  if kwargs:
3615    key = list(kwargs.keys())[0]
3616    raise TypeError("'{}' is an invalid keyword argument "
3617                    "for this function".format(key))
3618
3619  if indexing not in ("xy", "ij"):
3620    raise ValueError("indexing parameter must be either 'xy' or 'ij'")
3621
3622  with ops.name_scope(name, "meshgrid", args) as name:
3623    ndim = len(args)
3624    s0 = (1,) * ndim
3625
3626    if not ndim:
3627      return []
3628
3629    # Prepare reshape by inserting dimensions with size 1 where needed
3630    output = []
3631    for i, x in enumerate(args):
3632      output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
3633    # Create parameters for broadcasting each tensor to the full size
3634    shapes = [size(x) for x in args]
3635
3636    output_dtype = ops.convert_to_tensor(args[0]).dtype.base_dtype
3637
3638    if indexing == "xy" and ndim > 1:
3639      output[0] = reshape(output[0], (1, -1) + (1,) * (ndim - 2))
3640      output[1] = reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
3641      shapes[0], shapes[1] = shapes[1], shapes[0]
3642
3643    # TODO(nolivia): improve performance with a broadcast
3644    mult_fact = ones(shapes, output_dtype)
3645    return [x * mult_fact for x in output]
3646
3647
3648NEW_AXIS = -1
3649SHRINK_AXIS = -2
3650
3651
3652# PEP-8 naming
3653# pylint: disable=invalid-name,redefined-outer-name
3654def _compute_size_of_strided_dim(shrink, spec, size):
3655  """Computes the size of a single strided slice dimension."""
3656
  unknown = None  # Sentinel meaning the dimension size cannot be inferred.
  use_full_range = None  # Sentinel: an endpoint of None spans the full range.
3659  # if this is a shrink axis (i.e. a non-range index)
3660  # it either will produce an error or return 1
3661  if shrink:
3662    return 1
3663  if size is unknown or size.value is unknown:
3664    return unknown
3665  size = size.value
3666  stride = spec.step
3667  if stride is not unknown:
3668    if stride == 0:
3669      return unknown
3671    valid_range = [0, size] if stride > 0 else [-1, size - 1]
3672
3673    # PEP-8 naming
3674    # pylint: disable=invalid-name
3675    def canonical(x, c):
3676      if x is use_full_range:
3677        return valid_range[c] if stride > 0 else valid_range[(c + 1) & 1]
3678      else:
3679        x_fwd = size + x if x < 0 else x  # make negative indices positive
3680        return max(valid_range[0], min(valid_range[1], x_fwd))
3681
3682    begin = canonical(spec.start, 0)
3683    end = canonical(spec.stop, 1)
3684    interval_length = end - begin
3685    if interval_length == 0 or ((interval_length < 0) != (stride < 0)):
3686      return 0
3687    else:
3688      remainder = 1 if interval_length % stride != 0 else 0
3689      return interval_length // stride + remainder
3690  else:
3691    return unknown  # unknown because stride is unknown
3692
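# An illustrative worked example: for spec=slice(1, 7, 2) over a dimension of
# known size 8, begin and end canonicalize to 1 and 7, the interval length is
# 6, and the result is 6 // 2 == 3 (the elements at indices 1, 3, and 5).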
3693
3694def _TileGradShape(op):
3695  """Shape function for the TileGrad op."""
3696  multiples_shape = op.inputs[1].get_shape().with_rank(1)
3697  input_shape = op.inputs[0].get_shape().with_rank(multiples_shape[0])
3698  # NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
3699  # it is a vector of non-negative integers, and (ii) doing so allows
3700  # us to handle partially-known multiples.
3701  multiples = tensor_util.constant_value_as_shape(op.inputs[1]).with_rank(
3702      input_shape.ndims)
3703  if multiples.ndims is None:
3704    return [tensor_shape.unknown_shape()]
3705  else:
3706    output_dims = []
3707    for dim, multiple in zip(input_shape.dims, multiples.dims):
3708      output_dims.append(dim // multiple)
3709    return [tensor_shape.TensorShape(output_dims)]
3710
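# An illustrative sketch: for a TileGrad op with input shape [6, 8] and
# multiples [2, 4], each dimension is divided by its multiple, giving an
# inferred output shape of [6 // 2, 8 // 4] == [3, 2].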
3711
3712@tf_export("edit_distance")
3713@dispatch.add_dispatch_support
3714def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
3715  """Computes the Levenshtein distance between sequences.
3716
3717  This operation takes variable-length sequences (`hypothesis` and `truth`),
3718  each provided as a `SparseTensor`, and computes the Levenshtein distance.
3719  You can normalize the edit distance by length of `truth` by setting
3720  `normalize` to true.
3721
3722  For example:
3723
3724  Given the following input,
3725  * `hypothesis` is a `tf.SparseTensor` of shape `[2, 1, 1]`
3726  * `truth` is a `tf.SparseTensor` of shape `[2, 2, 2]`
3727
3728  >>> hypothesis = tf.SparseTensor(
3729  ...   [[0, 0, 0],
3730  ...    [1, 0, 0]],
3731  ...   ["a", "b"],
3732  ...   (2, 1, 1))
3733  >>> truth = tf.SparseTensor(
3734  ...   [[0, 1, 0],
3735  ...    [1, 0, 0],
3736  ...    [1, 0, 1],
3737  ...    [1, 1, 0]],
3738  ...    ["a", "b", "c", "a"],
3739  ...    (2, 2, 2))
3740  >>> tf.edit_distance(hypothesis, truth, normalize=True)
3741  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3742  array([[inf, 1. ],
3743         [0.5, 1. ]], dtype=float32)>
3744
3745  The operation returns a dense Tensor of shape `[2, 2]` with
3746  edit distances normalized by `truth` lengths.
3747
3748  **Note**: It is possible to calculate edit distance between two
3749  sparse tensors with variable-length values. However, attempting to create
3750  them while eager execution is enabled will result in a `ValueError`.
3751
3752  For the following  inputs,
3753
3754  ```python
3755  # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
3756  #   (0,0) = ["a"]
3757  #   (1,0) = ["b"]
3758  hypothesis = tf.sparse.SparseTensor(
3759      [[0, 0, 0],
3760       [1, 0, 0]],
3761      ["a", "b"],
3762      (2, 1, 1))
3763
3764  # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
3765  #   (0,0) = []
3766  #   (0,1) = ["a"]
3767  #   (1,0) = ["b", "c"]
3768  #   (1,1) = ["a"]
3769  truth = tf.sparse.SparseTensor(
3770      [[0, 1, 0],
3771       [1, 0, 0],
3772       [1, 0, 1],
3773       [1, 1, 0]],
3774      ["a", "b", "c", "a"],
3775      (2, 2, 2))
3776
3777  normalize = True
3778
  # The output would be a dense Tensor of shape `(2,)`, with edit distances
  # normalized by 'truth' lengths.
  # output => array([0., 0.5], dtype=float32)
3782  ```
3783
3784  Args:
3785    hypothesis: A `SparseTensor` containing hypothesis sequences.
3786    truth: A `SparseTensor` containing truth sequences.
3787    normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
3788      length of `truth.`
3789    name: A name for the operation (optional).
3790
3791  Returns:
3792    A dense `Tensor` with rank `R - 1`, where R is the rank of the
3793    `SparseTensor` inputs `hypothesis` and `truth`.
3794
3795  Raises:
3796    TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
3797  """
3798  if not isinstance(
3799      hypothesis,
3800      (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3801    raise TypeError("Hypothesis must be a SparseTensor.")
3802  if not isinstance(
3803      truth, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3804    raise TypeError("Truth must be a SparseTensor.")
3805
3806  return gen_array_ops.edit_distance(
3807      hypothesis.indices,
3808      hypothesis.values,
3809      hypothesis.dense_shape,
3810      truth.indices,
3811      truth.values,
3812      truth.dense_shape,
3813      normalize=normalize,
3814      name=name)
3815
3816
3817@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
3818def _FakeQuantWithMinMaxArgsGradient(op, grad):
3819  """Gradient for FakeQuantWithMinMaxArgs op."""
3820  return fake_quant_with_min_max_args_gradient(
3821      grad,
3822      op.inputs[0],
3823      min=op.get_attr("min"),
3824      max=op.get_attr("max"),
3825      num_bits=op.get_attr("num_bits"),
3826      narrow_range=op.get_attr("narrow_range"))
3827
3828
3829@ops.RegisterGradient("FakeQuantWithMinMaxVars")
3830def _FakeQuantWithMinMaxVarsGradient(op, grad):
3831  """Gradient for FakeQuantWithMinMaxVars op."""
3832  return fake_quant_with_min_max_vars_gradient(
3833      grad,
3834      op.inputs[0],
3835      op.inputs[1],
3836      op.inputs[2],
3837      num_bits=op.get_attr("num_bits"),
3838      narrow_range=op.get_attr("narrow_range"))
3839
3840
3841@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
3842def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
3843  """Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
3844  return fake_quant_with_min_max_vars_per_channel_gradient(
3845      grad,
3846      op.inputs[0],
3847      op.inputs[1],
3848      op.inputs[2],
3849      num_bits=op.get_attr("num_bits"),
3850      narrow_range=op.get_attr("narrow_range"))
3851
3852
3853@ops.RegisterGradient("QuantizeAndDequantizeV4")
3854def _QuantizeAndDequantizeV4Grad(op, grad):
3855  """Gradient for QuantizeAndDequantizeV4 op."""
3856  return quantize_and_dequantize_v4_grad(
3857      grad,
3858      op.inputs[0],
3859      op.inputs[1],
3860      op.inputs[2],
3861      axis=op.get_attr("axis"))
3862
3863
3864@ops.RegisterGradient("QuantizeAndDequantizeV4Grad")
3865def _QuantizeAndDequantizeV4GradGrad(op, grad):
3866  """Gradient for QuantizeAndDequantizeV4Grad op."""
3867  return _QuantizeAndDequantizeV4Grad(op, grad)
3868
3869
3870@tf_export("required_space_to_batch_paddings")
3871def required_space_to_batch_paddings(input_shape,
3872                                     block_shape,
3873                                     base_paddings=None,
3874                                     name=None):
3875  """Calculate padding required to make block_shape divide input_shape.
3876
3877  This function can be used to calculate a suitable paddings argument for use
3878  with space_to_batch_nd and batch_to_space_nd.
3879
3880  Args:
3881    input_shape: int32 Tensor of shape [N].
3882    block_shape: int32 Tensor of shape [N].
3883    base_paddings: Optional int32 Tensor of shape [N, 2].  Specifies the minimum
3884      amount of padding to use.  All elements must be >= 0.  If not specified,
3885      defaults to 0.
3886    name: string.  Optional name prefix.
3887
3888  Returns:
3889    (paddings, crops), where:
3890
3891    `paddings` and `crops` are int32 Tensors of rank 2 and shape [N, 2]
3892    satisfying:
3893
3894        paddings[i, 0] = base_paddings[i, 0].
3895        0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
3896        (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0
3897
3898        crops[i, 0] = 0
3899        crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]
3900
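  For example, an illustrative sketch padding a length-5 dimension so that a
  block shape of 3 divides it evenly:

  ```python
  paddings, crops = tf.required_space_to_batch_paddings(
      input_shape=[5], block_shape=[3])
  # paddings == [[0, 1]]  (5 + 0 + 1 == 6 is divisible by 3)
  # crops == [[0, 1]]
  ```
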
  Raises:
    ValueError: If called with incompatible shapes.
3902  """
3903  with ops.name_scope(name, "required_space_to_batch_paddings",
3904                      [input_shape, block_shape]):
3905    input_shape = ops.convert_to_tensor(
3906        input_shape, dtype=dtypes.int32, name="input_shape")
3907    block_shape = ops.convert_to_tensor(
3908        block_shape, dtype=dtypes.int32, name="block_shape")
3909
3910    block_shape.get_shape().assert_is_fully_defined()
3911    block_shape.get_shape().assert_has_rank(1)
3912    num_block_dims = block_shape.get_shape().dims[0].value
3913    if num_block_dims == 0:
3914      return zeros([0, 2], dtypes.int32), zeros([0, 2], dtypes.int32)
3915
3916    input_shape.get_shape().assert_is_compatible_with([num_block_dims])
3917
3918    if base_paddings is not None:
3919      base_paddings = ops.convert_to_tensor(
3920          base_paddings, dtype=dtypes.int32, name="base_paddings")
3921      base_paddings.get_shape().assert_is_compatible_with([num_block_dims, 2])
3922    else:
3923      base_paddings = zeros([num_block_dims, 2], dtypes.int32)
3924
3925    const_block_shape = tensor_util.constant_value(block_shape)
3926    const_input_shape = tensor_util.constant_value(input_shape)
3927    const_base_paddings = tensor_util.constant_value(base_paddings)
3928    if (const_block_shape is not None and const_input_shape is not None and
3929        const_base_paddings is not None):
3930      block_shape = const_block_shape
3931      input_shape = const_input_shape
3932      base_paddings = const_base_paddings
3933
3934    # Use same expression for both constant and non-constant case.
3935    pad_start = base_paddings[:, 0]
3936    orig_pad_end = base_paddings[:, 1]
3937    full_input_shape = input_shape + pad_start + orig_pad_end
3938    pad_end_extra = (block_shape - full_input_shape % block_shape) % block_shape
3939    pad_end = orig_pad_end + pad_end_extra
3940
3941    result_paddings = stack(
3942        [[pad_start[i], pad_end[i]] for i in range(num_block_dims)],
3943        name="paddings")
3944    result_crops = stack([[0, pad_end_extra[i]] for i in range(num_block_dims)],
3945                         name="crops")
3946    return result_paddings, result_crops
3947
3948
3949@tf_export(v1=["nn.space_to_batch", "space_to_batch"])
3950@dispatch.add_dispatch_support
3951@deprecation.deprecated_endpoints("space_to_batch")
3952def space_to_batch(  # pylint: disable=missing-docstring
3953    input,  # pylint: disable=redefined-builtin
3954    paddings,
3955    block_size=None,
3956    name=None,
3957    block_shape=None):  # pylint: disable=redefined-builtin
3958  block_size = deprecation.deprecated_argument_lookup("block_shape",
3959                                                      block_shape, "block_size",
3960                                                      block_size)
3961  result = space_to_batch_nd(
3962      input,
3963      paddings=paddings,
3964      block_shape=np.array([block_size, block_size], dtype=np.int64),
3965      name=name)
3966  result.set_shape(result.get_shape().with_rank(4))
3967  return result
3968
3969
3970space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
3971
3972
3973@tf_export("space_to_batch", "nn.space_to_batch", v1=[])
3974@dispatch.add_dispatch_support
3975def space_to_batch_v2(input, block_shape, paddings, name=None):  # pylint: disable=redefined-builtin
3976  return space_to_batch_nd(input, block_shape, paddings, name)
3977
3978
3979space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__
3980
3981
3982@tf_export(v1=["nn.space_to_depth", "space_to_depth"])
3983@dispatch.add_dispatch_support
3984@deprecation.deprecated_endpoints("space_to_depth")
3985def space_to_depth(input, block_size, name=None, data_format="NHWC"):  # pylint: disable=redefined-builtin
3986  return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
3987
3988
3989space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
3990
3991
3992@tf_export("nn.space_to_depth", v1=[])
3993@dispatch.add_dispatch_support
3994def space_to_depth_v2(input, block_size, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
3995  return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
3996
3997
3998space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__
3999
4000
4001@tf_export(v1=["nn.depth_to_space", "depth_to_space"])
4002@dispatch.add_dispatch_support
4003@deprecation.deprecated_endpoints("depth_to_space")
4004def depth_to_space(input, block_size, name=None, data_format="NHWC"):  # pylint: disable=redefined-builtin
4005  return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
4006
4007
4008depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
4009
4010
4011@tf_export("nn.depth_to_space", v1=[])
4012@dispatch.add_dispatch_support
4013def depth_to_space_v2(input, block_size, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4014  return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
4015
4016
4017depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__
4018
4019
4020@tf_export(v1=["batch_to_space"])
4021@dispatch.add_dispatch_support
4022def batch_to_space(input, crops, block_size, name=None, block_shape=None):  # pylint: disable=redefined-builtin,missing-docstring
4023  block_size = deprecation.deprecated_argument_lookup("block_shape",
4024                                                      block_shape, "block_size",
4025                                                      block_size)
4026  result = batch_to_space_nd(
4027      input,
4028      crops=crops,
4029      block_shape=np.array([block_size, block_size], dtype=np.int64),
4030      name=name)
4031  result.set_shape(result.get_shape().with_rank(4))
4032  return result
4033
4034
4035batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
4036
4037
4038@tf_export("batch_to_space", v1=[])
4039@dispatch.add_dispatch_support
4040def batch_to_space_v2(input, block_shape, crops, name=None):  # pylint: disable=redefined-builtin
4041  """BatchToSpace for N-D tensors of type T.
4042
4043  This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
4044  shape `block_shape + [batch]`, interleaves these blocks back into the grid
4045  defined by the spatial dimensions `[1, ..., M]`, to obtain a result with the
4046  same rank as the input.  The spatial dimensions of this intermediate result
4047  are then optionally cropped according to `crops` to produce the output.  This
4048  is the reverse of SpaceToBatch (see `tf.space_to_batch`).
4049
4050  Args:
4051    input: A N-D `Tensor` with shape `input_shape = [batch] + spatial_shape +
4052      remaining_shape`, where `spatial_shape` has M dimensions.
4053    block_shape: A 1-D `Tensor` with shape [M]. Must be one of the following
4054      types: `int32`, `int64`. All values must be >= 1. For backwards
4055      compatibility with TF 1.0, this parameter may be an int, in which case it
      is converted to `numpy.array([block_shape, block_shape],
      dtype=numpy.int64)`.
4059    crops: A  2-D `Tensor` with shape `[M, 2]`. Must be one of the
4060      following types: `int32`, `int64`. All values must be >= 0.
4061      `crops[i] = [crop_start, crop_end]` specifies the amount to crop from
4062      input dimension `i + 1`, which corresponds to spatial dimension `i`.
4063      It is required that
4064      `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`.
4065      This operation is equivalent to the following steps:
4066      1. Reshape `input` to `reshaped` of shape: [block_shape[0], ...,
4067        block_shape[M-1], batch / prod(block_shape), input_shape[1], ...,
4068        input_shape[N-1]]
4069      2. Permute dimensions of `reshaped` to produce `permuted` of shape
4070         [batch / prod(block_shape),  input_shape[1], block_shape[0], ...,
4071         input_shape[M], block_shape[M-1], input_shape[M+1],
4072        ..., input_shape[N-1]]
4073      3. Reshape `permuted` to produce `reshaped_permuted` of shape
4074         [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
4075         input_shape[M] * block_shape[M-1], input_shape[M+1], ...,
4076         input_shape[N-1]]
4077      4. Crop the start and end of dimensions `[1, ..., M]` of
4078         `reshaped_permuted` according to `crops` to produce the output
4079         of shape:
4080         [batch / prod(block_shape),  input_shape[1] *
4081           block_shape[0] - crops[0,0] - crops[0,1], ..., input_shape[M] *
4082           block_shape[M-1] - crops[M-1,0] - crops[M-1,1],  input_shape[M+1],
4083           ..., input_shape[N-1]]
4084    name: A name for the operation (optional).
4085
4086  Examples:
4087
4088  1. For the following input of shape `[4, 1, 1, 1]`,
4089     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4090
4091     ```python
4092     [[[[1]]],
4093      [[[2]]],
4094      [[[3]]],
4095      [[[4]]]]
4096     ```
4097
4098    The output tensor has shape `[1, 2, 2, 1]` and value:
4099
     ```python
     x = [[[[1], [2]],
           [[3], [4]]]]
     ```
4104
4105  2. For the following input of shape `[4, 1, 1, 3]`,
4106     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4107
4108     ```python
4109     [[[1,  2,   3]],
4110      [[4,  5,   6]],
4111      [[7,  8,   9]],
4112      [[10, 11, 12]]]
4113     ```
4114
4115    The output tensor has shape `[1, 2, 2, 3]` and value:
4116
4117    ```python
4118     x = [[[[1, 2, 3], [4,  5,  6 ]],
4119           [[7, 8, 9], [10, 11, 12]]]]
4120     ```
4121
4122  3. For the following
4123     input of shape `[4, 2, 2, 1]`,
4124     `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4125
4126     ```python
4127     x = [[[[1], [3]], [[ 9], [11]]],
4128          [[[2], [4]], [[10], [12]]],
4129          [[[5], [7]], [[13], [15]]],
4130          [[[6], [8]], [[14], [16]]]]
4131     ```
4132
4133    The output tensor has shape `[1, 4, 4, 1]` and value:
4134
4135    ```python
4136     x = [[[1],  [2],  [ 3], [ 4]],
4137          [[5],  [6],  [ 7], [ 8]],
4138          [[9],  [10], [11], [12]],
4139          [[13], [14], [15], [16]]]
4140     ```
4141
4142  4. For the following input of shape
4143      `[8, 1, 3, 1]`,
4144      `block_shape = [2, 2]`, and `crops = [[0, 0], [2, 0]]`:
4145
4146      ```python
4147      x = [[[[0], [ 1], [ 3]]],
4148           [[[0], [ 9], [11]]],
4149           [[[0], [ 2], [ 4]]],
4150           [[[0], [10], [12]]],
4151           [[[0], [ 5], [ 7]]],
4152           [[[0], [13], [15]]],
4153           [[[0], [ 6], [ 8]]],
4154           [[[0], [14], [16]]]]
4155      ```
4156
4157      The output tensor has shape `[2, 2, 4, 1]` and value:
4158
4159      ```python
4160      x = [[[[ 1], [ 2], [ 3], [ 4]],
4161            [[ 5], [ 6], [ 7], [ 8]]],
4162           [[[ 9], [10], [11], [12]],
4163            [[13], [14], [15], [16]]]]
4164      ```
4165
4166  Returns:
4167    A `Tensor`. Has the same type as `input`.
4168  """
4169  if isinstance(block_shape, int):
4170    block_shape = np.array([block_shape, block_shape], dtype=np.int64)
4171
4172  return batch_to_space_nd(
4173      input=input, block_shape=block_shape, crops=crops, name=name)
4174
4175
4176@tf_export("one_hot")
4177@dispatch.add_dispatch_support
4178def one_hot(indices,
4179            depth,
4180            on_value=None,
4181            off_value=None,
4182            axis=None,
4183            dtype=None,
4184            name=None):
4185  """Returns a one-hot tensor.
4186
4187  See also `tf.fill`, `tf.eye`.
4188
4189  The locations represented by indices in `indices` take value `on_value`,
4190  while all other locations take value `off_value`.
4191
4192  `on_value` and `off_value` must have matching data types. If `dtype` is also
4193  provided, they must be the same data type as specified by `dtype`.
4194
  If `on_value` is not provided, it will default to the value `1` with type
  `dtype`.

  If `off_value` is not provided, it will default to the value `0` with type
  `dtype`.
4200
4201  If the input `indices` is rank `N`, the output will have rank `N+1`. The
4202  new axis is created at dimension `axis` (default: the new axis is appended
4203  at the end).
4204
  If `indices` is a scalar, the output shape will be a vector of length
  `depth`.
4206
4207  If `indices` is a vector of length `features`, the output shape will be:
4208
4209  ```
4210    features x depth if axis == -1
4211    depth x features if axis == 0
4212  ```
4213
4214  If `indices` is a matrix (batch) with shape `[batch, features]`, the output
4215  shape will be:
4216
4217  ```
4218    batch x features x depth if axis == -1
4219    batch x depth x features if axis == 1
4220    depth x batch x features if axis == 0
4221  ```
4222
4223  If `indices` is a RaggedTensor, the 'axis' argument must be positive and refer
4224  to a non-ragged axis. The output will be equivalent to applying 'one_hot' on
4225  the values of the RaggedTensor, and creating a new RaggedTensor from the
4226  result.
4227
4228  If `dtype` is not provided, it will attempt to assume the data type of
4229  `on_value` or `off_value`, if one or both are passed in. If none of
4230  `on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
4231  value `tf.float32`.
4232
4233  Note: If a non-numeric data type output is desired (`tf.string`, `tf.bool`,
4234  etc.), both `on_value` and `off_value` _must_ be provided to `one_hot`.
4235
4236  For example:
4237
4238  ```python
4239  indices = [0, 1, 2]
4240  depth = 3
4241  tf.one_hot(indices, depth)  # output: [3 x 3]
4242  # [[1., 0., 0.],
4243  #  [0., 1., 0.],
4244  #  [0., 0., 1.]]
4245
4246  indices = [0, 2, -1, 1]
4247  depth = 3
4248  tf.one_hot(indices, depth,
4249             on_value=5.0, off_value=0.0,
4250             axis=-1)  # output: [4 x 3]
4251  # [[5.0, 0.0, 0.0],  # one_hot(0)
4252  #  [0.0, 0.0, 5.0],  # one_hot(2)
4253  #  [0.0, 0.0, 0.0],  # one_hot(-1)
4254  #  [0.0, 5.0, 0.0]]  # one_hot(1)
4255
4256  indices = [[0, 2], [1, -1]]
4257  depth = 3
4258  tf.one_hot(indices, depth,
4259             on_value=1.0, off_value=0.0,
4260             axis=-1)  # output: [2 x 2 x 3]
4261  # [[[1.0, 0.0, 0.0],   # one_hot(0)
4262  #   [0.0, 0.0, 1.0]],  # one_hot(2)
4263  #  [[0.0, 1.0, 0.0],   # one_hot(1)
4264  #   [0.0, 0.0, 0.0]]]  # one_hot(-1)
4265
4266  indices = tf.ragged.constant([[0, 1], [2]])
4267  depth = 3
4268  tf.one_hot(indices, depth)  # output: [2 x None x 3]
4269  # [[[1., 0., 0.],
4270  #   [0., 1., 0.]],
4271  #  [[0., 0., 1.]]]
4272  ```
4273
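  As a further sketch, a non-numeric output dtype, where both `on_value` and
  `off_value` must be supplied:

  ```python
  tf.one_hot([0, 1], depth=2, on_value="yes", off_value="no")
  # [["yes", "no"],
  #  ["no", "yes"]]
  ```
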
4274  Args:
4275    indices: A `Tensor` of indices.
4276    depth: A scalar defining the depth of the one hot dimension.
4277    on_value: A scalar defining the value to fill in output when `indices[j]
4278      = i`. (default: 1)
4279    off_value: A scalar defining the value to fill in output when `indices[j]
4280      != i`. (default: 0)
4281    axis: The axis to fill (default: -1, a new inner-most axis).
4282    dtype: The data type of the output tensor.
4283    name: A name for the operation (optional).
4284
4285  Returns:
4286    output: The one-hot tensor.
4287
4288  Raises:
    TypeError: If the dtype of either `on_value` or `off_value` doesn't match
      `dtype`.
    TypeError: If the dtypes of `on_value` and `off_value` don't match one
      another.
4291  """
4292  with ops.name_scope(
4293      name, "one_hot",
4294      [indices, depth, on_value, off_value, axis, dtype]) as name:
4295    on_exists = on_value is not None
4296    off_exists = off_value is not None
4297
4298    if on_exists:
4299      on_value = ops.convert_to_tensor(on_value, dtype_hint=dtype)
4300    if off_exists:
4301      off_value = ops.convert_to_tensor(off_value, dtype_hint=dtype)
4302
4303    on_dtype = on_value.dtype.base_dtype if on_exists else None
4304    off_dtype = off_value.dtype.base_dtype if off_exists else None
4305
4306    if on_exists or off_exists:
4307      if dtype is not None:
4308        # Ensure provided on_value and/or off_value match dtype
4309        if on_exists and on_dtype != dtype:
4310          raise TypeError("dtype {0} of on_value does not match "
4311                          "dtype parameter {1}".format(on_dtype, dtype))
4312        if off_exists and off_dtype != dtype:
4313          raise TypeError("dtype {0} of off_value does not match "
4314                          "dtype parameter {1}".format(off_dtype, dtype))
4315      else:
4316        # dtype not provided: automatically assign it
4317        dtype = on_dtype if on_exists else off_dtype
4318    elif dtype is None:
4319      # None of on_value, off_value, or dtype provided. Default dtype to float32
4320      dtype = dtypes.float32
4321
4322    if not on_exists:
4323      # on_value not provided: assign to value 1 of type dtype
4324      on_value = ops.convert_to_tensor(1, dtype, name="on_value")
4325      on_dtype = dtype
4326    if not off_exists:
4327      # off_value not provided: assign to value 0 of type dtype
4328      off_value = ops.convert_to_tensor(0, dtype, name="off_value")
4329      off_dtype = dtype
4330
4331    if on_dtype != off_dtype:
4332      raise TypeError("dtype {0} of on_value does not match "
4333                      "dtype {1} of off_value".format(on_dtype, off_dtype))
4334
4335    return gen_array_ops.one_hot(indices, depth, on_value, off_value, axis,
4336                                 name)
4337
4338
4339def _all_dimensions(x):
4340  """Returns a 1D-tensor listing all dimensions in x."""
4341  # Fast path: avoid creating Rank and Range ops if ndims is known.
4342  if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
4343    return constant_op.constant(
4344        np.arange(x.get_shape().ndims), dtype=dtypes.int32)
4345  if (isinstance(x, sparse_tensor.SparseTensor) and
4346      x.dense_shape.get_shape().is_fully_defined()):
4347    r = x.dense_shape.get_shape().dims[0].value  # sparse.dense_shape is 1-D.
4348    return constant_op.constant(np.arange(r), dtype=dtypes.int32)
4349
4350  # Otherwise, we rely on `range` and `rank` to do the right thing at runtime.
4351  return gen_math_ops._range(0, rank(x), 1)
4352
4353
4354@tf_export("sequence_mask")
4355@dispatch.add_dispatch_support
4356def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
4357  """Returns a mask tensor representing the first N positions of each cell.
4358
4359  If `lengths` has shape `[d_1, d_2, ..., d_n]` the resulting tensor `mask` has
4360  dtype `dtype` and shape `[d_1, d_2, ..., d_n, maxlen]`, with
4361
4362  ```
4363  mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
4364  ```
4365
4366  Examples:
4367
4368  ```python
4369  tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
4370                                  #  [True, True, True, False, False],
4371                                  #  [True, True, False, False, False]]
4372
4373  tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
4374                                    #   [True, True, True]],
4375                                    #  [[True, True, False],
4376                                    #   [False, False, False]]]
4377  ```
4378
4379  Args:
4380    lengths: integer tensor, all its values <= maxlen.
4381    maxlen: scalar integer tensor, size of last dimension of returned tensor.
4382      Default is the maximum value in `lengths`.
4383    dtype: output type of the resulting tensor.
4384    name: name of the op.
4385
4386  Returns:
4387    A mask tensor of shape `lengths.shape + (maxlen,)`, cast to specified dtype.
4388  Raises:
4389    ValueError: if `maxlen` is not a scalar.
4390  """
4391  with ops.name_scope(name, "SequenceMask", [lengths, maxlen]):
4392    lengths = ops.convert_to_tensor(lengths)
4393
4394    if maxlen is None:
4395      maxlen = gen_math_ops._max(lengths, _all_dimensions(lengths))
4396      maxlen = gen_math_ops.maximum(constant(0, maxlen.dtype), maxlen)
4397    else:
4398      maxlen = ops.convert_to_tensor(maxlen)
4399    if maxlen.get_shape().ndims is not None and maxlen.get_shape().ndims != 0:
4400      raise ValueError("maxlen must be scalar for sequence_mask")
4401
4402    # The basic idea is to compare a range row vector of size maxlen:
4403    # [0, 1, 2, 3, 4]
4404    # to length as a matrix with 1 column: [[1], [3], [2]].
4405    # Because of broadcasting on both arguments this comparison results
4406    # in a matrix of size (len(lengths), maxlen)
4407    row_vector = gen_math_ops._range(
4408        constant(0, maxlen.dtype), maxlen, constant(1, maxlen.dtype))
    # Since maxlen >= max(lengths), it is safe to use maxlen's dtype as the
    # authoritative type for the cast: whenever maxlen fits into tf.int32, so
    # do the lengths.
4411    matrix = gen_math_ops.cast(expand_dims(lengths, -1), maxlen.dtype)
4412    result = row_vector < matrix
4413    if dtype is None or result.dtype.is_compatible_with(dtype):
4414      return result
4415    else:
4416      return gen_math_ops.cast(result, dtype)
4417
4418
4419@tf_export(v1=["squeeze"])
4420@dispatch.add_dispatch_support
4421@deprecation.deprecated_args(None, "Use the `axis` argument instead",
4422                             "squeeze_dims")
4423def squeeze(input, axis=None, name=None, squeeze_dims=None):
4424  # pylint: disable=redefined-builtin
4425  """Removes dimensions of size 1 from the shape of a tensor.
4426
4427  Given a tensor `input`, this operation returns a tensor of the same type with
4428  all dimensions of size 1 removed. If you don't want to remove all size 1
4429  dimensions, you can remove specific size 1 dimensions by specifying
4430  `axis`.
4431
4432  For example:
4433
4434  >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4435  >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4436  >>> print(tf.shape(tf.squeeze(t)).numpy())
4437  [2 3]
4438
4439  Or, to remove specific size 1 dimensions:
4440
4441  >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4442  >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4443  >>> print(tf.shape(tf.squeeze(t, [2, 4])).numpy())
4444  [1 2 3 1]
4445
4446  Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4447  time, where `N` is the number of elements in the squeezed dimensions.
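
  Negative `axis` values also work (an illustrative doctest):

  >>> t = tf.ones([1, 2, 1])
  >>> print(tf.shape(tf.squeeze(t, [-1])).numpy())
  [1 2]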
4448
4449  Args:
4450    input: A `Tensor`. The `input` to squeeze.
4451    axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4452      squeezes the dimensions listed. The dimension index starts at 0. It is an
4453      error to squeeze a dimension that is not 1. Must be in the range
4454      `[-rank(input), rank(input))`. Must be specified if `input` is a
4455      `RaggedTensor`.
4456    name: A name for the operation (optional).
4457    squeeze_dims: Deprecated keyword argument. Use `axis` instead.
4458
4459  Returns:
4460    A `Tensor`. Has the same type as `input`.
4461    Contains the same data as `input`, but has one or more dimensions of
4462    size 1 removed.
4463
4464  Raises:
4465    ValueError: When both `squeeze_dims` and `axis` are specified.
4466  """
4467  axis = deprecation.deprecated_argument_lookup("axis", axis, "squeeze_dims",
4468                                                squeeze_dims)
4469  if np.isscalar(axis):
4470    axis = [axis]
4471  return gen_array_ops.squeeze(input, axis, name)
4472
4473
4474@tf_export("squeeze", v1=[])
4475@dispatch.add_dispatch_support
4476def squeeze_v2(input, axis=None, name=None):
4477  """Removes dimensions of size 1 from the shape of a tensor.
4478
4479  Given a tensor `input`, this operation returns a tensor of the same type with
4480  all dimensions of size 1 removed. If you don't want to remove all size 1
4481  dimensions, you can remove specific size 1 dimensions by specifying
4482  `axis`.
4483
4484  For example:
4485
4486  ```python
4487  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4488  tf.shape(tf.squeeze(t))  # [2, 3]
4489  ```
4490
4491  Or, to remove specific size 1 dimensions:
4492
4493  ```python
4494  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4495  tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]
4496  ```
4497
4498  Unlike the older op `tf.compat.v1.squeeze`, this op does not accept a
4499  deprecated `squeeze_dims` argument.
4500
4501  Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4502  time, where `N` is the number of elements in the squeezed dimensions.
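
  For a `tf.RaggedTensor`, `axis` is required and must select size-1
  dimensions, e.g. (an illustrative doctest):

  >>> rt = tf.ragged.constant([[[1, 2], [3]]])
  >>> tf.squeeze(rt, [0])
  <tf.RaggedTensor [[1, 2], [3]]>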
4503
4504  Args:
4505    input: A `Tensor`. The `input` to squeeze.
4506    axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4507      squeezes the dimensions listed. The dimension index starts at 0. It is an
4508      error to squeeze a dimension that is not 1. Must be in the range
4509      `[-rank(input), rank(input))`. Must be specified if `input` is a
4510      `RaggedTensor`.
4511    name: A name for the operation (optional).
4512
4513  Returns:
4514    A `Tensor`. Has the same type as `input`.
4515    Contains the same data as `input`, but has one or more dimensions of
4516    size 1 removed.
4517
4518  Raises:
4519    ValueError: The input cannot be converted to a tensor, or the specified
4520      axis cannot be squeezed.
4521  """
4522  # pylint: disable=redefined-builtin
4523  return squeeze(input, axis, name)
4524
4525
4526@tf_export(v1=["where"])
4527@dispatch.add_dispatch_support
4528def where(condition, x=None, y=None, name=None):
4529  """Return the elements, either from `x` or `y`, depending on the `condition`.
4530
4531  If both `x` and `y` are None, then this operation returns the coordinates of
4532  true elements of `condition`.  The coordinates are returned in a 2-D tensor
4533  where the first dimension (rows) represents the number of true elements, and
4534  the second dimension (columns) represents the coordinates of the true
4535  elements. Keep in mind, the shape of the output tensor can vary depending on
4536  how many true values there are in input. Indices are output in row-major
4537  order.
4538
4539  If both are non-None, `x` and `y` must have the same shape.
4540  The `condition` tensor must be a scalar if `x` and `y` are scalar.
4541  If `x` and `y` are tensors of higher rank, then `condition` must be either a
4542  vector with size matching the first dimension of `x`, or must have the same
4543  shape as `x`.
4544
4545  The `condition` tensor acts as a mask that chooses, based on the value at each
4546  element, whether the corresponding element / row in the output should be taken
4547  from `x` (if true) or `y` (if false).
4548
4549  If `condition` is a vector and `x` and `y` are higher rank matrices, then it
4550  chooses which row (outer dimension) to copy from `x` and `y`. If `condition`
4551  has the same shape as `x` and `y`, then it chooses which element to copy from
4552  `x` and `y`.
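
  For example (an illustrative sketch):

  ```python
  tf.compat.v1.where([[True, False], [False, True]])    # => [[0, 0], [1, 1]]
  tf.compat.v1.where([True, False], [1, 2], [10, 20])   # => [1, 20]
  ```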
4553
4554  Args:
4555    condition: A `Tensor` of type `bool`
4556    x: A Tensor which may have the same shape as `condition`. If `condition` is
4557      rank 1, `x` may have higher rank, but its first dimension must match the
4558      size of `condition`.
4559    y: A `tensor` with the same shape and type as `x`.
4560    name: A name for the operation (optional).
4561
4562  Returns:
4563    A `Tensor` with the same type and shape as `x`, `y` if they are non-None.
4564    Otherwise, a `Tensor` with shape `(num_true, rank(condition))`.
4565
4566  Raises:
4567    ValueError: When exactly one of `x` or `y` is non-None.
4568  """
4569  if x is None and y is None:
4570    with ops.name_scope(name, "Where", [condition]) as name:
4571      condition = ops.convert_to_tensor(
4572          condition, preferred_dtype=dtypes.bool, name="condition")
4573      return gen_array_ops.where(condition=condition, name=name)
4574  elif x is not None and y is not None:
4575    return gen_math_ops.select(condition=condition, x=x, y=y, name=name)
4576  else:
4577    raise ValueError("x and y must both be non-None or both be None.")
4578
4579
4580@tf_export("where", v1=["where_v2"])
4581@dispatch.add_dispatch_support
4582def where_v2(condition, x=None, y=None, name=None):
4583  """Return the elements where `condition` is `True` (multiplexing `x` and `y`).
4584
4585  This operator has two modes: in one mode both `x` and `y` are provided, in
4586  another mode neither are provided. `condition` is always expected to be a
4587  `tf.Tensor` of type `bool`.
4588
4589  #### Retrieving indices of `True` elements
4590
4591  If `x` and `y` are not provided (both are None):
4592
4593  `tf.where` will return the indices of `condition` that are `True`, in
4594  the form of a 2-D tensor with shape `(n, d)`, where `n` is the number of
4595  `True` elements in `condition` and `d` is the number of dimensions (the
4596  rank) of `condition`.
4597
4598  Indices are output in row-major order.
4599
4600  >>> tf.where([True, False, False, True])
4601  <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
4602  array([[0],
4603         [3]])>
4604
4605  >>> tf.where([[True, False], [False, True]])
4606  <tf.Tensor: shape=(2, 2), dtype=int64, numpy=
4607  array([[0, 0],
4608         [1, 1]])>
4609
4610  >>> tf.where([[[True, False], [False, True], [True, True]]])
4611  <tf.Tensor: shape=(4, 3), dtype=int64, numpy=
4612  array([[0, 0, 0],
4613         [0, 1, 1],
4614         [0, 2, 0],
4615         [0, 2, 1]])>
4616
4617  #### Multiplexing between `x` and `y`
4618
4619  If `x` and `y` are provided (both have non-None values):
4620
4621  `tf.where` will choose an output shape to which the shapes of `condition`,
4622  `x`, and `y` are all
4623  [broadcastable](https://docs.scipy.org/doc/numpy/reference/ufuncs.html).
4624
4625  The `condition` tensor acts as a mask that chooses whether the corresponding
4626  element / row in the output should be taken from `x`
4627  (if the element in `condition` is True) or `y` (if it is false).
4628
4629  >>> tf.where([True, False, False, True], [1,2,3,4], [100,200,300,400])
4630  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 200, 300,   4],
4631  dtype=int32)>
4632  >>> tf.where([True, False, False, True], [1,2,3,4], [100])
4633  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
4634  dtype=int32)>
4635  >>> tf.where([True, False, False, True], [1,2,3,4], 100)
4636  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
4637  dtype=int32)>
4638  >>> tf.where([True, False, False, True], 1, 100)
4639  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   1],
4640  dtype=int32)>
4641
4642  >>> tf.where(True, [1,2,3,4], 100)
4643  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
4644  dtype=int32)>
4645  >>> tf.where(False, [1,2,3,4], 100)
4646  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100],
4647  dtype=int32)>
4648
4649  Note that if the gradient of either branch of the tf.where generates
4650  a NaN, then the gradient of the entire tf.where will be NaN.
4651  A workaround is to use an inner tf.where to ensure the function has
4652  no asymptote, and to avoid computing a value whose gradient is NaN by
4653  replacing dangerous inputs with safe inputs.
4654
4655  Instead of this,
4656
4657  >>> y = tf.constant(-1, dtype=tf.float32)
4658  >>> tf.where(y > 0, tf.sqrt(y), y)
4659  <tf.Tensor: shape=(), dtype=float32, numpy=-1.0>
4660
4661  Use this:
4662
4663  >>> tf.where(y > 0, tf.sqrt(tf.where(y > 0, y, 1)), y)
4664  <tf.Tensor: shape=(), dtype=float32, numpy=-1.0>
4665
4666  Args:
4667    condition: A `tf.Tensor` of type `bool`
4668    x: If provided, a Tensor which is of the same type as `y`, and has a shape
4669      broadcastable with `condition` and `y`.
4670    y: If provided, a Tensor which is of the same type as `x`, and has a shape
4671      broadcastable with `condition` and `x`.
4672    name: A name for the operation (optional).
4673
4674  Returns:
4675    If `x` and `y` are provided:
4676      A `Tensor` with the same type as `x` and `y`, and shape that
4677      is broadcast from `condition`, `x`, and `y`.
4678    Otherwise, a `Tensor` with shape `(num_true, rank(condition))`.
4679
4680  Raises:
4681    ValueError: When exactly one of `x` or `y` is non-None, or the shapes
4682      are not all broadcastable.
4683  """
4684  if x is None and y is None:
4685    with ops.name_scope(name, "Where", [condition]) as name:
4686      condition = ops.convert_to_tensor(
4687          condition, preferred_dtype=dtypes.bool, name="condition")
4688      return gen_array_ops.where(condition=condition, name=name)
4689  elif x is not None and y is not None:
4690    return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
4691  else:
4692    raise ValueError("x and y must both be non-None or both be None.")
4693
4694
4695# pylint: disable=redefined-builtin
4696@tf_export(v1=["reverse_sequence"])
4697@deprecation.deprecated_args(None,
4698                             "seq_dim is deprecated, use seq_axis instead",
4699                             "seq_dim")
4700@deprecation.deprecated_args(None,
4701                             "batch_dim is deprecated, use batch_axis instead",
4702                             "batch_dim")
4703def reverse_sequence(input,
4704                     seq_lengths,
4705                     seq_axis=None,
4706                     batch_axis=None,
4707                     name=None,
4708                     seq_dim=None,
4709                     batch_dim=None):
4710  """Reverses variable length slices.
4711
4712  This op first slices `input` along the dimension `batch_axis`, and for
4713  each slice `i`, reverses the first `seq_lengths[i]` elements along the
4714  dimension `seq_axis`.
4715
4716  The elements of `seq_lengths` must obey `seq_lengths[i] <=
4717  input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4718  `input.dims[batch_axis]`.
4719
4720  The output slice `i` along dimension `batch_axis` is then given by
4721  input slice `i`, with the first `seq_lengths[i]` slices along
4722  dimension `seq_axis` reversed.
4723
4724  Example usage:
4725
4726  >>> seq_lengths = [7, 2, 3, 5]
4727  >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
4728  ...          [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
4729  >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
4730  >>> output
4731  <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
4732  array([[0, 0, 5, 4, 3, 2, 1, 0],
4733         [2, 1, 0, 0, 0, 0, 0, 0],
4734         [3, 2, 1, 4, 0, 0, 0, 0],
4735         [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
4736
4737  Args:
4738    input: A `Tensor`. The input to reverse.
4739    seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
4740      `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
4741      input.dims(seq_axis)`
4742    seq_axis: An `int`. The dimension which is partially reversed.
4743    batch_axis: An optional `int`. Defaults to `0`. The dimension along which
4744      reversal is performed.
4745    name: A name for the operation (optional).
4746
4747  Returns:
4748    A Tensor. Has the same type as input.
4749  """
4750  seq_axis = deprecation.deprecated_argument_lookup("seq_axis", seq_axis,
4751                                                    "seq_dim", seq_dim)
4752  batch_axis = deprecation.deprecated_argument_lookup("batch_axis", batch_axis,
4753                                                      "batch_dim", batch_dim)
4754  return gen_array_ops.reverse_sequence(
4755      input=input,
4756      seq_lengths=seq_lengths,
4757      seq_dim=seq_axis,
4758      batch_dim=batch_axis,
4759      name=name)
4760
4761
4762@tf_export("reverse_sequence", v1=[])
4763@dispatch.add_dispatch_support
4764def reverse_sequence_v2(input,
4765                        seq_lengths,
4766                        seq_axis=None,
4767                        batch_axis=None,
4768                        name=None):
4769  """Reverses variable length slices.
4770
4771  This op first slices `input` along the dimension `batch_axis`, and for
4772  each slice `i`, reverses the first `seq_lengths[i]` elements along the
4773  dimension `seq_axis`.
4774
4775  The elements of `seq_lengths` must obey `seq_lengths[i] <=
4776  input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4777  `input.dims[batch_axis]`.
4778
4779  The output slice `i` along dimension `batch_axis` is then given by
4780  input slice `i`, with the first `seq_lengths[i]` slices along
4781  dimension `seq_axis` reversed.
4782
4783  Example usage:
4784
4785  >>> seq_lengths = [7, 2, 3, 5]
4786  >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
4787  ...          [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
4788  >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
4789  >>> output
4790  <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
4791  array([[0, 0, 5, 4, 3, 2, 1, 0],
4792         [2, 1, 0, 0, 0, 0, 0, 0],
4793         [3, 2, 1, 4, 0, 0, 0, 0],
4794         [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
4795
4796  Args:
4797    input: A `Tensor`. The input to reverse.
4798    seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
4799      `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
4800      input.dims(seq_axis)`
4801    seq_axis: An `int`. The dimension which is partially reversed.
4802    batch_axis: An optional `int`. Defaults to `0`. The dimension along which
4803      reversal is performed.
4804    name: A name for the operation (optional).
4805
4806  Returns:
4807    A Tensor. Has the same type as input.
4808  """
4809  return gen_array_ops.reverse_sequence(
4810      input=input,
4811      seq_lengths=seq_lengths,
4812      seq_dim=seq_axis,
4813      batch_dim=batch_axis,
4814      name=name)
4815
4816# pylint: enable=redefined-builtin
4817
4818
4819@tf_export(v1=["gather"])
4820@deprecation.deprecated_args(None,
4821                             ("The `validate_indices` argument has no effect. "
4822                              "Indices are always validated on CPU and never "
4823                              "validated on GPU."),
4824                             "validate_indices")
4825@dispatch.add_dispatch_support
4826def gather(params,
4827           indices,
4828           validate_indices=None,
4829           name=None,
4830           axis=None,
4831           batch_dims=0):  # pylint: disable=g-doc-args
4832  r"""Gather slices from params axis `axis` according to indices.
4833
4834  Gather slices from `params` axis `axis` according to `indices`.  `indices`
4835  must be an integer tensor of any dimension (often 1-D).
4836
4837  `Tensor.__getitem__` works for scalars, `tf.newaxis`, and
4838  [python slices](https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing)
4839
4840  `tf.gather` extends indexing to handle tensors of indices.
4841
4842  In the simplest case it's identical to scalar indexing:
4843
4844  >>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
4845  >>> params[3].numpy()
4846  b'p3'
4847  >>> tf.gather(params, 3).numpy()
4848  b'p3'
4849
4850  The most common case is to pass a single axis tensor of indices (this
4851  can't be expressed as a python slice because the indices are not sequential):
4852
4853  >>> indices = [2, 0, 2, 5]
4854  >>> tf.gather(params, indices).numpy()
4855  array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
4856
4857  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
4858  <img style="width:100%" src="https://www.tensorflow.org/images/Gather.png"
4859  alt>
4860  </div>
4861
4862  The indices can have any shape. When the `params` has 1 axis, the
4863  output shape is equal to the shape of the indices:
4864
4865  >>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
4866  array([[b'p2', b'p0'],
4867         [b'p2', b'p5']], dtype=object)
4868
4869  The `params` may also have any shape. `gather` can select slices
4870  across any axis depending on the `axis` argument (which defaults to 0).
4871  Below it is used to gather first rows, then columns from a matrix:
4872
4873  >>> params = tf.constant([[0, 1.0, 2.0],
4874  ...                       [10.0, 11.0, 12.0],
4875  ...                       [20.0, 21.0, 22.0],
4876  ...                       [30.0, 31.0, 32.0]])
4877  >>> tf.gather(params, indices=[3,1]).numpy()
4878  array([[30., 31., 32.],
4879         [10., 11., 12.]], dtype=float32)
4880  >>> tf.gather(params, indices=[2,1], axis=1).numpy()
4881  array([[ 2.,  1.],
4882         [12., 11.],
4883         [22., 21.],
4884         [32., 31.]], dtype=float32)
4885
4886  More generally: the output has the same shape as the input, with the
4887  indexed axis replaced by the shape of the indices.
4888
4889  >>> def result_shape(p_shape, i_shape, axis=0):
4890  ...   return p_shape[:axis] + i_shape + p_shape[axis+1:]
4891  >>>
4892  >>> result_shape([1, 2, 3], [], axis=1)
4893  [1, 3]
4894  >>> result_shape([1, 2, 3], [7], axis=1)
4895  [1, 7, 3]
4896  >>> result_shape([1, 2, 3], [7, 5], axis=1)
4897  [1, 7, 5, 3]
4898
4899  Here are some examples:
4900
4901  >>> params.shape.as_list()
4902  [4, 3]
4903  >>> indices = tf.constant([[0, 2]])
4904  >>> tf.gather(params, indices=indices, axis=0).shape.as_list()
4905  [1, 2, 3]
4906  >>> tf.gather(params, indices=indices, axis=1).shape.as_list()
4907  [4, 1, 2]
4908
4909  >>> params = tf.random.normal(shape=(5, 6, 7, 8))
4910  >>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
4911  >>> result = tf.gather(params, indices, axis=2)
4912  >>> result.shape.as_list()
4913  [5, 6, 10, 11, 8]
4914
4915  This is because each index takes a slice from `params`, and
4916  places it at the corresponding location in the output. For the above example:
4917
4918  >>> # For any location in indices
4919  >>> a, b = 0, 1
4920  >>> tf.reduce_all(
4921  ...     # the corresponding slice of the result
4922  ...     result[:, :, a, b, :] ==
4923  ...     # is equal to the slice of `params` along `axis` at the index.
4924  ...     params[:, :, indices[a, b], :]
4925  ... ).numpy()
4926  True
4927
4928  ### Batching
4929
4930  The `batch_dims` argument lets you gather different items from each element
4931  of a batch.
4932
4933  Using `batch_dims=1` is equivalent to having an outer loop over the first
4934  axis of `params` and `indices`:
4935
4936  >>> params = tf.constant([
4937  ...     [0, 0, 1, 0, 2],
4938  ...     [3, 0, 0, 0, 4],
4939  ...     [0, 5, 0, 6, 0]])
4940  >>> indices = tf.constant([
4941  ...     [2, 4],
4942  ...     [0, 4],
4943  ...     [1, 3]])
4944
4945  >>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
4946  array([[1, 2],
4947         [3, 4],
4948         [5, 6]], dtype=int32)
4949
4950  This is equivalent to:
4951
4952  >>> def manually_batched_gather(params, indices, axis):
4953  ...   batch_dims=1
4954  ...   result = []
4955  ...   for p,i in zip(params, indices):
4956  ...     r = tf.gather(p, i, axis=axis-batch_dims)
4957  ...     result.append(r)
4958  ...   return tf.stack(result)
4959  >>> manually_batched_gather(params, indices, axis=1).numpy()
4960  array([[1, 2],
4961         [3, 4],
4962         [5, 6]], dtype=int32)
4963
4964  Higher values of `batch_dims` are equivalent to multiple nested loops over
4965  the outer axes of `params` and `indices`. So the overall shape function is
4966
4967  >>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
4968  ...   return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]
4969  >>>
4970  >>> batched_result_shape(
4971  ...     p_shape=params.shape.as_list(),
4972  ...     i_shape=indices.shape.as_list(),
4973  ...     axis=1,
4974  ...     batch_dims=1)
4975  [3, 2]
4976
4977  >>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
4978  [3, 2]
4979
4980  This comes up naturally if you need to use the indices of an operation like
4981  `tf.argsort`, or `tf.math.top_k` where the last dimension of the indices
4982  indexes into the last dimension of input, at the corresponding location.
4983  In this case you can use `tf.gather(values, indices, batch_dims=-1)`.
4984
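  A quick illustration (assuming `tf.math.top_k`'s usual descending order):

  >>> x = tf.constant([[3., 1., 2.], [5., 6., 4.]])
  >>> _, top_idx = tf.math.top_k(x, k=2)
  >>> tf.gather(x, top_idx, batch_dims=-1).numpy()
  array([[3., 2.],
         [6., 5.]], dtype=float32)
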
4985  See also:
4986
4987  * `tf.Tensor.__getitem__`: The direct tensor index operation (`t[]`), handles
4988    scalars and python-slices `tensor[..., 7, 1:-1]`
4989  * `tf.scatter`: A collection of operations similar to `__setitem__`
4990    (`t[i] = x`)
4991  * `tf.gather_nd`: An operation similar to `tf.gather` but gathers across
4992    multiple axes at once (it can gather elements of a matrix instead of rows
4993    or columns)
4994  * `tf.boolean_mask`, `tf.where`: Binary indexing.
4995  * `tf.slice` and `tf.strided_slice`: For lower level access to the
4996    implementation of `__getitem__`'s python-slice handling (`t[1:-1:2]`)
4997
4998  Args:
4999    params: The `Tensor` from which to gather values. Must be at least rank
5000      `axis + 1`.
5001    indices: The index `Tensor`.  Must be one of the following types: `int32`,
5002      `int64`. The values must be in range `[0, params.shape[axis])`.
5003    validate_indices: Deprecated, does nothing. Indices are always validated on
5004      CPU, never validated on GPU.
5005
5006      Caution: On CPU, if an out of bound index is found, an error is raised.
5007      On GPU, if an out of bound index is found, a 0 is stored in the
5008      corresponding output value.
5009    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5010      `axis` in `params` to gather `indices` from. Must be greater than or equal
5011      to `batch_dims`.  Defaults to the first non-batch dimension. Supports
5012      negative indexes.
5013    batch_dims: An `integer`.  The number of batch dimensions.  Must be less
5014      than or equal to `rank(indices)`.
5015    name: A name for the operation (optional).
5016
5017  Returns:
5018    A `Tensor`. Has the same type as `params`.
5019  """
5020  del validate_indices
5021
5022  if axis is None:
5023    axis = batch_dims
5024  if tensor_util.constant_value(axis) != 0:
5025    return gen_array_ops.gather_v2(
5026        params, indices, axis, batch_dims=batch_dims, name=name)
5027  try:
5028    # TODO(apassos) find a less bad way of detecting resource variables
5029    # without introducing a circular dependency.
5030    return params.sparse_read(indices, name=name)
5031  except AttributeError:
5032    return gen_array_ops.gather_v2(params, indices, axis, name=name)
5033
5034
5035@tf_export("gather", v1=[])
5036@dispatch.add_dispatch_support
5037def gather_v2(params,
5038              indices,
5039              validate_indices=None,
5040              axis=None,
5041              batch_dims=0,
5042              name=None):
5043  return gather(
5044      params,
5045      indices,
5046      validate_indices=validate_indices,
5047      name=name,
5048      axis=axis,
5049      batch_dims=batch_dims)
5050
5051
5052gather_v2.__doc__ = gather.__doc__
5053
5054
5055@tf_export(v1=["batch_gather"])
5056@dispatch.add_dispatch_support
5057@deprecation.deprecated(
5058    "2017-10-25", "`tf.batch_gather` is deprecated, please use `tf.gather` "
5059    "with `batch_dims=-1` instead.")  # pylint: disable=missing-docstring
5060def batch_gather(params, indices, name=None):
5061  """Gather slices from params according to indices with leading batch dims."""
5062  with ops.name_scope(name, "BatchGather", [params, indices]):
5063    indices = ops.convert_to_tensor(indices, name="indices")
5064    params = ops.convert_to_tensor(params, name="params")
5065    if indices.shape.ndims is None:
5066      raise ValueError(
5067          "batch_gather does not allow indices with unknown shape.")
5068    return _batch_gather(params, indices, batch_dims=indices.shape.ndims - 1)
5069
5070
5071def _batch_gather(params, indices, batch_dims, axis=None):
5072  r"""Gather slices from params according to indices with leading batch dims.
5073
5074  This operation assumes that the leading `batch_dims` dimensions of `indices`
5075  and `params` are batch dimensions; and performs a `tf.gather` operation within
5076  each batch. (If `batch_dims` is not specified, then it defaults to
5077  `rank(indices)-1`.)  In the case in which `batch_dims==0`, this operation
5078  is equivalent to `tf.gather`.
5079
5080  Args:
5081    params: A Tensor. The tensor from which to gather values.
5082    indices: A Tensor. Must be one of the following types: int32, int64. Index
5083      tensor. Must be in range `[0, params.shape[batch_dims])`.
5084    batch_dims: An integer or none.  The number of batch dimensions.  Must be
5085      less than `rank(indices)`.  Defaults to `rank(indices) - 1` if None.
5086    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5087      `axis` in `params` to gather `indices` from. Must be greater than or equal
5088      to `batch_dims`.  Defaults to the first non-batch dimension. Supports
5089      negative indexes.
5090
5091  Returns:
5092    A Tensor. Has the same type as `params`.
5093
5094  Raises:
5095    ValueError: if `indices` has an unknown shape.
5096  """
5097  if batch_dims is not None and not isinstance(batch_dims, int):
5098    raise TypeError("batch_dims must be an int; got %r" % (batch_dims,))
5099  indices = ops.convert_to_tensor(indices, name="indices")
5100  params = ops.convert_to_tensor(params, name="params")
5101
5102  indices_ndims = indices.shape.ndims
5103  if indices_ndims is None:
5104    raise ValueError("tf.gather does not allow indices with unknown "
5105                     "rank when batch_dims is specified.")
5106  if batch_dims is None:
5107    batch_dims = indices_ndims - 1
5108  if batch_dims < 0:
5109    batch_dims += indices_ndims
5110  if batch_dims < 0 or batch_dims >= indices_ndims:
5111    raise ValueError("batch_dims = %d must be less than rank(indices) = %d" %
5112                     (batch_dims, indices_ndims))
5113  if params.shape.ndims is not None and batch_dims >= params.shape.ndims:
5114    raise ValueError("batch_dims = %d must be less than rank(params) = %d" %
5115                     (batch_dims, params.shape.ndims))
5116
5117  # Handle axis by transposing the axis dimension to be the first non-batch
5118  # dimension, recursively calling batch_gather with axis=0, and then
5119  # transposing the result to put the pre-axis dimensions before the indices
5120  # dimensions.
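  # For example (illustrative shapes): with params [B, X, Y, Z], indices
  # [B, I], batch_dims=1 and axis=2, params is transposed to [B, Y, X, Z],
  # the recursive gather yields [B, I, X, Z], and the final transpose
  # produces [B, X, I, Z].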
5121  if axis is not None and axis != batch_dims:
5122    # Adjust axis to be positive.
5123    if not isinstance(axis, int):
5124      axis = where_v2(axis < 0, axis + rank(params), axis)
5125    elif axis < 0 and params.shape.ndims is None:
5126      axis = axis + rank(params)
5127    else:
5128      if (axis < -params.shape.ndims) or (axis >= params.shape.ndims):
5129        raise ValueError("axis (%d) out of range [%d, %d)" %
5130                         (axis, -params.shape.ndims, params.shape.ndims))
5131      if axis < 0:
5132        axis += params.shape.ndims
5133      if axis < batch_dims:
5134        raise ValueError("batch_dims = %d must be less than or equal to "
5135                         "axis = %d" % (batch_dims, axis))
5136
5137    # Move params[axis] up to params[batch_dims].
5138    perm = [
5139        list(range(batch_dims)), [axis],
5140        gen_math_ops._range(batch_dims, axis, 1),
5141        gen_math_ops._range(axis + 1, rank(params), 1)
5142    ]
5143    params = transpose(params, concat(perm, axis=0))
5144
5145    result = _batch_gather(params, indices, batch_dims=batch_dims)
5146
5147    # Move the result dimensions corresponding to params[batch_dims:axis]
5148    # to just before the dimensions corresponding to indices[batch_dims:].
5149    params_start = indices_ndims + axis - batch_dims
5150    perm = [
5151        list(range(batch_dims)),
5152        gen_math_ops._range(indices_ndims, params_start, 1),
5153        list(range(batch_dims, indices_ndims)),
5154        gen_math_ops._range(params_start, rank(result), 1)
5155    ]
5156    return transpose(result, perm=concat(perm, axis=0))
5157
5158  indices_shape = shape(indices)
5159  params_shape = shape(params)
5160  batch_indices = indices
5161  indices_dtype = indices.dtype.base_dtype
5162  accum_dim_value = ones((), dtype=indices_dtype)
5163  # Use correct type for offset index computation
5164  casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
5165  for dim in range(batch_dims, 0, -1):
5166    dim_value = casted_params_shape[dim - 1]
5167    accum_dim_value *= casted_params_shape[dim]
5168    start = zeros((), dtype=indices_dtype)
5169    step = ones((), dtype=indices_dtype)
5170    dim_indices = gen_math_ops._range(start, dim_value, step)
5171    dim_indices *= accum_dim_value
5172    dim_shape = stack(
5173        [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
5174    batch_indices += reshape(dim_indices, dim_shape)
5175
5176  flat_indices = reshape(batch_indices, [-1])
5177  outer_shape = params_shape[batch_dims + 1:]
5178  flat_inner_shape = gen_math_ops.prod(params_shape[:batch_dims + 1], [0],
5179                                       False)
5180
5181  flat_params = reshape(params, concat([[flat_inner_shape], outer_shape],
5182                                       axis=0))
5183  flat_result = gather(flat_params, flat_indices)
5184  result = reshape(flat_result, concat([indices_shape, outer_shape], axis=0))
5185  final_shape = indices.get_shape()[:batch_dims].merge_with(
5186      params.get_shape()[:batch_dims])
5187  final_shape = final_shape.concatenate(indices.get_shape().dims[batch_dims:])
5188  final_shape = final_shape.concatenate(params.get_shape()[batch_dims + 1:])
5189  result.set_shape(final_shape)
5190  return result
5191
5192
5193@tf_export(v1=["gather_nd", "manip.gather_nd"])
5194@dispatch.add_dispatch_support
5195@deprecated_endpoints("manip.gather_nd")
5196def gather_nd(params, indices, name=None, batch_dims=0):
5197  r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
5198
5199  `indices` is a K-dimensional integer tensor, best thought of as a
5200  (K-1)-dimensional tensor of indices into `params`, where each element defines
5201  a slice of `params`:
5202
5203      output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
5204
5205  Whereas in `tf.gather` `indices` defines slices into the first
5206  dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
5207  first `N` dimensions of `params`, where `N = indices.shape[-1]`.
5208
5209  The last dimension of `indices` can be at most the rank of
5210  `params`:
5211
5212      indices.shape[-1] <= params.rank
5213
5214  The last dimension of `indices` corresponds to elements
5215  (if `indices.shape[-1] == params.rank`) or slices
5216  (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]`
5217  of `params`.  The output tensor has shape
5218
5219      indices.shape[:-1] + params.shape[indices.shape[-1]:]
5220
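  For instance (an illustrative shape walk-through):

  ```python
  # params.shape == [4, 5, 6] and indices.shape == [2, 1], so N = 1 and
  # output.shape == [2] + [5, 6] == [2, 5, 6]
  ```
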
5221  Additionally, both 'params' and 'indices' can have M leading batch
5222  dimensions that exactly match. In this case 'batch_dims' must be M.
5223
5224  Note that on CPU, if an out of bound index is found, an error is returned.
5225  On GPU, if an out of bound index is found, a 0 is stored in the
5226  corresponding output value.
5227
5228  Some examples are given below.
5229
5230  Simple indexing into a matrix:
5231
5232  ```python
5233      indices = [[0, 0], [1, 1]]
5234      params = [['a', 'b'], ['c', 'd']]
5235      output = ['a', 'd']
5236  ```
5237
5238  Slice indexing into a matrix:
5239
5240  ```python
5241      indices = [[1], [0]]
5242      params = [['a', 'b'], ['c', 'd']]
5243      output = [['c', 'd'], ['a', 'b']]
5244  ```
5245
5246  Indexing into a 3-tensor:
5247
5248  ```python
5249      indices = [[1]]
5250      params = [[['a0', 'b0'], ['c0', 'd0']],
5251                [['a1', 'b1'], ['c1', 'd1']]]
5252      output = [[['a1', 'b1'], ['c1', 'd1']]]
5253
5254
5255      indices = [[0, 1], [1, 0]]
5256      params = [[['a0', 'b0'], ['c0', 'd0']],
5257                [['a1', 'b1'], ['c1', 'd1']]]
5258      output = [['c0', 'd0'], ['a1', 'b1']]
5259
5260
5261      indices = [[0, 0, 1], [1, 0, 1]]
5262      params = [[['a0', 'b0'], ['c0', 'd0']],
5263                [['a1', 'b1'], ['c1', 'd1']]]
5264      output = ['b0', 'b1']
5265  ```
5266
5267  The examples below are for the case when only 'indices' has leading extra
5268  dimensions. If both 'params' and 'indices' have leading batch dimensions, use
5269  the 'batch_dims' parameter to run gather_nd in batch mode.
5270
5271  Batched indexing into a matrix:
5272
5273  ```python
5274      indices = [[[0, 0]], [[0, 1]]]
5275      params = [['a', 'b'], ['c', 'd']]
5276      output = [['a'], ['b']]
5277  ```
5278
5279  Batched slice indexing into a matrix:
5280
5281  ```python
5282      indices = [[[1]], [[0]]]
5283      params = [['a', 'b'], ['c', 'd']]
5284      output = [[['c', 'd']], [['a', 'b']]]
5285  ```
5286
5287  Batched indexing into a 3-tensor:
5288
5289  ```python
5290      indices = [[[1]], [[0]]]
5291      params = [[['a0', 'b0'], ['c0', 'd0']],
5292                [['a1', 'b1'], ['c1', 'd1']]]
5293      output = [[[['a1', 'b1'], ['c1', 'd1']]],
5294                [[['a0', 'b0'], ['c0', 'd0']]]]
5295
5296      indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
5297      params = [[['a0', 'b0'], ['c0', 'd0']],
5298                [['a1', 'b1'], ['c1', 'd1']]]
5299      output = [[['c0', 'd0'], ['a1', 'b1']],
5300                [['a0', 'b0'], ['c1', 'd1']]]
5301
5302
5303      indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
5304      params = [[['a0', 'b0'], ['c0', 'd0']],
5305                [['a1', 'b1'], ['c1', 'd1']]]
5306      output = [['b0', 'b1'], ['d0', 'c1']]
5307  ```
5308
5309  Examples with batched 'params' and 'indices':
5310
5311  ```python
5312      batch_dims = 1
5313      indices = [[1], [0]]
5314      params = [[['a0', 'b0'], ['c0', 'd0']],
5315                [['a1', 'b1'], ['c1', 'd1']]]
5316      output = [['c0', 'd0'], ['a1', 'b1']]
5317
5318      batch_dims = 1
5319      indices = [[[1]], [[0]]]
5320      params = [[['a0', 'b0'], ['c0', 'd0']],
5321                [['a1', 'b1'], ['c1', 'd1']]]
5322      output = [[['c0', 'd0']], [['a1', 'b1']]]
5323
5324      batch_dims = 1
5325      indices = [[[1, 0]], [[0, 1]]]
5326      params = [[['a0', 'b0'], ['c0', 'd0']],
5327                [['a1', 'b1'], ['c1', 'd1']]]
5328      output = [['c0'], ['b1']]
5329  ```
5330
5331  See also `tf.gather`.
5332
5333  Args:
5334    params: A `Tensor`. The tensor from which to gather values.
5335    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
5336      Index tensor.
5337    name: A name for the operation (optional).
5338    batch_dims: An integer or a scalar 'Tensor'. The number of batch dimensions.
5339
5340  Returns:
5341    A `Tensor`. Has the same type as `params`.
5342  """
5343  batch_dims_ = tensor_util.constant_value(batch_dims)
5344  if batch_dims_ is not None:
5345    batch_dims = int(batch_dims_)
5346  if batch_dims == 0:
5347    try:
5348      # TODO(apassos) find a less bad way of detecting resource variables
5349      # without introducing a circular dependency.
5350      return params.gather_nd(indices, name=name)
5351    except AttributeError:
5352      return gen_array_ops.gather_nd(params, indices, name=name)
5353  else:
5354    return batch_gather_nd(params, indices, batch_dims=batch_dims, name=name)
5355
5356
5357@tf_export("gather_nd", v1=[])
5358@dispatch.add_dispatch_support
5359def gather_nd_v2(params, indices, batch_dims=0, name=None):
5360  return gather_nd(params, indices, name=name, batch_dims=batch_dims)
5361
5362
5363gather_nd_v2.__doc__ = gather_nd.__doc__
5364
5365
5366def batch_gather_nd(params, indices, batch_dims, name=None):
5367  """gather_nd implementation with batch support."""
5368  with ops.name_scope(name, "BatchGatherND", [params, indices]):
5369    indices = ops.convert_to_tensor(indices, name="indices")
5370    params = ops.convert_to_tensor(params, name="params")
5371
5372    if not isinstance(batch_dims, int):
5373      raise TypeError("batch_dims must be an int; got %r" % (batch_dims,))
5374    if batch_dims < 0:
5375      raise ValueError("tf.gather_nd does not allow negative batch_dims.")
5376    params_ndims = params.shape.ndims
5377    indices_ndims = indices.shape.ndims
5378    if indices_ndims is not None and batch_dims >= indices_ndims:
5379      raise ValueError("batch_dims = %d must be less than rank(indices) = %d" %
5380                       (batch_dims, indices_ndims))
5381    if params_ndims is not None and batch_dims >= params_ndims:
5382      raise ValueError("batch_dims = %d must be less than rank(params) = %d" %
5383                       (batch_dims, params_ndims))
5384
5385    expand = batch_dims == 0
5386    if expand:
5387      # Normally gather_nd will be called when batch_dims == 0.
5388      # But if this function is called with batch_dims = 0, e.g. for testing
5389      # purposes, this adds a dummy batch dimension to make batch_dims = 1.
5390      params = expand_dims(params, axis=0)
5391      indices = expand_dims(indices, axis=0)
5392      batch_dims = 1
5393
5394    params_shape = shape(params)
5395    indices_shape = shape(indices)
5396    batch_shape = params_shape[:batch_dims]
5397    batch_size = gen_math_ops.prod(batch_shape, [0])
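    # e.g. (illustrative): a batch_shape of [2, 3] gives batch_size == 6.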
5398    index_internal_ndims = rank(indices) - batch_dims - 1
5399    indices_internal_shape = indices_shape[batch_dims:-1]
5400
5401    # Assuming a 'params' with shape [b1, ..., bM, g1, ..., gN] and an 'indices'
5402    # with shape [b1, ..., bM, i1, ..., iK, C], where C <= N, we need to modify
5403    # 'indices' s.t. it has shape [i1, ..., iK, D], where D <= M + N and slices
5404    # to the entire 'params' tensor.
5405    # Assuming we have a batch of shape [B1, B2], we use meshgrid to create a
5406    # grid of size B1 x B2.
5407    batch_dim_list = unstack(batch_shape, axis=0)
5408    dim_ranges = [
5409        gen_math_ops.cast(gen_math_ops._range(0, x, 1), indices.dtype)
5410        for x in batch_dim_list
5411    ]
5412    mesh_list = meshgrid(*dim_ranges, indexing="ij") if dim_ranges else []
5413    # Then we flatten and stack the tensors to form a (B1.B2) by 2 matrix.
5414    flat_list = [reshape(x, shape=(-1,)) for x in mesh_list]
5415    index_grid = transpose(stack(flat_list, axis=0))
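    # For example (illustrative): a batch_shape of [2, 2] yields
    # index_grid == [[0, 0], [0, 1], [1, 0], [1, 1]].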
5416    # We need to concatenate these batch coordinates with the internal indices.
5417    # concat -> index_grid [B1.B2, 2] with indices [i1, ..., iK, C]
5418    # So we reshape them both to [(B1.B2), i1, ..., iK, *]
5419    index_grid_shape = shape(index_grid)
5420    index_grid = reshape(
5421        index_grid,
5422        concat([
5423            index_grid_shape[:1],
5424            ones(index_internal_ndims, dtype=dtypes.int32), index_grid_shape[1:]
5425        ],
5426               axis=0))
5427    tile_shape = concat(((1,), indices_internal_shape, (1,)), axis=0)
5428    index_grid = tile(index_grid, multiples=tile_shape)
5429    # index_grid now has shape [(B1.B2), i1, ..., iK, 2]
5430    flat_shape = concat(([batch_size], indices_shape[batch_dims:]), axis=0)
5431    flat_indices = reshape(indices, shape=flat_shape)
5432    # flat_indices now has shape [(B1.B2), i1, ..., iK, C]
5433    indices = concat((index_grid, flat_indices), axis=-1)
5434    # indices has shape [(B1.B2), i1, ..., iK, 2+C]
5435    out = gen_array_ops.gather_nd(params, indices)
5436    # out has shape [(B1.B2), i1, ..., iK, N-C]. Now we reshape batch to
5437    # its original form.
5438    out_shape = shape(out)
5439    out = reshape(out, shape=concat((batch_shape, out_shape[1:]), axis=0))
5440    if expand:
5441      out = squeeze(out, axis=0)
5442  return out
5443
5444
5445@deprecation.deprecated_endpoints("tensor_scatter_update")
5446@tf_export(
5447    "tensor_scatter_nd_update",
5448    v1=["tensor_scatter_nd_update", "tensor_scatter_update"])
5449@dispatch.add_dispatch_support
5450def tensor_scatter_nd_update(tensor, indices, updates, name=None):
5451  """"Scatter `updates` into an existing tensor according to `indices`.
5452
5453  This operation creates a new tensor by applying sparse `updates` to the
5454  input `tensor`. This is similar to an index assignment.
5455
5456  ```
5457  # Not implemented: tensors cannot be updated inplace.
5458  tensor[indices] = updates
5459  ```
5460
5461  If an out of bound index is found on CPU, an error is returned.
5462
5463  > **WARNING**: There are some GPU specific semantics for this operation.
5464  >
5465  > - If an out of bound index is found, the index is ignored.
5466  > - The order in which updates are applied is nondeterministic, so the output
5467  >   will be nondeterministic if `indices` contains duplicates.
5468
5469  This operation is very similar to `tf.scatter_nd`, except that the updates are
5470  scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
5471  for the existing tensor cannot be re-used, a copy is made and updated.
5472
5473  In general:
5474
5475  * `indices` is an integer tensor - the indices to update in `tensor`.
5476  * `indices` has **at least two** axes, the last axis is the depth of the
5477    index vectors.
5478  * For each index vector in `indices` there is a corresponding entry in
5479    `updates`.
5480  * If the length of the index vectors matches the rank of the `tensor`, then
5481    the index vectors each point to scalars in `tensor` and each update is a
5482    scalar.
5483  * If the length of the index vectors is less than the rank of `tensor`, then
5484    the index vectors each point to slices of `tensor` and shape of the updates
5485    must match that slice.
5486
5487  Overall this leads to the following shape constraints:
5488
5489  ```
5490  assert tf.rank(indices) >= 2
5491  index_depth = indices.shape[-1]
5492  batch_shape = indices.shape[:-1]
5493  assert index_depth <= tf.rank(tensor)
5494  outer_shape = tensor.shape[:index_depth]
5495  inner_shape = tensor.shape[index_depth:]
5496  assert updates.shape == batch_shape + inner_shape
5497  ```
5498
5499  Typical usage is often much simpler than this general form, and it
5500  can be better understood starting with simple examples:
5501
5502  ### Scalar updates
5503
5504  The simplest usage inserts scalar elements into a tensor by index.
5505  In this case, the `index_depth` must equal the rank of the
5506  input `tensor`, since each column of `indices` is an index into an axis of the
5507  input `tensor`.
5508
5509  In this simplest case the shape constraints are:
5510
5511  ```
5512  num_updates, index_depth = indices.shape.as_list()
5513  assert updates.shape == [num_updates]
5514  assert index_depth == tf.rank(tensor)
5515  ```
5516
5517  For example, to insert 4 scattered elements in a rank-1 tensor with
5518  8 elements:
5519
5520  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
5521  <img style="width:100%"
5522    src="https://www.tensorflow.org/images/ScatterNd1.png">
5523  </div>
5524
5525  This scatter operation would look like this:
5526
5527  >>> tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
5528  >>> indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
5529  >>> updates = [9, 10, 11, 12]            # num_updates == 4
5530  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5531  tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
5532
5533  The length (first axis) of `updates` must equal the length of the `indices`:
5534  `num_updates`. This is the number of updates being inserted. Each scalar
5535  update is inserted into `tensor` at the indexed location.
5536
5537  For a higher rank input `tensor` scalar updates can be inserted by using an
5538  `index_depth` that matches `tf.rank(tensor)`:
5539
5540  >>> tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
5541  >>> indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
5542  >>> updates = [5, 10]                    # num_updates == 2
5543  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5544  tf.Tensor(
5545      [[ 1  5]
5546       [ 1  1]
5547       [10  1]], shape=(3, 2), dtype=int32)
5548
5549  ### Slice updates
5550
5551  When the input `tensor` has more than one axis scatter can be used to update
5552  entire slices.
5553
5554  In this case it's helpful to think of the input `tensor` as being a two level
5555  array-of-arrays. The shape of this two level array is split into the
5556  `outer_shape` and the `inner_shape`.
5557
5558  `indices` indexes into the outer level of the input tensor (`outer_shape`)
5559  and replaces the sub-array at that location with the corresponding item from
5560  the `updates` list. The shape of each update is `inner_shape`.
5561
5562  When updating a list of slices the shape constraints are:
5563
5564  ```
5565  num_updates, index_depth = indices.shape.as_list()
5566  outer_shape = tensor.shape[:index_depth]
5567  inner_shape = tensor.shape[index_depth:]
5568  assert updates.shape == [num_updates] + inner_shape
5569  ```
5570
5571  For example, to update rows of a `(6, 3)` `tensor`:
5572
5573  >>> tensor = tf.zeros([6, 3], dtype=tf.int32)
5574
5575  Use an index depth of one.
5576
5577  >>> indices = tf.constant([[2], [4]])     # num_updates == 2, index_depth == 1
5578  >>> num_updates, index_depth = indices.shape.as_list()
5579
5580  The `outer_shape` is `6`, the `inner_shape` is `3`:
5581
5582  >>> outer_shape = tensor.shape[:index_depth]
5583  >>> inner_shape = tensor.shape[index_depth:]
5584
5585  2 rows are being indexed so 2 `updates` must be supplied.
5586  Each update must be shaped to match the `inner_shape`.
5587
5588  >>> # num_updates == 2, inner_shape == 3
5589  >>> updates = tf.constant([[1, 2, 3],
5590  ...                        [4, 5, 6]])
5591
5592  Altogether this gives:
5593
5594  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5595  array([[0, 0, 0],
5596         [0, 0, 0],
5597         [1, 2, 3],
5598         [0, 0, 0],
5599         [4, 5, 6],
5600         [0, 0, 0]], dtype=int32)
5601
5602  #### More slice update examples
5603
5604  A tensor representing a batch of uniformly sized video clips naturally has 5
5605  axes: `[batch_size, time, width, height, channels]`.
5606
5607  For example:
5608
5609  >>> batch_size, time, width, height, channels = 13, 11, 7, 5, 3
5610  >>> video_batch = tf.zeros([batch_size, time, width, height, channels])
5611
5612  To replace a selection of video clips:
5613    * Use an `index_depth` of 1 (indexing the `outer_shape`: `[batch_size]`)
5614    * Provide updates each with a shape matching the `inner_shape`:
5615      `[time, width, height, channels]`.
5616
5617  To replace the first two clips with ones:
5618
5619  >>> indices = [[0],[1]]
5620  >>> new_clips = tf.ones([2, time, width, height, channels])
5621  >>> tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
5622
5623  To replace a selection of frames in the videos:
5624
5625  * `indices` must have an `index_depth` of 2 for the `outer_shape`:
5626    `[batch_size, time]`.
5627  * `updates` must be shaped like a list of images.  Each update must have a
5628    shape, matching the `inner_shape`: `[width, height, channels]`.
5629
5630  To replace the first frame of the first three video clips:
5631
5632  >>> indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
5633  >>> new_images = tf.ones([
5634  ...   # num_updates=3, inner_shape=(width, height, channels)
5635  ...   3, width, height, channels])
5636  >>> tf.tensor_scatter_nd_update(video_batch, indices, new_images)
5637
5638  ### Folded indices
5639
5640  In simple cases it's convenient to think of `indices` and `updates` as
5641  lists, but this is not a strict requirement. Instead of a flat `num_updates`,
5642  the `indices` and `updates` can be folded into a `batch_shape`. This
5643  `batch_shape` is all axes of the `indices`, except for the innermost
5644  `index_depth` axis.
5645
5646  ```
5647  index_depth = indices.shape[-1]
5648  batch_shape = indices.shape[:-1]
5649  ```
5650
5651  Note: The one exception is that the `batch_shape` cannot be `[]`. You can't
5652  update a single index by passing indices with shape `[index_depth]`.
5653
5654  `updates` must have a matching `batch_shape` (the axes before `inner_shape`).
5655
5656  ```
5657  assert updates.shape == batch_shape + inner_shape
5658  ```
5659
5660  Note: The result is equivalent to flattening the `batch_shape` axes of
5661  `indices` and `updates`. This generalization just avoids the need
5662  for reshapes when it is more natural to construct "folded" indices and
5663  updates.
5664
5665  With this generalization the full shape constraints are:
5666
5667  ```
5668  assert tf.rank(indices) >= 2
5669  index_depth = indices.shape[-1]
5670  batch_shape = indices.shape[:-1]
5671  assert index_depth <= tf.rank(tensor)
5672  outer_shape = tensor.shape[:index_depth]
5673  inner_shape = tensor.shape[index_depth:]
5674  assert updates.shape == batch_shape + inner_shape
5675  ```
5676
5677  For example, to draw an `X` on a `(5,5)` matrix start with these indices:
5678
5679  >>> tensor = tf.zeros([5,5])
5680  >>> indices = tf.constant([
5681  ...  [[0,0],
5682  ...   [1,1],
5683  ...   [2,2],
5684  ...   [3,3],
5685  ...   [4,4]],
5686  ...  [[0,4],
5687  ...   [1,3],
5688  ...   [2,2],
5689  ...   [3,1],
5690  ...   [4,0]],
5691  ... ])
5692  >>> indices.shape.as_list()  # batch_shape == [2, 5], index_depth == 2
5693  [2, 5, 2]
5694
5695  Here the `indices` do not have a shape of `[num_updates, index_depth]`, but a
5696  shape of `batch_shape + [index_depth]`.
5697
5698  Since the `index_depth` is equal to the rank of `tensor`:
5699
5700  * `outer_shape` is `(5,5)`
5701  * `inner_shape` is `()` - each update is scalar
5702  * `updates.shape` is `batch_shape + inner_shape == (2, 5) + ()`
5703
5704  >>> updates = [
5705  ...   [1,1,1,1,1],
5706  ...   [1,1,1,1,1],
5707  ... ]
5708
5709  Putting this together gives:
5710
5711  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5712  array([[1., 0., 0., 0., 1.],
5713         [0., 1., 0., 1., 0.],
5714         [0., 0., 1., 0., 0.],
5715         [0., 1., 0., 1., 0.],
5716         [1., 0., 0., 0., 1.]], dtype=float32)
5717
5718  Args:
5719    tensor: Tensor to copy/update.
5720    indices: Indices to update.
5721    updates: Updates to apply at the indices.
5722    name: Optional name for the operation.
5723
5724  Returns:
5725    A new tensor with the given shape and updates applied according to the
5726    indices.
5727  """
5728  return gen_array_ops.tensor_scatter_update(
5729      tensor=tensor, indices=indices, updates=updates, name=name)
5730
5731
5732# Define quantize_v2 here in order to make name the second-to-last attribute,
5733# because round_mode was added later.
5734# (And also now because of 'axis' processing).
5735@tf_export(v1=["quantize_v2"])
5736@dispatch.add_dispatch_support
5737@deprecation.deprecated(
5738    "2017-10-25",
5739    "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
5740    "instead.")  # pylint: disable=missing-docstring
5741def quantize_v2(
5742    input,  # pylint: disable=redefined-builtin
5743    min_range,
5744    max_range,
5745    T,
5746    mode="MIN_COMBINED",
5747    name=None,
5748    round_mode="HALF_AWAY_FROM_ZERO",
5749    narrow_range=False,
5750    axis=None,
5751    ensure_minimum_range=0.01):
5752  if axis is None:
5753    axis = -1
5754  elif axis < 0:
5755    if input.shape.ndims is None:
5756      raise ValueError("input should have known rank to use negative axis.")
5757    axis %= input.shape.ndims
5758
5759  if ensure_minimum_range != 0.01:
5760    return gen_array_ops.quantize_v2(
5761        input,
5762        min_range,
5763        max_range,
5764        T=T,
5765        mode=mode,
5766        name=name,
5767        round_mode=round_mode,
5768        narrow_range=narrow_range,
5769        axis=axis,
5770        ensure_minimum_range=ensure_minimum_range)
5771  return gen_array_ops.quantize_v2(
5772      input,
5773      min_range,
5774      max_range,
5775      T=T,
5776      mode=mode,
5777      name=name,
5778      round_mode=round_mode,
5779      narrow_range=narrow_range,
5780      axis=axis)
5781
5782
5783quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead."""
5784
5785
5786# We want to expose tf.quantization.quantize instead of tf.quantize; we can
5787# deprecate tf.quantize in the next version of TensorFlow.
5789@tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"])
5790@dispatch.add_dispatch_support
5791@deprecation.deprecated_endpoints("quantize")
5792def quantize(
5793    input,  # pylint: disable=redefined-builtin
5794    min_range,
5795    max_range,
5796    T,
5797    mode="MIN_COMBINED",
5798    round_mode="HALF_AWAY_FROM_ZERO",
5799    name=None,
5800    narrow_range=False,
5801    axis=None,
5802    ensure_minimum_range=0.01):
5803  """Quantize the input tensor."""
5804  if ensure_minimum_range != 0.01:
5805    return quantize_v2(
5806        input,
5807        min_range,
5808        max_range,
5809        T,
5810        mode=mode,
5811        round_mode=round_mode,
5812        name=name,
5813        narrow_range=narrow_range,
5814        axis=axis,
5815        ensure_minimum_range=ensure_minimum_range)
5816  return quantize_v2(
5817      input,
5818      min_range,
5819      max_range,
5820      T,
5821      mode=mode,
5822      round_mode=round_mode,
5823      name=name,
5824      narrow_range=narrow_range,
5825      axis=axis)
5826
5827
5828@tf_export("quantization.dequantize", v1=["quantization.dequantize",
5829                                          "dequantize"])
5830@dispatch.add_dispatch_support
5831@deprecation.deprecated_endpoints("dequantize")
5832def dequantize(  # pylint: disable=missing-docstring
5833    input,  # pylint: disable=redefined-builtin
5834    min_range,
5835    max_range,
5836    mode="MIN_COMBINED",
5837    name=None,
5838    axis=None,
5839    narrow_range=False,
5840    dtype=dtypes.float32):
5841  if axis is None:
5842    axis = -1
5843  elif axis < 0:
5844    if input.shape.ndims is None:
5845      raise ValueError("input should have known rank to use negative axis.")
5846    axis %= input.shape.ndims
5847
5848  if axis >= 0 or narrow_range:
5849    return gen_array_ops.dequantize(
5850        input,
5851        min_range,
5852        max_range,
5853        mode=mode,
5854        name=name,
5855        narrow_range=narrow_range,
5856        axis=axis,
5857        dtype=dtype)
5858  return gen_array_ops.dequantize(
5859      input, min_range, max_range, mode=mode, name=name, dtype=dtype)
5860
5861
5862dequantize.__doc__ = gen_array_ops.dequantize.__doc__
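
# A minimal round-trip sketch (illustrative; the dequantized values are only
# approximations of the original floats):
#
#   x = tf.constant([0.0, 1.0, 2.0])
#   q = tf.quantization.quantize(x, min_range=0.0, max_range=2.0, T=tf.quint8)
#   x_approx = tf.quantization.dequantize(q.output, min_range=0.0,
#                                         max_range=2.0)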
5863
5864
5865@tf_export("quantization.quantize_and_dequantize")
5866@dispatch.add_dispatch_support
@deprecation.deprecated(None,
                        "This Op has been deprecated; use "
                        "`quantize_and_dequantize_v2` instead. To simulate "
                        "the V1 behavior of "
                        "tf.quantization.quantize_and_dequantize(...) use "
                        "tf.grad_pass_through("
                        "tf.quantization.quantize_and_dequantize_v2)(...).")
5874def quantize_and_dequantize(
5875    input,  # pylint: disable=redefined-builtin
5876    input_min,
5877    input_max,
5878    signed_input=True,
5879    num_bits=8,
5880    range_given=False,
5881    round_mode="HALF_TO_EVEN",
5882    name=None,
5883    narrow_range=False,
5884    axis=None):
5885  """Quantizes then dequantizes a tensor.
5886
5887  Args:
5888    input: A `Tensor` to quantize and dequantize.
    input_min: If range_given=True, the minimum input value that needs to be
      represented in the quantized representation. If axis is specified, this
      should be a vector of minimum values for each slice along axis.
    input_max: If range_given=True, the maximum input value that needs to be
      represented in the quantized representation. If axis is specified, this
      should be a vector of maximum values for each slice along axis.
    signed_input: True if the quantization is signed, false if it is unsigned.
    num_bits: The bitwidth of the quantization.
    range_given: If true, use `input_min` and `input_max` for the range of the
      input; otherwise determine the min and max from the input `Tensor`.
    round_mode: Rounding mode when rounding from float values to quantized
      ones. One of ['HALF_TO_EVEN', 'HALF_UP'].
    name: Optional name for the operation.
    narrow_range: If true, then the absolute value of the quantized minimum
      value is the same as the quantized maximum value, instead of 1 greater.
      I.e., for 8-bit quantization, the minimum value is -127 instead of -128.
    axis: Integer. If specified, refers to a dimension of the input tensor,
      such that quantization will be per slice along that dimension.
5907
5908  Returns:
5909    A `Tensor`. Each element is the result of quantizing and dequantizing the
5910    corresponding element of `input`.
5911  """
5912  if axis is None:
5913    axis = -1
5914  elif axis < 0:
5915    if input.shape.ndims is None:
5916      raise ValueError("input should have known rank to use negative axis.")
5917    axis %= input.shape.ndims
5918
5919  return gen_array_ops.quantize_and_dequantize_v2(
5920      input,
5921      input_min=input_min,
5922      input_max=input_max,
5923      signed_input=signed_input,
5924      num_bits=num_bits,
5925      range_given=range_given,
5926      round_mode=round_mode,
5927      narrow_range=narrow_range,
5928      axis=axis,
5929      name=name)
5930
5931
5932@tf_export("quantization.quantize_and_dequantize_v2")
5933@dispatch.add_dispatch_support
5934def quantize_and_dequantize_v2(
5935    input,  # pylint: disable=redefined-builtin
5936    input_min,
5937    input_max,
5938    signed_input=True,
5939    num_bits=8,
5940    range_given=False,
5941    round_mode="HALF_TO_EVEN",
5942    name=None,
5943    narrow_range=False,
5944    axis=None):
5945  """Quantizes then dequantizes a tensor.
5946
  Updates the gradient definition for quantization that is outside the range
  to be 0. To simulate the V1 behavior of
  tf.quantization.quantize_and_dequantize(...) use
  tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
5951
5952  Example usage:
5953
  ```python
  def quantize_op(input_tensor, min_threshold, max_threshold):
      return tf.quantization.quantize_and_dequantize_v2(input_tensor,
                                                        input_min=min_threshold,
                                                        input_max=max_threshold,
                                                        range_given=True)

  # To simulate the V1 behavior (identity gradient outside the clipping
  # range), wrap the op in `tf.grad_pass_through`:
  def f(input_tensor):
      return tf.quantization.quantize_and_dequantize_v2(input_tensor,
                                                        input_min=-10.0,
                                                        input_max=5.0,
                                                        range_given=True)

  input_tensor = tf.zeros([4, 4], dtype=tf.float32)
  net = tf.grad_pass_through(f)(input_tensor)
  ```
5973
5974  Args:
5975    input: A `Tensor` to quantize and dequantize.
    input_min: If range_given=True, the minimum input value that needs to be
      represented in the quantized representation. If axis is specified, this
      should be a vector of minimum values for each slice along axis.
    input_max: If range_given=True, the maximum input value that needs to be
      represented in the quantized representation. If axis is specified, this
      should be a vector of maximum values for each slice along axis.
    signed_input: True if the quantization is signed, false if it is unsigned.
    num_bits: The bitwidth of the quantization.
    range_given: If true, use `input_min` and `input_max` for the range of the
      input; otherwise determine the min and max from the input `Tensor`.
    round_mode: Rounding mode when rounding from float values to quantized
      ones. One of ['HALF_TO_EVEN', 'HALF_UP'].
    name: Optional name for the operation.
    narrow_range: If true, then the absolute value of the quantized minimum
      value is the same as the quantized maximum value, instead of 1 greater.
      I.e., for 8-bit quantization, the minimum value is -127 instead of -128.
    axis: Integer. If specified, refers to a dimension of the input tensor,
      such that quantization will be per slice along that dimension.
5994
5995  Returns:
5996    A `Tensor`. Each element is the result of quantizing and dequantizing the
5997    corresponding element of `input`.
5998  """
5999  if axis is None:
6000    axis = -1
6001  elif axis < 0:
6002    if input.shape.ndims is None:
6003      raise ValueError("input should have known rank to use negative axis.")
6004    axis %= input.shape.ndims
6005
6006  return gen_array_ops.quantize_and_dequantize_v4(
6007      input,
6008      input_min=input_min,
6009      input_max=input_max,
6010      signed_input=signed_input,
6011      num_bits=num_bits,
6012      range_given=range_given,
6013      round_mode=round_mode,
6014      narrow_range=narrow_range,
6015      axis=axis,
6016      name=name)
6017
6018
6019@tf_export("searchsorted")
6020@dispatch.add_dispatch_support
6021def searchsorted(sorted_sequence,
6022                 values,
6023                 side="left",
6024                 out_type=dtypes.int32,
6025                 name=None):
6026  """Searches for where a value would go in a sorted sequence.
6027
  This is not a method for checking containment (like Python `in`).
6029
6030  The typical use case for this operation is "binning", "bucketing", or
6031  "discretizing". The `values` are assigned to bucket-indices based on the
6032  **edges** listed in `sorted_sequence`. This operation
6033  returns the bucket-index for each value.
6034
6035  >>> edges = [-1, 3.3, 9.1, 10.0]
6036  >>> values = [0.0, 4.1, 12.0]
6037  >>> tf.searchsorted(edges, values).numpy()
6038  array([1, 2, 4], dtype=int32)
6039
6040  The `side` argument controls which index is returned if a value lands exactly
6041  on an edge:
6042
6043  >>> seq = [0, 3, 9, 10, 10]
6044  >>> values = [0, 4, 10]
6045  >>> tf.searchsorted(seq, values).numpy()
6046  array([0, 2, 3], dtype=int32)
6047  >>> tf.searchsorted(seq, values, side="right").numpy()
6048  array([1, 2, 5], dtype=int32)
6049
6050  The `axis` is not settable for this operation. It always operates on the
6051  innermost dimension (`axis=-1`). The operation will accept any number of
6052  outer dimensions. Here it is applied to the rows of a matrix:
6053
6054  >>> sorted_sequence = [[0., 3., 8., 9., 10.],
6055  ...                    [1., 2., 3., 4., 5.]]
6056  >>> values = [[9.8, 2.1, 4.3],
6057  ...           [0.1, 6.6, 4.5, ]]
6058  >>> tf.searchsorted(sorted_sequence, values).numpy()
6059  array([[4, 1, 2],
6060         [0, 5, 4]], dtype=int32)
6061
  Note: This operation assumes that `sorted_sequence` **is sorted** along the
  innermost axis, e.g. using `tf.sort(..., axis=-1)`. **If the sequence is not
  sorted, no error is raised** and the content of the returned tensor is not
  well defined.
6066
6067  Args:
6068    sorted_sequence: N-D `Tensor` containing a sorted sequence.
6069    values: N-D `Tensor` containing the search values.
6070    side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
6071      upper_bound.
6072    out_type: The output type (`int32` or `int64`).  Default is `tf.int32`.
6073    name: Optional name for the operation.
6074
6075  Returns:
6076    An N-D `Tensor` the size of `values` containing the result of applying
6077    either lower_bound or upper_bound (depending on side) to each value.  The
6078    result is not a global index to the entire `Tensor`, but the index in the
6079    last dimension.
6080
6081  Raises:
    ValueError: If the last dimension of `sorted_sequence` has `>= 2^31 - 1`
      elements, if the total size of `values` exceeds `2^31 - 1` elements, or
      if the first `N-1` dimensions of the two tensors don't match.
6085  """
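  # The underlying bound-search kernels operate on 2-D inputs, so flatten all
  # outer dimensions into one, run the search, and restore the original shape
  # afterwards.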
6086  sequence_size = shape_internal(sorted_sequence)[-1]
6087  values_size = shape_internal(values)[-1]
6088  sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
6089  values_2d = reshape(values, [-1, values_size])
6090  if side == "right":
6091    output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
6092                                       name)
6093  elif side == "left":
6094    output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
6095                                       name)
6096  else:
6097    raise ValueError("side must be either 'right' or 'left'.  Saw: %s." % side)
6098  return reshape(output, shape_internal(values))
6099
6100
6101quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
6102
6103
6104@tf_export("image.extract_patches")
6105@dispatch.add_dispatch_support
6106def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
6107  r"""Extract `patches` from `images`.
6108
6109  This op collects patches from the input image, as if applying a
6110  convolution. All extracted patches are stacked in the depth (last) dimension
6111  of the output.
6112
6113  Specifically, the op extracts patches of shape `sizes` which are `strides`
6114  apart in the input image. The output is subsampled using the `rates` argument,
6115  in the same manner as "atrous" or "dilated" convolutions.
6116
6117  The result is a 4D tensor which is indexed by batch, row, and column.
6118  `output[i, x, y]` contains a flattened patch of size `sizes[1], sizes[2]`
6119  which is taken from the input starting at
6120  `images[i, x*strides[1], y*strides[2]]`.
6121
6122  Each output patch can be reshaped to `sizes[1], sizes[2], depth`, where
6123  `depth` is `images.shape[3]`.
6124
  The output elements are taken from the input at intervals given by the
  `rates` argument, as in dilated convolutions.

  The `padding` argument has no effect on the size of each patch; it
  determines how many patches are extracted. If `VALID`, only patches which
  are fully contained in the input image are included. If `SAME`, all patches
  whose starting point is inside the input are included, and areas outside
  the input default to zero.
6133
6134  Example:
6135
6136  ```
6137    n = 10
6138    # images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100
6139    images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]
6140
6141    # We generate two outputs as follows:
6142    # 1. 3x3 patches with stride length 5
6143    # 2. Same as above, but the rate is increased to 2
6144    tf.image.extract_patches(images=images,
6145                             sizes=[1, 3, 3, 1],
6146                             strides=[1, 5, 5, 1],
6147                             rates=[1, 1, 1, 1],
6148                             padding='VALID')
6149
6150    # Yields:
6151    [[[[ 1  2  3 11 12 13 21 22 23]
6152       [ 6  7  8 16 17 18 26 27 28]]
6153      [[51 52 53 61 62 63 71 72 73]
6154       [56 57 58 66 67 68 76 77 78]]]]
6155  ```
6156
6157  If we mark the pixels in the input image which are taken for the output with
6158  `*`, we see the pattern:
6159
6160  ```
6161     *  *  *  4  5  *  *  *  9 10
6162     *  *  * 14 15  *  *  * 19 20
6163     *  *  * 24 25  *  *  * 29 30
6164    31 32 33 34 35 36 37 38 39 40
6165    41 42 43 44 45 46 47 48 49 50
6166     *  *  * 54 55  *  *  * 59 60
6167     *  *  * 64 65  *  *  * 69 70
6168     *  *  * 74 75  *  *  * 79 80
6169    81 82 83 84 85 86 87 88 89 90
6170    91 92 93 94 95 96 97 98 99 100
6171  ```
6172
6173  ```
6174    tf.image.extract_patches(images=images,
6175                             sizes=[1, 3, 3, 1],
6176                             strides=[1, 5, 5, 1],
6177                             rates=[1, 2, 2, 1],
6178                             padding='VALID')
6179
6180    # Yields:
6181    [[[[  1   3   5  21  23  25  41  43  45]
6182       [  6   8  10  26  28  30  46  48  50]]
6183
6184      [[ 51  53  55  71  73  75  91  93  95]
6185       [ 56  58  60  76  78  80  96  98 100]]]]
6186  ```
6187
6188  We can again draw the effect, this time using the symbols `*`, `x`, `+` and
6189  `o` to distinguish the patches:
6190
6191  ```
6192     *  2  *  4  *  x  7  x  9  x
6193    11 12 13 14 15 16 17 18 19 20
6194     * 22  * 24  *  x 27  x 29  x
6195    31 32 33 34 35 36 37 38 39 40
6196     * 42  * 44  *  x 47  x 49  x
6197     + 52  + 54  +  o 57  o 59  o
6198    61 62 63 64 65 66 67 68 69 70
6199     + 72  + 74  +  o 77  o 79  o
6200    81 82 83 84 85 86 87 88 89 90
6201     + 92  + 94  +  o 97  o 99  o
6202  ```
6203
6204  Args:
6205    images: A 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
6206    sizes: The size of the extracted patches. Must be
6207      `[1, size_rows, size_cols, 1]`.
6208    strides: A 1-D Tensor of length 4. How far the centers of two consecutive
6209      patches are in the images. Must be: `[1, stride_rows, stride_cols, 1]`.
6210    rates: A 1-D Tensor of length 4. Must be: `[1, rate_rows, rate_cols, 1]`.
6211      This is the input stride, specifying how far two consecutive patch samples
6212      are in the input. Equivalent to extracting patches with `patch_sizes_eff =
6213      patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by subsampling
6214      them spatially by a factor of `rates`. This is equivalent to `rate` in
6215      dilated (a.k.a. Atrous) convolutions.
6216    padding: The type of padding algorithm to use.
6217    name: A name for the operation (optional).
6218
6219  Returns:
6220    A 4-D Tensor of the same type as the input.
6221  """
6222  return gen_array_ops.extract_image_patches(images, sizes, strides, rates,
6223                                             padding, name)
6224
6225
6226@tf_export(v1=["image.extract_image_patches", "extract_image_patches"])
6227@dispatch.add_dispatch_support
6228@deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead",
6229                             "ksizes")
6230def extract_image_patches(  # pylint: disable=missing-docstring
6231    images,
6232    ksizes=None,
6233    strides=None,
6234    rates=None,
6235    padding=None,
6236    name=None,
6237    sizes=None):
6238  """Extract patches from images and put them in the "depth" output dimension.
6239
  Args:
    images: A `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`,
      `uint16`, `half`, `uint32`, `uint64`. 4-D Tensor with shape
      `[batch, in_rows, in_cols, depth]`.
    ksizes: A list of `ints` that has length `>= 4`. The size of the sliding
      window for each dimension of `images`. Must be:
      `[1, ksize_rows, ksize_cols, 1]`.
    strides: A list of `ints` that has length `>= 4`. 1-D of length 4. How far
      the centers of two consecutive patches are in the images. Must be:
      `[1, stride_rows, stride_cols, 1]`.
    rates: A list of `ints` that has length `>= 4`. 1-D of length 4. Must be:
      `[1, rate_rows, rate_cols, 1]`. This is the input stride, specifying how
      far two consecutive patch samples are in the input. Equivalent to
      extracting patches with `patch_sizes_eff = patch_sizes +
      (patch_sizes - 1) * (rates - 1)`, followed by subsampling them spatially
      by a factor of `rates`. This is equivalent to `rate` in dilated
      (a.k.a. Atrous) convolutions.
    padding: A `string` from: `"SAME"`, `"VALID"`. The type of padding
      algorithm to use.
    name: A name for the operation (optional).
6262
6263  Returns:
6264    A Tensor. Has the same type as images.
6265  """
6266  ksizes = deprecation.deprecated_argument_lookup("sizes", sizes, "ksizes",
6267                                                  ksizes)
6268  return gen_array_ops.extract_image_patches(images, ksizes, strides, rates,
6269                                             padding, name)
6270
6271
6272extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__
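
# A minimal v1-endpoint sketch (illustrative values; `ksizes` is the
# deprecated alias of `sizes`):
#
#   images = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
#   patches = tf.compat.v1.image.extract_image_patches(
#       images, ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1],
#       rates=[1, 1, 1, 1], padding="VALID")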
6273
6274
6275@tf_export("fingerprint")
6276@dispatch.add_dispatch_support
6277def fingerprint(data, method="farmhash64", name=None):
6278  r"""Generates fingerprint values.
6279
6280  Generates fingerprint values of `data`.
6281
6282  Fingerprint op considers the first dimension of `data` as the batch dimension,
6283  and `output[i]` contains the fingerprint value generated from contents in
6284  `data[i, ...]` for all `i`.
6285
  Fingerprint op writes fingerprint values as byte arrays. For example, the
  default method `farmhash64` generates a 64-bit fingerprint value at a time.
  This 8-byte value is written out as a `tf.uint8` array of size 8, in
  little-endian order.
6290
6291  For example, suppose that `data` has data type `tf.int32` and shape (2, 3, 4),
6292  and that the fingerprint method is `farmhash64`. In this case, the output
6293  shape is (2, 8), where 2 is the batch dimension size of `data`, and 8 is the
6294  size of each fingerprint value in bytes. `output[0, :]` is generated from
  12 integers in `data[0, :, :]` and similarly `output[1, :]` is generated
  from the other 12 integers in `data[1, :, :]`.
6297
  Note that this op fingerprints the raw underlying buffer, and it does not
  fingerprint the tensor's metadata such as data type and/or shape. For
  example, the fingerprint values are invariant under reshapes and bitcasts as
  long as the batch dimension remains the same:
6302
6303  ```python
6304  tf.fingerprint(data) == tf.fingerprint(tf.reshape(data, ...))
6305  tf.fingerprint(data) == tf.fingerprint(tf.bitcast(data, ...))
6306  ```
6307
  For string data, one should expect `tf.fingerprint(data) !=
  tf.fingerprint(tf.strings.reduce_join(data))` in general.
6310
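  A shape-only sketch (the exact byte values depend on the `farmhash64`
  implementation):

  >>> tf.fingerprint(tf.zeros([2, 3], dtype=tf.int32)).shape
  TensorShape([2, 8])
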
6311  Args:
6312    data: A `Tensor`. Must have rank 1 or higher.
    method: A `Tensor` of type `tf.string`. Fingerprint method used by this
      op. Currently, the only available method is `farmhash64`.
6315    name: A name for the operation (optional).
6316
6317  Returns:
6318    A two-dimensional `Tensor` of type `tf.uint8`. The first dimension equals to
6319    `data`'s first dimension, and the second dimension size depends on the
6320    fingerprint algorithm.
6321  """
6322  return gen_array_ops.fingerprint(data, method, name)
6323
6324
6325def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
6326  """Converts the given value to an integer Tensor."""
6327  tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
6328  if tensor.dtype.is_integer:
6329    tensor = gen_math_ops.cast(tensor, dtype)
6330  else:
6331    raise TypeError("%s must be an integer tensor; dtype=%s" %
6332                    (name, tensor.dtype))
6333  return tensor
6334
6335
6336def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
6337  """Validate an `axis` parameter, and normalize it to be positive.
6338
6339  If `ndims` is known (i.e., not `None`), then check that `axis` is in the
6340  range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
6341  `axis + ndims` (otherwise).
  If `ndims` is not known, and `axis` is non-negative, then return it as-is.
6343  If `ndims` is not known, and `axis` is negative, then report an error.
6344
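  For example (illustrative values): with `ndims=3`, `axis=-1` normalizes to
  `2` and `axis=1` is returned unchanged; with `ndims=None`, `axis=-1` raises
  a `ValueError`.
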
6345  Args:
    axis: An integer constant.
    ndims: An integer constant, or `None`.
6348    axis_name: The name of `axis` (for error messages).
6349    ndims_name: The name of `ndims` (for error messages).
6350
6351  Returns:
6352    The normalized `axis` value.
6353
6354  Raises:
6355    ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
6356      `ndims is None`.
6357  """
6358  if not isinstance(axis, int):
6359    raise TypeError("%s must be an int; got %s" %
6360                    (axis_name, type(axis).__name__))
6361  if ndims is not None:
6362    if 0 <= axis < ndims:
6363      return axis
6364    elif -ndims <= axis < 0:
6365      return axis + ndims
6366    else:
6367      raise ValueError("%s=%s out of bounds: expected %s<=%s<%s" %
6368                       (axis_name, axis, -ndims, axis_name, ndims))
6369  elif axis < 0:
6370    raise ValueError("%s may only be negative if %s is statically known." %
6371                     (axis_name, ndims_name))
6372  return axis
6373
6374
6375# This op is intended to exactly match the semantics of numpy.repeat, with
6376# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
6377# when axis is not specified.  Rather than implement that special behavior, we
6378# simply make `axis` be a required argument.
6379#
6380# External (OSS) `tf.repeat` feature request:
6381# https://github.com/tensorflow/tensorflow/issues/8246
6382def repeat_with_axis(data, repeats, axis, name=None):
6383  """Repeats elements of `data`.
6384
6385  Args:
6386    data: An `N`-dimensional tensor.
6387    repeats: A 1-D integer tensor specifying how many times each element in
6388      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
6389      Supports broadcasting from a scalar value.
6390    axis: `int`.  The axis along which to repeat values.  Must be less than
6391      `max(N, 1)`.
6392    name: A name for the operation.
6393
6394  Returns:
6395    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
6396    except that dimension `axis` has size `sum(repeats)`.
6397
6398  Example usage:
6399
6400  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6401  <tf.Tensor: shape=(5,), dtype=string,
6402  numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6403  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6404  <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6405  array([[1, 2],
6406         [1, 2],
6407         [3, 4],
6408         [3, 4],
6409         [3, 4]], dtype=int32)>
6410  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6411  <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6412  array([[1, 1, 2, 2, 2],
6413         [3, 3, 4, 4, 4]], dtype=int32)>
6414
6415  """
6416  if not isinstance(axis, int):
6417    raise TypeError("axis must be an int; got %s" % type(axis).__name__)
6418
6419  with ops.name_scope(name, "Repeat", [data, repeats]):
6420    data = ops.convert_to_tensor(data, name="data")
6421    repeats = convert_to_int_tensor(repeats, name="repeats")
6422    repeats.shape.with_rank_at_most(1)
6423
6424    # If `data` is a scalar, then upgrade it to a vector.
6425    data = _with_nonzero_rank(data)
6426    data_shape = shape(data)
6427
6428    # If `axis` is negative, then convert it to a positive value.
6429    axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
6430
6431    # If we know that `repeats` is a scalar, then we can just tile & reshape.
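    # E.g., for data=[1, 2], axis=0, repeats=3: `expand_dims` gives
    # [[1], [2]], tiling gives [[1, 1, 1], [2, 2, 2]], and the reshape yields
    # [1, 1, 1, 2, 2, 2].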
6432    if repeats.shape.num_elements() == 1:
6433      repeats = reshape(repeats, [])
6434      expanded = expand_dims(data, axis + 1)
6435      tiled = tile_one_dimension(expanded, axis + 1, repeats)
6436      result_shape = concat([
6437          data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
6438      ],
6439                            axis=0)
      return reshape(tiled, result_shape)

6443    # Check data Tensor shapes.
6444    if repeats.shape.ndims == 1:
6445      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
6446
6447    repeats = broadcast_to(repeats, [data_shape[axis]])
6448    repeats_original = repeats
6449
6450    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
6451    if repeats.shape.ndims != axis + 1:
6452      repeats_shape = shape(repeats)
6453      repeats_ndims = rank(repeats)
6454      broadcast_shape = concat(
6455          [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
6456      repeats = broadcast_to(repeats, broadcast_shape)
6457      repeats.set_shape([None] * (axis + 1))
6458
6459    # Create a "sequence mask" based on `repeats`, where slices across `axis`
6460    # contain one `True` value for each repetition.  E.g., if
6461    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
6462    max_repeat = gen_math_ops.maximum(
6463        0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
6464    mask = sequence_mask(repeats, max_repeat)
6465
6466    # Add a new dimension around each value that needs to be repeated, and
6467    # then tile that new dimension to match the maximum number of repetitions.
6468    expanded = expand_dims(data, axis + 1)
6469    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
6470
6471    # Use `boolean_mask` to discard the extra repeated values.  This also
6472    # flattens all dimensions up through `axis`.
6473    masked = boolean_mask(tiled, mask)
6474
6475    # Reshape the output tensor to add the outer dimensions back.
6476    if axis == 0:
6477      result = masked
6478    else:
6479      repeated_dim_size = gen_math_ops._sum(
6480          repeats_original,
6481          axis=gen_math_ops._range(0, rank(repeats_original), 1))
6482      result_shape = concat(
6483          [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
6484          axis=0)
6485      result = reshape(masked, result_shape)
6486
6487    # Preserve shape information.
6488    if data.shape.ndims is not None:
6489      new_axis_size = 0 if repeats.shape[0] == 0 else None
6490      result.set_shape(data.shape[:axis].concatenate(
6491          [new_axis_size]).concatenate(data.shape[axis + 1:]))
6492
6493    return result
6494
6495
6496def tile_one_dimension(data, axis, multiple):
6497  """Tiles a single dimension of a tensor."""
6498  # Assumes axis is a nonnegative int.
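  # E.g. (illustrative): tile_one_dimension(t, axis=1, multiple=2) maps a
  # tensor of shape [2, 3] to shape [2, 6].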
6499  if data.shape.ndims is not None:
6500    multiples = [1] * data.shape.ndims
6501    multiples[axis] = multiple
6502  else:
6503    ones_value = ones(rank(data), dtypes.int32)
6504    multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]],
6505                       axis=0)
6506  return tile(data, multiples)
6507
6508
6509def _with_nonzero_rank(data):
6510  """If `data` is scalar, then add a dimension; otherwise return as-is."""
6511  if data.shape.ndims is not None:
6512    if data.shape.ndims == 0:
6513      return stack([data])
6514    else:
6515      return data
6516  else:
6517    data_shape = shape(data)
6518    data_ndims = rank(data)
6519    return reshape(data, concat([[1], data_shape], axis=0)[-data_ndims:])
6520
6521
6522@tf_export("repeat")
6523@dispatch.add_dispatch_support
6524def repeat(input, repeats, axis=None, name=None):  # pylint: disable=redefined-builtin
6525  """Repeat elements of `input`.
6526
6527  See also `tf.concat`, `tf.stack`, `tf.tile`.
6528
6529  Args:
6530    input: An `N`-dimensional Tensor.
    repeats: A 1-D `int` Tensor. The number of repetitions for each element.
      `repeats` is broadcast to fit the shape of the given axis. `len(repeats)`
      must equal `input.shape[axis]` if axis is not None.
6534    axis: An int. The axis along which to repeat values. By default (axis=None),
6535      use the flattened input array, and return a flat output array.
6536    name: A name for the operation.
6537
6538  Returns:
6539    A Tensor which has the same shape as `input`, except along the given axis.
6540      If axis is None then the output array is flattened to match the flattened
6541      input array.
6542
6543  Example usage:
6544
6545  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6546  <tf.Tensor: shape=(5,), dtype=string,
6547  numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6548
6549  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6550  <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6551  array([[1, 2],
6552         [1, 2],
6553         [3, 4],
6554         [3, 4],
6555         [3, 4]], dtype=int32)>
6556
6557  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6558  <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6559  array([[1, 1, 2, 2, 2],
6560         [3, 3, 4, 4, 4]], dtype=int32)>
6561
6562  >>> repeat(3, repeats=4)
6563  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([3, 3, 3, 3], dtype=int32)>
6564
6565  >>> repeat([[1,2], [3,4]], repeats=2)
6566  <tf.Tensor: shape=(8,), dtype=int32,
6567  numpy=array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)>
6568
6569  """
6570  if axis is None:
6571    input = reshape(input, [-1])
6572    axis = 0
6573  return repeat_with_axis(input, repeats, axis, name)
6574