• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for array_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import re
21import time
22import unittest
23
24import numpy as np
25
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.python.client import session
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import errors
33from tensorflow.python.framework import errors_impl
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import sparse_tensor
36from tensorflow.python.framework import tensor_shape
37from tensorflow.python.framework import tensor_spec
38from tensorflow.python.framework import test_ops
39from tensorflow.python.framework import test_util
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import gen_array_ops
42from tensorflow.python.ops import gradients_impl
43from tensorflow.python.ops import init_ops
44from tensorflow.python.ops import map_fn
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import resource_variable_ops
47from tensorflow.python.ops import state_ops
48from tensorflow.python.ops import variable_scope
49from tensorflow.python.ops import variables
50from tensorflow.python.platform import test as test_lib
51
52
@test_util.run_all_in_graph_and_eager_modes
class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
  """Tests for `array_ops.matrix_transpose` on single and batched matrices."""

  def _transpose_with_unknown_rank(self, matrix):
    """Transposes `matrix` inside a `tf.function` that hides its static rank.

    The input signature uses `shape=None`, so inside the traced function the
    rank of the argument is unknown and `matrix_transpose` cannot rely on a
    static rank.

    Args:
      matrix: an int32 tensor to transpose.

    Returns:
      The result of `array_ops.matrix_transpose(matrix)`.
    """

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def transpose(matrix):
      # Guard that the test really runs with an unknown static rank.
      self.assertIs(matrix.shape.ndims, None)
      return array_ops.matrix_transpose(matrix)

    return transpose(matrix)

  def testNonBatchMatrix(self):
    matrix = [[1, 2, 3], [4, 5, 6]]  # Shape (2, 3)
    expected_transposed = [[1, 4], [2, 5], [3, 6]]  # Shape (3, 2)
    transposed = array_ops.matrix_transpose(matrix)
    self.assertEqual((3, 2), transposed.get_shape())
    self.assertAllEqual(expected_transposed, transposed)

  def testConjugate(self):
    m = [[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]
    expected_transposed = [[1 - 1j, 4 - 4j], [2 - 2j, 5 - 5j], [3 - 3j, 6 - 6j]]
    matrix = ops.convert_to_tensor(m)
    transposed = array_ops.matrix_transpose(matrix, conjugate=True)
    self.assertEqual((3, 2), transposed.get_shape())
    self.assertAllEqual(expected_transposed, transposed)

  def testBatchMatrix(self):
    matrix_0 = [[1, 2, 3], [4, 5, 6]]
    matrix_0_t = [[1, 4], [2, 5], [3, 6]]
    matrix_1 = [[11, 22, 33], [44, 55, 66]]
    matrix_1_t = [[11, 44], [22, 55], [33, 66]]
    batch_matrix = [matrix_0, matrix_1]  # Shape (2, 2, 3)
    expected_transposed = [matrix_0_t, matrix_1_t]  # Shape (2, 3, 2)
    transposed = array_ops.matrix_transpose(batch_matrix)
    self.assertEqual((2, 3, 2), transposed.get_shape())
    self.assertAllEqual(expected_transposed, transposed)

  def testNonBatchMatrixDynamicallyDefined(self):
    # Needs an explicit `constant` because lists are not automatically
    # converted to tensors when passed to the `tf.function` below.
    matrix = constant_op.constant([[1, 2, 3], [4, 5, 6]])  # Shape (2, 3)
    expected_transposed = [[1, 4], [2, 5], [3, 6]]  # Shape (3, 2)
    self.assertAllEqual(expected_transposed,
                        self._transpose_with_unknown_rank(matrix))

  def testBatchMatrixDynamicallyDefined(self):
    matrix_0 = [[1, 2, 3], [4, 5, 6]]
    matrix_0_t = [[1, 4], [2, 5], [3, 6]]
    matrix_1 = [[11, 22, 33], [44, 55, 66]]
    matrix_1_t = [[11, 44], [22, 55], [33, 66]]
    # Needs an explicit `constant` because lists are not automatically
    # converted to tensors when passed to the `tf.function` below.
    batch_matrix = constant_op.constant([matrix_0, matrix_1])  # Shape (2, 2, 3)
    expected_transposed = [matrix_0_t, matrix_1_t]  # Shape (2, 3, 2)
    self.assertAllEqual(expected_transposed,
                        self._transpose_with_unknown_rank(batch_matrix))

  def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
    vector = [1, 2, 3]
    with self.assertRaisesRegexp(ValueError, "should be a "):
      array_ops.matrix_transpose(vector)
120
121
class BooleanMaskTest(test_util.TensorFlowTestCase):
  """Checks `array_ops.boolean_mask` against NumPy boolean indexing."""

  def setUp(self):
    # Seeded generator: arrays, masks and random shapes are all drawn from it
    # so that any failure is reproducible.
    self.rng = np.random.RandomState(42)

  def CheckVersusNumpy(self, ndims_mask, arr_shape, make_mask=None, axis=None):
    """Check equivalence between boolean_mask and numpy masking.

    Args:
      ndims_mask: rank of the boolean mask.
      arr_shape: shape of the random array to be masked.
      make_mask: optional callable mapping a shape to a boolean array. By
        default a random mask is drawn from `self.rng`.
      axis: optional non-negative axis at which the mask starts, forwarded to
        `boolean_mask`; `None` masks from axis 0.
    """
    if make_mask is None:
      make_mask = lambda shape: self.rng.randint(0, 2, size=shape).astype(bool)
    # Use the seeded generator rather than the global np.random state.
    arr = self.rng.rand(*arr_shape)
    start = 0 if axis is None else axis
    mask = make_mask(arr_shape[start:start + ndims_mask])
    # NumPy equivalent of boolean_mask(arr, mask, axis=start): keep the first
    # `start` dimensions whole and apply the mask right after them. This
    # works for any non-negative axis, not only 0, 1 and 2.
    masked_arr = arr[(slice(None),) * start + (mask,)]
    with self.cached_session():
      masked_tensor = array_ops.boolean_mask(arr, mask, axis=axis)

      # Leading dimension size of masked_tensor is always unknown until runtime
      # since we don't know how many elements will be kept.
      leading = 1 if axis is None else axis + 1
      self.assertAllEqual(masked_tensor.get_shape()[leading:],
                          masked_arr.shape[leading:])

      self.assertAllClose(masked_arr, masked_tensor.eval())

  @test_util.run_deprecated_v1
  def testMaskDim1ArrDim2Axis1(self):
    ndims_mask = 1
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      self.CheckVersusNumpy(ndims_mask, arr_shape, axis=1)

  @test_util.run_deprecated_v1
  def testMaskDim2ArrDim2Axis1(self):
    # NOTE(review): for a rank-2 array with axis=1, arr_shape[1:3] has a single
    # entry, so the mask CheckVersusNumpy builds here is actually 1-D.
    ndims_mask = 2
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      self.CheckVersusNumpy(ndims_mask, arr_shape, axis=1)

  @test_util.run_deprecated_v1
  def testMaskDim1ArrDim1(self):
    ndims_mask = 1
    for arr_shape in [(1,), (2,), (3,), (10,)]:
      self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testMaskDim1ArrDim2(self):
    ndims_mask = 1
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testMaskDim2ArrDim2(self):
    ndims_mask = 2
    for arr_shape in [(1, 1), (2, 2), (2, 5)]:
      self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testMaskDim2ArrDim3(self):
    ndims_mask = 2
    for arr_shape in [(1, 1, 1), (1, 2, 2), (2, 2, 1)]:
      self.CheckVersusNumpy(ndims_mask, arr_shape)

  @test_util.run_deprecated_v1
  def testEmptyInput2D(self):
    mask = np.array([True, False])
    arr = np.array([[], []]).astype(np.float32)
    numpy_result = arr[mask]
    tf_result = array_ops.boolean_mask(arr, mask)
    self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
    with self.cached_session():
      self.assertAllClose(numpy_result, tf_result.eval())

  @test_util.run_deprecated_v1
  def testEmptyInput1D(self):
    mask = np.array([]).astype(bool)
    arr = np.array([]).astype(np.float32)
    numpy_result = arr[mask]
    tf_result = array_ops.boolean_mask(arr, mask)
    self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
    with self.cached_session():
      self.assertAllClose(numpy_result, tf_result.eval())

  @test_util.run_deprecated_v1
  def testEmptyOutput(self):
    # An all-False mask keeps nothing, for every mask/array rank combination.
    make_mask = lambda shape: np.zeros(shape, dtype=bool)
    for ndims_mask in range(1, 4):
      for ndims_arr in range(ndims_mask, ndims_mask + 3):
        for _ in range(3):
          arr_shape = self.rng.randint(1, 5, size=ndims_arr)
          self.CheckVersusNumpy(ndims_mask, arr_shape, make_mask=make_mask)

  @test_util.run_deprecated_v1
  def testWorksWithDimensionsEqualToNoneDuringGraphBuild(self):
    # The rank of the mask tensor must be specified. This is explained
    # in the docstring as well.
    with self.cached_session() as sess:
      ph_tensor = array_ops.placeholder(dtypes.int32, shape=None)
      ph_mask = array_ops.placeholder(dtypes.bool, shape=[None])

      arr = np.array([[1, 2], [3, 4]])
      mask = np.array([False, True])

      masked_tensor = sess.run(
          array_ops.boolean_mask(ph_tensor, ph_mask),
          feed_dict={
              ph_tensor: arr,
              ph_mask: mask
          })
      np.testing.assert_allclose(masked_tensor, arr[mask])

  @test_util.run_deprecated_v1
  def testMaskDimensionsSetToNoneRaises(self):
    # The rank of the mask tensor must be specified. This is explained
    # in the docstring as well.
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.int32, shape=[None, 2])
      mask = array_ops.placeholder(dtypes.bool, shape=None)
      with self.assertRaisesRegexp(ValueError, "dimensions must be specified"):
        array_ops.boolean_mask(tensor, mask)

  def testMaskHasMoreDimsThanTensorRaises(self):
    mask = [[True, True], [False, False]]
    tensor = [1, 2, 3, 4]
    with self.cached_session():
      with self.assertRaisesRegexp(ValueError, "incompatible"):
        array_ops.boolean_mask(tensor, mask).eval()

  def testMaskIsScalarRaises(self):
    mask = True
    tensor = 1
    with self.cached_session():
      with self.assertRaisesRegexp(ValueError, "mask.*scalar"):
        array_ops.boolean_mask(tensor, mask).eval()

  def testMaskShapeDifferentThanFirstPartOfTensorShapeRaises(self):
    mask = [True, True, True]
    tensor = [[1, 2], [3, 4]]
    with self.cached_session():
      with self.assertRaisesRegexp(ValueError, "incompatible"):
        array_ops.boolean_mask(tensor, mask).eval()

  @test_util.run_deprecated_v1
  def testStringMask(self):
    # Reproduces b/111171330, where the optimized boolean_mask graph would
    # be incorrectly placed on GPU.
    with ops.Graph().as_default():
      tile_placeholder = array_ops.placeholder(dtypes.int32, [2])
      string_tensor = array_ops.tile([["hello"]], tile_placeholder)
      bool_tensor = array_ops.tile([[True]], tile_placeholder)
      masked_tensor = array_ops.boolean_mask(string_tensor, bool_tensor)
      config = config_pb2.ConfigProto()
      config.graph_options.rewrite_options.shape_optimization = 1
      config.gpu_options.per_process_gpu_memory_fraction = 0.3
      with session.Session(config=config) as sess:
        result = sess.run(masked_tensor, feed_dict={tile_placeholder: [2, 2]})
        self.assertAllEqual([b"hello", b"hello", b"hello", b"hello"], result)
