• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for array_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import re
21import time
22import unittest
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.python.client import session
29from tensorflow.python.eager import backprop
30from tensorflow.python.eager import context
31from tensorflow.python.eager import def_function
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import errors_impl
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import sparse_tensor
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.framework import tensor_spec
40from tensorflow.python.framework import test_ops
41from tensorflow.python.framework import test_util
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import gen_array_ops
44from tensorflow.python.ops import gradient_checker_v2
45from tensorflow.python.ops import gradients_impl
46from tensorflow.python.ops import init_ops
47from tensorflow.python.ops import list_ops
48from tensorflow.python.ops import map_fn
49from tensorflow.python.ops import math_ops
50from tensorflow.python.ops import resource_variable_ops
51from tensorflow.python.ops import state_ops
52from tensorflow.python.ops import variable_scope
53from tensorflow.python.ops import variables
54from tensorflow.python.platform import test as test_lib
55
56
@test_util.run_all_in_graph_and_eager_modes
class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
  """Tests for `array_ops.matrix_transpose` on single and batched matrices."""

  def testNonBatchMatrix(self):
    matrix = [[1, 2, 3], [4, 5, 6]]  # Shape (2, 3)
    expected_transposed = [[1, 4], [2, 5], [3, 6]]  # Shape (3, 2)
    transposed = array_ops.matrix_transpose(matrix)
    self.assertEqual((3, 2), transposed.get_shape())
    self.assertAllEqual(expected_transposed, transposed)

  def testConjugate(self):
    # With conjugate=True each transposed entry is also complex-conjugated.
    m = [[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]
    expected_transposed = [[1 - 1j, 4 - 4j], [2 - 2j, 5 - 5j], [3 - 3j, 6 - 6j]]
    matrix = ops.convert_to_tensor(m)
    transposed = array_ops.matrix_transpose(matrix, conjugate=True)
    self.assertEqual((3, 2), transposed.get_shape())
    self.assertAllEqual(expected_transposed, transposed)

  def testBatchMatrix(self):
    # Only the last two dimensions are transposed; the batch dim is kept.
    matrix_0 = [[1, 2, 3], [4, 5, 6]]
    matrix_0_t = [[1, 4], [2, 5], [3, 6]]
    matrix_1 = [[11, 22, 33], [44, 55, 66]]
    matrix_1_t = [[11, 44], [22, 55], [33, 66]]
    batch_matrix = [matrix_0, matrix_1]  # Shape (2, 2, 3)
    expected_transposed = [matrix_0_t, matrix_1_t]  # Shape (2, 3, 2)
    transposed = array_ops.matrix_transpose(batch_matrix)
    self.assertEqual((2, 3, 2), transposed.get_shape())
    self.assertAllEqual(expected_transposed, transposed)

  def testNonBatchMatrixDynamicallyDefined(self):
    # needs explicit `constant` because lists are not automatically
    # converted to tensors when applying `transpose` below
    matrix = constant_op.constant([[1, 2, 3], [4, 5, 6]])  # Shape (2, 3)
    expected_transposed = [[1, 4], [2, 5], [3, 6]]  # Shape (3, 2)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def transpose(matrix):
      # shape=None in the signature means the rank is unknown at trace time.
      self.assertIs(matrix.shape.ndims, None)
      return array_ops.matrix_transpose(matrix)

    self.assertAllEqual(expected_transposed, transpose(matrix))

  def testBatchMatrixDynamicallyDefined(self):
    matrix_0 = [[1, 2, 3], [4, 5, 6]]
    matrix_0_t = [[1, 4], [2, 5], [3, 6]]
    matrix_1 = [[11, 22, 33], [44, 55, 66]]
    matrix_1_t = [[11, 44], [22, 55], [33, 66]]
    # needs explicit `constant` because lists are not automatically
    # converted to tensors when applying `transpose` below
    batch_matrix = constant_op.constant([matrix_0, matrix_1])  # Shape (2, 2, 3)
    expected_transposed = [matrix_0_t, matrix_1_t]  # Shape (2, 3, 2)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def transpose(matrix):
      # shape=None in the signature means the rank is unknown at trace time.
      self.assertIs(matrix.shape.ndims, None)
      return array_ops.matrix_transpose(matrix)

    self.assertAllEqual(expected_transposed, transpose(batch_matrix))

  def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
    vector = [1, 2, 3]
    with self.assertRaisesRegex(ValueError, "should be a "):
      array_ops.matrix_transpose(vector)
124
125
class BooleanMaskTest(test_util.TensorFlowTestCase):
  """Tests for `array_ops.boolean_mask`, mostly by comparison with NumPy."""

  def setUp(self):
    # Call the base-class setUp so TensorFlowTestCase's own per-test
    # initialization is not skipped by this override.
    super(BooleanMaskTest, self).setUp()
    self.rng = np.random.RandomState(42)

  def CheckVersusNumpy(self, ndims_mask, arr_shape, make_mask=None, axis=None):
    """Check equivalence between boolean_mask and numpy masking.

    Args:
      ndims_mask: Number of array dimensions covered by the mask, starting at
        `axis` (or at 0 when `axis` is None).
      arr_shape: Shape of the random array to mask.
      make_mask: Optional callable mapping a shape to a boolean numpy array.
        Defaults to a random mask drawn from `self.rng`.
      axis: None, 0, 1 or 2; the first dimension the mask applies to.

    Raises:
      ValueError: If `axis` is not one of None, 0, 1 or 2.
    """
    if make_mask is None:
      make_mask = lambda shape: self.rng.randint(0, 2, size=shape).astype(bool)
    start = 0 if axis is None else axis
    if start not in (0, 1, 2):
      raise ValueError("CheckVersusNumpy only supports axis None, 0, 1 or 2; "
                       "got %r" % (axis,))
    # Draw the data from the seeded per-test RNG so inputs are reproducible.
    arr = self.rng.rand(*arr_shape)
    mask = make_mask(arr_shape[start:start + ndims_mask])
    if start == 0:
      masked_arr = arr[mask]
    elif start == 1:
      masked_arr = arr[:, mask]
    else:
      masked_arr = arr[:, :, mask]
    with self.cached_session():
      masked_tensor = array_ops.boolean_mask(arr, mask, axis=axis)

      # The size of the masked dimension is unknown until runtime since we
      # don't know how many elements will be kept, so only the dimensions
      # after it are compared statically.
      leading = start + 1
      self.assertAllEqual(masked_tensor.get_shape()[leading:],
                          masked_arr.shape[leading:])

      self.assertAllClose(masked_arr, masked_tensor)

  @test_util.run_deprecated_v1
  def testMaskDim1ArrDim2Axis1(self):
    ndims_mask = 1
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      with self.subTest(arr_shape=arr_shape):
        self.CheckVersusNumpy(ndims_mask, arr_shape, axis=1)

  @test_util.run_deprecated_v1
  def testMaskDim2ArrDim2Axis1(self):
    ndims_mask = 2
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      with self.subTest(arr_shape=arr_shape):
        self.CheckVersusNumpy(ndims_mask, arr_shape, axis=1)

  @test_util.run_deprecated_v1
  def testMaskDim1ArrDim1(self):
    ndims_mask = 1
    for arr_shape in [(1,), (2,), (3,), (10,)]:
      with self.subTest(arr_shape=arr_shape):
        self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testMaskDim1ArrDim2(self):
    ndims_mask = 1
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      with self.subTest(arr_shape=arr_shape):
        self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testMaskDim2ArrDim2(self):
    ndims_mask = 2
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      with self.subTest(arr_shape=arr_shape):
        self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testMaskDim2ArrDim3(self):
    ndims_mask = 2
    for arr_shape in [(1, 1, 1), (1, 2, 2), (2, 2, 1)]:
      with self.subTest(arr_shape=arr_shape):
        self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testEmptyInput2D(self):
    mask = np.array([True, False])
    arr = np.array([[], []]).astype(np.float32)
    numpy_result = arr[mask]
    tf_result = array_ops.boolean_mask(arr, mask)
    self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
    with self.cached_session():
      self.assertAllClose(numpy_result, tf_result)

  @test_util.run_deprecated_v1
  def testEmptyInput1D(self):
    mask = np.array([]).astype(bool)
    arr = np.array([]).astype(np.float32)
    numpy_result = arr[mask]
    tf_result = array_ops.boolean_mask(arr, mask)
    self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
    with self.cached_session():
      self.assertAllClose(numpy_result, tf_result)

  @test_util.run_deprecated_v1
  def testEmptyOutput(self):
    # An all-False mask keeps nothing, for a range of mask/array ranks.
    make_mask = lambda shape: np.zeros(shape, dtype=bool)
    for ndims_mask in range(1, 4):
      for ndims_arr in range(ndims_mask, ndims_mask + 3):
        for trial in range(3):
          with self.subTest(
              ndims_mask=ndims_mask, ndims_arr=ndims_arr, trial=trial):
            arr_shape = self.rng.randint(1, 5, size=ndims_arr)
            self.CheckVersusNumpy(ndims_mask, arr_shape, make_mask=make_mask)

  @test_util.run_deprecated_v1
  def testWorksWithDimensionsEqualToNoneDuringGraphBuild(self):
    # The rank of the mask tensor must be specified. This is explained
    # in the docstring as well.
    with self.cached_session() as sess:
      ph_tensor = array_ops.placeholder(dtypes.int32, shape=None)
      ph_mask = array_ops.placeholder(dtypes.bool, shape=[None])

      arr = np.array([[1, 2], [3, 4]])
      mask = np.array([False, True])

      masked_tensor = sess.run(
          array_ops.boolean_mask(ph_tensor, ph_mask),
          feed_dict={
              ph_tensor: arr,
              ph_mask: mask
          })
      np.testing.assert_allclose(masked_tensor, arr[mask])

  @test_util.run_deprecated_v1
  def testMaskDimensionsSetToNoneRaises(self):
    # The rank of the mask tensor must be specified. This is explained
    # in the docstring as well.
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.int32, shape=[None, 2])
      mask = array_ops.placeholder(dtypes.bool, shape=None)
      with self.assertRaisesRegex(ValueError, "dimensions must be specified"):
        array_ops.boolean_mask(tensor, mask)

  def testMaskHasMoreDimsThanTensorRaises(self):
    mask = [[True, True], [False, False]]
    tensor = [1, 2, 3, 4]
    with self.cached_session():
      with self.assertRaisesRegex(ValueError, "incompatible"):
        array_ops.boolean_mask(tensor, mask).eval()

  def testMaskIsScalarRaises(self):
    mask = True
    tensor = 1
    with self.cached_session():
      with self.assertRaisesRegex(ValueError, "mask.*scalar"):
        array_ops.boolean_mask(tensor, mask).eval()

  def testMaskShapeDifferentThanFirstPartOfTensorShapeRaises(self):
    mask = [True, True, True]
    tensor = [[1, 2], [3, 4]]
    with self.cached_session():
      with self.assertRaisesRegex(ValueError, "incompatible"):
        array_ops.boolean_mask(tensor, mask).eval()

  @test_util.run_deprecated_v1
  def testStringMask(self):
    # Reproduces b/111171330, where the optimized boolean_mask graph would
    # be incorrectly placed on GPU.
    with ops.Graph().as_default():
      tile_placeholder = array_ops.placeholder(dtypes.int32, [2])
      string_tensor = array_ops.tile([["hello"]], tile_placeholder)
      bool_tensor = array_ops.tile([[True]], tile_placeholder)
      masked_tensor = array_ops.boolean_mask(string_tensor, bool_tensor)
      config = config_pb2.ConfigProto()
      config.graph_options.rewrite_options.shape_optimization = 1
      config.gpu_options.per_process_gpu_memory_fraction = 0.3
      with session.Session(config=config) as sess:
        result = sess.run(masked_tensor, feed_dict={tile_placeholder: [2, 2]})
        self.assertAllEqual([b"hello", b"hello", b"hello", b"hello"], result)

  def testMaskWithAxisTensor(self):
    # `axis` passed as a constant tensor inside a tf.function.

    @def_function.function(autograph=False)
    def f():
      return array_ops.boolean_mask([1, 2, 3], [True, False, True],
                                    axis=constant_op.constant(
                                        0, dtype=dtypes.int32))

    self.assertAllEqual(self.evaluate(f()), [1, 3])

  def testMaskWithAxisNonConstTensor(self):
    # `axis` passed as a function argument, so it is not a graph constant.

    @def_function.function(
        autograph=False,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
        ])
    def f(axis):
      return array_ops.boolean_mask([1, 2, 3], [True, False, True], axis=axis)

    self.assertAllEqual(
        self.evaluate(f(constant_op.constant(0, dtype=dtypes.int32))), [1, 3])
