• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17 
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 
27 namespace xla {
28 namespace {
29 
30 using SlicingTest = xla::ClientLibraryTestBase;
31 
BValsRight()32 xla::Array2D<float> BValsRight() {
33   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
34 }
35 
BValsLeft()36 xla::Array2D<float> BValsLeft() {
37   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
38 }
39 
AValsFull()40 xla::Array2D<float> AValsFull() {
41   return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
42 }
43 
BatchedAValsFull()44 xla::Array3D<float> BatchedAValsFull() {
45   return {{
46               {2, 0, 1, 2},
47               {3, 6, 0, 1},
48               {4, 7, 9, 0},
49               {5, 8, 10, 11},
50           },
51           {
52               {16, 24, 8, 12},
53               {24, 61, 82, 48},
54               {8, 82, 456, 106},
55               {12, 48, 106, 62},
56           }};
57 }
58 
// Reads the single element at (row 2, column 1) of BValsRight() with a 1x1
// dynamic slice; that element is 10.
XLA_TEST_F(SlicingTest, Simple2dLookup) {
  XlaBuilder builder(TestName());

  XlaOp matrix;
  XlaOp row;
  XlaOp col;
  const auto matrix_arg =
      CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &matrix);
  const auto row_arg = CreateR0Parameter<int>(2, 1, "x", &builder, &row);
  const auto col_arg = CreateR0Parameter<int>(1, 2, "y", &builder, &col);

  DynamicSliceInMinorDims(matrix, {row, col}, {1, 1});

  ComputeAndCompareR2<float>(&builder, {{10}},
                             {matrix_arg.get(), row_arg.get(), col_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
72 
// Takes a 1x4 window starting at (row = index, column = 0) from each 4x4
// matrix of BatchedAValsFull(); with index = 1 this is row 1 of both batches.
XLA_TEST_F(SlicingTest, Simple3dLookup) {
  xla::XlaBuilder builder(TestName());

  xla::XlaOp a, index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  // Declared as int32_t (not int) so that both start indices below are
  // spelled with the same element type as the ConstantR0<int32_t> zero.
  auto index_data =
      CreateR0Parameter<int32_t>(1, 1, "index", &builder, &index);

  DynamicSliceInMinorDims(a, {index, xla::ConstantR0<int32_t>(&builder, 0)},
                          {1, 4});

  ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
                             {a_data.get(), index_data.get()});
}
87 
// Chains two dynamic slices:
//   1. Take row `index` (a 1x4 window) of each batch.
//   2. Take column `index` (a 1x1 window) of that row.
// With index = 1, each batch yields its element at [1][1] (6 and 61).
XLA_TEST_F(SlicingTest, NestedLookup) {
  xla::XlaBuilder builder(TestName());

  xla::XlaOp a, index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  // Declared as int32_t (not int) so that it has the same element type as the
  // ConstantR0<int32_t> zero it is paired with in both start-index lists.
  auto index_data =
      CreateR0Parameter<int32_t>(1, 1, "index", &builder, &index);

  // A single zero constant serves as the fixed start index in both slices.
  auto zero = xla::ConstantR0<int32_t>(&builder, 0);
  auto slice = DynamicSliceInMinorDims(a, {index, zero}, {1, 4});
  DynamicSliceInMinorDims(slice, {zero, index}, {1, 1});

  ComputeAndCompareR3<float>(&builder, {{{6}}, {{61}}},
                             {a_data.get(), index_data.get()});
}
104 
// Overwrites the 1x3 window at (row 2, column 1) of AValsFull() with
// {9, 1, -10}; every other element must be left unchanged.
XLA_TEST_F(SlicingTest, SimpleSliceUpdate) {
  XlaBuilder builder(TestName());

  XlaOp matrix;
  XlaOp update;
  XlaOp row;
  XlaOp col;
  const auto matrix_arg =
      CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &matrix);
  const auto update_arg =
      CreateR2Parameter<float>({{9, 1, -10}}, 1, "b", &builder, &update);
  const auto row_arg = CreateR0Parameter<int>(2, 2, "x", &builder, &row);
  const auto col_arg = CreateR0Parameter<int>(1, 3, "y", &builder, &col);

  DynamicUpdateSliceInMinorDims(matrix, update, {row, col});

  // Only row 2, columns 1..3 differ from AValsFull().
  const Array2D<float> expected{
      {2, 0, 1, 2},
      {3, 6, 0, 1},
      {4, 9, 1, -10},
      {5, 8, 10, 11},
  };
  ComputeAndCompareR2<float>(
      &builder, expected,
      {matrix_arg.get(), update_arg.get(), row_arg.get(), col_arg.get()});
}
123 
// Performs a slice update in three steps:
//   1. Extract row x of `a`.
//   2. Overwrite columns y and y+1 of that row with the 1x2 operand `b`.
//   3. Write the patched row back into `a` at row x.
XLA_TEST_F(SlicingTest, NestedSliceUpdate) {
  xla::XlaBuilder builder(TestName());

  xla::XlaOp a, b, x, y;
  auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>({{1, -10}}, 1, "b", &builder, &b);
  // Row and column parameters are declared as int32_t (not int) so they share
  // an element type with the ConstantR0<int32_t> zero they are paired with in
  // every start-index list below.
  auto x_data = CreateR0Parameter<int32_t>(2, 2, "x", &builder, &x);
  auto y_data = CreateR0Parameter<int32_t>(1, 3, "y", &builder, &y);

  auto z = xla::ConstantR0<int32_t>(&builder, 0);
  auto slice = DynamicSliceInMinorDims(a, {x, z}, {1, 4});
  auto inner = DynamicUpdateSliceInMinorDims(slice, b, {z, y});
  DynamicUpdateSlice(a, inner, {x, z});

  xla::Array2D<float> expected(
      {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 1, -10, 0}, {5, 8, 10, 11}}});

  ComputeAndCompareR2<float>(
      &builder, expected,
      {a_data.get(), b_data.get(), x_data.get(), y_data.get()});
}
145 
// Gathers along dimension 1 using TorchGather's default for its optional last
// argument (TorchGatherDense passes false explicitly). The expected values
// follow torch.gather semantics: out[i][j] = input[i][index[i][j]].
XLA_TEST_F(SlicingTest, TorchGatherSparse) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  const auto input_arg =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 0, "input", &builder, &input);
  const auto index_arg =
      CreateR2Parameter<int>({{0, 0}, {1, 0}}, 1, "index", &builder, &index);

  TorchGather(input, index, 1);

  ComputeAndCompareR2<int>(&builder, {{1, 1}, {4, 3}},
                           {input_arg.get(), index_arg.get()});
}
159 
// Same gather along dimension 1 as the sparse variant, but with `false` passed
// as TorchGather's fourth argument. The result must still follow
// out[i][j] = input[i][index[i][j]].
XLA_TEST_F(SlicingTest, TorchGatherDense) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  const auto input_arg =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 0, "input", &builder, &input);
  const auto index_arg =
      CreateR2Parameter<int>({{0, 0}, {1, 0}}, 1, "index", &builder, &index);

  TorchGather(input, index, 1, false);

  ComputeAndCompareR2<int>(&builder, {{1, 1}, {4, 3}},
                           {input_arg.get(), index_arg.get()});
}
173 
// Scatters `src` into a zero-filled 2x3 `input` along dimension 1, combining
// values with addition: out[i][index[i][j]] += src[i][j].
XLA_TEST_F(SlicingTest, TorchScatterDense) {
  XlaBuilder builder(TestName());

  XlaOp src;
  XlaOp index;
  XlaOp input;
  const auto input_arg = CreateR2Parameter<int>({{0, 0, 0}, {0, 0, 0}}, 0,
                                                "input", &builder, &input);
  const auto index_arg =
      CreateR2Parameter<int>({{1, 0}, {1, 2}}, 1, "index", &builder, &index);
  const auto src_arg =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 2, "src", &builder, &src);

  auto add_combiner = [](XlaOp lhs, XlaOp rhs) { return lhs + rhs; };
  TorchScatterDense(input, index, src, 1, add_combiner);

  ComputeAndCompareR2<int>(&builder, {{2, 1, 0}, {0, 3, 4}},
                           {input_arg.get(), index_arg.get(), src_arg.get()});
}
191 
// Selects rows 0 and 2 (dimension 0) of a 3x4 float matrix.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
  XlaBuilder builder(TestName());

  const Array2D<float> values{{0.1427, 0.0231, -0.5414, -1.0009},
                              {-0.4664, 0.2647, -0.1228, -1.1068},
                              {-1.1734, -0.6571, 0.7230, -0.6004}};

  XlaOp input;
  XlaOp index;
  const auto input_arg =
      CreateR2Parameter<float>(values, 0, "input", &builder, &input);
  const auto index_arg =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 0);

  ComputeAndCompareR2<float>(&builder,
                             {{0.1427, 0.0231, -0.5414, -1.0009},
                              {-1.1734, -0.6571, 0.7230, -0.6004}},
                             {input_arg.get(), index_arg.get()});
}
210 
// Selects row 0 of a 1x4 matrix six times, so the result repeats that row in
// each of its six rows.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0Size1) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  const auto input_arg = CreateR2Parameter<float>(
      {{-1.1734, -0.6571, 0.7230, -0.6004}}, 0, "input", &builder, &input);
  const auto index_arg =
      CreateR1Parameter<int>({0, 0, 0, 0, 0, 0}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 0);

  ComputeAndCompareR2<float>(&builder,
                             {{-1.1734, -0.6571, 0.7230, -0.6004},
                              {-1.1734, -0.6571, 0.7230, -0.6004},
                              {-1.1734, -0.6571, 0.7230, -0.6004},
                              {-1.1734, -0.6571, 0.7230, -0.6004},
                              {-1.1734, -0.6571, 0.7230, -0.6004},
                              {-1.1734, -0.6571, 0.7230, -0.6004}},
                             {input_arg.get(), index_arg.get()});
}
230 
// Selects columns 0 and 2 (dimension 1) of a 3x4 float matrix.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
  XlaBuilder builder(TestName());

  const Array2D<float> values{{0.1427, 0.0231, -0.5414, -1.0009},
                              {-0.4664, 0.2647, -0.1228, -1.1068},
                              {-1.1734, -0.6571, 0.7230, -0.6004}};

  XlaOp input;
  XlaOp index;
  const auto input_arg =
      CreateR2Parameter<float>(values, 0, "input", &builder, &input);
  const auto index_arg =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 1);

  ComputeAndCompareR2<float>(
      &builder, {{0.1427, -0.5414}, {-0.4664, -0.1228}, {-1.1734, 0.7230}},
      {input_arg.get(), index_arg.get()});
}
249 
// Selecting zero columns (dimension 1) of a 3x1 matrix yields a 3x0 result.
XLA_TEST_F(SlicingTest, EmptyIndexSelect) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  const auto input_arg =
      CreateR2Parameter<float>({{0}, {0}, {0}}, 0, "input", &builder, &input);
  const auto index_arg =
      CreateR1Parameter<int>({}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 1);

  ComputeAndCompareR2<float>(&builder, {{}, {}, {}},
                             {input_arg.get(), index_arg.get()});
}
261 
// Both the input (F32[0,1,2,0]) and the index vector (S32[0]) are empty.
// Selecting along dimension 0 must produce a result equal to the input
// literal.
XLA_TEST_F(SlicingTest, DoubleEmptyIndexSelect) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  Literal empty_input(ShapeUtil::MakeShape(F32, {0, 1, 2, 0}));
  Literal empty_index(ShapeUtil::MakeShape(S32, {0}));

  TF_ASSERT_OK_AND_ASSIGN(auto input_arg,
                          CreateParameterAndTransferLiteral(
                              0, empty_input, "input", &builder, &input));
  TF_ASSERT_OK_AND_ASSIGN(auto index_arg,
                          CreateParameterAndTransferLiteral(
                              1, empty_index, "index", &builder, &index));

  TorchIndexSelect(input, index, 0);

  ComputeAndCompareLiteral(&builder, empty_input,
                           {input_arg.get(), index_arg.get()});
}
277 
// Selects index 0 three times along dimension 0 of an F32[0,2] input that has
// no rows. The test expects a 3x2 result filled with zeros.
XLA_TEST_F(SlicingTest, EmptyIndexSelectNonZero) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  Literal empty_input(ShapeUtil::MakeShape(F32, {0, 2}));

  TF_ASSERT_OK_AND_ASSIGN(auto input_arg,
                          CreateParameterAndTransferLiteral(
                              0, empty_input, "input", &builder, &input));
  const auto index_arg =
      CreateR1Parameter<int>({0, 0, 0}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 0);

  ComputeAndCompareR2<float>(&builder,
                             {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}},
                             {input_arg.get(), index_arg.get()});
}
293 
// Batched index-select on a [2, 3, 4] input. TorchIndexSelect is called with
// dim = 1 and a fourth argument of 1 (presumably the batch-dimension count —
// TODO confirm). Each batch b picks rows index[b] from its own 3x4 slice:
// batch 0 takes rows {0, 2}; batch 1 takes rows {1, 2}.
XLA_TEST_F(SlicingTest, BatchTorchIndexSelectOn0) {
  XlaBuilder builder(TestName());

  XlaOp input;
  XlaOp index;
  const auto input_arg =
      CreateR3Parameter<int>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
                              {{3, 2, 1, 0}, {7, 6, 5, 4}, {11, 10, 9, 8}}},
                             0, "input", &builder, &input);
  const auto index_arg =
      CreateR2Parameter<int>({{0, 2}, {1, 2}}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 1, 1);

  ComputeAndCompareR3<int>(
      &builder,
      {{{0, 1, 2, 3}, {8, 9, 10, 11}}, {{7, 6, 5, 4}, {11, 10, 9, 8}}},
      {input_arg.get(), index_arg.get()});
}
311 
312 }  // namespace
313 }  // namespace xla
314