282
283
@test_util.run_all_in_graph_and_eager_modes
class OperatorShapeTest(test_util.TensorFlowTestCase):
  """Static-shape checks for `array_ops.expand_dims` and `array_ops.squeeze`."""

  def testExpandScalar(self):
    scalar = "hello"
    scalar_expanded = array_ops.expand_dims(scalar, [0])
    self.assertEqual(scalar_expanded.get_shape(), (1,))

  def testSqueezeScalar(self):
    scalar = "hello"
    scalar_squeezed = array_ops.squeeze(scalar, ())
    self.assertEqual(scalar_squeezed.get_shape(), ())

  def testSqueezeMatrix(self):
    matrix = [[1, 2, 3]]
    matrix_squeezed = array_ops.squeeze(matrix, [0])
    # `(3,)` is the rank-1 shape; a bare `(3)` would just be the integer 3.
    self.assertEqual(matrix_squeezed.get_shape(), (3,))

    with self.assertRaisesRegexp(
        Exception, "Can not squeeze dim.1., expected a dimension of 1, got 3"):
      array_ops.squeeze(matrix, [1])

  def testSqueezeScalarDim(self):
    matrix = [[1, 2, 3]]
    matrix_squeezed = array_ops.squeeze(matrix, 0)
    self.assertEqual(matrix_squeezed.get_shape(), (3,))

  def testExpandDimsWithNonScalarDim(self):
    with self.assertRaisesRegexp(Exception,
                                 "must be a tensor with a single value"):
      array_ops.expand_dims(1, axis=[0, 1])
315
316
class ReverseV2Test(test_util.TensorFlowTestCase):
  """Tests for `array_ops.reverse_v2` (and `array_ops.reverse`)."""

  # Dtypes exercised by the 1-D and 2-D reverse tests. `np.bool_` is used
  # because the `np.bool` alias was removed in NumPy 1.24.
  _DTYPES = (np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
             np.bool_, np.float16, np.float32, np.float64, np.complex64,
             np.complex128, np.array(b"").dtype.type)

  # Small sizes 0..49 plus one large size, swept by the channel tests.
  _CHANNEL_SWEEP_SIZES = list(range(50)) + [100000]

  @test_util.run_deprecated_v1
  def testReverse0DimAuto(self):
    x_np = 4
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        x_tf = array_ops.reverse_v2(x_np, []).eval()
        self.assertAllEqual(x_tf, x_np)

  def _reverse1DimAuto(self, np_dtype):
    """Reverses a 1-D array of `np_dtype` using int32 and int64 axis tensors."""
    x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype)

    for use_gpu in [False, True]:
      for axis_dtype in [dtypes.int32, dtypes.int64]:
        with self.cached_session(use_gpu=use_gpu):
          x_tf = array_ops.reverse_v2(
              x_np, constant_op.constant([0], dtype=axis_dtype)).eval()
          self.assertAllEqual(x_tf, x_np[::-1])

  def _reverse2DimAuto(self, np_dtype):
    """Reverses a 2-D array along each axis, by positive and negative index."""
    x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype)
    # (axes passed to reverse, expected NumPy result) pairs.
    cases = [
        ([0], x_np[::-1, :]),
        ([-2], x_np[::-1, :]),
        ([1], x_np[:, ::-1]),
        ([-1], x_np[:, ::-1]),
        ([1, 0], x_np[::-1, ::-1]),
    ]

    for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
      for use_gpu in [False, True]:
        for axis_dtype in [dtypes.int32, dtypes.int64]:
          with self.cached_session(use_gpu=use_gpu):
            for axes, expected in cases:
              x_tf = reverse_f(x_np,
                               constant_op.constant(axes,
                                                    dtype=axis_dtype)).eval()
              self.assertAllEqual(x_tf, expected)

  # This test covers the axis validation in the shape function
  # (no eval())
  @test_util.run_deprecated_v1
  def testInvalidAxis(self):
    x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    with self.assertRaisesRegexp(ValueError, "is out of valid range"):
      array_ops.reverse_v2(x_np, [-30])
    with self.assertRaisesRegexp(ValueError, "is out of valid range"):
      array_ops.reverse_v2(x_np, [2])
    with self.assertRaisesRegexp(ValueError, "axis 0 specified more than once"):
      array_ops.reverse_v2(x_np, [0, -2])

  # This is the version of reverse that uses axis indices rather than
  # bool tensors
  # TODO(b/32254538): Change this test to use array_ops.reverse
  #
  # Note: this test passes placeholder as constant axis is validated
  # in shape function (see testInvalidAxis)
  @test_util.run_deprecated_v1
  def testInvalid(self):
    x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    axis = array_ops.placeholder(dtypes.int32)
    with self.cached_session():
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "is out of.*range"):
        array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [-30]})
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "is out of.*range"):
        array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [2]})
      with self.assertRaisesRegexp(
          errors_impl.InvalidArgumentError,
          "(axis 0 specified more than once|canonicalized axis 0 was repeated.)"
      ):
        array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [0, -2]})

  @test_util.run_deprecated_v1
  def testReverse1DimAuto(self):
    for dtype in self._DTYPES:
      self._reverse1DimAuto(dtype)

  @test_util.run_deprecated_v1
  def testReverse2DimAuto(self):
    for dtype in self._DTYPES:
      self._reverse2DimAuto(dtype)

  @test_util.run_deprecated_v1
  def testUnknownDims(self):
    reverse_v2 = array_ops.reverse_v2
    data_t = array_ops.placeholder(dtypes.float32)
    axis_known_t = array_ops.placeholder(dtypes.int32, shape=[3])
    reverse_known_t = reverse_v2(data_t, axis_known_t)
    # Unlike V1 we cannot know this anymore
    self.assertIsNone(reverse_known_t.get_shape().ndims)

    axis_unknown_t = array_ops.placeholder(dtypes.int32)
    reverse_unknown_t = reverse_v2(data_t, axis_unknown_t)
    self.assertIsNone(reverse_unknown_t.get_shape().ndims)

    data_2d_t = array_ops.placeholder(dtypes.float32, shape=[None, None])
    axis_2d_t = array_ops.placeholder(dtypes.int32, shape=[3])
    reverse_2d_t = reverse_v2(data_2d_t, axis_2d_t)
    self.assertEqual(2, reverse_2d_t.get_shape().ndims)

  def _checkReverseChannels(self, outer_sizes, middle_sizes, channels, axis):
    """Compares reverse/reverse_v2 with NumPy on (outer, middle, channels) data.

    Args:
      outer_sizes: iterable of sizes for dimension 0 (outer loop).
      middle_sizes: iterable of sizes for dimension 1 (inner loop).
      channels: size of the last dimension.
      axis: the single axis that is reversed.
    """
    with self.session(use_gpu=True):
      for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
        for outer_size in outer_sizes:
          for middle_size in middle_sizes:
            # Positional shape: the `newshape=` keyword is deprecated in
            # NumPy 2.1.
            x_np = np.reshape(
                np.arange(outer_size * middle_size * channels,
                          dtype=np.float32),
                (outer_size, middle_size, channels))
            x_tf = reverse_f(x_np, [axis]).eval()
            self.assertAllEqual(x_tf, np.flip(x_np, axis))

  @test_util.run_deprecated_v1
  def testReverseRowsOf3Channels(self):
    """Tests optimized code for reversing rows with last dim size = 3."""
    self._checkReverseChannels((1, 2), self._CHANNEL_SWEEP_SIZES,
                               channels=3, axis=1)

  @test_util.run_deprecated_v1
  def testReverseRowsOf4Channels(self):
    self._checkReverseChannels((1, 2), self._CHANNEL_SWEEP_SIZES,
                               channels=4, axis=1)

  @test_util.run_deprecated_v1
  def testReverseColumnsOf3Channels(self):
    self._checkReverseChannels(self._CHANNEL_SWEEP_SIZES, (1, 2),
                               channels=3, axis=0)
