# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `slicing._slice_single_param` and for slicing LinearOperators."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import slicing
from tensorflow.python.platform import test

linalg = linalg_lib


class _MakeSlices(object):
  """Helper that turns subscript syntax into a tuple of slice arguments.

  `make_slices[a:b, ..., c]` returns exactly what `__getitem__` received,
  always wrapped as a tuple, so tests can spell `slices=` arguments with
  ordinary indexing syntax.
  """

  def __getitem__(self, slices):
    # A single (non-tuple) index is normalized to a 1-tuple.
    return slices if isinstance(slices, tuple) else (slices,)


make_slices = _MakeSlices()


@test_util.run_all_in_graph_and_eager_modes
class SlicingTest(test.TestCase):
  """Tests for slicing LinearOperators."""

  def test_single_param_slice_withstep_broadcastdim(self):
    """Stepped slices over size-1 batch dims leave the param shape as-is."""
    event_dim = 3
    sliced = slicing._slice_single_param(
        array_ops.zeros([1, 1, event_dim]),
        param_ndims_to_matrix_ndims=1,
        slices=make_slices[44:-52:-3, -94::],
        batch_shape=constant_op.constant([2, 7], dtype=dtypes.int32))
    self.assertAllEqual((1, 1, event_dim), self.evaluate(sliced).shape)

  def test_single_param_slice_stop_leadingdim(self):
    """`[:2]` truncates the leading batch dim from 7 to 2."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:2],
        batch_shape=constant_op.constant([7, 6, 5], dtype=dtypes.int32))
    self.assertAllEqual((2, 6, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_stop_trailingdim(self):
    """`[..., :2]` truncates the last batch dim (5 -> 2), not a matrix dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[..., :2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 6, 2, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_stop_broadcastdim(self):
    """A stop-slice on a size-1 batch dim keeps that dim at size 1."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 1, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, :2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 1, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_newaxis_leading(self):
    """`[:, newaxis]` inserts a size-1 dim after the first batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, array_ops.newaxis],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 1, 6, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_newaxis_trailing(self):
    """`[..., newaxis, :]` inserts a size-1 dim before the last batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[..., array_ops.newaxis, :],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 6, 1, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_start(self):
    """`[:, 2:]` drops the first two entries of the second batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2:],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 4, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_start_broadcastdim(self):
    """A start-slice on a size-1 batch dim keeps that dim at size 1."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 1, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2:],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 1, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_int(self):
    """An integer index removes the indexed batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_int_broadcastdim(self):
    """An integer index on a size-1 batch dim still removes that dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 1, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_tensor(self):
    """A scalar tensor index on a param of unknown static shape drops a dim."""
    # `shape=None` hides the static shape of the parameter.
    param = array_ops.placeholder_with_default(
        array_ops.zeros([7, 6, 5, 4, 3]), shape=None)
    idx = array_ops.placeholder_with_default(
        constant_op.constant(2, dtype=dtypes.int32), shape=[])
    sliced = slicing._slice_single_param(
        param,
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, idx],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_tensor_broadcastdim(self):
    """A scalar tensor index on a size-1 batch dim also drops that dim."""
    # `shape=None` hides the static shape of the parameter.
    param = array_ops.placeholder_with_default(
        array_ops.zeros([7, 1, 5, 4, 3]), shape=None)
    idx = array_ops.placeholder_with_default(
        constant_op.constant(2, dtype=dtypes.int32), shape=[])
    sliced = slicing._slice_single_param(
        param,
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, idx],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_broadcast_batch(self):
    """Param with fewer batch dims than `batch_shape` slices like [1, 4, 3]."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
        param_ndims_to_matrix_ndims=1,
        slices=make_slices[..., array_ops.newaxis, 2:, array_ops.newaxis],
        batch_shape=constant_op.constant([7, 4, 3]))
    # Expected shape: the same slices applied to a [1, 4, 3] batch, followed
    # by the size-1 event dim.
    self.assertAllEqual(
        list(array_ops.zeros([1, 4, 3])[
            ..., array_ops.newaxis, 2:, array_ops.newaxis].shape) + [1],
        self.evaluate(sliced).shape)

  def test_single_param_slice_broadcast_batch_leading_newaxis(self):
    """Same as the broadcast-batch case, with a leading `newaxis` as well."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
        param_ndims_to_matrix_ndims=1,
        slices=make_slices[
            array_ops.newaxis, ..., array_ops.newaxis, 2:, array_ops.newaxis],
        batch_shape=constant_op.constant([7, 4, 3]))
    # Expected shape: the same slices applied to a [1, 4, 3] batch, followed
    # by the size-1 event dim.
    expected = array_ops.zeros(
        [1, 4, 3])[
            array_ops.newaxis, ..., array_ops.newaxis,
            2:, array_ops.newaxis].shape + [1]
    self.assertAllEqual(expected, self.evaluate(sliced).shape)

  def test_single_param_multi_ellipsis(self):
    """More than one `...` in the slices raises a `ValueError`."""
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(ValueError, 'Found multiple `...`'):
      slicing._slice_single_param(
          array_ops.zeros([7, 6, 5, 4, 3]),
          param_ndims_to_matrix_ndims=2,
          slices=make_slices[:, ..., 2, ...],
          batch_shape=constant_op.constant([7, 6, 5]))

  def test_single_param_too_many_slices(self):
    """Four explicit slices plus `...` against three batch dims is an error."""
    # Any of these exception types is accepted by the test.
    with self.assertRaises(
        (IndexError, ValueError, errors.InvalidArgumentError)):
      slicing._slice_single_param(
          array_ops.zeros([7, 6, 5, 4, 3]),
          param_ndims_to_matrix_ndims=2,
          slices=make_slices[:, :3, ..., -2:, :],
          batch_shape=constant_op.constant([7, 6, 5]))

  def test_slice_single_param_operator(self):
    """Slicing a `LinearOperatorFullMatrix` slices its batch shape."""
    matrix = linear_operator_test_util.random_normal(
        shape=[1, 4, 3, 2, 2], dtype=dtypes.float32)
    operator = linalg.LinearOperatorFullMatrix(matrix, is_square=True)
    sliced = operator[..., array_ops.newaxis, 2:, array_ops.newaxis]

    # The operator's batch shape should match the same slices applied to a
    # plain tensor of shape [1, 4, 3].
    self.assertAllEqual(
        list(array_ops.zeros([1, 4, 3])[
            ..., array_ops.newaxis, 2:, array_ops.newaxis].shape),
        sliced.batch_shape_tensor())

  def test_slice_nested_operator(self):
    """Slicing a Kronecker-of-BlockDiag operator yields the batch shapes below."""
    linop = linalg.LinearOperatorKronecker([
        linalg.LinearOperatorBlockDiag([
            linalg.LinearOperatorDiag(array_ops.ones([1, 2, 2])),
            linalg.LinearOperatorDiag(array_ops.ones([3, 5, 2, 2]))]),
        linalg.LinearOperatorFullMatrix(
            linear_operator_test_util.random_normal(
                shape=[4, 1, 1, 1, 3, 3], dtype=dtypes.float32))])
    self.assertAllEqual(linop[0, ...].batch_shape_tensor(), [3, 5, 2])
    self.assertAllEqual(linop[
        0, ..., array_ops.newaxis].batch_shape_tensor(), [3, 5, 2, 1])
    self.assertAllEqual(linop[
        ..., array_ops.newaxis].batch_shape_tensor(), [4, 3, 5, 2, 1])


if __name__ == '__main__':
  test.main()