316
317
@test_util.run_all_in_graph_and_eager_modes
class OperatorShapeTest(test_util.TensorFlowTestCase):
  """Shape checks for `expand_dims` and `squeeze`."""

  def testExpandScalar(self):
    scalar = "hello"
    scalar_expanded = array_ops.expand_dims(scalar, [0])
    self.assertEqual(scalar_expanded.get_shape(), (1,))

  def testSqueezeScalar(self):
    scalar = "hello"
    scalar_squeezed = array_ops.squeeze(scalar, ())
    self.assertEqual(scalar_squeezed.get_shape(), ())

  def testSqueezeMatrix(self):
    matrix = [[1, 2, 3]]
    matrix_squeezed = array_ops.squeeze(matrix, [0])
    # `(3,)` is a one-element shape tuple; a bare `(3)` is just the int 3.
    self.assertEqual(matrix_squeezed.get_shape(), (3,))

    # Dimension 1 has size 3, so it cannot be squeezed.
    with self.assertRaisesRegex(
        Exception, "Can not squeeze dim.1., expected a dimension of 1, got 3"):
      array_ops.squeeze(matrix, [1])

  def testSqueezeScalarDim(self):
    # `axis` given as a plain int rather than a list.
    matrix = [[1, 2, 3]]
    matrix_squeezed = array_ops.squeeze(matrix, 0)
    self.assertEqual(matrix_squeezed.get_shape(), (3,))

  def testExpandDimsWithNonScalarDim(self):
    with self.assertRaisesRegex(Exception,
                                "must be a tensor with a single value"):
      array_ops.expand_dims(1, axis=[0, 1])
349
350
class ReverseV2Test(test_util.TensorFlowTestCase):
  """Tests for `array_ops.reverse_v2` (and `array_ops.reverse`)."""

  # NumPy dtypes covered by the 1-D and 2-D reversal tests. `np.bool_` is used
  # instead of the `np.bool` alias, which NumPy 1.24 removed.
  _AUTO_DTYPES = (
      np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool_,
      np.float16, np.float32, np.float64, np.complex64, np.complex128,
      np.array(b"").dtype.type
  )

  @test_util.run_deprecated_v1
  def testReverse0DimAuto(self):
    x_np = 4
    for use_gpu in [False, True]:
      with self.subTest(use_gpu=use_gpu):
        with self.cached_session(use_gpu=use_gpu):
          x_tf = array_ops.reverse_v2(x_np, []).eval()
          self.assertAllEqual(x_tf, x_np)

  def _reverse1DimAuto(self, np_dtype):
    """Reverses a 1-D `np_dtype` array along axis 0 with int32/int64 axes."""
    x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype)

    for use_gpu in [False, True]:
      for axis_dtype in [dtypes.int32, dtypes.int64]:
        with self.subTest(use_gpu=use_gpu, axis_dtype=axis_dtype):
          with self.cached_session(use_gpu=use_gpu):
            x_tf = array_ops.reverse_v2(
                x_np, constant_op.constant([0], dtype=axis_dtype)).eval()
            self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])

  def _reverse2DimAuto(self, np_dtype):
    """Reverses a 2-D `np_dtype` array along each axis and along both."""
    x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype)

    for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
      for use_gpu in [False, True]:
        for axis_dtype in [dtypes.int32, dtypes.int64]:
          with self.subTest(
              reverse_f=reverse_f, use_gpu=use_gpu, axis_dtype=axis_dtype):
            with self.cached_session(use_gpu=use_gpu):
              x_tf_1 = reverse_f(x_np,
                                 constant_op.constant([0],
                                                      dtype=axis_dtype)).eval()
              x_tf_2 = reverse_f(x_np,
                                 constant_op.constant([-2],
                                                      dtype=axis_dtype)).eval()
              x_tf_3 = reverse_f(x_np,
                                 constant_op.constant([1],
                                                      dtype=axis_dtype)).eval()
              x_tf_4 = reverse_f(x_np,
                                 constant_op.constant([-1],
                                                      dtype=axis_dtype)).eval()
              x_tf_5 = reverse_f(x_np,
                                 constant_op.constant([1, 0],
                                                      dtype=axis_dtype)).eval()
              self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
              self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
              self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
              self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
              self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])

  # This test covers the axis validation in the shape function
  # (no eval())
  @test_util.run_deprecated_v1
  def testInvalidAxis(self):
    x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    with self.assertRaisesRegex(ValueError, "is out of valid range"):
      array_ops.reverse_v2(x_np, [-30])
    with self.assertRaisesRegex(ValueError, "is out of valid range"):
      array_ops.reverse_v2(x_np, [2])
    with self.assertRaisesRegex(ValueError, "axis 0 specified more than once"):
      array_ops.reverse_v2(x_np, [0, -2])

  # This is the version of reverse that uses axis indices rather than
  # bool tensors
  # TODO(b/32254538): Change this test to use array_ops.reverse
  #
  # Note: this test passes placeholder as constant axis is validated
  # in shape function (see testInvalidAxis)
  @test_util.run_deprecated_v1
  def testInvalid(self):
    x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    axis = array_ops.placeholder(dtypes.int32)
    with self.cached_session():
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  "is out of.*range"):
        array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [-30]})
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  "is out of.*range"):
        array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [2]})
      with self.assertRaisesRegex(
          errors_impl.InvalidArgumentError,
          "(axis 0 specified more than once|canonicalized axis 0 was repeated.)"
      ):
        array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [0, -2]})

  @test_util.run_deprecated_v1
  def testReverse1DimAuto(self):
    for dtype in self._AUTO_DTYPES:
      self._reverse1DimAuto(dtype)

  @test_util.run_deprecated_v1
  def testReverse2DimAuto(self):
    for dtype in self._AUTO_DTYPES:
      self._reverse2DimAuto(dtype)

  @test_util.run_deprecated_v1
  def testUnknownDims(self):
    reverse_v2 = array_ops.reverse_v2
    data_t = array_ops.placeholder(dtypes.float32)
    axis_known_t = array_ops.placeholder(dtypes.int32, shape=[3])
    reverse_known_t = reverse_v2(data_t, axis_known_t)
    # Unlike V1 we cannot know this anymore
    self.assertIsNone(reverse_known_t.get_shape().ndims)

    axis_unknown_t = array_ops.placeholder(dtypes.int32)
    reverse_unknown_t = reverse_v2(data_t, axis_unknown_t)
    self.assertIsNone(reverse_unknown_t.get_shape().ndims)

    # A known input rank is preserved even when the axes are unknown.
    data_2d_t = array_ops.placeholder(dtypes.float32, shape=[None, None])
    axis_2d_t = array_ops.placeholder(dtypes.int32, shape=[3])
    reverse_2d_t = reverse_v2(data_2d_t, axis_2d_t)
    self.assertEqual(2, reverse_2d_t.get_shape().ndims)

  @test_util.run_deprecated_v1
  def testReverseRowsOf3Channels(self):
    """Tests optimized code for reversing rows with last dim size = 3."""
    with self.session():
      for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
        for outer_size in (1, 2):
          for middle_size in list(range(50)) + [100000]:
            with self.subTest(
                reverse_f=reverse_f,
                outer_size=outer_size,
                middle_size=middle_size):
              x_np = np.reshape(
                  np.arange(outer_size * middle_size * 3, dtype=np.float32),
                  newshape=(outer_size, middle_size, 3))
              x_tf = reverse_f(x_np, [1]).eval()
              np_answer = x_np[:, ::-1, :]
              self.assertAllEqual(x_tf, np_answer)

  @test_util.run_deprecated_v1
  def testReverseRowsOf4Channels(self):
    """Reverses the middle axis of (outer, middle, 4) arrays."""
    with self.session():
      for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
        for outer_size in (1, 2):
          for middle_size in list(range(50)) + [100000]:
            with self.subTest(
                reverse_f=reverse_f,
                outer_size=outer_size,
                middle_size=middle_size):
              x_np = np.reshape(
                  np.arange(outer_size * middle_size * 4, dtype=np.float32),
                  newshape=(outer_size, middle_size, 4))
              x_tf = reverse_f(x_np, [1]).eval()
              np_answer = x_np[:, ::-1, :]
              self.assertAllEqual(x_tf, np_answer)

  @test_util.run_deprecated_v1
  def testReverseColumnsOf3Channels(self):
    """Reverses the outer axis of (outer, middle, 3) arrays."""
    with self.session():
      for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
        for outer_size in list(range(50)) + [100000]:
          for middle_size in (1, 2):
            with self.subTest(
                reverse_f=reverse_f,
                outer_size=outer_size,
                middle_size=middle_size):
              x_np = np.reshape(
                  np.arange(outer_size * middle_size * 3, dtype=np.float32),
                  newshape=(outer_size, middle_size, 3))
              x_tf = reverse_f(x_np, [0]).eval()
              np_answer = x_np[::-1, :, :]
              self.assertAllEqual(x_tf, np_answer)

  def testReverseInvalidShape(self):
    # Reversing a tensor that has a zero-sized dimension must still work.
    x = np.ndarray(shape=[0, 1, 1])
    v = array_ops.reverse_v2(x, axis=[1])
    # Compare against the NumPy result; comparing `v` with itself could
    # never fail.
    self.assertAllEqual(self.evaluate(v), x[:, ::-1, :])
