1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/matrix.h"
17
18 #include <limits>
19 #include <map>
20 #include <string>
21 #include <vector>
22
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33
34 namespace xla {
35 namespace {
36
37 class MatrixTest : public ClientLibraryTestBase {
38 protected:
39 template <typename T>
40 void TestMatrixDiagonal();
41 template <typename T>
42 void TestMatrixDiagonal4D();
43 template <typename T>
44 void TestSetMatrixDiagonal();
45
46 template <typename T>
k_and_expected() const47 std::map<int, Array2D<T>> k_and_expected() const {
48 return std::map<int, Array2D<T>>{
49 {0, {{0, 5, 10}, {12, 17, 22}}},
50 {1, {{1, 6, 11}, {13, 18, 23}}},
51 {2, {{2, 7}, {14, 19}}},
52 {3, {{3}, {15}}},
53 {4, {{}, {}}},
54 {-1, {{4, 9}, {16, 21}}},
55 {-2, {{8}, {20}}},
56 {-3, {{}, {}}},
57 {-4, {{}, {}}},
58 };
59 }
60 };
61
// LowerTriangle keeps the main diagonal and everything below it in each batch
// matrix and zeroes the strictly-upper entries.
XLA_TEST_F(MatrixTest, Triangle) {
  XlaBuilder builder(TestName());
  Array3D<int32> values(2, 3, 4);
  values.FillIota(0);

  XlaOp param;
  auto param_data =
      CreateR3Parameter<int32>(values, 0, "a", &builder, &param);
  LowerTriangle(param);

  const Array3D<int32> expected(
      {{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
       {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});
  ComputeAndCompareR3<int32>(&builder, expected, {param_data.get()});
}
75
76 template <typename T>
TestMatrixDiagonal()77 void MatrixTest::TestMatrixDiagonal() {
78 XlaBuilder builder("SetMatrixDiagonal");
79 Array3D<T> input(2, 3, 4);
80 input.FillIota(0);
81 for (const auto& kv : k_and_expected<T>()) {
82 XlaOp a;
83 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
84 GetMatrixDiagonal(a, kv.first);
85
86 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
87 }
88 }
89
90 template <typename T>
TestSetMatrixDiagonal()91 void MatrixTest::TestSetMatrixDiagonal() {
92 XlaBuilder builder("GetMatrixDiagonal");
93 Array3D<T> input(2, 3, 4);
94 input.FillIota(0);
95 for (const auto& kv : k_and_expected<T>()) {
96 XlaOp a;
97 XlaOp b;
98 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
99 auto new_diag =
100 CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
101
102 GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
103 kv.first) -
104 ScalarLike(b, 1);
105
106 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
107 }
108 }
109
// SetMatrixDiagonal round trip with int32 elements.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
  TestSetMatrixDiagonal<int32>();  // Shared body; see the fixture helper.
}
// SetMatrixDiagonal round trip with int64 elements.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
  TestSetMatrixDiagonal<int64>();  // Shared body; see the fixture helper.
}
// SetMatrixDiagonal round trip with float elements.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
  TestSetMatrixDiagonal<float>();  // Shared body; see the fixture helper.
}
119
// GetMatrixDiagonal on a rank-3 input with int32 elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  TestMatrixDiagonal<int32>();
}
121
// GetMatrixDiagonal on a rank-3 input with int64 elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  TestMatrixDiagonal<int64>();
}
123
// GetMatrixDiagonal on a rank-3 input with float elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  TestMatrixDiagonal<float>();
}
125
126 template <typename T>
TestMatrixDiagonal4D()127 void MatrixTest::TestMatrixDiagonal4D() {
128 XlaBuilder builder("GetMatrixDiagonal");
129 Array4D<T> input(2, 2, 4, 3);
130 input.FillIota(0);
131 std::map<int, Array3D<T>> k_and_expected = {
132 {0, {{{0, 4, 8}, {12, 16, 20}}, {{24, 28, 32}, {36, 40, 44}}}},
133 {1, {{{1, 5}, {13, 17}}, {{25, 29}, {37, 41}}}},
134 {2, {{{2}, {14}}, {{26}, {38}}}},
135 {3, {{{}, {}}, {{}, {}}}},
136 {4, {{{}, {}}, {{}, {}}}},
137 {-1, {{{3, 7, 11}, {15, 19, 23}}, {{27, 31, 35}, {39, 43, 47}}}},
138 {-2, {{{6, 10}, {18, 22}}, {{30, 34}, {42, 46}}}},
139 {-3, {{{9}, {21}}, {{33}, {45}}}},
140 {-4, {{{}, {}}, {{}, {}}}},
141 };
142 for (const auto& kv : k_and_expected) {
143 XlaOp a;
144 auto a_data = CreateR4Parameter<T>(input, 0, "a", &builder, &a);
145 GetMatrixDiagonal(a, kv.first);
146
147 ComputeAndCompareR3<T>(&builder, kv.second, {a_data.get()});
148 }
149 }
150
// GetMatrixDiagonal on a rank-4 input with int32 elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) {
  TestMatrixDiagonal4D<int32>();  // Shared body; see the fixture helper.
}
154
// GetMatrixDiagonal on a rank-4 input with int64 elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) {
  TestMatrixDiagonal4D<int64>();  // Shared body; see the fixture helper.
}
158
// GetMatrixDiagonal on a rank-4 input with float elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) {
  TestMatrixDiagonal4D<float>();  // Shared body; see the fixture helper.
}
162
BatchedAValsFull()163 Array3D<float> BatchedAValsFull() {
164 return {{
165 {2, 0, 1, 2},
166 {3, 6, 0, 1},
167 {4, 7, 9, 0},
168 {5, 8, 10, 11},
169 },
170 {
171 {16, 24, 8, 12},
172 {24, 61, 82, 48},
173 {8, 82, 456, 106},
174 {12, 48, 106, 62},
175 }};
176 }
177
// Dynamically slices row `index` out of every batch of BatchedAValsFull() and
// multiplies it by that batch's (transposed) row vector with BatchDot.
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // index == 1 picks {3, 6, 0, 1} from batch 0 and {24, 61, 82, 48} from
  // batch 1.
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  const XlaOp zero = ConstantR0<int32>(&builder, 0);
  auto selected_row = DynamicSliceInMinorDims(a, {index, zero}, {1, n});
  BatchDot(selected_row, TransposeInMinorDims(row));

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
197
// Same row-times-row product as a batched dot, but expressed as an einsum:
// "abc,adc->abd" keeps batch dimension a and contracts the minor dimension c.
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // index == 1 picks {3, 6, 0, 1} from batch 0 and {24, 61, 82, 48} from
  // batch 1.
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  const XlaOp zero = ConstantR0<int32>(&builder, 0);
  auto selected_row = DynamicSliceInMinorDims(a, {index, zero}, {1, n});
  Einsum(selected_row, row, "abc,adc->abd");

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
218
// Checks that ParseEinsumString splits "x,y->o" into per-operand label
// vectors, and that malformed strings are rejected.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Builds the label vector the parser is expected to produce for one spec:
  // each letter becomes its character code and the three dots of an ellipsis
  // become -3, -2, -1 in order.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64> v;
    v.reserve(s.size());
    int64 e = -3;
    for (const char c : s) {
      v.push_back(c == '.' ? e++ : int64{c});
    }
    return v;
  };

  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Every case must parse. The cases after the blank line parse successfully
  // even though their labels look inconsistent (output labels absent from
  // both inputs, or labels repeated within an operand).
  const std::vector<std::vector<string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"},
      {"...ab", "...bc", "...ac"},
      {"a...bc", "...abd", "cbd..."},
      {"...ab", "...bc", "ac"},
      {"...b", "...bc", "...c"},
      {"...abz", "...bc", "...ac"},
      {"...ab", "...bcz", "...ac"},
      {"abz", "bc", "ac"},
      {"ab", "bcz", "ac"},

      {"a", "b", "c"},
      {"...a", "...b", "...c"},
      {"abb", "bcc", "ac"},
      {"ab", "bc", "ad"},
  };
  for (const auto& test_case : good_test_cases) {
    const string einsum = to_string(test_case[0], test_case[1], test_case[2]);
    auto parse_result_or_status =
        ParseEinsumString(einsum, test_case[0].size(), test_case[1].size());
    // ASSERT, not EXPECT: ValueOrDie() below would abort the whole test binary
    // on an error status instead of reporting a test failure.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << einsum << ": " << parse_result_or_status.status().ToString();
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i])) << einsum;
    }
  }

  const std::vector<string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
  };
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
    EXPECT_FALSE(parse_result_or_status.status().ok()) << test_case;
  }
}
273
// NormalizeEinsumString fills in an implicit output spec. The expectations
// below show that labels are emitted in sorted order, a label shared between
// operands ('a' in the ellipsis case) is dropped, a leading ellipsis is kept,
// and a string that already has "->" yields "".
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
  // Implicit output: inferred and appended.
  EXPECT_EQ(NormalizeEinsumString("ba"), "ba->ab");
  EXPECT_EQ(NormalizeEinsumString("a,b"), "a,b->ab");
  EXPECT_EQ(NormalizeEinsumString("ab,dc"), "ab,dc->abcd");
  EXPECT_EQ(NormalizeEinsumString("...ba,ca..."), "...ba,ca...->...bc");

  // Explicit output: nothing to normalize.
  EXPECT_EQ(NormalizeEinsumString("a,b->ab"), "");
}
281
282 } // namespace
283 } // namespace xla
284