475
476
class MeshgridTest(test_util.TensorFlowTestCase):
  """Compares `array_ops.meshgrid` with `np.meshgrid`."""

  def _compareDiff(self, x, y, use_gpu):
    """Checks a two-input meshgrid under both "ij" and "xy" indexing."""
    for index in ("ij", "xy"):
      numpy_out = np.meshgrid(x, y, indexing=index)
      tf_out = array_ops.meshgrid(x, y, indexing=index)
      with self.cached_session(use_gpu=use_gpu):
        for xx, yy in zip(numpy_out, tf_out):
          self.assertAllEqual(xx, yy.eval())

  def _compareDiffType(self, n, np_dtype, use_gpu):
    """Checks an `n`-input meshgrid of `np_dtype` vectors for both indexings.

    Args:
      n: number of 1-D inputs passed to meshgrid.
      np_dtype: NumPy dtype of the inputs; complex dtypes get a non-zero
        imaginary part.
      use_gpu: forwarded to `cached_session`.
    """
    # Build the inputs once, before the indexing loop. Previously the list
    # was shared and kept growing, so the "xy" pass saw 2n inputs, not n.
    inputs = []
    for _ in range(n):
      x = np.linspace(-10, 10, 5).astype(np_dtype)
      if np_dtype in (np.complex64, np.complex128):
        x += 1j
      inputs.append(x)
    for index in ("ij", "xy"):
      numpy_out = np.meshgrid(*inputs, indexing=index)
      with self.cached_session(use_gpu=use_gpu):
        tf_out = array_ops.meshgrid(*inputs, indexing=index)
        for x_np, x_tf in zip(numpy_out, tf_out):
          self.assertAllEqual(x_np, x_tf.eval())

  @test_util.run_deprecated_v1
  def testCompare(self):
    for t in (np.float16, np.float32, np.float64, np.int32, np.int64,
              np.complex64, np.complex128):
      self._compareDiffType(2, t, False)
      self._compareDiffType(3, t, False)

    # These Python-list cases do not depend on the dtype loop above, so run
    # them once rather than once per dtype.
    x = [1, 2, 3]
    y = [4, 5]

    a = [[1, 1], [1, 1]]

    self._compareDiff(x, y, False)
    self._compareDiff(x, a, False)
515
516
class StridedSliceChecker(object):
  """Check a given tensor against the numpy result.

  Indexing a checker (`checker[spec]`) applies `spec` both to a TensorFlow
  constant and to the matching NumPy array. It evaluates the TensorFlow
  result with `.eval()` and asserts the two agree. For non-mask specs it can
  also assert that the statically inferred shape matches the evaluated
  shape. It returns the evaluated result.
  """

  REF_TENSOR = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
  REF_TENSOR_ALIGNED = np.arange(1, 97, dtype=np.float32).reshape(3, 4, 8)

  def __init__(self, test, x, tensor_type=dtypes.int32, check_type_infer=True):
    """Builds the NumPy reference and the TensorFlow constant for `x`.

    Args:
      test: the test case whose assertion methods are used.
      x: array-like reference value (scalar, nested list or ndarray).
      tensor_type: `DType` of the constant under test.
      check_type_infer: if True, non-mask slices also compare the static
        shape of the sliced tensor with the evaluated shape.
    """
    if tensor_type.is_bool:
      # Multiples of 3 map to False, other values to True. Converting with
      # `np.array` first lets plain Python lists be used with `%` too, and
      # `np.bool_` replaces the `np.bool` alias removed in NumPy 1.24.
      self.x_np = (np.array(x) % 3).astype(np.bool_)
    else:
      self.x_np = np.array(x).astype(tensor_type.as_numpy_dtype)
    # Give the value a non-zero imaginary component for complex types.
    if tensor_type.is_complex:
      self.x_np -= 1j * self.x_np
    self.test = test
    self.x = constant_op.constant(self.x_np, dtype=tensor_type)
    self.check_type_infer = check_type_infer

  def __getitem__(self, spec):
    op = self.x.__getitem__(spec)

    def eval_if_tensor(x):
      # Tensors are evaluated; plain Python/NumPy values pass through.
      try:
        return x.eval()
      except AttributeError:
        return x

    is_boolean_mask = (
        isinstance(spec, bool) or
        (isinstance(spec, ops.Tensor) and spec.dtype == dtypes.bool) or
        (isinstance(spec, np.ndarray) and spec.dtype == bool) or
        (isinstance(spec, (list, tuple)) and np.asarray(spec).dtype == bool))
    if is_boolean_mask:
      # Boolean masks are compared directly against NumPy mask indexing; the
      # static-shape check below is not applied to them.
      tensor = op.eval()
      np_spec = eval_if_tensor(spec)
      self.test.assertAllEqual(self.x_np[np_spec], tensor)
      return tensor

    if not isinstance(spec, (list, tuple)):
      spec = [spec]

    tensor = op.eval()

    # Make a numpy spec that pre-evals the tensors
    np_specs = []

    for s in spec:
      if isinstance(s, slice):
        start = eval_if_tensor(s.start)
        stop = eval_if_tensor(s.stop)
        step = eval_if_tensor(s.step)
        np_specs.append(slice(start, stop, step))
      else:
        np_specs.append(eval_if_tensor(s))

    self.test.assertAllEqual(self.x_np[tuple(np_specs)], tensor)
    if self.check_type_infer:
      self.test.assertAllEqual(tensor.shape, op.get_shape())
    return tensor
573
574
# Dtypes over which the basic strided-slice checks run (see
# StridedSliceTest.test_basic_slice): integer, floating, complex and bool.
STRIDED_SLICE_TYPES = [
    dtypes.int32,
    dtypes.int64,
    dtypes.int16,
    dtypes.int8,
    dtypes.float32,
    dtypes.float64,
    dtypes.complex64,
    dtypes.complex128,
    dtypes.bool,
]
579
580
class StridedSliceTest(test_util.TensorFlowTestCase):
  """Test the strided slice operation with variants of slices."""

  @test_util.run_deprecated_v1
  def test_basic_slice(self):
    for tensor_type in STRIDED_SLICE_TYPES:
      with self.cached_session(use_gpu=True):
        checker = StridedSliceChecker(
            self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type)
        # Various ways of representing identity slice
        _ = checker[:, :, :]
        _ = checker[::, ::, ::]
        _ = checker[::1, ::1, ::1]
        # Non-unit strides
        _ = checker[::1, ::5, ::2]
        # Reverse in each dimension independently
        _ = checker[::-1, :, :]
        _ = checker[:, ::-1, :]
        _ = checker[:, :, ::-1]
        # Negative index tests, i.e. n-2 in first component
        _ = checker[-2::-1, :, ::1]
        # Negative index tests, i.e. n-2 in first component, non-unit stride
        _ = checker[-2::-1, :, ::2]

        # Check rank-0 examples
        checker2 = StridedSliceChecker(self, 5, tensor_type=tensor_type)
        _ = checker2[None]
        _ = checker2[...]
        _ = checker2[tuple()]

  def testInt64GPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with test_util.force_gpu():
      x = constant_op.constant([1., 2., 3.])
      begin = constant_op.constant([2], dtype=dtypes.int64)
      end = constant_op.constant([3], dtype=dtypes.int64)
      strides = constant_op.constant([1], dtype=dtypes.int64)
      s = array_ops.strided_slice(x, begin, end, strides)
      self.assertAllEqual([3.], self.evaluate(s))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  @test_util.assert_no_garbage_created
  def testTensorSliceEagerMemory(self):
    with context.eager_mode():
      inputs = constant_op.constant([[[1], [2], [3], [4]]],
                                    dtype=dtypes.float32)
      # Tests that slicing an EagerTensor doesn't leak memory
      inputs[0]  # pylint: disable=pointless-statement

  @test_util.assert_no_new_pyobjects_executing_eagerly
  @test_util.assert_no_garbage_created
  def testVariableSliceEagerMemory(self):
    with context.eager_mode():
      v = variables.Variable([1., 2.])
      v[0]  # pylint: disable=pointless-statement

  @test_util.run_deprecated_v1
  def testDegenerateSlices(self):
    with self.session(use_gpu=True):
      checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
      # degenerate by offering a forward interval with a negative stride
      _ = checker[0:-1:-1, :, :]
      # degenerate with a reverse interval with a positive stride
      _ = checker[-1:0, :, :]
      # empty interval in every dimension
      _ = checker[-1:0, 2:2, 2:3:-1]
      # empty first dimension only (used to break for aligned tensors).
      checker = StridedSliceChecker(self,
                                    StridedSliceChecker.REF_TENSOR_ALIGNED)
      _ = checker[1:0]

  @test_util.run_deprecated_v1
  def testSliceWithUndefinedDimension(self):
    t = constant_op.constant([1, 2, 3])
    d = tensor_shape.Dimension(None)
    self.assertAllEqual(t[d:d:d], t)

  @test_util.run_deprecated_v1
  def testEllipsis(self):
    with self.session(use_gpu=True):
      raw = [[[[[1, 2], [3, 4], [5, 6]]], [[[7, 8], [9, 10], [11, 12]]]]]
      checker = StridedSliceChecker(self, raw)

      _ = checker[0:]
      # implicit ellipsis
      _ = checker[0:, ...]
      # ellipsis alone
      _ = checker[...]
      # ellipsis at end
      _ = checker[0:1, ...]
      # ellipsis at begin
      _ = checker[..., 0:1]
      # ellipsis at middle
      _ = checker[0:1, ..., 0:1]
      # multiple ellipses not allowed; the error is raised while slicing.
      with self.assertRaisesRegexp(ValueError, "Multiple ellipses"):
        _ = checker[..., :, ...]

  @test_util.run_deprecated_v1
  def testShrink(self):
    with self.session(use_gpu=True):
      raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
              [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
      checker = StridedSliceChecker(self, raw)
      _ = checker[:, :, :, :, 3]
      _ = checker[..., 3]
      _ = checker[:, 0]
      _ = checker[:, :, 0]

  @test_util.run_deprecated_v1
  def testBothNewAxisAndShrink(self):
    with self.session(use_gpu=True):
      ones = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int16)
      self.assertAllEqual(
          ones[array_ops.newaxis, :,
               0].eval(feed_dict={ones: [[1, 1], [1, 1]]}), [[1, 1]])

  @test_util.run_deprecated_v1
  def testTensorIndexing(self):
    with self.session(use_gpu=True):
      raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
              [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
      checker = StridedSliceChecker(self, raw, check_type_infer=False)
      bar = constant_op.constant(2)
      bar2 = constant_op.constant(3)
      _ = checker[..., bar:bar2]
      _ = checker[..., bar]
      _ = checker[..., 3]
      _ = checker[..., 2**64 // 2**63]  # Test longs in Python 2

  def testTensorIndexingTypeError(self):
    with self.session(use_gpu=True):
      checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
      expected = re.escape(array_ops._SLICE_TYPE_ERROR)
      with self.assertRaisesRegexp(TypeError, expected):
        _ = checker["foo"]
      with self.assertRaisesRegexp(TypeError, expected):
        _ = checker[constant_op.constant("foo")]
      with self.assertRaisesRegexp(TypeError, expected):
        _ = checker[0.0]
      with self.assertRaisesRegexp(TypeError, expected):
        _ = checker[constant_op.constant(0.0)]
      with self.assertRaisesRegexp(TypeError, expected):
        _ = checker[constant_op.constant([1, 2, 3])]
      with self.assertRaisesRegexp(TypeError, expected):
        _ = checker[[2.1, -0.7, 1.5]]

  @test_util.run_deprecated_v1
  def testExpand(self):
    with self.session(use_gpu=True):
      raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
              [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]]
      checker = StridedSliceChecker(self, raw)
      # new axis (followed by implicit ellipsis)
      _ = checker[np.newaxis]
      # newaxis after ellipsis
      _ = checker[..., np.newaxis]
      # newaxis in between ellipsis and explicit range
      _ = checker[..., np.newaxis, :]
      _ = checker[:, ..., np.newaxis, :, :]
      # Reverse final dimension with new axis
      _ = checker[:, :, np.newaxis, :, 2::-1]
      # Ellipsis in middle of two newaxis
      _ = checker[np.newaxis, ..., np.newaxis]

  @test_util.run_deprecated_v1
  def testExpandVariable(self):
    with self.session(use_gpu=True):
      x = variables.Variable(7, dtype=dtypes.int32)
      x.initializer.run()
      y = x[None].eval()
      self.assertEqual(y.shape, (1,))
      self.assertAllEqual(y, (7,))

  @test_util.run_deprecated_v1
  def testOptimizedCases(self):
    with self.session(use_gpu=True):
      checker = StridedSliceChecker(self,
                                    StridedSliceChecker.REF_TENSOR_ALIGNED)
      # Identity
      _ = checker[:]
      # Identity
      _ = checker[...]
      # Identity plus a new leading and a new trailing axis
      _ = checker[np.newaxis, ..., np.newaxis]
      # First axis slice
      _ = checker[1:]
      # First axis slice
      _ = checker[np.newaxis, 1:]

  @test_util.run_v1_only("currently failing on v2")
  def testMasks(self):
    with self.session(use_gpu=True):
      scalar = np.array(0)
      # Test tensor type mask
      checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
      _ = checker[checker.x > 2]
      _ = checker[checker.x <= 5]
      _ = checker[ops.convert_to_tensor(scalar)]

      # Test numpy array type mask
      raw = np.array([[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
                       [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23,
                                                              24]]]]])
      checker1 = StridedSliceChecker(self, raw)
      _ = checker1[raw >= 4]
      _ = checker1[raw < 19]
      _ = checker1[scalar]

      # Test boolean and non boolean cases
      mask = np.array([True, False, True])
      raw1 = np.array([[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]])
      checker2 = StridedSliceChecker(self, raw1)
      _ = checker2[mask]
      _ = checker2[ops.convert_to_tensor(mask)]