530
531
class MeshgridTest(test_util.TensorFlowTestCase):
  """Compares `array_ops.meshgrid` against `np.meshgrid`."""

  def _compareDiff(self, x, y, use_gpu):
    """Checks two-input meshgrid for both "ij" and "xy" indexing."""
    for index in ("ij", "xy"):
      numpy_out = np.meshgrid(x, y, indexing=index)
      tf_out = array_ops.meshgrid(x, y, indexing=index)
      with self.cached_session(use_gpu=use_gpu):
        for xx, yy in zip(numpy_out, tf_out):
          self.assertAllEqual(xx, yy)

  def _compareDiffType(self, n, np_dtype, use_gpu):
    """Checks n-input meshgrid of `np_dtype` vectors for both indexings."""
    # Build the n inputs once, outside the indexing loop, so that both "ij"
    # and "xy" are tested with exactly n inputs.
    inputs = []
    for _ in range(n):
      x = np.linspace(-10, 10, 5).astype(np_dtype)
      if np_dtype in (np.complex64, np.complex128):
        # Give complex inputs a non-zero imaginary part.
        x += 1j
      inputs.append(x)
    for index in ("ij", "xy"):
      numpy_out = np.meshgrid(*inputs, indexing=index)
      with self.cached_session(use_gpu=use_gpu):
        tf_out = array_ops.meshgrid(*inputs, indexing=index)
        for x_np, x_tf in zip(numpy_out, tf_out):
          self.assertAllEqual(x_np, x_tf)

  @test_util.run_deprecated_v1
  def testCompare(self):
    for t in (np.float16, np.float32, np.float64, np.int32, np.int64,
              np.complex64, np.complex128):
      with self.subTest(t=t):
        self._compareDiffType(2, t, False)
        self._compareDiffType(3, t, False)

    # These Python-list inputs do not depend on `t`, so check them once.
    x = [1, 2, 3]
    y = [4, 5]

    a = [[1, 1], [1, 1]]

    self._compareDiff(x, y, False)
    self._compareDiff(x, a, False)
571
572
class StridedSliceChecker(object):
  """Check a given tensor against the numpy result.

  `checker[spec]` applies `spec` to both the tensor under test and its numpy
  reference, asserts that the results match (and, optionally, that the
  statically inferred shape matches the evaluated one), and returns the
  evaluated tensor.
  """

  REF_TENSOR = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
  REF_TENSOR_ALIGNED = np.arange(1, 97, dtype=np.float32).reshape(3, 4, 8)

  def __init__(self, test, x, tensor_type=dtypes.int32, check_type_infer=True):
    """Creates a checker for `x`.

    Args:
      test: Test case whose assertion methods are used for the comparisons.
      x: Array-like reference values (an ndarray, nested list or scalar).
      tensor_type: DType of the tensor under test.
      check_type_infer: If True, also compare the op's static shape with the
        evaluated shape for non-boolean-mask specs.
    """
    self.x_np = np.array(x).astype(tensor_type.as_numpy_dtype)
    if tensor_type.is_bool:
      # Derive a mix of True/False values from `x`. Convert with np.asarray
      # first so nested-list inputs work, and use np.bool_ because the
      # np.bool alias was removed in NumPy 1.24.
      self.x_np = np.array(np.asarray(x) % 3, dtype=np.bool_)
    # Give the value a non-zero imaginary component for complex types.
    if tensor_type.is_complex:
      self.x_np -= 1j * self.x_np
    self.test = test
    self.x = constant_op.constant(self.x_np, dtype=tensor_type)
    self.check_type_infer = check_type_infer

  def __getitem__(self, spec):
    op = self.x.__getitem__(spec)

    def eval_if_tensor(value):
      """Evaluates `value` if it has an `eval` method, else returns it."""
      # Look the method up explicitly instead of catching AttributeError, so
      # an AttributeError raised *inside* eval() is not silently swallowed.
      evaluate = getattr(value, "eval", None)
      return value if evaluate is None else evaluate()

    is_boolean_mask = (
        isinstance(spec, bool) or
        (isinstance(spec, ops.Tensor) and spec.dtype == dtypes.bool) or
        (isinstance(spec, np.ndarray) and spec.dtype == bool) or
        (isinstance(spec, (list, tuple)) and np.asarray(spec).dtype == bool))
    if is_boolean_mask:
      # Boolean masks are compared directly; the static shape is not checked.
      tensor = op.eval()
      np_spec = eval_if_tensor(spec)
      self.test.assertAllEqual(self.x_np[np_spec], tensor)
      return tensor

    if not isinstance(spec, (list, tuple)):
      spec = [spec]

    tensor = op.eval()

    # Make a numpy spec that pre-evals the tensors
    np_specs = []
    for s in spec:
      if isinstance(s, slice):
        np_specs.append(
            slice(eval_if_tensor(s.start), eval_if_tensor(s.stop),
                  eval_if_tensor(s.step)))
      else:
        np_specs.append(eval_if_tensor(s))

    self.test.assertAllEqual(self.x_np[tuple(np_specs)], tensor)
    if self.check_type_infer:
      self.test.assertAllEqual(tensor.shape, op.get_shape())
    return tensor
