1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26
27 namespace xla {
28 namespace {
29
30 using SlicingTest = xla::ClientLibraryTestBase;
31
BValsRight()32 xla::Array2D<float> BValsRight() {
33 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
34 }
35
BValsLeft()36 xla::Array2D<float> BValsLeft() {
37 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
38 }
39
AValsFull()40 xla::Array2D<float> AValsFull() {
41 return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
42 }
43
BatchedAValsFull()44 xla::Array3D<float> BatchedAValsFull() {
45 return {{
46 {2, 0, 1, 2},
47 {3, 6, 0, 1},
48 {4, 7, 9, 0},
49 {5, 8, 10, 11},
50 },
51 {
52 {16, 24, 8, 12},
53 {24, 61, 82, 48},
54 {8, 82, 456, 106},
55 {12, 48, 106, 62},
56 }};
57 }
58
// Extracts a single element from a rank-2 operand: slicing BValsRight() at
// start indices (2, 1) with sizes {1, 1} is expected to produce {{10}}.
XLA_TEST_F(SlicingTest, Simple2dLookup) {
  XlaBuilder builder(TestName());

  XlaOp operand;
  XlaOp row_start;
  XlaOp col_start;
  auto operand_data =
      CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &operand);
  auto row_start_data = CreateR0Parameter<int>(2, 1, "x", &builder, &row_start);
  auto col_start_data = CreateR0Parameter<int>(1, 2, "y", &builder, &col_start);

  // The slice is the last op added to the builder, so it is the result that
  // gets compared below.
  DynamicSliceInMinorDims(operand, {row_start, col_start}, {1, 1});

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(
      &builder, {{10}},
      {operand_data.get(), row_start_data.get(), col_start_data.get()},
      tolerance);
}
72
// Slices the two minor dimensions of a rank-3 operand: with start indices
// (1, 0) and sizes {1, 4}, row 1 of each batch entry is expected back.
XLA_TEST_F(SlicingTest, Simple3dLookup) {
  XlaBuilder builder(TestName());

  XlaOp batched;
  XlaOp row_start;
  auto batched_data = CreateR3Parameter<float>(BatchedAValsFull(), 0, "a",
                                               &builder, &batched);
  auto row_start_data =
      CreateR0Parameter<int>(1, 1, "index", &builder, &row_start);

  const XlaOp col_start = ConstantR0<int32_t>(&builder, 0);
  DynamicSliceInMinorDims(batched, {row_start, col_start}, {1, 4});

  ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
                             {batched_data.get(), row_start_data.get()});
}
87
// Chains two minor-dimension slices: the first picks row 1 of every batch
// entry, the second picks column 1 of that row, so one element per batch
// entry is expected.
XLA_TEST_F(SlicingTest, NestedLookup) {
  XlaBuilder builder(TestName());

  XlaOp batched;
  XlaOp position;
  auto batched_data = CreateR3Parameter<float>(BatchedAValsFull(), 0, "a",
                                               &builder, &batched);
  auto position_data =
      CreateR0Parameter<int>(1, 1, "index", &builder, &position);

  // Each slice gets its own zero constant, created right before the slice
  // that consumes it.
  const XlaOp first_col = ConstantR0<int32_t>(&builder, 0);
  const XlaOp row = DynamicSliceInMinorDims(batched, {position, first_col},
                                            {1, 4});
  const XlaOp first_row = ConstantR0<int32_t>(&builder, 0);
  DynamicSliceInMinorDims(row, {first_row, position}, {1, 1});

  ComputeAndCompareR3<float>(&builder, {{{6}}, {{61}}},
                             {batched_data.get(), position_data.get()});
}
104
// Overwrites part of one row: a 1x3 update written into AValsFull() at start
// indices (2, 1) should replace the last three entries of row 2.
XLA_TEST_F(SlicingTest, SimpleSliceUpdate) {
  XlaBuilder builder(TestName());

  XlaOp operand;
  XlaOp update;
  XlaOp row_start;
  XlaOp col_start;
  auto operand_data =
      CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &operand);
  auto update_data =
      CreateR2Parameter<float>({{9, 1, -10}}, 1, "b", &builder, &update);
  auto row_start_data = CreateR0Parameter<int>(2, 2, "x", &builder, &row_start);
  auto col_start_data = CreateR0Parameter<int>(1, 3, "y", &builder, &col_start);

  DynamicUpdateSliceInMinorDims(operand, update, {row_start, col_start});

  // Only row 2 differs from AValsFull().
  Array2D<float> expected(
      {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}});

  ComputeAndCompareR2<float>(&builder, expected,
                             {operand_data.get(), update_data.get(),
                              row_start_data.get(), col_start_data.get()});
}
123
// Slice, patch, write back: row 2 of AValsFull() is sliced out, a 1x2 update
// is written into it starting at column 1, and the patched row is stored back
// into the full operand.
XLA_TEST_F(SlicingTest, NestedSliceUpdate) {
  XlaBuilder builder(TestName());

  XlaOp operand;
  XlaOp update;
  XlaOp row_start;
  XlaOp col_start;
  auto operand_data =
      CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &operand);
  auto update_data =
      CreateR2Parameter<float>({{1, -10}}, 1, "b", &builder, &update);
  auto row_start_data = CreateR0Parameter<int>(2, 2, "x", &builder, &row_start);
  auto col_start_data = CreateR0Parameter<int>(1, 3, "y", &builder, &col_start);

  const XlaOp zero = ConstantR0<int32_t>(&builder, 0);
  const XlaOp row = DynamicSliceInMinorDims(operand, {row_start, zero}, {1, 4});
  const XlaOp patched_row =
      DynamicUpdateSliceInMinorDims(row, update, {zero, col_start});
  DynamicUpdateSlice(operand, patched_row, {row_start, zero});

  // Only row 2 differs from AValsFull().
  Array2D<float> expected(
      {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 1, -10, 0}, {5, 8, 10, 11}}});

  ComputeAndCompareR2<float>(&builder, expected,
                             {operand_data.get(), update_data.get(),
                              row_start_data.get(), col_start_data.get()});
}
145
// Gathers along dimension 1 of a 2x2 input, relying on TorchGather's default
// for any argument after the dimension.
XLA_TEST_F(SlicingTest, TorchGatherSparse) {
  XlaBuilder builder(TestName());

  XlaOp values;
  XlaOp positions;
  auto values_data =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 0, "input", &builder, &values);
  auto positions_data = CreateR2Parameter<int>({{0, 0}, {1, 0}}, 1, "index",
                                               &builder, &positions);
  TorchGather(values, positions, 1);

  // Row 0 picks column 0 twice; row 1 picks column 1 and then column 0.
  ComputeAndCompareR2<int>(&builder, {{1, 1}, {4, 3}},
                           {values_data.get(), positions_data.get()});
}
159
// Same inputs and expected output as the sparse gather test, but passes false
// as the fourth TorchGather argument (presumably selecting the dense
// lowering, going by the test name -- confirm against slicing.h).
XLA_TEST_F(SlicingTest, TorchGatherDense) {
  XlaBuilder builder(TestName());

  XlaOp values;
  XlaOp positions;
  auto values_data =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 0, "input", &builder, &values);
  auto positions_data = CreateR2Parameter<int>({{0, 0}, {1, 0}}, 1, "index",
                                               &builder, &positions);
  TorchGather(values, positions, 1, false);

  // Row 0 picks column 0 twice; row 1 picks column 1 and then column 0.
  ComputeAndCompareR2<int>(&builder, {{1, 1}, {4, 3}},
                           {values_data.get(), positions_data.get()});
}
173
// Scatters a 2x2 source into an all-zero 2x3 input along dimension 1, using
// addition as the combiner.
XLA_TEST_F(SlicingTest, TorchScatterDense) {
  XlaBuilder builder(TestName());

  XlaOp updates;
  XlaOp positions;
  XlaOp target;
  auto target_data = CreateR2Parameter<int>({{0, 0, 0}, {0, 0, 0}}, 0, "input",
                                            &builder, &target);
  auto positions_data = CreateR2Parameter<int>({{1, 0}, {1, 2}}, 1, "index",
                                               &builder, &positions);
  auto updates_data =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 2, "src", &builder, &updates);

  auto add = [](XlaOp lhs, XlaOp rhs) { return lhs + rhs; };
  TorchScatterDense(target, positions, updates, 1, add);

  ComputeAndCompareR2<int>(
      &builder, {{2, 1, 0}, {0, 3, 4}},
      {target_data.get(), positions_data.get(), updates_data.get()});
}
191
// Selects rows 0 and 2 (dimension 0) of a 3x4 input.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
  XlaBuilder builder(TestName());

  const Array2D<float> table = {{0.1427, 0.0231, -0.5414, -1.0009},
                                {-0.4664, 0.2647, -0.1228, -1.1068},
                                {-1.1734, -0.6571, 0.7230, -0.6004}};
  // Rows 0 and 2 of `table`.
  const Array2D<float> selected = {{0.1427, 0.0231, -0.5414, -1.0009},
                                   {-1.1734, -0.6571, 0.7230, -0.6004}};

  XlaOp values;
  XlaOp positions;
  auto values_data =
      CreateR2Parameter<float>(table, 0, "input", &builder, &values);
  auto positions_data =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &positions);
  TorchIndexSelect(values, positions, 0);

  ComputeAndCompareR2<float>(&builder, selected,
                             {values_data.get(), positions_data.get()});
}
210
// Selects index 0 six times from an input with a single row, so the result is
// that row repeated six times.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0Size1) {
  XlaBuilder builder(TestName());

  const Array2D<float> single_row = {{-1.1734, -0.6571, 0.7230, -0.6004}};
  const Array2D<float> repeated = {{-1.1734, -0.6571, 0.7230, -0.6004},
                                   {-1.1734, -0.6571, 0.7230, -0.6004},
                                   {-1.1734, -0.6571, 0.7230, -0.6004},
                                   {-1.1734, -0.6571, 0.7230, -0.6004},
                                   {-1.1734, -0.6571, 0.7230, -0.6004},
                                   {-1.1734, -0.6571, 0.7230, -0.6004}};

  XlaOp values;
  XlaOp positions;
  auto values_data =
      CreateR2Parameter<float>(single_row, 0, "input", &builder, &values);
  auto positions_data = CreateR1Parameter<int>({0, 0, 0, 0, 0, 0}, 1, "index",
                                               &builder, &positions);
  TorchIndexSelect(values, positions, 0);

  ComputeAndCompareR2<float>(&builder, repeated,
                             {values_data.get(), positions_data.get()});
}
230
// Selects columns 0 and 2 (dimension 1) of a 3x4 input.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
  XlaBuilder builder(TestName());

  const Array2D<float> table = {{0.1427, 0.0231, -0.5414, -1.0009},
                                {-0.4664, 0.2647, -0.1228, -1.1068},
                                {-1.1734, -0.6571, 0.7230, -0.6004}};
  // Columns 0 and 2 of `table`.
  const Array2D<float> selected = {
      {0.1427, -0.5414}, {-0.4664, -0.1228}, {-1.1734, 0.7230}};

  XlaOp values;
  XlaOp positions;
  auto values_data =
      CreateR2Parameter<float>(table, 0, "input", &builder, &values);
  auto positions_data =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &positions);

  TorchIndexSelect(values, positions, 1);

  ComputeAndCompareR2<float>(&builder, selected,
                             {values_data.get(), positions_data.get()});
}
249
// Selecting with an empty index along dimension 1 of a 3x1 input is expected
// to give three empty rows.
XLA_TEST_F(SlicingTest, EmptyIndexSelect) {
  XlaBuilder builder(TestName());

  XlaOp values;
  XlaOp positions;
  auto values_data =
      CreateR2Parameter<float>({{0}, {0}, {0}}, 0, "input", &builder, &values);
  auto positions_data =
      CreateR1Parameter<int>({}, 1, "index", &builder, &positions);
  TorchIndexSelect(values, positions, 1);

  ComputeAndCompareR2<float>(&builder, {{}, {}, {}},
                             {values_data.get(), positions_data.get()});
}
261
// Both the input (shape {0, 1, 2, 0}) and the index (shape {0}) are empty;
// the result is compared against the empty input literal itself.
XLA_TEST_F(SlicingTest, DoubleEmptyIndexSelect) {
  XlaBuilder builder(TestName());

  Literal empty_input(ShapeUtil::MakeShape(F32, {0, 1, 2, 0}));
  Literal empty_index(ShapeUtil::MakeShape(S32, {0}));

  XlaOp values;
  XlaOp positions;
  TF_ASSERT_OK_AND_ASSIGN(auto values_data,
                          CreateParameterAndTransferLiteral(
                              0, empty_input, "input", &builder, &values));
  TF_ASSERT_OK_AND_ASSIGN(auto positions_data,
                          CreateParameterAndTransferLiteral(
                              1, empty_index, "index", &builder, &positions));
  TorchIndexSelect(values, positions, 0);

  ComputeAndCompareLiteral(&builder, empty_input,
                           {values_data.get(), positions_data.get()});
}
277
// The input has shape {0, 2} (no rows) while the index holds three entries;
// the test expects a 3x2 result filled with zeros.
XLA_TEST_F(SlicingTest, EmptyIndexSelectNonZero) {
  XlaBuilder builder(TestName());

  Literal empty_input(ShapeUtil::MakeShape(F32, {0, 2}));

  XlaOp values;
  XlaOp positions;
  TF_ASSERT_OK_AND_ASSIGN(auto values_data,
                          CreateParameterAndTransferLiteral(
                              0, empty_input, "input", &builder, &values));
  auto positions_data =
      CreateR1Parameter<int>({0, 0, 0}, 1, "index", &builder, &positions);
  TorchIndexSelect(values, positions, 0);

  ComputeAndCompareR2<float>(&builder,
                             {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}},
                             {values_data.get(), positions_data.get()});
}
293
// Index-select with a rank-2 index over a 2x3x4 input. The trailing `1, 1`
// arguments are presumably the selected dimension and the number of batch
// dimensions -- confirm against slicing.h. Per the expected values, batch
// entry 0 keeps rows 0 and 2, and batch entry 1 keeps rows 1 and 2.
XLA_TEST_F(SlicingTest, BatchTorchIndexSelectOn0) {
  XlaBuilder builder(TestName());

  XlaOp values;
  XlaOp positions;
  auto values_data = CreateR3Parameter<int>(
      {{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
       {{3, 2, 1, 0}, {7, 6, 5, 4}, {11, 10, 9, 8}}},
      0, "input", &builder, &values);
  auto positions_data = CreateR2Parameter<int>({{0, 2}, {1, 2}}, 1, "index",
                                               &builder, &positions);
  TorchIndexSelect(values, positions, 1, 1);

  ComputeAndCompareR3<int>(
      &builder,
      {{{0, 1, 2, 3}, {8, 9, 10, 11}}, {{7, 6, 5, 4}, {11, 10, 9, 8}}},
      {values_data.get(), positions_data.get()});
}
311
312 } // namespace
313 } // namespace xla
314