• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.framework import constant_op
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import errors
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops.linalg import linalg as linalg_lib
22from tensorflow.python.ops.linalg import linear_operator_test_util
23from tensorflow.python.ops.linalg import slicing
24from tensorflow.python.platform import test
25
# Short module alias; the tests below reach operators through it, e.g.
# `linalg.LinearOperatorFullMatrix` and `linalg.LinearOperatorKronecker`.
linalg = linalg_lib
27
28
29class _MakeSlices(object):
30
31  def __getitem__(self, slices):
32    return slices if isinstance(slices, tuple) else (slices,)
33
34
35make_slices = _MakeSlices()
36
37
@test_util.run_all_in_graph_and_eager_modes
class SlicingTest(test.TestCase):
  """Tests for slicing LinearOperators."""

  def test_single_param_slice_withstep_broadcastdim(self):
    """Stepped slices leave size-1 (broadcast) batch dims at size 1."""
    event_dim = 3
    sliced = slicing._slice_single_param(
        array_ops.zeros([1, 1, event_dim]),
        param_ndims_to_matrix_ndims=1,
        slices=make_slices[44:-52:-3, -94::],
        batch_shape=constant_op.constant([2, 7], dtype=dtypes.int32))
    self.assertAllEqual((1, 1, event_dim), self.evaluate(sliced).shape)

  def test_single_param_slice_stop_leadingdim(self):
    """A stop-only slice on the leading batch dim truncates it."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:2],
        batch_shape=constant_op.constant([7, 6, 5], dtype=dtypes.int32))
    self.assertAllEqual((2, 6, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_stop_trailingdim(self):
    """With `...`, a stop-only slice targets the last batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[..., :2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 6, 2, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_stop_broadcastdim(self):
    """A stop-only slice leaves a size-1 (broadcast) batch dim at size 1."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 1, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, :2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 1, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_newaxis_leading(self):
    """A `newaxis` inserts a size-1 batch dim at its position."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, array_ops.newaxis],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 1, 6, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_newaxis_trailing(self):
    """A `newaxis` after `...` is positioned from the end of the batch."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[..., array_ops.newaxis, :],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 6, 1, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_start(self):
    """A start-only slice shortens the targeted batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2:],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 4, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_start_broadcastdim(self):
    """A start-only slice leaves a size-1 (broadcast) batch dim at size 1."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 1, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2:],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 1, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_int(self):
    """An integer index removes the indexed batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_int_broadcastdim(self):
    """An integer index also removes a size-1 (broadcast) batch dim."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([7, 1, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, 2],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_tensor(self):
    """A scalar tensor index acts like an int on an unknown-shape param."""
    # `shape=None` hides the param's static shape from graph construction.
    param = array_ops.placeholder_with_default(
        array_ops.zeros([7, 6, 5, 4, 3]), shape=None)
    idx = array_ops.placeholder_with_default(
        constant_op.constant(2, dtype=dtypes.int32), shape=[])
    sliced = slicing._slice_single_param(
        param,
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, idx],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_tensor_broadcastdim(self):
    """A scalar tensor index also removes a size-1 (broadcast) batch dim."""
    param = array_ops.placeholder_with_default(
        array_ops.zeros([7, 1, 5, 4, 3]), shape=None)
    idx = array_ops.placeholder_with_default(
        constant_op.constant(2, dtype=dtypes.int32), shape=[])
    sliced = slicing._slice_single_param(
        param,
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, idx],
        batch_shape=constant_op.constant([7, 6, 5]))
    self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)

  def test_single_param_slice_broadcast_batch(self):
    """A param with fewer batch dims than `batch_shape` slices correctly.

    The result should match slicing a [1, 4, 3] batch, followed by the
    param's event dim of size 1.
    """
    sliced = slicing._slice_single_param(
        array_ops.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
        param_ndims_to_matrix_ndims=1,
        slices=make_slices[..., array_ops.newaxis, 2:, array_ops.newaxis],
        batch_shape=constant_op.constant([7, 4, 3]))
    expected = array_ops.zeros([1, 4, 3])[
        ..., array_ops.newaxis, 2:, array_ops.newaxis].shape.as_list() + [1]
    self.assertAllEqual(expected, self.evaluate(sliced).shape)

  def test_single_param_slice_broadcast_batch_leading_newaxis(self):
    """Like the broadcast-batch test, but with a leading `newaxis` too."""
    sliced = slicing._slice_single_param(
        array_ops.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
        param_ndims_to_matrix_ndims=1,
        slices=make_slices[
            array_ops.newaxis, ..., array_ops.newaxis, 2:, array_ops.newaxis],
        batch_shape=constant_op.constant([7, 4, 3]))
    # Build the expected shape as a plain int list (consistent with the other
    # tests) rather than relying on `TensorShape + list`.
    expected = array_ops.zeros([1, 4, 3])[
        array_ops.newaxis, ..., array_ops.newaxis,
        2:, array_ops.newaxis].shape.as_list() + [1]
    self.assertAllEqual(expected, self.evaluate(sliced).shape)

  def test_single_param_multi_ellipsis(self):
    """More than one `...` in the slices raises ValueError."""
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
    # The dots are escaped so `...` matches literally.
    with self.assertRaisesRegex(ValueError, r'Found multiple `\.\.\.`'):
      slicing._slice_single_param(
          array_ops.zeros([7, 6, 5, 4, 3]),
          param_ndims_to_matrix_ndims=2,
          slices=make_slices[:, ..., 2, ...],
          batch_shape=constant_op.constant([7, 6, 5]))

  def test_single_param_too_many_slices(self):
    """More non-ellipsis slices than batch dims raises an error."""
    # The error type may differ between graph and eager execution, so any of
    # these is accepted.
    with self.assertRaises(
        (IndexError, ValueError, errors.InvalidArgumentError)):
      slicing._slice_single_param(
          array_ops.zeros([7, 6, 5, 4, 3]),
          param_ndims_to_matrix_ndims=2,
          slices=make_slices[:, :3, ..., -2:, :],
          batch_shape=constant_op.constant([7, 6, 5]))

  def test_slice_single_param_operator(self):
    """Indexing an operator slices its batch shape like a tensor's."""
    matrix = linear_operator_test_util.random_normal(
        shape=[1, 4, 3, 2, 2], dtype=dtypes.float32)
    operator = linalg.LinearOperatorFullMatrix(matrix, is_square=True)
    sliced = operator[..., array_ops.newaxis, 2:, array_ops.newaxis]

    expected = array_ops.zeros([1, 4, 3])[
        ..., array_ops.newaxis, 2:, array_ops.newaxis].shape.as_list()
    self.assertAllEqual(expected, sliced.batch_shape_tensor())

  def test_slice_nested_operator(self):
    """Slicing works through nested (composite) operators."""
    linop = linalg.LinearOperatorKronecker([
        linalg.LinearOperatorBlockDiag([
            linalg.LinearOperatorDiag(array_ops.ones([1, 2, 2])),
            linalg.LinearOperatorDiag(array_ops.ones([3, 5, 2, 2]))]),
        linalg.LinearOperatorFullMatrix(
            linear_operator_test_util.random_normal(
                shape=[4, 1, 1, 1, 3, 3], dtype=dtypes.float32))])
    self.assertAllEqual(linop[0, ...].batch_shape_tensor(), [3, 5, 2])
    self.assertAllEqual(linop[
        0, ..., array_ops.newaxis].batch_shape_tensor(), [3, 5, 2, 1])
    self.assertAllEqual(linop[
        ..., array_ops.newaxis].batch_shape_tensor(), [4, 3, 5, 2, 1])
212
213
# Hand off to TensorFlow's test runner when this module is run as a script.
if __name__ == '__main__':
  test.main()
216