629
630
# DTypes that the strided-slice tests iterate over (e.g. `test_basic_slice`).
STRIDED_SLICE_TYPES = [
    dtypes.int32,
    dtypes.int64,
    dtypes.int16,
    dtypes.int8,
    dtypes.float32,
    dtypes.float64,
    dtypes.complex64,
    dtypes.complex128,
    dtypes.bool,
]
635
636
637class StridedSliceTest(test_util.TensorFlowTestCase):
638  """Test the strided slice operation with variants of slices."""
639
640  @test_util.run_deprecated_v1
641  def test_basic_slice(self):
642    for tensor_type in STRIDED_SLICE_TYPES:
643      with self.subTest(tensor_type=tensor_type):
644        with self.cached_session():
645          checker = StridedSliceChecker(
646              self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type)
647          _ = checker[:, :, :]
648          # Various ways of representing identity slice
649          _ = checker[:, :, :]
650          _ = checker[::, ::, ::]
651          _ = checker[::1, ::1, ::1]
652          # Not zero slice
653          _ = checker[::1, ::5, ::2]
654          # Reverse in each dimension independently
655          _ = checker[::-1, :, :]
656          _ = checker[:, ::-1, :]
657          _ = checker[:, :, ::-1]
658          ## negative index tests i.e. n-2 in first component
659          _ = checker[-2::-1, :, ::1]
660          # negative index tests i.e. n-2 in first component, non-unit stride
661          _ = checker[-2::-1, :, ::2]
662
663          # Check rank-0 examples
664          checker2 = StridedSliceChecker(self, 5, tensor_type=tensor_type)
665          _ = checker2[None]
666          _ = checker2[...]
667          _ = checker2[tuple()]
668
669  def testInt64GPU(self):
670    if not test_util.is_gpu_available():
671      self.skipTest("No GPU available")
672
673    with test_util.force_gpu():
674      x = constant_op.constant([1., 2., 3.])
675      begin = constant_op.constant([2], dtype=dtypes.int64)
676      end = constant_op.constant([3], dtype=dtypes.int64)
677      strides = constant_op.constant([1], dtype=dtypes.int64)
678      s = array_ops.strided_slice(x, begin, end, strides)
679      self.assertAllEqual([3.], self.evaluate(s))
680
681  @test_util.assert_no_new_pyobjects_executing_eagerly
682  @test_util.assert_no_garbage_created
683  def testTensorSliceEagerMemory(self):
684    with context.eager_mode():
685      inputs = constant_op.constant([[[1], [2], [3], [4]]],
686                                    dtype=dtypes.float32)
687      # Tests that slicing an EagerTensor doesn't leak memory
688      inputs[0]  # pylint: disable=pointless-statement
689
690  @test_util.assert_no_new_pyobjects_executing_eagerly
691  @test_util.assert_no_garbage_created
692  def testVariableSliceEagerMemory(self):
693    with context.eager_mode():
694      v = variables.Variable([1., 2.])
695      v[0]  # pylint: disable=pointless-statement
696
697  @test_util.run_deprecated_v1
698  def testDegenerateSlices(self):
699    with self.session():
700      checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
701      # degenerate by offering a forward interval with a negative stride
702      _ = checker[0:-1:-1, :, :]
703      # degenerate with a reverse interval with a positive stride
704      _ = checker[-1:0, :, :]
705      # empty interval in every dimension
706      _ = checker[-1:0, 2:2, 2:3:-1]
707      # empty first dimension only (used to break for aligned tensors).
708      checker = StridedSliceChecker(self,
709                                    StridedSliceChecker.REF_TENSOR_ALIGNED)
710      _ = checker[1:0]
711
712  @test_util.run_deprecated_v1
713  def testSliceWithUndefinedDimension(self):
714    t = constant_op.constant([1, 2, 3])
715    d = tensor_shape.Dimension(None)
716    self.assertAllEqual(t[d:d:d], t)
717
718  @test_util.run_deprecated_v1
719  def testEllipsis(self):
720    with self.session():
721      raw = [[[[[1, 2], [3, 4], [5, 6]]], [[[7, 8], [9, 10], [11, 12]]]]]
722      checker = StridedSliceChecker(self, raw)
723
724      _ = checker[0:]
725      # implicit ellipsis
726      _ = checker[0:, ...]
727      # ellipsis alone
728      _ = checker[...]
729      # ellipsis at end
730      _ = checker[0:1, ...]
731      # ellipsis at begin
732      _ = checker[..., 0:1]
733      # ellipsis at middle
734      _ = checker[0:1, ..., 0:1]
735      # multiple ellipses not allowed
736      with self.assertRaisesRegex(ValueError, "Multiple ellipses"):
737        _ = checker[..., :, ...].eval()
738
739  @test_util.run_deprecated_v1
740  def testShrink(self):
741    with self.session():
742      raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
743              [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
744      checker = StridedSliceChecker(self, raw)
745      _ = checker[:, :, :, :, 3]
746      _ = checker[..., 3]
747      _ = checker[:, 0]
748      _ = checker[:, :, 0]
749
750  @test_util.run_deprecated_v1
751  def testBothNewAxisAndShrink(self):
752    with self.session():
753      ones = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int16)
754      self.assertAllEqual(
755          ones[array_ops.newaxis, :,
756               0].eval(feed_dict={ones: [[1, 1], [1, 1]]}), [[1, 1]])
757
758  @test_util.run_deprecated_v1
759  def testTensorIndexing(self):
760    with self.session():
761      raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
762              [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
763      checker = StridedSliceChecker(self, raw, check_type_infer=False)
764      bar = constant_op.constant(2)
765      bar2 = constant_op.constant(3)
766      _ = checker[..., bar:bar2]
767      _ = checker[..., bar]
768      _ = checker[..., 3]
769      _ = checker[..., 2**64 // 2**63]  # Test longs in Python 2
770
771  def testTensorIndexingTypeError(self):
772    with self.session():
773      checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
774      expected = re.escape(array_ops._SLICE_TYPE_ERROR)
775      with self.assertRaisesRegex(TypeError, expected):
776        _ = checker["foo"]
777      with self.assertRaisesRegex(TypeError, expected):
778        _ = checker[constant_op.constant("foo")]
779      with self.assertRaisesRegex(TypeError, expected):
780        _ = checker[0.0]
781      with self.assertRaisesRegex(TypeError, expected):
782        _ = checker[constant_op.constant(0.0)]
783      with self.assertRaisesRegex(TypeError, expected):
784        _ = checker[constant_op.constant([1, 2, 3])]
785      with self.assertRaisesRegex(TypeError, expected):
786        _ = checker[[2.1, -0.7, 1.5]]
787
788  @test_util.run_deprecated_v1
789  def testExpand(self):
790    with self.session():
791      raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
792              [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
793      checker = StridedSliceChecker(self, raw)
794      # new axis (followed by implicit ellipsis)
795      _ = checker[np.newaxis]
796      # newaxis after ellipsis
797      _ = checker[..., np.newaxis]
798      # newaxis in between ellipsis and explicit range
799      _ = checker[..., np.newaxis, :]
800      _ = checker[:, ..., np.newaxis, :, :]
801      # Reverse final dimension with new axis
802      _ = checker[:, :, np.newaxis, :, 2::-1]
803      # Ellipsis in middle of two newaxis
804      _ = checker[np.newaxis, ..., np.newaxis]
805
806  @test_util.run_deprecated_v1
807  def testExpandVariable(self):
808    with self.session():
809      x = variables.Variable(7, dtype=dtypes.int32)
810      self.evaluate(x.initializer)
811      y = x[None].eval()
812      self.assertEqual(y.shape, (1,))
813      self.assertAllEqual(y, (7,))
814
815  @test_util.run_deprecated_v1
816  def testOptimizedCases(self):
817    with self.session():
818      checker = StridedSliceChecker(self,
819                                    StridedSliceChecker.REF_TENSOR_ALIGNED)
820      # Identity
821      _ = checker[:]
822      # Identity
823      _ = checker[...]
824      # Identity
825      _ = checker[np.newaxis, ..., np.newaxis]
826      # First axis slice
827      _ = checker[1:]
828      # First axis slice
829      _ = checker[np.newaxis, 1:]
830
  @test_util.run_v1_only("currently failing on v2")
  def testMasks(self):
    """Checks boolean-mask and 0-d array indexing through StridedSliceChecker.

    Masks are supplied both as tensors and as NumPy arrays, and a 0-d NumPy
    array (`scalar`) is also used as a plain index.
    """
    with self.session():
      # 0-d integer array, used below as an ordinary (non-mask) index.
      scalar = np.array(0)
      # Test tensor type mask
      checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
      # NOTE(review): `checker.x` is presumably the checker's tensor form of
      # its reference data -- confirm against StridedSliceChecker.
      _ = checker[checker.x > 2]
      _ = checker[checker.x <= 5]
      _ = checker[ops.convert_to_tensor(scalar)]

      # Test numpy array type mask
      raw = np.array([[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
                       [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23,
                                                              24]]]]])
      checker1 = StridedSliceChecker(self, raw)
      # Comparisons on `raw` yield masks with exactly `raw`'s shape.
      _ = checker1[raw >= 4]
      _ = checker1[raw < 19]
      _ = checker1[scalar]

      # Test boolean and non boolean cases
      # A 1-D mask of length 3 selects rows of the (3, 4) array `raw1`.
      mask = np.array([True, False, True])
      raw1 = np.array([[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]])
      checker2 = StridedSliceChecker(self, raw1)
      _ = checker2[mask]
      _ = checker2[ops.convert_to_tensor(mask)]
856
857
class StridedSliceShapeChecker(object):
  """Slices a wrapped tensor and returns only the static result shape."""

  def __init__(self, x):
    # The tensor (or tensor-like object) that gets sliced.
    self.x = x

  def __getitem__(self, spec):
    sliced = self.x.__getitem__(spec)
    shape = sliced.get_shape()
    return shape
866
867
class StridedSliceShapeTest(test_util.TensorFlowTestCase):
  """Test the shape inference of StridedSliceShapes."""

  @test_util.run_deprecated_v1
  def testUnknown(self):
    """Slicing a tensor of unknown rank keeps the rank unknown."""
    with self.session():
      uncertain_tensor = array_ops.placeholder(dtypes.float32)
      a = StridedSliceShapeChecker(uncertain_tensor)
      a_slice_shape = a[...]
      self.assertIsNone(a_slice_shape.ndims)

  def tensorShapeEqual(self, x, y):
    """Asserts two shapes are both None, or have equal dimension lists.

    Args:
      x: A `TensorShape` or None.
      y: A `TensorShape` or None.
    """
    # Exactly one of them being None is a mismatch.
    self.assertEqual(x is None, y is None)
    if x is None:
      # Both are None: nothing further to compare (and no `.as_list()`).
      return
    self.assertEqual(x.as_list(), y.as_list())

  @test_util.run_deprecated_v1
  def testTensorShapeUncertain(self):
    """Static shape inference with one unknown middle dimension."""
    with self.session():
      uncertain_tensor = array_ops.placeholder(
          dtypes.float32, shape=(5, None, 7))
      a = StridedSliceShapeChecker(uncertain_tensor)
      self.tensorShapeEqual(a[3:5], tensor_shape.TensorShape([2, None, 7]))
      self.tensorShapeEqual(a[3:5, :, 4], tensor_shape.TensorShape([2, None]))
      self.tensorShapeEqual(a[3:5, 3:4, 4], tensor_shape.TensorShape([2, None]))
      self.tensorShapeEqual(a[3:5, :, 5:10],
                            tensor_shape.TensorShape([2, None, 2]))
      self.tensorShapeEqual(a[3:5, :, 50:3],
                            tensor_shape.TensorShape([2, None, 0]))
      self.tensorShapeEqual(a[3:5, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([2, None, 1, 0]))
      self.tensorShapeEqual(a[1:5:2, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([2, None, 1, 0]))
      self.tensorShapeEqual(a[:5:3, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([2, None, 1, 0]))
      self.tensorShapeEqual(a[:2:3, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([1, None, 1, 0]))
      self.tensorShapeEqual(a[::-1, :, array_ops.newaxis, ::-2],
                            tensor_shape.TensorShape([5, None, 1, 4]))

  @test_util.run_deprecated_v1
  def testTensorValuedIndexShape(self):
    """Static shape inference when indices/bounds are scalar tensors."""
    with self.session():
      defined_shape_tensor = array_ops.placeholder(
          dtypes.float32, shape=(5, 3, 7))
      index_value = array_ops.placeholder(dtypes.int32, shape=())
      a = StridedSliceShapeChecker(defined_shape_tensor)
      self.tensorShapeEqual(a[index_value], tensor_shape.TensorShape([3, 7]))
      self.tensorShapeEqual(a[index_value, ::-1],
                            tensor_shape.TensorShape([3, 7]))
      self.tensorShapeEqual(a[index_value, ::-2],
                            tensor_shape.TensorShape([2, 7]))
      other_scalar = array_ops.placeholder(dtypes.int32, shape=())
      # A tensor-valued bound makes that dimension statically unknown.
      self.tensorShapeEqual(a[index_value, other_scalar:2],
                            tensor_shape.TensorShape([None, 7]))
922
923
class GradSliceChecker(object):
  """Tests that we can compute a gradient for var^2."""

  def __init__(self, test, sess, var, varnp):
    """Builds `var * var` so that slices of it can be differentiated.

    Args:
      test: Test case used for assertions.
      sess: Session used to run graph tensors.
      var: Variable whose square is sliced and differentiated.
      varnp: NumPy array holding the same values as `var`.
    """
    self.test = test
    self.sess = sess
    self.val = var * var
    self.var = var
    self.varnp = varnp

  def __getitem__(self, spec):
    """Checks first and second gradients of `(var * var)[spec]`.

    Args:
      spec: A slice spec accepted by tensor `__getitem__`. A `Tensor` spec
        (e.g. a boolean mask) is evaluated before indexing the NumPy reference.
    """
    slice_var = self.var[spec]
    slice_val = self.val[spec]

    # compute analytic 2nd derivative
    analytic_grad2 = 2 * slice_val

    dy = variables.Variable(
        array_ops.ones_like(slice_var, dtype=dtypes.float32))
    assign = dy.assign(slice_var)
    slice_val_grad, = gradients_impl.gradients(slice_val, self.var, grad_ys=dy)
    slice_val_grad2, = gradients_impl.gradients(
        slice_val_grad, dy, grad_ys=self.var)
    self.sess.run(assign)
    slice_val_grad_evaled, slice_val_grad2_evaled = (
        self.sess.run([slice_val_grad, slice_val_grad2]))
    analytic_grad2_evaled = analytic_grad2.eval()
    self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled)

    # compute analytic gradient for slice
    np_val_grad = (2 * self.varnp * self.varnp)
    np_sliceval_grad = np.zeros(self.var.get_shape().as_list())
    if isinstance(spec, ops.Tensor):
      # Run the tensor itself, not `[spec]`: a one-element list holding an
      # array is not a valid multidimensional index in modern NumPy.
      spec = self.sess.run(spec)
    np_sliceval_grad[spec] = np_val_grad[spec]
    # verify gradient
    self.test.assertAllEqual(slice_val_grad_evaled, np_sliceval_grad)
961
962
class StridedSliceGradTest(test_util.TensorFlowTestCase):
  """Test that strided slice's custom gradient produces correct gradients."""

  @test_util.run_v1_only("b/120545219")
  def testGradient(self):
    """Checks gradients of many slice specs of a (6, 4, 4) variable."""
    with self.session() as sess:
      # Values 1..96 laid out as (6, 4, 4).
      var = variables.Variable(
          array_ops.reshape(
              math_ops.range(1, 97, 1, dtype=dtypes.float32), shape=(6, 4, 4)))
      init = variables.global_variables_initializer()
      sess.run(init)

      # NumPy reference holding the same values as `var`.
      raw = np.array(range(1, 97, 1)).reshape((6, 4, 4))
      grad = GradSliceChecker(self, sess, var, raw)
      _ = grad[2:6:2, 1:3, 1:3]
      _ = grad[3:0:-2, 1:3, 1:3]
      _ = grad[3:0:-2, array_ops.newaxis, 1:3, 2, array_ops.newaxis]
      _ = grad[3:0:-2, 1:3, 2]
      _ = grad[:, -1, :]
      _ = grad[:, -2, :]
      # Indices outside the size-4 middle axis raise a ValueError.
      with self.assertRaisesRegex(ValueError, "out of bounds"):
        _ = grad[:, -200, :]
      with self.assertRaisesRegex(ValueError, "out of bounds"):
        _ = grad[:, 200, :]

      # Test numpy array type mask
      _ = grad[raw > 51]
      # Test tensor type mask
      _ = grad[ops.convert_to_tensor(raw) <= 76]

  @test_util.run_v1_only("b/120545219")
  def testGradientZero(self):
    """Gradient of a scalar variable indexed with an empty tuple."""
    with self.session() as sess:
      var = variables.Variable(8.)
      init = variables.global_variables_initializer()
      sess.run(init)
      grad = GradSliceChecker(self, sess, var, np.array(8))
      _ = grad[tuple()]

  @test_util.run_deprecated_v1
  def testInt64Indices(self):
    """An int64 scalar index routes the gradient to that element only."""
    with self.session():
      a = math_ops.range(3, dtype=dtypes.float32)
      index = constant_op.constant(1, dtype=dtypes.int64)
      b = 2. * a[index]
      grad, = gradients_impl.gradients(b, a)
      self.assertAllEqual(self.evaluate(grad), [0., 2., 0.])
1010
1011
class StridedSliceGradTypeTest(test_util.TensorFlowTestCase):
  """Test varied index types and host located memory."""

  def _make_dy(self):
    """Returns float32 values 1..4 reshaped to (4, 1, 1)."""
    return array_ops.reshape(
        math_ops.cast(math_ops.range(1, 5, 1), dtypes.float32),
        shape=(4, 1, 1))

  @test_util.run_deprecated_v1
  def testHostVsDevice(self):
    """StridedSliceGrad runs when shape and dy both come from variables."""
    with self.session() as sess:
      var2 = variables.Variable(self._make_dy())
      varshape = variables.Variable([6, 4, 4], dtype=dtypes.int32)
      self.evaluate(variables.global_variables_initializer())
      begin, end, strides = [
          constant_op.constant(bound)
          for bound in ([0, 0, 0], [4, 1, 1], [1, 1, 1])
      ]
      sess.run(
          array_ops.strided_slice_grad(varshape, begin, end, strides, var2))

  @test_util.run_deprecated_v1
  def testInt64Shape(self):
    """StridedSliceGrad accepts int64 shape and index tensors."""
    with self.session() as sess:
      original_dy = self._make_dy()
      original_shape = constant_op.constant([6, 4, 4], dtype=dtypes.int64)
      self.evaluate(variables.global_variables_initializer())
      begin, end, strides = [
          constant_op.constant(bound, dtype=dtypes.int64)
          for bound in ([0, 0, 0], [4, 1, 1], [1, 1, 1])
      ]
      sess.run(
          array_ops.strided_slice_grad(original_shape, begin, end, strides,
                                       original_dy))

  @test_util.run_deprecated_v1
  def testMixedIndexTypes(self):
    """Mixing int32 and int64 index tensors raises a TypeError."""
    with self.session() as sess:
      original_dy = self._make_dy()
      original_shape = constant_op.constant([6, 4, 4], dtype=dtypes.int64)
      self.evaluate(variables.global_variables_initializer())
      begin = constant_op.constant([0, 0, 0], dtype=dtypes.int32)
      end, strides = [
          constant_op.constant(bound, dtype=dtypes.int64)
          for bound in ([4, 1, 1], [1, 1, 1])
      ]
      with self.assertRaisesRegex(
          TypeError, "Input 'begin' of 'StridedSliceGrad' Op has type int32"
          " that does not match type int64 of argument 'shape'"):
        dx = array_ops.strided_slice_grad(original_shape, begin, end, strides,
                                          original_dy)
        sess.run(dx)
1062
1063
class BenchmarkSlice(object):
  """Thin wrapper that forwards `[...]` to the wrapped tensor."""

  def __init__(self, tensor):
    self.tensor = tensor

  def __getitem__(self, x):
    wrapped = self.tensor
    return wrapped[x]
1071
1072
class StridedSliceBenchmark(test_lib.Benchmark):
  """Benchmark new strided slice operation on non-trivial case."""

  def run_and_time(self, slice_op, iters=1000, warmup_iters=10):
    """Times repeated evaluation of `slice_op` and reports per-call wall time.

    Args:
      slice_op: Tensor to evaluate.
      iters: Number of timed evaluations; the reported wall time is the
        total elapsed time divided by this count.
      warmup_iters: Number of untimed evaluations run first.
    """
    self.evaluate(variables.global_variables_initializer())
    for _ in range(warmup_iters):
      _ = self.evaluate(slice_op)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for _ in range(iters):
      self.evaluate(slice_op)
    t1 = time.perf_counter()
    # Divide by the actual iteration count rather than a hard-coded constant.
    self.report_benchmark(iters=iters, wall_time=(t1 - t0) / iters)

  def make_variable(self):
    """Returns a float32 (256, 256, 256) variable with values 1..256**3."""
    n = 256
    shape = (n, n, n)
    items = n**3
    var = variables.Variable(
        array_ops.reshape(math_ops.linspace(1., float(items), items), shape),
        dtype=dtypes.float32)
    return var

  def benchmark_strided_slice_skip(self):
    with session.Session():
      var = self.make_variable()
      helper = BenchmarkSlice(var)
      slice_op = helper[::2, ::1, ::2]
      self.run_and_time(slice_op)

  def benchmark_strided_slice_easy(self):
    with session.Session():
      var = self.make_variable()
      helper = BenchmarkSlice(var)
      slice_op = helper[3::1, 3::1, 3::1]
      self.run_and_time(slice_op)

  def benchmark_slice_easy(self):
    with session.Session():
      var = self.make_variable()
      slice_op = var[3::1, 3::1, 3::1]
      self.run_and_time(slice_op)
1115
1116
class StridedSliceAssignChecker(object):
  """Checks sliced assignment into a variable against NumPy slice assignment.

  `checker[index] = value` assigns into a fresh variable initialized from `x`
  in two ways: via `var[index].assign(value)` and via `state_ops.assign`. Both
  results are compared with the same assignment applied to a NumPy copy.
  """

  def __init__(self, test, x, tensor_type=dtypes.float32, use_resource=False):
    """Stores the initial value converted to `tensor_type`.

    Args:
      test: Test case used for sessions and assertions.
      x: Initial value (anything `np.array` accepts).
      tensor_type: TF dtype for the variable and the assigned values.
      use_resource: If True, use a ResourceVariable instead of a Variable.
    """
    self.tensor_type = tensor_type
    self.test = test
    self._use_resource = use_resource

    self.x_np = np.array(x).astype(tensor_type.as_numpy_dtype)
    # Give the value a non-zero imaginary component for complex types.
    if tensor_type.is_complex:
      self.x_np -= 1j * self.x_np
    self.x = constant_op.constant(self.x_np, dtype=tensor_type)

  def __setitem__(self, index, value):
    value = np.array(value).astype(self.tensor_type.as_numpy_dtype)
    # Give the value a non-zero imaginary component for complex types.
    if self.tensor_type.is_complex:
      value -= 1j * value

    with self.test.test_session() as sess:
      # A new variable per assignment, so each check starts from `self.x`.
      if self._use_resource:
        var = resource_variable_ops.ResourceVariable(self.x)
      else:
        var = variables.Variable(self.x)
      sess.run(variables.variables_initializer([var]))
      val = sess.run(var[index].assign(value))
      # val_copy is used to check that tf.compat.v1.assign works equivalently
      # to the assign method above.
      val_copy = sess.run(state_ops.assign(var[index], value))
      # NumPy reference: apply the same slice assignment to a copy.
      valnp = np.copy(self.x_np)
      valnp[index] = np.array(value)
      self.test.assertAllEqual(val, valnp)
      self.test.assertAllEqual(val_copy, valnp)
1150
1151
class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests assignment to slices of variables and tensor_strided_slice_update."""

  def testInvalidSlice(self):
    """A slice of a plain constant tensor has no `assign` method."""
    foo = constant_op.constant([1, 2, 3])
    with self.assertRaisesRegex(AttributeError, "no attribute 'assign'"):
      bar = foo[:2].assign(constant_op.constant([1, 2]))
      self.evaluate(bar)

  def doTestSliceAssign(self, use_resource):
    """Runs the slice-assignment cases for each dtype in STRIDED_SLICE_TYPES.

    Args:
      use_resource: Whether the checker assigns into a ResourceVariable.
    """
    for dtype in STRIDED_SLICE_TYPES:
      with self.subTest(dtype=dtype):
        checker = StridedSliceAssignChecker(
            self, [[1, 2, 3], [4, 5, 6]],
            use_resource=use_resource,
            tensor_type=dtype)
        # Check if equal
        checker[:] = [[10, 20, 30], [40, 50, 60]]
        # Check trivial (1,1) shape tensor
        checker[1:2, 1:2] = [[66]]
        # shrinks shape changes
        checker[1:2, 1] = [66]
        checker[1, 1:2] = [66]
        checker[1, 1] = 66
        # newaxis shape changes
        checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
        # shrink and newaxis
        checker[None, None, 0, 0:1] = [[[99]]]
        # Non unit strides
        checker[::1, ::-2] = [[3, 33], [4, 44]]
        # degenerate interval
        checker[8:10, 0] = []
        checker[8:10, 8:10] = [[]]
    # Assign vector to scalar (rank-0) using newaxis
    # (this checker uses the default float32 dtype, outside the dtype loop).
    checker2 = StridedSliceAssignChecker(self, 222)
    checker2[()] = 6  # no indices
    checker2[...] = 6  # ellipsis
    checker2[None] = [6]  # new axis

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/123559667")
  def testSliceAssign(self):
    self.doTestSliceAssign(use_resource=False)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/123559667")
  def testSliceAssignResource(self):
    self.doTestSliceAssign(use_resource=True)

  @test_util.run_v1_only("b/120545219")
  def testUninitialized(self):
    """Assigning to a slice of an uninitialized variable fails at run time."""
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Attempting to use uninitialized value Variable"):
      with self.cached_session() as sess:
        v = variables.VariableV1([1, 2])
        sess.run(v[:].assign([1, 2]))

  @test_util.run_v1_only("b/120545219")
  def testTypeError(self):
    """Slice-assigning a different int dtype to a ref variable: TypeError."""
    init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
    too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
    too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
    v = variables.VariableV1(init_val)
    with self.assertRaises(TypeError):
      v[:].assign(too_small_val)
    with self.assertRaises(TypeError):
      v[:].assign(too_large_val)

  @test_util.run_deprecated_v1
  def testTypeErrorResource(self):
    """Same dtype mismatch on a ResourceVariable raises ValueError instead."""
    init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
    too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
    too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
    v = resource_variable_ops.ResourceVariable(init_val)
    with self.cached_session() as sess:
      self.evaluate(v.initializer)
      with self.assertRaises(ValueError):
        sess.run(v[:].assign(too_large_val))
      with self.assertRaises(ValueError):
        sess.run(v[:].assign(too_small_val))

  @test_util.disable_xla("b/123559667")
  @test_util.run_in_graph_and_eager_modes
  def testTensorStridedSliceUpdateWithInputForward(self):
    """Tests tensor_strided_slice_update with input-forwarding taking effect."""
    @def_function.function
    def assign(x):
      # y = x + 1 is [1, 1]; element 0 is then overwritten with 0.
      y = x + 1
      return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0])
    self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2]))))

  @test_util.disable_xla("b/123559667")
  @test_util.run_in_graph_and_eager_modes
  def testTensorStridedSliceUpdateNoInputForward(self):
    """Tests tensor_strided_slice_update with no input-forwarding."""
    x = constant_op.constant([0.2, 0.3])
    y = x + 1
    # y's buffer won't be forwarded to z because y and z will be alive at the
    # same time later.
    z = gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0.4])
    ans = y + z
    self.assertAllClose([1.6, 2.6], self.evaluate(ans))

  @test_util.disable_xla("b/123559667")
  def testTensorStridedSliceUpdateGradSimple(self):
    """Overwritten positions get zero gradient wrt the original tensor.

    The updates instead receive the output gradient at the slice positions.
    """
    original = constant_op.constant([0.2, 0.3])
    updates = constant_op.constant([0.4])
    with backprop.GradientTape() as tape:
      tape.watch([original, updates])
      updated = gen_array_ops.tensor_strided_slice_update(
          original, [0], [1], [1], updates)
    d1, d2 = tape.gradient(updated, [original, updates],
                           output_gradients=constant_op.constant([2.0, 3.0]))
    self.assertAllClose([0.0, 3.0], d1)
    self.assertAllClose([2.0], d2)

  # Each case: (shape, begin, end, strides, updates_shape, *trailing ints).
  # NOTE(review): the five trailing ints are presumably the begin/end/
  # ellipsis/new_axis/shrink_axis masks passed positionally -- confirm against
  # the gen_array_ops.tensor_strided_slice_update signature.
  @parameterized.named_parameters(
      ("_%s" % i, *args) for i, args in enumerate([  # pylint:disable=g-complex-comprehension
          ([2, 5], [0, 1], [1, 0], [1, 2], [2], 0, 2, 0, 0, 1),
          ([4], [5], [3], [1], [3], 1, 0, 0, 0, 0),
          ([2, 2, 3, 2], [0, 0, 1], [1, 0, 2], [1, 0, 1], [2, 3], 0, 0, 2, 0, 5)
      ]))
  @test_util.disable_xla("b/123559667")
  def testTensorStridedSliceUpdateGrad(
      self, shape, begin, end, strides, updates_shape, *args):
    """Theoretical and numerical gradients of the update op agree."""
    with self.cached_session():
      def f(a, b):
        return gen_array_ops.tensor_strided_slice_update(
            a, begin, end, strides, b, *args)
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          f, [array_ops.zeros(shape), array_ops.ones(updates_shape)], delta=1.0)
      self.assertAllClose(theoretical, numerical)
1284
1285
class ShapeSizeRankTest(test_util.TensorFlowTestCase):
  """Tests array_ops.shape/size/rank on dense and sparse inputs."""

  def _checkShapeSizeRank(self, value, shape, size, rank):
    """Asserts shape, size and rank of `value` evaluate to the given values."""
    self.assertAllEqual(shape, self.evaluate(array_ops.shape(value)))
    self.assertEqual(size, self.evaluate(array_ops.size(value)))
    self.assertEqual(rank, self.evaluate(array_ops.rank(value)))

  @test_util.run_in_graph_and_eager_modes
  def testDenseShape(self):
    t_value = [[0, 42], [24, 0]]
    self._checkShapeSizeRank(t_value, (2, 2), 4, 2)
    self._checkShapeSizeRank(constant_op.constant(t_value), (2, 2), 4, 2)

  @test_util.run_in_graph_and_eager_modes
  def testSparseShape(self):
    sp_value = sparse_tensor.SparseTensorValue(
        indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
    self._checkShapeSizeRank(sp_value, (2, 2), 4, 2)
    self._checkShapeSizeRank(
        sparse_tensor.SparseTensor.from_value(sp_value), (2, 2), 4, 2)

  @test_util.run_in_graph_and_eager_modes
  def testSizeDtype(self):
    tensor = [1]
    default_size = self.evaluate(array_ops.size(tensor))
    self.assertEqual(dtypes.int32, default_size.dtype)
    wide_size = self.evaluate(array_ops.size(tensor, out_type=dtypes.int64))
    self.assertEqual(dtypes.int64, wide_size.dtype)
1320
1321
class SequenceMaskTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.sequence_mask."""

  def testExceptions(self):
    with self.cached_session():
      with self.assertRaisesRegex(ValueError, "maxlen must be scalar"):
        array_ops.sequence_mask([10, 20], [10, 20])

  @test_util.run_deprecated_v1
  def testOneDimensionalWithMaxlen(self):
    with self.cached_session():
      res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
      self.assertAllEqual(res.get_shape(), [3, 5])
      self.assertAllEqual(
          res,
          [[True, False, False, False, False], [True, True, True, False, False],
           [True, True, False, False, False]])

  @test_util.run_deprecated_v1
  def testOneDimensionalDtypeWithoutMaxlen(self):
    with self.cached_session():
      # test dtype and default maxlen:
      res = array_ops.sequence_mask(
          constant_op.constant([0, 1, 4]), dtype=dtypes.float32)
      self.assertAllEqual(res.get_shape().as_list(), [3, 4])
      self.assertAllEqual(
          res,
          [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])

  @test_util.run_deprecated_v1
  def testOneDimensionalWithoutMaxlen(self):
    with self.cached_session():
      res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]))
      self.assertAllEqual(res.get_shape().as_list(), [3, 4])
      self.assertAllEqual(
          res, [[False, False, False, False], [True, False, False, False],
                [True, True, True, True]])

  @test_util.run_deprecated_v1
  def testTwoDimensional(self):
    with self.cached_session():
      res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
      self.assertAllEqual(res.get_shape(), [1, 3, 5])
      self.assertAllEqual(res, [[[True, False, False, False, False],
                                 [True, True, True, False, False],
                                 [True, True, False, False, False]]])

      # test dtype and default maxlen:
      res = array_ops.sequence_mask(
          constant_op.constant([[0, 1, 4], [1, 2, 3]]), dtype=dtypes.float32)
      self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4])
      self.assertAllEqual(
          res,
          [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
           [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]])

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    lengths = array_ops.placeholder(dtype=dtypes.int32)
    res = array_ops.sequence_mask(lengths)
    self.assertEqual(res.shape, None)

  @test_util.run_deprecated_v1
  def testDtypes(self):
    """All int32/int64 combinations of lengths and maxlen are accepted."""

    def check_dtypes(lengths_dtype, maxlen_dtype):
      res = array_ops.sequence_mask(
          constant_op.constant([1, 3, 2], dtype=lengths_dtype),
          constant_op.constant(5, dtype=maxlen_dtype))
      self.assertAllEqual(res.get_shape(), [3, 5])
      self.assertAllEqual(
          res,
          [[True, False, False, False, False], [True, True, True, False, False],
           [True, True, False, False, False]])

    with self.cached_session():
      check_dtypes(dtypes.int32, dtypes.int32)
      check_dtypes(dtypes.int32, dtypes.int64)
      check_dtypes(dtypes.int64, dtypes.int32)
      check_dtypes(dtypes.int64, dtypes.int64)

  def testOutputDtype(self):
    """`dtype` may be a TF DType, a dtype name string, or a NumPy type."""

    def check_output_dtype(output_dtype):
      res = self.evaluate(
          array_ops.sequence_mask(
              constant_op.constant([1, 3, 2], dtype=dtypes.int32),
              constant_op.constant(5, dtype=dtypes.int32),
              dtype=output_dtype))
      self.assertAllEqual(
          res,
          self.evaluate(
              math_ops.cast([[True, False, False, False, False],
                             [True, True, True, False, False],
                             [True, True, False, False, False]], output_dtype)))

    check_output_dtype(dtypes.bool)
    check_output_dtype("bool")
    # `np.bool` was a deprecated alias of the builtin and was removed in
    # NumPy 1.24; `np.bool_` is NumPy's boolean type.
    check_output_dtype(np.bool_)
    check_output_dtype(dtypes.int32)
    check_output_dtype("int32")
    check_output_dtype(np.int32)
    check_output_dtype(dtypes.float32)
    check_output_dtype("float32")
    check_output_dtype(np.float32)
    check_output_dtype(dtypes.int64)
    check_output_dtype("int64")
    check_output_dtype(np.int64)
    check_output_dtype(dtypes.float64)
    check_output_dtype("float64")
    check_output_dtype(np.float64)
