1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
22 #include "tensorflow/compiler/xla/tests/test_macros.h"
23 #include "tensorflow/compiler/xla/types.h"
24
25 namespace xla {
26 namespace {
27
28 using SlicingTest = xla::ClientLibraryTestBase;
29
BValsRight()30 xla::Array2D<float> BValsRight() {
31 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
32 }
33
BValsLeft()34 xla::Array2D<float> BValsLeft() {
35 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
36 }
37
AValsFull()38 xla::Array2D<float> AValsFull() {
39 return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
40 }
41
BatchedAValsFull()42 xla::Array3D<float> BatchedAValsFull() {
43 return {{
44 {2, 0, 1, 2},
45 {3, 6, 0, 1},
46 {4, 7, 9, 0},
47 {5, 8, 10, 11},
48 },
49 {
50 {16, 24, 8, 12},
51 {24, 61, 82, 48},
52 {8, 82, 456, 106},
53 {12, 48, 106, 62},
54 }};
55 }
56
// Extracts a single element from a 2D operand with DynamicSliceInMinorDims,
// feeding both start indices in as scalar parameters.
XLA_TEST_F(SlicingTest, Simple2dLookup) {
  XlaBuilder builder(TestName());

  XlaOp operand;
  XlaOp row;
  XlaOp col;
  const auto operand_data =
      CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &operand);
  const auto row_data = CreateR0Parameter<int>(2, 1, "x", &builder, &row);
  const auto col_data = CreateR0Parameter<int>(1, 2, "y", &builder, &col);

  // Must remain the last op added to the builder before the comparison runs.
  DynamicSliceInMinorDims(operand, {row, col}, {1, 1});

  // Row 2, column 1 of BValsRight() holds 10.
  const Array2D<float> expected = {{10}};
  ComputeAndCompareR2<float>(
      &builder, expected,
      {operand_data.get(), row_data.get(), col_data.get()},
      ErrorSpec(1e-2, 1e-2));
}
70
// Slices one full row out of every matrix of a batched (3D) operand. The row
// index is a scalar parameter; the column start is a constant zero.
XLA_TEST_F(SlicingTest, Simple3dLookup) {
  XlaBuilder builder(TestName());

  XlaOp batched;
  XlaOp row;
  const auto batched_data = CreateR3Parameter<float>(BatchedAValsFull(), 0,
                                                     "a", &builder, &batched);
  const auto row_data = CreateR0Parameter<int>(1, 1, "index", &builder, &row);

  // Created before the slice so that the slice stays the last op added.
  const XlaOp col_start = ConstantR0<int32>(&builder, 0);
  DynamicSliceInMinorDims(batched, {row, col_start}, {1, 4});

  // Row 1 of each of the two matrices returned by BatchedAValsFull().
  const Array3D<float> expected = {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}};
  ComputeAndCompareR3<float>(&builder, expected,
                             {batched_data.get(), row_data.get()});
}
85
// Overwrites part of one row of a 2D operand with DynamicUpdateSliceInMinorDims
// and checks that every other element is left untouched.
XLA_TEST_F(SlicingTest, SimpleSliceUpdate) {
  XlaBuilder builder(TestName());

  XlaOp target;
  XlaOp update;
  XlaOp row;
  XlaOp col;
  const Array2D<float> update_values = {{9, 1, -10}};
  const auto target_data =
      CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &target);
  const auto update_data =
      CreateR2Parameter<float>(update_values, 1, "b", &builder, &update);
  const auto row_data = CreateR0Parameter<int>(2, 2, "x", &builder, &row);
  const auto col_data = CreateR0Parameter<int>(1, 3, "y", &builder, &col);

  DynamicUpdateSliceInMinorDims(target, update, {row, col});

  // AValsFull() with row 2, columns 1 through 3 replaced by the update values.
  const Array2D<float> expected = {
      {2, 0, 1, 2},
      {3, 6, 0, 1},
      {4, 9, 1, -10},
      {5, 8, 10, 11},
  };

  ComputeAndCompareR2<float>(&builder, expected,
                             {target_data.get(), update_data.get(),
                              row_data.get(), col_data.get()});
}
104
// Runs TorchGather along dimension 1 of a two-by-two integer matrix. With the
// values below, output[i][j] equals input[i][index[i][j]].
XLA_TEST_F(SlicingTest, TorchGather) {
  XlaBuilder builder(TestName());

  const Array2D<int> input_values = {{1, 2}, {3, 4}};
  const Array2D<int> index_values = {{0, 0}, {1, 0}};

  XlaOp input;
  XlaOp index;
  const auto input_data =
      CreateR2Parameter<int>(input_values, 0, "input", &builder, &input);
  const auto index_data =
      CreateR2Parameter<int>(index_values, 1, "index", &builder, &index);

  TorchGather(input, index, 1);

  const Array2D<int> expected = {{1, 1}, {4, 3}};
  ComputeAndCompareR2<int>(&builder, expected,
                           {input_data.get(), index_data.get()});
}
118
// Runs TorchIndexSelect along dimension 0 with indices {0, 2}: the expected
// result is rows 0 and 2 of the input, in that order.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
  XlaBuilder builder(TestName());

  const Array2D<float> input_values = {
      {0.1427, 0.0231, -0.5414, -1.0009},
      {-0.4664, 0.2647, -0.1228, -1.1068},
      {-1.1734, -0.6571, 0.7230, -0.6004},
  };

  XlaOp input;
  XlaOp index;
  const auto input_data =
      CreateR2Parameter<float>(input_values, 0, "input", &builder, &input);
  const auto index_data =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 0);

  const Array2D<float> expected = {
      {0.1427, 0.0231, -0.5414, -1.0009},
      {-1.1734, -0.6571, 0.7230, -0.6004},
  };
  ComputeAndCompareR2<float>(&builder, expected,
                             {input_data.get(), index_data.get()});
}
137
// Runs TorchIndexSelect along dimension 1 with indices {0, 2}: the expected
// result is columns 0 and 2 of the input, in that order.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
  XlaBuilder builder(TestName());

  const Array2D<float> input_values = {
      {0.1427, 0.0231, -0.5414, -1.0009},
      {-0.4664, 0.2647, -0.1228, -1.1068},
      {-1.1734, -0.6571, 0.7230, -0.6004},
  };

  XlaOp input;
  XlaOp index;
  const auto input_data =
      CreateR2Parameter<float>(input_values, 0, "input", &builder, &input);
  const auto index_data =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index);

  TorchIndexSelect(input, index, 1);

  const Array2D<float> expected = {
      {0.1427, -0.5414},
      {-0.4664, -0.1228},
      {-1.1734, 0.7230},
  };
  ComputeAndCompareR2<float>(&builder, expected,
                             {input_data.get(), index_data.get()});
}
155
156 } // namespace
157 } // namespace xla
158