799
800
class StridedSliceShapeChecker(object):
  """Indexing returns the `get_shape()` of `x` sliced by the same spec."""

  def __init__(self, x):
    # The object (normally a tensor) that specs are applied to.
    self.x = x

  def __getitem__(self, spec):
    sliced = self.x[spec]
    return sliced.get_shape()
809
810
class StridedSliceShapeTest(test_util.TensorFlowTestCase):
  """Test the shape inference of StridedSliceShapes."""

  @test_util.run_deprecated_v1
  def testUnknown(self):
    with self.session(use_gpu=True):
      uncertain_tensor = array_ops.placeholder(dtypes.float32)
      a = StridedSliceShapeChecker(uncertain_tensor)
      a_slice_shape = a[...]
      # Slicing a tensor of unknown rank keeps the rank unknown.
      self.assertIsNone(a_slice_shape.ndims)

  def tensorShapeEqual(self, x, y):
    """Asserts that `x` and `y` are both None or have equal `as_list()`s."""
    self.assertEqual(x is None, y is None)
    if x is None:
      # Both are None: nothing further to compare (and None has no as_list).
      return
    self.assertEqual(x.as_list(), y.as_list())

  @test_util.run_deprecated_v1
  def testTensorShapeUncertain(self):
    with self.session(use_gpu=True):
      uncertain_tensor = array_ops.placeholder(
          dtypes.float32, shape=(5, None, 7))
      a = StridedSliceShapeChecker(uncertain_tensor)
      self.tensorShapeEqual(a[3:5], tensor_shape.TensorShape([2, None, 7]))
      self.tensorShapeEqual(a[3:5, :, 4], tensor_shape.TensorShape([2, None]))
      self.tensorShapeEqual(a[3:5, 3:4, 4], tensor_shape.TensorShape([2, None]))
      self.tensorShapeEqual(a[3:5, :, 5:10],
                            tensor_shape.TensorShape([2, None, 2]))
      self.tensorShapeEqual(a[3:5, :, 50:3],
                            tensor_shape.TensorShape([2, None, 0]))
      self.tensorShapeEqual(a[3:5, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([2, None, 1, 0]))
      self.tensorShapeEqual(a[1:5:2, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([2, None, 1, 0]))
      self.tensorShapeEqual(a[:5:3, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([2, None, 1, 0]))
      self.tensorShapeEqual(a[:2:3, :, array_ops.newaxis, 50:3,],
                            tensor_shape.TensorShape([1, None, 1, 0]))
      self.tensorShapeEqual(a[::-1, :, array_ops.newaxis, ::-2],
                            tensor_shape.TensorShape([5, None, 1, 4]))

  @test_util.run_deprecated_v1
  def testTensorValuedIndexShape(self):
    with self.session(use_gpu=True):
      defined_shape_tensor = array_ops.placeholder(
          dtypes.float32, shape=(5, 3, 7))
      index_value = array_ops.placeholder(dtypes.int32, shape=())
      a = StridedSliceShapeChecker(defined_shape_tensor)
      self.tensorShapeEqual(a[index_value], tensor_shape.TensorShape([3, 7]))
      self.tensorShapeEqual(a[index_value, ::-1],
                            tensor_shape.TensorShape([3, 7]))
      self.tensorShapeEqual(a[index_value, ::-2],
                            tensor_shape.TensorShape([2, 7]))
      other_scalar = array_ops.placeholder(dtypes.int32, shape=())
      self.tensorShapeEqual(a[index_value, other_scalar:2],
                            tensor_shape.TensorShape([None, 7]))
865
866
class GradSliceChecker(object):
  """Tests that we can compute a gradient for var^2.

  Indexing an instance with a slice spec builds `(var * var)[spec]`, computes
  its first and second derivatives with `gradients_impl.gradients`, runs them
  in `sess`, and asserts them against references (the analytic `2 * slice`
  for the second derivative, a numpy scatter of `2 * varnp**2` for the first).
  """

  def __init__(self, test, sess, var, varnp):
    self.test = test
    self.sess = sess
    self.val = var * var
    self.var = var
    self.varnp = varnp

  def __getitem__(self, spec):
    slice_var = self.var[spec]
    slice_val = self.val[spec]

    # compute analytic 2nd derivative
    analytic_grad2 = 2 * slice_val

    # `dy` is the upstream gradient; it is set to the sliced variable values
    # so the second-derivative check below has non-trivial inputs.
    dy = variables.Variable(
        array_ops.ones_like(slice_var, dtype=dtypes.float32))
    assign = dy.assign(slice_var)
    slice_val_grad, = gradients_impl.gradients(slice_val, self.var, grad_ys=dy)
    slice_val_grad2, = gradients_impl.gradients(
        slice_val_grad, dy, grad_ys=self.var)
    self.sess.run(assign)
    slice_val_grad_evaled, slice_val_grad2_evaled = (
        self.sess.run([slice_val_grad, slice_val_grad2]))
    analytic_grad2_evaled = analytic_grad2.eval()
    self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled)

    # compute analytic gradient for slice
    np_val_grad = (2 * self.varnp * self.varnp)
    np_sliceval_grad = np.zeros(self.var.get_shape())
    if isinstance(spec, ops.Tensor):
      # Evaluate the tensor itself, not wrapped in a list: numpy treats a
      # list index as an array index (NumPy >= 1.23), so `[mask]` would
      # fail instead of acting as the boolean mask `mask`.
      spec = self.sess.run(spec)
    np_sliceval_grad[spec] = np_val_grad[spec]
    # verify gradient
    self.test.assertAllEqual(slice_val_grad_evaled, np_sliceval_grad)
904
905
class StridedSliceGradTest(test_util.TensorFlowTestCase):
  """Test that strided slice's custom gradient produces correct gradients."""

  @test_util.run_v1_only("b/120545219")
  def testGradient(self):
    with self.session(use_gpu=True) as sess:
      flat = math_ops.range(1, 97, 1, dtype=dtypes.float32)
      var = variables.Variable(array_ops.reshape(flat, shape=(6, 4, 4)))
      sess.run(variables.global_variables_initializer())

      raw = np.array(range(1, 97, 1)).reshape((6, 4, 4))
      checker = GradSliceChecker(self, sess, var, raw)
      in_bounds_specs = [
          (slice(2, 6, 2), slice(1, 3), slice(1, 3)),
          (slice(3, 0, -2), slice(1, 3), slice(1, 3)),
          (slice(3, 0, -2), array_ops.newaxis, slice(1, 3), 2,
           array_ops.newaxis),
          (slice(3, 0, -2), slice(1, 3), 2),
          (slice(None), -1, slice(None)),
          (slice(None), -2, slice(None)),
      ]
      for spec in in_bounds_specs:
        _ = checker[spec]
      for bad_index in (-200, 200):
        with self.assertRaisesRegexp(ValueError, "out of bounds"):
          _ = checker[:, bad_index, :]

      # Boolean masks: a numpy array first, then a tensor.
      _ = checker[raw > 51]
      _ = checker[ops.convert_to_tensor(raw) <= 76]

  @test_util.run_v1_only("b/120545219")
  def testGradientZero(self):
    with self.session(use_gpu=True) as sess:
      scalar_var = variables.Variable(8.)
      sess.run(variables.global_variables_initializer())
      checker = GradSliceChecker(self, sess, scalar_var, np.array(8))
      _ = checker[()]

  @test_util.run_deprecated_v1
  def testInt64Indices(self):
    with self.session(use_gpu=True):
      values = math_ops.range(3, dtype=dtypes.float32)
      position = constant_op.constant(1, dtype=dtypes.int64)
      picked = 2. * values[position]
      grad, = gradients_impl.gradients(picked, values)
      self.assertAllEqual(self.evaluate(grad), [0., 2., 0.])
953
954
class StridedSliceGradTypeTest(test_util.TensorFlowTestCase):
  """Test varied index types and host located memory."""

  def _make_dy(self):
    # A (4, 1, 1) float32 tensor holding 1..4, used as the upstream gradient.
    return array_ops.reshape(
        math_ops.cast(math_ops.range(1, 5, 1), dtypes.float32),
        shape=(4, 1, 1))

  @test_util.run_deprecated_v1
  def testHostVsDevice(self):
    with self.session(use_gpu=True) as sess:
      dy_var = variables.Variable(self._make_dy())
      shape_var = variables.Variable([6, 4, 4], dtype=dtypes.int32)
      self.evaluate(variables.global_variables_initializer())
      begin = constant_op.constant([0, 0, 0])
      end = constant_op.constant([4, 1, 1])
      strides = constant_op.constant([1, 1, 1])
      sess.run(
          array_ops.strided_slice_grad(shape_var, begin, end, strides, dy_var))

  @test_util.run_deprecated_v1
  def testInt64Shape(self):
    with self.session(use_gpu=True) as sess:
      dy = self._make_dy()
      shape = constant_op.constant([6, 4, 4], dtype=dtypes.int64)
      self.evaluate(variables.global_variables_initializer())
      begin, end, strides = [
          constant_op.constant(v, dtype=dtypes.int64)
          for v in ([0, 0, 0], [4, 1, 1], [1, 1, 1])
      ]
      sess.run(array_ops.strided_slice_grad(shape, begin, end, strides, dy))

  @test_util.run_deprecated_v1
  def testMixedIndexTypes(self):
    with self.session(use_gpu=True) as sess:
      dy = self._make_dy()
      shape = constant_op.constant([6, 4, 4], dtype=dtypes.int64)
      self.evaluate(variables.global_variables_initializer())
      begin = constant_op.constant([0, 0, 0], dtype=dtypes.int32)
      end = constant_op.constant([4, 1, 1], dtype=dtypes.int64)
      strides = constant_op.constant([1, 1, 1], dtype=dtypes.int64)
      with self.assertRaisesRegexp(
          TypeError, "Input 'begin' of 'StridedSliceGrad' Op has type int32"
          " that does not match type int64 of argument 'shape'"):
        dx = array_ops.strided_slice_grad(shape, begin, end, strides, dy)
        sess.run(dx)
1005
1006
class BenchmarkSlice(object):
  """Forwards `[]` indexing to the wrapped object unchanged."""

  def __init__(self, tensor):
    self.tensor = tensor

  def __getitem__(self, x):
    wrapped = self.tensor
    return wrapped[x]
1014
1015
class StridedSliceBenchmark(test_lib.Benchmark):
  """Benchmark new strided slice operation on non-trivial case."""

  def run_and_time(self, slice_op, iters=1000, warmup_iters=10):
    """Evaluates `slice_op` repeatedly and reports mean wall time per call.

    Uses the default session (via `.run()` / `.eval()`), so it must be called
    inside a session context.

    Args:
      slice_op: tensor to evaluate.
      iters: number of timed evaluations.
      warmup_iters: number of untimed evaluations run beforehand.
    """
    variables.global_variables_initializer().run()
    for _ in range(warmup_iters):
      _ = slice_op.eval()
    t0 = time.time()
    for _ in range(iters):
      slice_op.eval()
    t1 = time.time()
    # Divide by `iters` (not a literal) so the reported per-iteration time
    # stays correct if the iteration count changes.
    self.report_benchmark(iters=iters, wall_time=(t1 - t0) / iters)

  def make_variable(self, n=256):
    """Returns an (n, n, n) float32 variable holding 1..n**3 in order."""
    shape = (n, n, n)
    items = n**3
    var = variables.Variable(
        array_ops.reshape(math_ops.linspace(1., float(items), items), shape),
        dtype=dtypes.float32)
    return var

  def benchmark_strided_slice_skip(self):
    with session.Session():
      var = self.make_variable()
      helper = BenchmarkSlice(var)
      slice_op = helper[::2, ::1, ::2]
      self.run_and_time(slice_op)

  def benchmark_strided_slice_easy(self):
    with session.Session():
      var = self.make_variable()
      helper = BenchmarkSlice(var)
      slice_op = helper[3::1, 3::1, 3::1]
      self.run_and_time(slice_op)

  def benchmark_slice_easy(self):
    with session.Session():
      var = self.make_variable()
      slice_op = var[3::1, 3::1, 3::1]
      self.run_and_time(slice_op)
1058
1059
class StridedSliceAssignChecker(object):
  """Checks sliced assignment into a variable against numpy's setitem.

  `checker[index] = value` assigns `value` into a fresh variable (initialized
  from `x`) both via `var[index].assign` and `state_ops.assign`, and compares
  each result with the same assignment performed on a numpy copy of `x`.
  """

  def __init__(self, test, x, tensor_type=dtypes.float32, use_resource=False):
    self.tensor_type = tensor_type
    self.test = test
    self._use_resource = use_resource

    np_dtype = tensor_type.as_numpy_dtype
    self.x_np = np.array(x).astype(np_dtype)
    if tensor_type.is_complex:
      # Complex types get a non-zero imaginary part so it is exercised too.
      self.x_np -= 1j * self.x_np
    self.x = constant_op.constant(self.x_np, dtype=tensor_type)

  def __setitem__(self, index, value):
    value = np.array(value).astype(self.tensor_type.as_numpy_dtype)
    if self.tensor_type.is_complex:
      # Mirror the imaginary component added to the initial value.
      value -= 1j * value

    with self.test.test_session(use_gpu=True) as sess:
      variable_cls = (resource_variable_ops.ResourceVariable
                      if self._use_resource else variables.Variable)
      var = variable_cls(self.x)
      sess.run(variables.variables_initializer([var]))
      via_method = sess.run(var[index].assign(value))
      # The functional tf.compat.v1.assign must match the method form.
      via_state_ops = sess.run(state_ops.assign(var[index], value))
      expected = np.copy(self.x_np)
      expected[index] = np.array(value)
      self.test.assertAllEqual(via_method, expected)
      self.test.assertAllEqual(via_state_ops, expected)
1093
1094
class SliceAssignTest(test_util.TensorFlowTestCase):
  """Tests assignment through slices of variables (`var[...].assign`)."""

  @test_util.run_deprecated_v1
  def testInvalidSlice(self):
    """Slices of a non-variable tensor reject `.assign` with ValueError."""
    with self.cached_session() as sess:
      foo = constant_op.constant([1, 2, 3])
      with self.assertRaisesRegexp(
          ValueError, "Sliced assignment"
          " is only supported for variables"):
        bar = foo[:2].assign(constant_op.constant([1, 2]))
        sess.run(bar)

  def doTestSliceAssign(self, use_resource):
    """Runs the sliced-assignment cases for every dtype in the type list.

    Args:
      use_resource: whether StridedSliceAssignChecker should build a
        ResourceVariable instead of a ref Variable.
    """
    # NOTE(review): STRIDED_SLICE_TYPES is defined elsewhere in this file.
    for dtype in STRIDED_SLICE_TYPES:
      checker = StridedSliceAssignChecker(
          self, [[1, 2, 3], [4, 5, 6]],
          use_resource=use_resource,
          tensor_type=dtype)
      # Check if equal
      checker[:] = [[10, 20, 30], [40, 50, 60]]
      # Check trivial (1,1) shape tensor
      checker[1:2, 1:2] = [[66]]
      # shrinks shape changes
      checker[1:2, 1] = [66]
      checker[1, 1:2] = [66]
      checker[1, 1] = 66
      # newaxis shape changes
      checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
      # shrink and newaxis
      checker[None, None, 0, 0:1] = [[[99]]]
      # Non unit strides
      checker[::1, ::-2] = [[3, 33], [4, 44]]
      # degenerate interval
      checker[8:10, 0] = []
      checker[8:10, 8:10] = [[]]
    # Assign into a rank-0 (scalar) value: with an empty index and an
    # ellipsis (scalar value), and with newaxis (length-1 vector value).
    checker2 = StridedSliceAssignChecker(self, 222)
    checker2[()] = 6  # no indices
    checker2[...] = 6  # ellipsis
    checker2[None] = [6]  # new axis

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/123559667")
  def testSliceAssign(self):
    self.doTestSliceAssign(use_resource=False)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/123559667")
  def testSliceAssignResource(self):
    self.doTestSliceAssign(use_resource=True)

  @test_util.run_v1_only("b/120545219")
  def testUninitialized(self):
    """Sliced assignment to an uninitialized variable fails at run time."""
    # The expected message names the variable; it relies on the default
    # variable name "Variable", so no other variable may be created first.
    with self.assertRaisesRegexp(
        errors.FailedPreconditionError,
        "Attempting to use uninitialized value Variable"):
      with self.cached_session() as sess:
        v = variables.VariableV1([1, 2])
        sess.run(v[:].assign([1, 2]))

  @test_util.run_v1_only("b/120545219")
  def testTypeError(self):
    """Assigning a mismatched integer dtype to a ref-variable slice raises."""
    init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
    too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
    too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
    v = variables.VariableV1(init_val)
    # The error is raised while building the op; no session is needed.
    with self.assertRaises(TypeError):
      v[:].assign(too_small_val)
    with self.assertRaises(TypeError):
      v[:].assign(too_large_val)

  @test_util.run_deprecated_v1
  def testTypeErrorResource(self):
    """For resource variables the same dtype mismatch raises ValueError."""
    init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
    too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
    too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
    v = resource_variable_ops.ResourceVariable(init_val)
    with self.cached_session() as sess:
      self.evaluate(v.initializer)
      with self.assertRaises(ValueError):
        sess.run(v[:].assign(too_large_val))
      with self.assertRaises(ValueError):
        sess.run(v[:].assign(too_small_val))
1178
1179
class ShapeSizeRankTest(test_util.TensorFlowTestCase):
  """Tests array_ops.shape/size/rank on dense and sparse inputs."""

  def _checkShapeSizeRank(self, value, expected_shape, expected_size,
                          expected_rank):
    # Evaluates shape, size and rank of `value` and compares each.
    self.assertAllEqual(expected_shape, self.evaluate(array_ops.shape(value)))
    self.assertEqual(expected_size, self.evaluate(array_ops.size(value)))
    self.assertEqual(expected_rank, self.evaluate(array_ops.rank(value)))

  @test_util.run_in_graph_and_eager_modes
  def testDenseShape(self):
    t_value = [[0, 42], [24, 0]]
    self._checkShapeSizeRank(t_value, (2, 2), 4, 2)
    self._checkShapeSizeRank(constant_op.constant(t_value), (2, 2), 4, 2)

  @test_util.run_in_graph_and_eager_modes
  def testSparseShape(self):
    sp_value = sparse_tensor.SparseTensorValue(
        indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
    self._checkShapeSizeRank(sp_value, (2, 2), 4, 2)
    self._checkShapeSizeRank(
        sparse_tensor.SparseTensor.from_value(sp_value), (2, 2), 4, 2)

  @test_util.run_in_graph_and_eager_modes
  def testSizeDtype(self):
    tensor = [1]
    default_size = self.evaluate(array_ops.size(tensor))
    self.assertEqual(dtypes.int32, default_size.dtype)
    int64_size = self.evaluate(array_ops.size(tensor, out_type=dtypes.int64))
    self.assertEqual(dtypes.int64, int64_size.dtype)
1214
1215
class SequenceMaskTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.sequence_mask."""

  def testExceptions(self):
    with self.cached_session():
      with self.assertRaisesRegexp(ValueError, "maxlen must be scalar"):
        array_ops.sequence_mask([10, 20], [10, 20])

  @test_util.run_deprecated_v1
  def testOneDimensionalWithMaxlen(self):
    with self.cached_session():
      res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
      self.assertAllEqual(res.get_shape(), [3, 5])
      self.assertAllEqual(
          res.eval(),
          [[True, False, False, False, False], [True, True, True, False, False],
           [True, True, False, False, False]])

  @test_util.run_deprecated_v1
  def testOneDimensionalDtypeWithoutMaxlen(self):
    with self.cached_session():
      # test dtype and default maxlen:
      res = array_ops.sequence_mask(
          constant_op.constant([0, 1, 4]), dtype=dtypes.float32)
      self.assertAllEqual(res.get_shape().as_list(), [3, 4])
      self.assertAllEqual(
          res.eval(),
          [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])

  @test_util.run_deprecated_v1
  def testOneDimensionalWithoutMaxlen(self):
    with self.cached_session():
      res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]))
      self.assertAllEqual(res.get_shape().as_list(), [3, 4])
      self.assertAllEqual(
          res.eval(), [[False, False, False, False],
                       [True, False, False, False], [True, True, True, True]])

  @test_util.run_deprecated_v1
  def testTwoDimensional(self):
    with self.cached_session():
      res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
      self.assertAllEqual(res.get_shape(), [1, 3, 5])
      self.assertAllEqual(res.eval(), [[[True, False, False, False, False],
                                        [True, True, True, False, False],
                                        [True, True, False, False, False]]])

      # test dtype and default maxlen:
      res = array_ops.sequence_mask(
          constant_op.constant([[0, 1, 4], [1, 2, 3]]), dtype=dtypes.float32)
      self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4])
      self.assertAllEqual(
          res.eval(),
          [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
           [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]])

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """Lengths of unknown rank yield a result of unknown rank."""
    lengths = array_ops.placeholder(dtype=dtypes.int32)
    res = array_ops.sequence_mask(lengths)
    # Check the rank directly rather than relying on TensorShape.__eq__
    # coercing None into an unknown shape.
    self.assertIsNone(res.shape.ndims)

  @test_util.run_deprecated_v1
  def testDtypes(self):

    def check_dtypes(lengths_dtype, maxlen_dtype):
      res = array_ops.sequence_mask(
          constant_op.constant([1, 3, 2], dtype=lengths_dtype),
          constant_op.constant(5, dtype=maxlen_dtype))
      self.assertAllEqual(res.get_shape(), [3, 5])
      self.assertAllEqual(
          res.eval(),
          [[True, False, False, False, False], [True, True, True, False, False],
           [True, True, False, False, False]])

    with self.cached_session():
      check_dtypes(dtypes.int32, dtypes.int32)
      check_dtypes(dtypes.int32, dtypes.int64)
      check_dtypes(dtypes.int64, dtypes.int32)
      check_dtypes(dtypes.int64, dtypes.int64)
1295
1296
class ConcatSliceResourceTest(test_util.TensorFlowTestCase):
  """Tests strided_slice applied to a stacked tensor of resource handles."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def testConcatSlice(self):
    """A handle sliced out of a stack still refers to the same resource.

    Creating the resource through the slice `s` (which holds the second
    handle, `r2`) must make a later creation directly through `r2` fail with
    AlreadyExistsError.
    """
    r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b")
    r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c")
    c = array_ops.stack([r1, r2])
    # begin=[1], end=[2]: keeps only the second stacked handle (`r2`).
    s = array_ops.strided_slice(c, [1], [2])
    self.evaluate(test_ops.resource_create_op(s))
    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(test_ops.resource_create_op(r2))
1309
1310
class IdentityTest(test_util.TensorFlowTestCase):
  """Tests array_ops.identity across devices."""

  @test_util.run_gpu_only
  def testEagerIdentity(self):
    """Each identity copies values and lands on the forced device."""
    with context.eager_mode():

      def _check(src, dst, device):
        self.assertAllEqual(src.numpy(), dst.numpy())
        self.assertIn(device, dst.device.lower())

      with test_util.force_gpu():
        current = constant_op.constant([[2], [3]], dtype=dtypes.float32)
      # Sequence of (device-forcing context, expected device substring).
      hops = ((test_util.force_gpu, "gpu"),
              (test_util.force_cpu, "cpu"),
              (test_util.force_cpu, "cpu"),
              (test_util.force_gpu, "gpu"))
      for force_device, device in hops:
        with force_device():
          copied = array_ops.identity(current)
          _check(current, copied, device)
        current = copied
1335
1336
class PadTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.pad."""

  def testEager(self):
    """Constant-pads a 2x3 matrix by 1 row and 2 columns on each side."""
    with context.eager_mode():
      t = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      paddings = constant_op.constant([[1, 1], [2, 2]])
      padded = array_ops.pad(t, paddings, "CONSTANT")
      expected = [[0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 1, 2, 3, 0, 0],
                  [0, 0, 4, 5, 6, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0]]
      self.assertAllEqual(padded.numpy(), expected)
1350
1351
class InvertPermutationTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.invert_permutation."""

  @test_util.run_deprecated_v1
  def testInvertPermutation(self):
    """The inverse maps each value back to the position that held it."""
    perm = [3, 4, 0, 2, 1]
    expected_inverse = [2, 4, 3, 0, 1]
    for dtype in (dtypes.int32, dtypes.int64):
      with self.cached_session(use_gpu=True):
        inverse = array_ops.invert_permutation(
            constant_op.constant(perm, dtype=dtype))
        self.assertAllEqual(inverse.get_shape(), [5])
        self.assertAllEqual(inverse.eval(), expected_inverse)
1362
1363
class UnravelIndexTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.unravel_index."""

  # TODO(b/73086570): Reenable test.
  @unittest.skip("Test does not pass internally.")
  def testUnravelIndex(self):
    # (flat indices, dims, expected per-dimension coordinates)
    cases = [
        (1621, [6, 7, 8, 9], [3, 1, 4, 1]),
        ([1621], [6, 7, 8, 9], [[3], [1], [4], [1]]),
        ([22, 41, 37], [7, 6], [[3, 6, 6], [4, 5, 1]]),
    ]
    with self.cached_session():
      for dtype in [dtypes.int32, dtypes.int64]:
        for indices, dims, expected in cases:
          out = array_ops.unravel_index(
              constant_op.constant(indices, dtype=dtype),
              constant_op.constant(dims, dtype=dtype))
          self.assertAllEqual(out.eval(), expected)
1385
1386
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.guarantee_const."""

  @test_util.run_deprecated_v1
  def testSimple(self):
    with self.cached_session():
      guaranteed = array_ops.guarantee_const(array_ops.constant(10))
      self.assertEqual(10, guaranteed.eval())

  @test_util.run_deprecated_v1
  def testVariables(self):
    """Works on both ref and resource variables (read as values)."""
    with self.cached_session():
      for use_resource in [False, True]:
        var = variable_scope.get_variable(
            "var_{}".format(use_resource), [],
            initializer=init_ops.constant_initializer(10.0),
            use_resource=use_resource)
        guaranteed = array_ops.guarantee_const(var)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(10.0, guaranteed.eval())

  @test_util.run_deprecated_v1
  def testResourceRejection(self):
    """Passing a raw resource handle fails when the op runs."""
    with self.cached_session():
      var = variable_scope.get_variable(
          "resource_var", [],
          initializer=init_ops.constant_initializer(10.0),
          use_resource=True)
      guaranteed = array_ops.guarantee_const(var.handle)
      self.evaluate(variables.global_variables_initializer())
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               "cannot be a resource variable"):
        guaranteed.eval()
1420
1421
class SnapshotOpTest(test_util.TensorFlowTestCase):
  """Tests for gen_array_ops.snapshot."""

  # Named for the op under test; previously a copy-pasted
  # `testInvertPermutation`, which misreported what failed.
  @test_util.run_deprecated_v1
  def testSnapshot(self):
    """snapshot returns a tensor with the same values as its input."""
    for dtype in [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]:
      with self.cached_session(use_gpu=True):
        x = constant_op.constant([0, 1, 2, 3], dtype=dtype)
        y = gen_array_ops.snapshot(x)
        self.assertAllEqual(y.eval(), [0, 1, 2, 3])
1431
1432
@test_util.run_all_in_graph_and_eager_modes
class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
  """Tests array_ops.quantize_and_dequantize, per-tensor and per-axis."""

  # Generates a tensor of the specified `shape` using values from `values`
  # scaled by (slice_idx + 1) along `axis` dimension.
  def _scale_per_slice(self, shape, axis, values):
    # Note: repeats the values if the shape is larger than values.
    out = np.take(values, np.remainder(np.arange(np.prod(shape)),
                                       len(values))).reshape(shape)
    if axis is not None:
      scale_shape = [1] * len(shape)
      scale_shape[axis] = shape[axis]
      out *= np.arange(1, shape[axis] + 1).reshape(scale_shape)
    return out

  def testAxis(self):
    shape = np.array([2, 3, 4, 5])
    rank = len(shape)
    values = np.array([-1, -0.5, 0, 0.3, 0.8, 0.555, 0.5], dtype=np.float32)
    quant_values = np.array(
        [-1, -0.5, 0, 38.0 / 128, 102.0 / 128, 71.0 / 128, 0.5],
        dtype=np.float32)
    # None means a single per-tensor range; otherwise one range per slice.
    for axis in [None] + list(range(rank)):
      inputs = constant_op.constant(self._scale_per_slice(shape, axis, values))
      expected = self._scale_per_slice(shape, axis, quant_values)
      # With range_given=False these min/max inputs are only placeholders,
      # sized to match the chosen axis.
      unused_minmax_value = 0 if axis is None else [0] * shape[axis]
      fake_quantized = self.evaluate(
          array_ops.quantize_and_dequantize(
              inputs,
              unused_minmax_value,
              unused_minmax_value,
              range_given=False,
              round_mode="HALF_UP",
              axis=axis))
      self.assertAllEqual(fake_quantized, expected)
      if axis is not None:
        # The equivalent negative axis (derived from the rank rather than a
        # hard-coded 4) must give the same result.
        fake_quantized = self.evaluate(
            array_ops.quantize_and_dequantize(
                inputs,
                unused_minmax_value,
                unused_minmax_value,
                range_given=False,
                axis=(axis - rank)))
        self.assertAllClose(fake_quantized, expected)
1476
1477
1478@test_util.run_all_in_graph_and_eager_modes
1479class SortedSearchTest(test_util.TensorFlowTestCase):
1480
1481  def testUpperBoundFloatHandCoded(self):
1482    cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
1483    arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
1484                   dtype=np.float32)
1485    result = np.searchsorted(cdf, arr, side="right")
1486    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
1487    self.assertAllEqual(result, tf_result)
1488
1489  def testUpperBoundFloatRandomNd(self):
1490    dim_size = 7
1491    for d in range(1, 5):
1492      shape = [dim_size] * d
1493      cdf = np.cumsum(
1494          np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
1495      arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
1496
1497      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
1498
1499      cdf = cdf.reshape([-1, dim_size])
1500      arr = arr.reshape([-1, dim_size])
1501      result = np.zeros(arr.shape, dtype=np.int32)
1502      for i in range(dim_size**(d - 1)):
1503        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
1504
1505      result = result.reshape(shape)
1506
1507      self.assertAllEqual(result, tf_result)
1508
1509  def testUpperBoundFloatUneven(self):
1510    batch_size = 7
1511    size_search_array = 1000
1512    size_values = 47
1513    cdf = np.cumsum(
1514        np.random.uniform(size=[batch_size, size_search_array]).astype(
1515            np.float32),
1516        axis=1)
1517    arr = np.random.uniform(size=[batch_size, size_values]).astype(
1518        np.float32) * size_search_array
1519
1520    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
1521
1522    result = np.zeros(arr.shape, dtype=np.int32)
1523    for i in range(batch_size):
1524      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
1525
1526    self.assertAllEqual(result, tf_result)
1527
1528  def testLowerBoundFloatHandCoded(self):
1529    cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
1530    arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
1531                   dtype=np.float32)
1532    result = np.searchsorted(cdf, arr, side="left")
1533    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
1534    self.assertAllEqual(result, tf_result)
1535
1536  def testLowerBoundFloatRandomNd(self):
1537    dim_size = 7
1538    for d in range(1, 5):
1539      shape = [dim_size] * d
1540      cdf = np.cumsum(
1541          np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
1542      arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
1543
1544      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
1545
1546      cdf = cdf.reshape([-1, dim_size])
1547      arr = arr.reshape([-1, dim_size])
1548      result = np.zeros(arr.shape, dtype=np.int32)
1549      for i in range(dim_size**(d - 1)):
1550        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
1551
1552      result = result.reshape(shape)
1553
1554      self.assertAllEqual(result, tf_result)
1555
1556  def testLowerBoundFloatUneven(self):
1557    batch_size = 7
1558    size_search_array = 1000
1559    size_values = 47
1560    cdf = np.cumsum(
1561        np.random.uniform(size=[batch_size, size_search_array]).astype(
1562            np.float32),
1563        axis=1)
1564    arr = np.random.uniform(size=[batch_size, size_values]).astype(
1565        np.float32) * size_search_array
1566
1567    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
1568
1569    result = np.zeros(arr.shape, dtype=np.int32)
1570    for i in range(batch_size):
1571      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
1572
1573    self.assertAllEqual(result, tf_result)
1574
1575  def testUpperBoundIntHandCoded(self):
1576    cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
1577    arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
1578    result = np.searchsorted(cdf, arr, side="right")
1579    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
1580    self.assertAllEqual(result, tf_result)
1581
1582  def testUpperBoundIntRandomNd(self):
1583    dim_size = 7
1584    for d in range(1, 5):
1585      shape = [dim_size] * d
1586      cdf = np.cumsum(
1587          np.random.randint(low=0, high=10, size=shape).astype(np.int64),
1588          axis=(d - 1))
1589      arr = np.random.randint(
1590          low=0, high=10 * dim_size, size=shape).astype(np.int64)
1591
1592      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
1593
1594      cdf = cdf.reshape([-1, dim_size])
1595      arr = arr.reshape([-1, dim_size])
1596      result = np.zeros(arr.shape, dtype=np.int32)
1597      for i in range(dim_size**(d - 1)):
1598        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
1599
1600      result = result.reshape(shape)
1601
1602      self.assertAllEqual(result, tf_result)
1603
1604  def testUpperBoundIntUneven(self):
1605    batch_size = 7
1606    size_search_array = 1000
1607    size_values = 47
1608    cdf = np.cumsum(
1609        np.random.randint(low=0, high=10,
1610                          size=[batch_size,
1611                                size_search_array]).astype(np.int64),
1612        axis=1)
1613    arr = np.random.randint(
1614        low=0, high=10 * size_search_array, size=[batch_size,
1615                                                  size_values]).astype(np.int64)
1616
1617    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
1618
1619    result = np.zeros(arr.shape, dtype=np.int32)
1620    for i in range(batch_size):
1621      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
1622
1623    self.assertAllEqual(result, tf_result)
1624
1625  def testLowerBoundIntHandCoded(self):
1626    cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
1627    arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
1628    result = np.searchsorted(cdf, arr, side="left")
1629    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
1630    self.assertAllEqual(result, tf_result)
1631
1632  def testLowerBoundIntRandomNd(self):
1633    dim_size = 7
1634    for d in range(1, 5):
1635      shape = [dim_size] * d
1636      cdf = np.cumsum(
1637          np.random.randint(low=0, high=10, size=shape).astype(np.int64),
1638          axis=(d - 1))
1639      arr = np.random.randint(
1640          low=0, high=10 * dim_size, size=shape).astype(np.int64)
1641
1642      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
1643
1644      cdf = cdf.reshape([-1, dim_size])
1645      arr = arr.reshape([-1, dim_size])
1646      result = np.zeros(arr.shape, dtype=np.int32)
1647      for i in range(dim_size**(d - 1)):
1648        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
1649
1650      result = result.reshape(shape)
1651
1652      self.assertAllEqual(result, tf_result)
1653
1654  def testLowerBoundIntUneven(self):
1655    batch_size = 7
1656    size_search_array = 1000
1657    size_values = 47
1658    cdf = np.cumsum(
1659        np.random.randint(low=0, high=10,
1660                          size=[batch_size,
1661                                size_search_array]).astype(np.int64),
1662        axis=1)
1663    arr = np.random.randint(
1664        low=0, high=10 * size_search_array, size=[batch_size,
1665                                                  size_values]).astype(np.int64)
1666
1667    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
1668
1669    result = np.zeros(arr.shape, dtype=np.int32)
1670    for i in range(batch_size):
1671      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
1672
1673    self.assertAllEqual(result, tf_result)
1674
1675
class BatchGatherNdTest(test_util.TensorFlowTestCase):
  """Tests for array_ops.batch_gather_nd and gather_nd's batch_dims argument."""

  # (params_shape, indices_shape, batch_dims) cases with batch_dims >= 1,
  # shared by several tests. Tests that draw random data per case depend on
  # this order, so append new cases at the end.
  _BATCHED_CASES = (
      ((2, 2, 2), (2, 1), 1),
      ((2, 2, 2), (2, 2), 1),
      ((2, 2, 3, 2), (2, 3), 1),
      ((2, 2, 3, 2), (2, 2), 1),
      ((2, 2, 3, 2), (2, 1), 1),
      ((2, 2, 3, 2), (2, 1, 3), 1),
      ((2, 2, 3, 2), (2, 2, 2), 1),
      ((2, 2, 3, 2), (2, 3, 1), 1),
      ((3, 2, 2, 3, 4), (3, 2, 3), 2),
      ((3, 2, 2, 3, 4), (3, 2, 2), 2),
      ((3, 2, 2, 3, 4), (3, 2, 1), 2),
      ((3, 2, 2, 3, 4), (3, 2, 1, 3), 2),
      ((3, 2, 2, 3, 4), (3, 2, 2, 2), 2),
      ((3, 2, 2, 3, 4), (3, 2, 3, 1), 2),
  )

  # Extra batch_dims == 0 cases, only exercised by testShapesMatch.
  _UNBATCHED_CASES = (
      ((2, 2, 2), (2, 3), 0),
      ((2, 2, 2), (3,), 0),
      ((2, 2, 2), (1,), 0),
  )

  @staticmethod
  def _case_msg(params_shape, indices_shape, batch_dims):
    """Returns a failure message identifying one shape case."""
    return "params_shape=%s, indices_shape=%s, batch_dims=%d" % (
        params_shape, indices_shape, batch_dims)

  @staticmethod
  def _expected_out_shape(params_shape, indices_shape, batch_dims):
    """Returns the expected output shape of a batched gather_nd as a tuple.

    That is indices_shape without its last dimension, followed by whatever
    trailing params dimensions remain after removing the `batch_dims` leading
    dimensions and the `indices_shape[-1]` dimensions addressed by each index.
    """
    row_ndims = len(params_shape) - batch_dims - indices_shape[-1]
    expected = tuple(indices_shape[:-1])
    # Guard needed: params_shape[-0:] would be the whole shape, not empty.
    if row_ndims > 0:
      expected += tuple(params_shape[-row_ndims:])
    return expected

  def testShapesMatch(self):
    """Tests for various different shape combinations."""
    for params_shape, indices_shape, batch_dims in (
        self._UNBATCHED_CASES + self._BATCHED_CASES):
      params = constant_op.constant(1.0, shape=params_shape)
      indices = constant_op.constant(
          1, shape=indices_shape, dtype=dtypes.int32)
      out = array_ops.batch_gather_nd(
          params=params, indices=indices, batch_dims=batch_dims)
      self.assertSequenceEqual(
          out.shape,
          self._expected_out_shape(params_shape, indices_shape, batch_dims),
          msg=self._case_msg(params_shape, indices_shape, batch_dims))

  def testReducesToGatherNDWhenBatchDimIsZero(self):
    """Confirms setting batch_dims to zero reduces to tf.gather_nd."""
    params = constant_op.constant(np.random.uniform(0.0, 1.0, size=(7, 8, 9)))
    indices_shapes = (
        (1,), (3, 1), (3, 3, 1),
        (2,), (3, 2), (3, 3, 2),
        (3,), (3, 3), (3, 3, 3),
    )
    for indices_shape in indices_shapes:
      # Index values are drawn below 7, the smallest params dimension.
      indices = np.random.randint(0, 7, size=indices_shape)
      gather_nd_result = gen_array_ops.gather_nd(params, indices)
      batch_gather_nd_result = array_ops.batch_gather_nd(
          params=params, indices=indices, batch_dims=0)
      self.assertAllEqual(gather_nd_result, batch_gather_nd_result)

  def testSameResultAsMapFn(self):
    """Compares results with gather_nd called on every element with map_fn."""
    for params_shape, indices_shape, batch_dims in self._BATCHED_CASES:
      params = constant_op.constant(
          np.random.uniform(0.0, 1.0, size=params_shape))
      indices = np.random.randint(0, 2, size=indices_shape)
      batch_gather_nd_result = array_ops.batch_gather_nd(
          params=params, indices=indices, batch_dims=batch_dims)

      # map_fn unstacks along a single leading dimension, so fold all batch
      # dimensions into one before mapping and unfold them afterwards.
      if batch_dims > 1:
        params = array_ops.reshape(
            params, shape=[-1] + list(params_shape[batch_dims:]))
        indices = array_ops.reshape(
            indices, shape=[-1] + list(indices_shape[batch_dims:]))

      map_fn_gather_nd_result = map_fn.map_fn(
          fn=self._map_fn_body, elems=(params, indices), dtype=dtypes.float64)

      if batch_dims > 1:
        out_shape = map_fn_gather_nd_result.shape.as_list()
        out_shape = list(params_shape[:batch_dims]) + out_shape[1:]
        map_fn_gather_nd_result = array_ops.reshape(
            map_fn_gather_nd_result, shape=out_shape)

      self.assertAllEqual(
          map_fn_gather_nd_result, batch_gather_nd_result,
          msg=self._case_msg(params_shape, indices_shape, batch_dims))

  def _map_fn_body(self, elems):
    """Applies gather_nd to one (params, indices) pair unstacked by map_fn."""
    return gen_array_ops.gather_nd(elems[0], elems[1])

  def testBatchDimsAsTensor(self):
    """Tests Tensor batch_dims as input works as intended."""
    shapes = (
        # params_shape, indices_shape, batch_dims
        ((3, 2, 2, 3, 4), (3, 2, 3, 1), 0),
        ((3, 2, 2, 3, 4), (3, 2, 3, 1), 1),
        ((3, 2, 2, 3, 4), (3, 2, 3, 1), 2),
    )
    for params_shape, indices_shape, batch_dims in shapes:
      params = constant_op.constant(
          np.random.uniform(0.0, 1.0, size=params_shape))
      indices = np.random.randint(0, 2, size=indices_shape)
      batch_gather_nd_result = array_ops.gather_nd(
          params=params, indices=indices, batch_dims=batch_dims)
      # Same batch_dims value, passed as a one-element tensor.
      batch_dims_tensor = constant_op.constant([batch_dims])
      batch_gather_nd_tensor_batch_dims_result = array_ops.gather_nd(
          params=params, indices=indices, batch_dims=batch_dims_tensor)

      self.assertAllEqual(
          batch_gather_nd_tensor_batch_dims_result,
          batch_gather_nd_result,
          msg=self._case_msg(params_shape, indices_shape, batch_dims))

  def testInvalidBatchDimsRaisesException(self):
    """Tests whether invalid batch_dims raise expected exceptions."""
    params = constant_op.constant(
        np.random.uniform(0.0, 1.0, size=(3, 2, 2, 3, 4)))
    indices = np.random.randint(0, 2, size=(3, 2, 3))

    # A multi-element tensor is rejected outright.
    with self.assertRaises(TypeError):
      array_ops.batch_gather_nd(
          params=params,
          indices=indices,
          batch_dims=constant_op.constant((0, 1)))

    # Negative batch_dims.
    with self.assertRaises(ValueError):
      array_ops.batch_gather_nd(params=params, indices=indices, batch_dims=-1)

    # batch_dims larger than the rank (3) of indices.
    with self.assertRaises(ValueError):
      array_ops.batch_gather_nd(params=params, indices=indices, batch_dims=4)

  @test_util.run_deprecated_v1
  def testNoneBatchDimensions(self):
    """Tests gather_nd works with None dimensions."""
    for params_shape, indices_shape, batch_dims in self._BATCHED_CASES:
      # Leave the batch dimensions statically unknown in both placeholders.
      params_ph_shape = [None] * batch_dims + list(params_shape[batch_dims:])
      indices_ph_shape = [None] * batch_dims + list(indices_shape[batch_dims:])

      params = array_ops.placeholder(dtypes.float32, shape=params_ph_shape)
      indices = array_ops.placeholder(dtypes.int32, shape=indices_ph_shape)
      out = array_ops.batch_gather_nd(
          params=params, indices=indices, batch_dims=batch_dims)

      with self.cached_session() as sess:
        params_val = np.ones(dtype=np.float32, shape=params_shape)
        indices_val = np.ones(dtype=np.int32, shape=indices_shape)
        res = sess.run(
            out, feed_dict={
                params: params_val,
                indices: indices_val
            })

      self.assertSequenceEqual(
          res.shape,
          self._expected_out_shape(params_shape, indices_shape, batch_dims),
          msg=self._case_msg(params_shape, indices_shape, batch_dims))

  @test_util.run_deprecated_v1
  def testUnknownIndices(self):
    """Tests whether indices with unknown rank works correctly."""
    params = constant_op.constant(((0, 1, 2),))
    indices = array_ops.placeholder(dtypes.int32)
    gather_nd_t = array_ops.gather_nd(params, indices, batch_dims=1)
    shape = gather_nd_t.get_shape()
    self.assertIsNone(shape.ndims)
    self.assertIsNone(tensor_shape.dimension_value(shape[0]))
1875
1876
class RepeatTest(test_util.TensorFlowTestCase):
  """Checks array_ops.repeat against np.repeat on small inputs.

  Each case builds its graph inside self.cached_session() (the deprecated
  self.test_session() was used previously) and fetches results with
  self.evaluate().
  """

  @test_util.run_deprecated_v1
  def testRepeatScalar(self):
    """A scalar repeated a scalar number of times."""
    with self.cached_session():
      v_tf = array_ops.repeat(constant_op.constant(3), 4)
      v_np = np.repeat(3, 4)
      self.assertAllEqual(self.evaluate(v_tf), v_np)

  @test_util.run_deprecated_v1
  def testRepeatMatrix(self):
    """A matrix with no axis; np.repeat flattens its input in this case."""
    with self.cached_session():
      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
      v_tf = array_ops.repeat(constant_op.constant(x), 2)
      v_np = np.repeat(x, 2)
      self.assertAllEqual(self.evaluate(v_tf), v_np)

  @test_util.run_deprecated_v1
  def testRepeatMatrixAxis0(self):
    """Per-row repeat counts, given as a tensor, along axis 0."""
    with self.cached_session():
      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
      v_tf = array_ops.repeat(
          constant_op.constant(x), constant_op.constant([1, 2]), axis=0)
      v_np = np.repeat(x, [1, 2], axis=0)
      self.assertAllEqual(self.evaluate(v_tf), v_np)

  @test_util.run_deprecated_v1
  def testRepeatMatrixAxis1(self):
    """A scalar repeat count, given as a tensor, along axis 1."""
    with self.cached_session():
      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
      v_tf = array_ops.repeat(
          constant_op.constant(x), constant_op.constant(3), axis=1)
      v_np = np.repeat(x, 3, axis=1)
      self.assertAllEqual(self.evaluate(v_tf), v_np)

  @test_util.run_deprecated_v1
  def testRepeatMatrixRepeatArray(self):
    """Per-element repeat counts over the flattened matrix (no axis)."""
    with self.cached_session():
      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
      v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4])
      v_np = np.repeat(x, [1, 2, 3, 4])
      self.assertAllEqual(self.evaluate(v_tf), v_np)
1919
1920
if __name__ == "__main__":
  # Script entry point: hand off to test_lib.main(). test_lib is imported
  # outside this excerpt (presumably TensorFlow's test runner; confirm).
  test_lib.main()
1923