1429
1430
class ConcatSliceResourceTest(test_util.TensorFlowTestCase):
  """Slicing a stack of resource handles yields a usable handle."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def testConcatSlice(self):
    first = test_ops.stub_resource_handle_op(container="a", shared_name="b")
    second = test_ops.stub_resource_handle_op(container="a", shared_name="c")
    stacked = array_ops.stack([first, second])
    second_again = array_ops.strided_slice(stacked, [1], [2])
    # Creating through the sliced handle must create the resource behind
    # `second`, so creating it again directly collides.
    self.evaluate(test_ops.resource_create_op(second_again))
    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(test_ops.resource_create_op(second))
1443
1444
class IdentityTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.identity across devices."""

  @test_util.run_gpu_only
  def testEagerIdentity(self):
    """Identity preserves values while moving between forced devices."""
    with context.eager_mode():

      def _test(x, y, device):
        """Asserts `y` equals `x` and its device string contains `device`."""
        self.assertAllEqual(x.numpy(), y.numpy())
        # assertIn reports both operands on failure, unlike assertTrue(... in).
        self.assertIn(device, y.device.lower())

      with test_util.force_gpu():
        a = constant_op.constant([[2], [3]], dtype=dtypes.float32)
      with test_util.force_gpu():
        b = array_ops.identity(a)
        _test(a, b, "gpu")
      with test_util.force_cpu():
        c = array_ops.identity(b)
        _test(b, c, "cpu")
      with test_util.force_cpu():
        d = array_ops.identity(c)
        _test(c, d, "cpu")
      with test_util.force_gpu():
        e = array_ops.identity(d)
        _test(d, e, "gpu")
1469
1470
class PadTest(test_util.TensorFlowTestCase):
  """Tests for constant padding and the mirror-pad gradient op."""

  def testEager(self):
    with context.eager_mode():
      values = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      paddings = constant_op.constant([[1, 1], [2, 2]])
      padded = array_ops.pad(values, paddings, "CONSTANT")
      expected = [[0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 1, 2, 3, 0, 0],
                  [0, 0, 4, 5, 6, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0]]
      self.assertAllEqual(padded.numpy(), expected)

  def testSymmetricMirrorPadGrad(self):
    grad_in = np.broadcast_to(np.arange(0, 7), (3, 2, 1, 7))
    paddings = constant_op.constant([[1, 1], [0, 0], [0, 0], [2, 2]])
    expected = np.broadcast_to(np.array([9, 27, 27]), (1, 2, 1, 3))
    self.assertAllEqual(
        gen_array_ops.mirror_pad_grad(grad_in, paddings, "SYMMETRIC"),
        expected)

  def testReflectMirrorPadGrad(self):
    column = np.reshape(np.arange(0, 7), (7, 1))
    grad_in = np.broadcast_to(column, (1, 4, 7, 1))
    paddings = constant_op.constant([[0, 0], [1, 1], [2, 2], [0, 0]])
    expected = np.broadcast_to(
        np.reshape(np.array([16, 18, 8]), (3, 1)), (1, 2, 3, 1))
    self.assertAllEqual(
        gen_array_ops.mirror_pad_grad(grad_in, paddings, "REFLECT"), expected)
1509
1510
class InvertPermutationTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.invert_permutation."""

  @test_util.run_deprecated_v1
  def testInvertPermutation(self):
    permutation = [3, 4, 0, 2, 1]
    inverse = [2, 4, 3, 0, 1]
    for dtype in (dtypes.int32, dtypes.int64):
      with self.subTest(dtype=dtype), self.cached_session():
        result = array_ops.invert_permutation(
            constant_op.constant(permutation, dtype=dtype))
        self.assertAllEqual(result.get_shape(), [5])
        self.assertAllEqual(result, inverse)
1522
1523
class UnravelIndexTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.unravel_index."""

  # TODO(b/73086570): Reenable test.
  @unittest.skip("Test does not pass internally.")
  def testUnravelIndex(self):
    with self.cached_session():
      for dtype in [dtypes.int32, dtypes.int64]:
        with self.subTest(dtype=dtype):
          indices_1 = constant_op.constant(1621, dtype=dtype)
          dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
          out_1 = array_ops.unravel_index(indices_1, dims_1)
          self.assertAllEqual(out_1, [3, 1, 4, 1])

          indices_2 = constant_op.constant([1621], dtype=dtype)
          dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
          out_2 = array_ops.unravel_index(indices_2, dims_2)
          self.assertAllEqual(out_2, [[3], [1], [4], [1]])

          indices_3 = constant_op.constant([22, 41, 37], dtype=dtype)
          dims_3 = constant_op.constant([7, 6], dtype=dtype)
          out_3 = array_ops.unravel_index(indices_3, dims_3)
          self.assertAllEqual(out_3, [[3, 6, 6], [4, 5, 1]])

  # Test case for GitHub issue 40204.
  def testUnravelIndexZeroDim(self):
    """A zero-sized dimension makes every index out of bounds."""
    with self.cached_session():
      for dtype in [dtypes.int32, dtypes.int64]:
        # subTest labels a failure with the dtype, matching testUnravelIndex.
        with self.subTest(dtype=dtype):
          with self.assertRaisesRegex(errors.InvalidArgumentError,
                                      "index is out of bound as with dims"):
            indices = constant_op.constant([2, 5, 7], dtype=dtype)
            dims = constant_op.constant([3, 0], dtype=dtype)
            self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
1556
1557
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.guarantee_const."""

  @test_util.run_deprecated_v1
  def testSimple(self):
    """guarantee_const passes a constant's value through unchanged."""
    with self.cached_session():
      a = array_ops.constant(10)
      guarantee_a = array_ops.guarantee_const(a)
      self.assertEqual(10, self.evaluate(guarantee_a))

  @test_util.run_deprecated_v1
  def testVariables(self):
    """guarantee_const accepts variables created with and without resources."""
    with self.cached_session():
      for use_resource in [False, True]:
        with self.subTest(use_resource=use_resource):
          a = variable_scope.get_variable(
              "var_{}".format(use_resource), [],
              initializer=init_ops.constant_initializer(10.0),
              use_resource=use_resource)
          guarantee_a = array_ops.guarantee_const(a)
          self.evaluate(variables.global_variables_initializer())
          self.assertEqual(10.0, self.evaluate(guarantee_a))

  @test_util.run_deprecated_v1
  def testResourceRejection(self):
    """Feeding a resource handle to guarantee_const fails at evaluation."""
    with self.cached_session():
      a = variable_scope.get_variable(
          "resource_var", [],
          initializer=init_ops.constant_initializer(10.0),
          use_resource=True)
      guarantee_a = array_ops.guarantee_const(a.handle)
      self.evaluate(variables.global_variables_initializer())
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               "cannot be a resource variable"):
        self.evaluate(guarantee_a)
1592
1593
class SnapshotOpTest(test_util.TensorFlowTestCase):
  """Tests for gen_array_ops.snapshot."""

  @test_util.run_deprecated_v1
  def testSnapshot(self):
    """snapshot yields a tensor holding the same values as its input."""
    # Previously misnamed testInvertPermutation (copy-paste from the
    # InvertPermutationTest case); this test exercises snapshot.
    for dtype in [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]:
      with self.subTest(dtype=dtype):
        with self.cached_session():
          x = constant_op.constant([0, 1, 2, 3], dtype=dtype)
          y = gen_array_ops.snapshot(x)
          self.assertAllEqual(y, [0, 1, 2, 3])
1604
1605
@test_util.run_all_in_graph_and_eager_modes
class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.quantize_and_dequantize_v2."""

  def _scale_per_slice(self, shape, axis, values):
    """Tiles `values` into `shape`, scaling slice i along `axis` by (i + 1).

    Args:
      shape: Target shape (sequence or 1-D array of ints).
      axis: Axis whose slices are scaled, or None to leave values unscaled.
      values: 1-D numpy array; repeated cyclically when `shape` holds more
        elements than `values` does.

    Returns:
      A numpy array of shape `shape` with the dtype of `values`.
    """
    out = np.take(values, np.remainder(np.arange(np.prod(shape)),
                                       len(values))).reshape(shape)
    if axis is not None:
      scale_shape = [1] * len(shape)
      scale_shape[axis] = shape[axis]
      # In-place multiply keeps out's dtype.
      out *= np.arange(1, shape[axis] + 1).reshape(scale_shape)
    return out

  def testAxis(self):
    """Per-axis quantization matches hand-computed values for every axis."""
    shape = np.array([2, 3, 4, 5])
    values = np.array([-1, -0.5, 0, 0.3, 0.8, 0.555, 0.5], dtype=np.float32)
    quant_values = np.array(
        [-1, -0.5, 0, 38.0 / 128, 102.0 / 128, 71.0 / 128, 0.5],
        dtype=np.float32)
    for axis in [None, 0, 1, 2, 3]:
      with self.subTest(axis=axis):
        inputs = constant_op.constant(
            self._scale_per_slice(shape, axis, values))
        expected = self._scale_per_slice(shape, axis, quant_values)
        # With range_given=False the min/max arguments are ignored.
        unused_minmax_value = 0 if axis is None else [0] * shape[axis]
        fake_quantized = self.evaluate(
            array_ops.quantize_and_dequantize_v2(
                inputs,
                unused_minmax_value,
                unused_minmax_value,
                range_given=False,
                round_mode="HALF_UP",
                axis=axis))
        self.assertAllEqual(fake_quantized, expected)
        if axis is not None:
          # Same axis spelled negatively. The rounding mode is pinned to match
          # the call above so that only the axis spelling differs.
          fake_quantized = self.evaluate(
              array_ops.quantize_and_dequantize_v2(
                  inputs,
                  unused_minmax_value,
                  unused_minmax_value,
                  range_given=False,
                  round_mode="HALF_UP",
                  axis=(axis - 4)))
          self.assertAllClose(fake_quantized, expected)

  def testBadAxis(self):
    """An axis beyond the input rank is rejected."""
    input_tensor = [2.5, 2.5]
    input_min = [0, 0]
    input_max = [1, 1]
    error_message_pattern = "Shape must be at least rank 11 but is rank 1"
    # TODO(b/171260356): Eager mode and graph mode throw different error types
    error = errors.InvalidArgumentError if context.executing_eagerly(
    ) else ValueError
    with self.assertRaisesRegex(error, error_message_pattern):
      self.evaluate(
          array_ops.quantize_and_dequantize_v2(
              input=input_tensor,
              input_min=input_min,
              input_max=input_max,
              axis=10))

  def testQuantizeDequantizeGrad(self):
    """Checks the theoretical Jacobian for a fixed [-10, 0] quantization range."""
    shape = (2, 2)
    max_threshold = 0
    min_threshold = -10
    input_value = np.random.rand(2, 2) * 40.0 - 20.0
    input_tensor = constant_op.constant(input_value, shape=shape,
                                        name="input_tensor")
    with self.cached_session():
      def f(a):
        return array_ops.quantize_and_dequantize_v2(
            a,
            input_min=min_threshold,
            input_max=max_threshold,
            range_given=True)
      output_grad = gradient_checker_v2.compute_gradient(f, [input_tensor])
      # output_grad[0] holds the theoretical Jacobians, one (4, 4) matrix for
      # the single input.
      self.assertAllClose(output_grad[0], np.zeros([1, 4, 4]))
1684
1685
@test_util.run_all_in_graph_and_eager_modes
class SortedSearchTest(test_util.TensorFlowTestCase):
  """Checks array_ops.searchsorted against a row-wise np.searchsorted."""

  # Hand-coded inputs; .8 / 20 etc. appear in both cdf and values to exercise
  # the tie-breaking difference between side="left" and side="right".
  _FLOAT_CDF = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
  _FLOAT_VALUES = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
                           dtype=np.float32)
  _INT_CDF = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
  _INT_VALUES = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)

  def _np_searchsorted(self, cdf, arr, side):
    """Reference result: np.searchsorted applied independently to each row.

    Both inputs are flattened to 2-D over their innermost axis, searched row by
    row, and the result is reshaped back to `arr.shape`.
    """
    cdf_rows = cdf.reshape([-1, cdf.shape[-1]])
    arr_rows = arr.reshape([-1, arr.shape[-1]])
    result = np.zeros(arr_rows.shape, dtype=np.int32)
    for i, (cdf_row, arr_row) in enumerate(zip(cdf_rows, arr_rows)):
      result[i, :] = np.searchsorted(cdf_row, arr_row, side=side)
    return result.reshape(arr.shape)

  def _assert_matches_numpy(self, cdf, arr, side):
    """Asserts array_ops.searchsorted agrees with the numpy reference."""
    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side=side))
    self.assertAllEqual(self._np_searchsorted(cdf, arr, side), tf_result)

  def _random_cdf(self, shape, integer):
    """Returns a random array of `shape`, non-decreasing along its last axis."""
    if integer:
      steps = np.random.randint(low=0, high=10, size=shape).astype(np.int64)
    else:
      steps = np.random.uniform(size=shape).astype(np.float32)
    return np.cumsum(steps, axis=-1)

  def _random_values(self, shape, limit, integer):
    """Returns random search values scaled to the range a random cdf spans."""
    if integer:
      return np.random.randint(
          low=0, high=10 * limit, size=shape).astype(np.int64)
    return np.random.uniform(size=shape).astype(np.float32) * limit

  def _check_random_nd(self, side, integer):
    """Compares against numpy for [7] * rank inputs with rank 1 through 4."""
    dim_size = 7
    for rank in range(1, 5):
      with self.subTest(rank=rank):
        shape = [dim_size] * rank
        cdf = self._random_cdf(shape, integer)
        arr = self._random_values(shape, dim_size, integer)
        self._assert_matches_numpy(cdf, arr, side)

  def _check_uneven(self, side, integer):
    """Compares against numpy when cdf rows are longer than value rows."""
    batch_size = 7
    size_search_array = 1000
    size_values = 47
    cdf = self._random_cdf([batch_size, size_search_array], integer)
    arr = self._random_values([batch_size, size_values], size_search_array,
                              integer)
    self._assert_matches_numpy(cdf, arr, side)

  def testUpperBoundFloatHandCoded(self):
    self._assert_matches_numpy(self._FLOAT_CDF, self._FLOAT_VALUES, "right")

  def testUpperBoundFloatRandomNd(self):
    self._check_random_nd("right", integer=False)

  def testUpperBoundFloatUneven(self):
    self._check_uneven("right", integer=False)

  def testLowerBoundFloatHandCoded(self):
    self._assert_matches_numpy(self._FLOAT_CDF, self._FLOAT_VALUES, "left")

  def testLowerBoundFloatRandomNd(self):
    self._check_random_nd("left", integer=False)

  def testLowerBoundFloatUneven(self):
    self._check_uneven("left", integer=False)

  def testUpperBoundIntHandCoded(self):
    self._assert_matches_numpy(self._INT_CDF, self._INT_VALUES, "right")

  def testUpperBoundIntRandomNd(self):
    self._check_random_nd("right", integer=True)

  def testUpperBoundIntUneven(self):
    self._check_uneven("right", integer=True)

  def testLowerBoundIntHandCoded(self):
    self._assert_matches_numpy(self._INT_CDF, self._INT_VALUES, "left")

  def testLowerBoundIntRandomNd(self):
    self._check_random_nd("left", integer=True)

  def testLowerBoundIntUneven(self):
    self._check_uneven("left", integer=True)

  def testZeroSequenceSize(self):
    """An empty sorted sequence yields index 0 for every value."""
    dtype = dtypes.int32
    for side in ("left", "right"):
      with self.subTest(side=side):
        self.assertAllEqual(
            array_ops.searchsorted(
                array_ops.ones([2, 0]),
                array_ops.ones([2, 3]),
                side=side,
                out_type=dtype), array_ops.zeros([2, 3], dtype))

  def testZeroValueSize(self):
    """Searching for no values yields an empty result."""
    dtype = dtypes.int32
    for side in ("left", "right"):
      with self.subTest(side=side):
        self.assertAllEqual(
            array_ops.searchsorted(
                array_ops.ones([2, 3]),
                array_ops.ones([2, 0]),
                side=side,
                out_type=dtype), array_ops.zeros([2, 0], dtype))
1904
1905
class BatchGatherNdTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.batch_gather_nd and gather_nd with batch_dims."""

  # (params_shape, indices_shape, batch_dims) combinations with batch_dims > 0.
  _BATCHED_SHAPES = (
      ((2, 2, 2), (2, 1), 1),
      ((2, 2, 2), (2, 2), 1),
      ((2, 2, 3, 2), (2, 3), 1),
      ((2, 2, 3, 2), (2, 2), 1),
      ((2, 2, 3, 2), (2, 1), 1),
      ((2, 2, 3, 2), (2, 1, 3), 1),
      ((2, 2, 3, 2), (2, 2, 2), 1),
      ((2, 2, 3, 2), (2, 3, 1), 1),
      ((3, 2, 2, 3, 4), (3, 2, 3), 2),
      ((3, 2, 2, 3, 4), (3, 2, 2), 2),
      ((3, 2, 2, 3, 4), (3, 2, 1), 2),
      ((3, 2, 2, 3, 4), (3, 2, 1, 3), 2),
      ((3, 2, 2, 3, 4), (3, 2, 2, 2), 2),
      ((3, 2, 2, 3, 4), (3, 2, 3, 1), 2),
  )

  # Extra (params_shape, indices_shape, batch_dims) cases with batch_dims == 0.
  _UNBATCHED_SHAPES = (
      ((2, 2, 2), (2, 3), 0),
      ((2, 2, 2), (3,), 0),
      ((2, 2, 2), (1,), 0),
  )

  @staticmethod
  def _expected_out_shape(params_shape, indices_shape, batch_dims):
    """Expected output shape of a batched gather_nd.

    All indices dims except the last come first, followed by the params dims
    not consumed by the batch dims or by the index tuple (indices_shape[-1]).
    """
    row_ndims = len(params_shape) - batch_dims - indices_shape[-1]
    expected = tuple(indices_shape[:-1])
    if row_ndims > 0:
      expected += tuple(params_shape[-row_ndims:])
    return expected

  def testShapesMatch(self):
    """Tests for various different shape combinations."""
    for params_shape, indices_shape, batch_dims in (
        self._BATCHED_SHAPES + self._UNBATCHED_SHAPES):
      with self.subTest(
          params_shape=params_shape,
          indices_shape=indices_shape,
          batch_dims=batch_dims):
        params = constant_op.constant(1.0, shape=params_shape)
        indices = constant_op.constant(
            1, shape=indices_shape, dtype=dtypes.int32)
        out = array_ops.batch_gather_nd(
            params=params, indices=indices, batch_dims=batch_dims)
        self.assertSequenceEqual(
            out.shape,
            self._expected_out_shape(params_shape, indices_shape, batch_dims))

  def testReducesToGatherNDWhenBatchDimIsZero(self):
    """Confirms setting batch_dims to zero reduces to tf.gather_nd."""
    params = constant_op.constant(np.random.uniform(0.0, 1.0, size=(7, 8, 9)))
    indices_shapes = (
        (1,), (3, 1), (3, 3, 1),
        (2,), (3, 2), (3, 3, 2),
        (3,), (3, 3), (3, 3, 3),
    )

    for indices_shape in indices_shapes:
      with self.subTest(indices_shape=indices_shape):
        indices = np.random.randint(0, 7, size=indices_shape)
        gather_nd_result = gen_array_ops.gather_nd(params, indices)
        batch_gather_nd_result = array_ops.batch_gather_nd(
            params=params, indices=indices, batch_dims=0)
        self.assertAllEqual(gather_nd_result, batch_gather_nd_result)

  def testSameResultAsMapFn(self):
    """Compares results with gather_nd called on every element with map_fn."""
    for params_shape, indices_shape, batch_dims in self._BATCHED_SHAPES:
      with self.subTest(
          params_shape=params_shape,
          indices_shape=indices_shape,
          batch_dims=batch_dims):
        params = constant_op.constant(
            np.random.uniform(0.0, 1.0, size=params_shape))
        indices = np.random.randint(0, 2, size=indices_shape)
        batch_gather_nd_result = array_ops.batch_gather_nd(
            params=params, indices=indices, batch_dims=batch_dims)

        # map_fn only strips one leading dim, so fold multiple batch dims into
        # one before mapping and unfold the result afterwards.
        if batch_dims > 1:
          params = array_ops.reshape(
              params, shape=[-1] + list(params_shape[batch_dims:]))
          indices = array_ops.reshape(
              indices, shape=[-1] + list(indices_shape[batch_dims:]))

        map_fn_gather_nd_result = map_fn.map_fn(
            fn=self._map_fn_body, elems=(params, indices), dtype=dtypes.float64)

        if batch_dims > 1:
          out_shape = map_fn_gather_nd_result.shape.as_list()
          out_shape = list(params_shape[:batch_dims]) + out_shape[1:]
          map_fn_gather_nd_result = array_ops.reshape(
              map_fn_gather_nd_result, shape=out_shape)

        self.assertAllEqual(map_fn_gather_nd_result, batch_gather_nd_result)

  def _map_fn_body(self, elems):
    """Per-element body for map_fn: a plain gather_nd of (params, indices)."""
    return gen_array_ops.gather_nd(elems[0], elems[1])

  def testBatchDimsAsTensor(self):
    """Tests Tensor batch_dims as input works as intended."""
    # params_shape, indices_shape, batch_dims
    shapes = (
        ((3, 2, 2, 3, 4), (3, 2, 3, 1), 0),
        ((3, 2, 2, 3, 4), (3, 2, 3, 1), 1),
        ((3, 2, 2, 3, 4), (3, 2, 3, 1), 2),
    )

    for params_shape, indices_shape, batch_dims in shapes:
      with self.subTest(
          params_shape=params_shape,
          indices_shape=indices_shape,
          batch_dims=batch_dims):
        params = constant_op.constant(
            np.random.uniform(0.0, 1.0, size=params_shape))
        indices = np.random.randint(0, 2, size=indices_shape)
        batch_gather_nd_result = array_ops.gather_nd(
            params=params, indices=indices, batch_dims=batch_dims)
        batch_dims_tensor = constant_op.constant([batch_dims])
        batch_gather_nd_tensor_batch_dims_result = array_ops.gather_nd(
            params=params, indices=indices, batch_dims=batch_dims_tensor)

        self.assertAllEqual(batch_gather_nd_tensor_batch_dims_result,
                            batch_gather_nd_result)

  def testInvalidBatchDimsRaisesException(self):
    """Tests whether invalid batch_dims raise expected exceptions."""
    params = constant_op.constant(
        np.random.uniform(0.0, 1.0, size=(3, 2, 2, 3, 4)))
    indices = np.random.randint(0, 2, size=(3, 2, 3))

    with self.assertRaises(TypeError):
      array_ops.batch_gather_nd(
          params=params,
          indices=indices,
          batch_dims=constant_op.constant((0, 1)))

    with self.assertRaises(ValueError):
      array_ops.batch_gather_nd(params=params, indices=indices, batch_dims=-1)

    with self.assertRaises(ValueError):
      array_ops.batch_gather_nd(params=params, indices=indices, batch_dims=4)

  @test_util.run_deprecated_v1
  def testNoneBatchDimensions(self):
    """Tests batch_gather_nd works with None (unknown) batch dimensions."""
    for params_shape, indices_shape, batch_dims in self._BATCHED_SHAPES:
      with self.subTest(
          params_shape=params_shape,
          indices_shape=indices_shape,
          batch_dims=batch_dims):
        # Leave the batch dims unknown in the placeholder shapes.
        params_ph_shape = [None] * batch_dims + list(params_shape[batch_dims:])
        indices_ph_shape = ([None] * batch_dims +
                            list(indices_shape[batch_dims:]))

        params = array_ops.placeholder(dtypes.float32, shape=params_ph_shape)
        indices = array_ops.placeholder(dtypes.int32, shape=indices_ph_shape)
        out = array_ops.batch_gather_nd(
            params=params, indices=indices, batch_dims=batch_dims)

        with self.cached_session() as sess:
          params_val = np.ones(dtype=np.float32, shape=params_shape)
          indices_val = np.ones(dtype=np.int32, shape=indices_shape)
          res = sess.run(
              out, feed_dict={
                  params: params_val,
                  indices: indices_val
              })

        self.assertSequenceEqual(
            res.shape,
            self._expected_out_shape(params_shape, indices_shape, batch_dims))

  @test_util.run_deprecated_v1
  def testUnknownIndices(self):
    """Tests whether indices with unknown rank works correctly."""
    params = constant_op.constant(((0, 1, 2),))
    indices = array_ops.placeholder(dtypes.int32)
    gather_nd_t = array_ops.gather_nd(params, indices, batch_dims=1)
    shape = gather_nd_t.get_shape()
    self.assertEqual(None, shape.ndims)
    self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
2118
2119
@test_util.run_all_in_graph_and_eager_modes
class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Compares array_ops.repeat with np.repeat."""

  @parameterized.parameters(
      (3, 4, None),
      ([[1, 2], [3, 4]], 2, None),
      ([[1, 2], [3, 4]], [1, 2], 0),
      ([[1, 2], [3, 4]], [1, 2], 1),
      ([[1, 2], [3, 4]], 3, 1),
      ([[1, 2], [3, 4]], [1, 2, 3, 4], None),
      (np.ones([0, 4]), 0, 1),
      (np.ones([1, 2]), [2], None),
  )
  def testRepeat(self, array, repeats, axis):
    """Eager repeat and a shape-agnostic tf.function both match numpy."""
    np_array = np.array(array)
    int32_any_shape = tensor_spec.TensorSpec(None, dtypes.int32)

    # The same spec object is reused for both arguments, so the traced
    # function accepts int32 inputs of any shape.
    @def_function.function(input_signature=[int32_any_shape, int32_any_shape])
    def repeat_fn(array, repeats):
      return array_ops.repeat(array, repeats, axis)

    eager_result = array_ops.repeat(
        constant_op.constant(np_array), repeats, axis)
    traced_result = repeat_fn(
        constant_op.constant(np_array, dtype=dtypes.int32), repeats)
    reference = np.repeat(np_array, repeats, axis)
    for actual in (eager_result, traced_result):
      self.assertAllEqual(actual, reference)
2147
2148
@test_util.run_all_in_graph_and_eager_modes
class TileVariantTest(test_util.TensorFlowTestCase):
  """Tests tiling of variant (TensorList) tensors."""

  def test_tile_tensor_list(self):
    """Each tiled TensorList handle stacks back to the original tensor."""
    source = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
    list_handle = list_ops.tensor_list_from_tensor(source, element_shape=None)

    def stack(handle):
      # Materialize a 2-element list of [3, 4] items as a dense tensor.
      return list_ops.tensor_list_stack(handle, source.dtype, 2, [3, 4])

    with ops.device("CPU:0"):
      tiled = array_ops.tile(array_ops.reshape(list_handle, [1]), [2])
    before = (stack(tiled[0]), stack(tiled[1]))
    for stacked in before:
      self.assertAllEqual(source, stacked)

    # Derive new lists from the source handle and from one tiled handle; the
    # stacks built afterwards from the tiled handles must be unaffected.
    derived_lists = [
        list_ops.tensor_list_scatter([source[0] + 1], [0],
                                     input_handle=list_handle),
        list_ops.tensor_list_set_item(tiled[0], 0, source[0] + 2),
    ]
    with ops.control_dependencies(derived_lists):
      after = (stack(tiled[0]), stack(tiled[1]))
    for stacked in after:
      self.assertAllEqual(source, stacked)
2174
2175
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == "__main__":
  test_lib.main()
2178