1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <limits>
#include <map>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
33
34 namespace xla {
35 namespace {
36
37 class MatrixTest : public ClientLibraryTestBase {
38 protected:
39 template <typename T>
40 void TestMatrixDiagonal();
41 template <typename T>
42 void TestMatrixDiagonal4D();
43 template <typename T>
44 void TestSetMatrixDiagonal();
45
46 template <typename T>
k_and_expected() const47 std::map<int, Array2D<T>> k_and_expected() const {
48 return std::map<int, Array2D<T>>{
49 {0, {{0, 5, 10}, {12, 17, 22}}},
50 {1, {{1, 6, 11}, {13, 18, 23}}},
51 {2, {{2, 7}, {14, 19}}},
52 {3, {{3}, {15}}},
53 {4, {{}, {}}},
54 {-1, {{4, 9}, {16, 21}}},
55 {-2, {{8}, {20}}},
56 {-3, {{}, {}}},
57 {-4, {{}, {}}},
58 };
59 }
60 };
61
// LowerTriangle must keep the diagonal and everything below it in each 3x4
// matrix of the batch and zero out the strictly-upper part.
XLA_TEST_F(MatrixTest, Triangle) {
  const Array3D<int32> expected(
      {{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
       {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});

  XlaBuilder builder(TestName());
  Array3D<int32> iota_input(2, 3, 4);
  iota_input.FillIota(0);

  XlaOp a;
  auto a_data = CreateR3Parameter<int32>(iota_input, 0, "a", &builder, &a);
  // The last op added to the builder is the computation's result.
  LowerTriangle(a);

  ComputeAndCompareR3<int32>(&builder, expected, {a_data.get()});
}
75
// Symmetrize must rebuild a symmetric matrix from one triangle and ignore the
// NaN-filled other triangle. For lower == false the input is transposed first
// so that the valid entries live in the upper triangle; both variants must
// yield the same result.
XLA_TEST_F(MatrixTest, Symmetrize) {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  Array<float> input = {
      {1, nan, nan},
      {2, 3, nan},
      {4, 5, 6},
  };
  const Array<float> expected = {
      {1, 2, 4},
      {2, 3, 5},
      {4, 5, 6},
  };

  for (const bool lower : {false, true}) {
    XlaBuilder builder(TestName());
    XlaOp a;
    auto a_data = CreateParameter<float>(input, 0, "a", &builder, &a);

    XlaOp source = lower ? a : TransposeInMinorDims(a);
    Symmetrize(source, /*lower=*/lower);

    ComputeAndCompare<float>(&builder, expected, {a_data.get()});
  }
}
99
// Complex Symmetrize must produce a Hermitian matrix: the rebuilt triangle is
// the conjugate transpose of the source triangle, and the expected diagonal is
// real even though the input diagonal carries NaN imaginary parts. For
// lower == false the input is conjugate-transposed so the valid entries sit in
// the upper triangle.
XLA_TEST_F(MatrixTest, SymmetrizeComplex) {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  Array<complex64> input = {
      {complex64{1, nan}, nan, nan},
      {complex64{2, 7}, complex64{3, nan}, nan},
      {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}},
  };
  const Array<complex64> expected = {
      {1, complex64{2, -7}, complex64{4, -8}},
      {complex64{2, 7}, 3, complex64{5, -9}},
      {complex64{4, 8}, complex64{5, 9}, 6},
  };

  for (const bool lower : {false, true}) {
    XlaBuilder builder(TestName());
    XlaOp a;
    auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);

    XlaOp source = lower ? a : Conj(TransposeInMinorDims(a));
    Symmetrize(source, /*lower=*/lower);

    ComputeAndCompare<complex64>(&builder, expected, {a_data.get()});
  }
}
123
124 template <typename T>
TestMatrixDiagonal()125 void MatrixTest::TestMatrixDiagonal() {
126 XlaBuilder builder("SetMatrixDiagonal");
127 Array3D<T> input(2, 3, 4);
128 input.FillIota(0);
129 for (const auto& kv : k_and_expected<T>()) {
130 XlaOp a;
131 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
132 GetMatrixDiagonal(a, kv.first);
133
134 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
135 }
136 }
137
138 template <typename T>
TestSetMatrixDiagonal()139 void MatrixTest::TestSetMatrixDiagonal() {
140 XlaBuilder builder("GetMatrixDiagonal");
141 Array3D<T> input(2, 3, 4);
142 input.FillIota(0);
143 for (const auto& kv : k_and_expected<T>()) {
144 XlaOp a;
145 XlaOp b;
146 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
147 auto new_diag =
148 CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
149
150 GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
151 kv.first) -
152 ScalarLike(b, 1);
153
154 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
155 }
156 }
157
// SetMatrixDiagonal round-trip, instantiated per element type.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
  TestSetMatrixDiagonal<int32>();
}

XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
  TestSetMatrixDiagonal<int64>();
}

XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
  TestSetMatrixDiagonal<float>();
}

// GetMatrixDiagonal on rank-3 input, instantiated per element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  TestMatrixDiagonal<int32>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  TestMatrixDiagonal<int64>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  TestMatrixDiagonal<float>();
}
173
174 template <typename T>
TestMatrixDiagonal4D()175 void MatrixTest::TestMatrixDiagonal4D() {
176 XlaBuilder builder("GetMatrixDiagonal");
177 Array4D<T> input(2, 2, 4, 3);
178 input.FillIota(0);
179 std::map<int, Array3D<T>> k_and_expected = {
180 {0, {{{0, 4, 8}, {12, 16, 20}}, {{24, 28, 32}, {36, 40, 44}}}},
181 {1, {{{1, 5}, {13, 17}}, {{25, 29}, {37, 41}}}},
182 {2, {{{2}, {14}}, {{26}, {38}}}},
183 {3, {{{}, {}}, {{}, {}}}},
184 {4, {{{}, {}}, {{}, {}}}},
185 {-1, {{{3, 7, 11}, {15, 19, 23}}, {{27, 31, 35}, {39, 43, 47}}}},
186 {-2, {{{6, 10}, {18, 22}}, {{30, 34}, {42, 46}}}},
187 {-3, {{{9}, {21}}, {{33}, {45}}}},
188 {-4, {{{}, {}}, {{}, {}}}},
189 };
190 for (const auto& kv : k_and_expected) {
191 XlaOp a;
192 auto a_data = CreateR4Parameter<T>(input, 0, "a", &builder, &a);
193 GetMatrixDiagonal(a, kv.first);
194
195 ComputeAndCompareR3<T>(&builder, kv.second, {a_data.get()});
196 }
197 }
198
// GetMatrixDiagonal on rank-4 input, instantiated per element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) {
  TestMatrixDiagonal4D<int32>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) {
  TestMatrixDiagonal4D<int64>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) {
  TestMatrixDiagonal4D<float>();
}
210
BatchedAValsFull()211 Array3D<float> BatchedAValsFull() {
212 return {{
213 {2, 0, 1, 2},
214 {3, 6, 0, 1},
215 {4, 7, 9, 0},
216 {5, 8, 10, 11},
217 },
218 {
219 {16, 24, 8, 12},
220 {24, 61, 82, 48},
221 {8, 82, 456, 106},
222 {12, 48, 106, 62},
223 }};
224 }
225
// Picks row `index` of every matrix in BatchedAValsFull() with a dynamic slice
// and multiplies it by a per-batch row vector using BatchDot.
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // A 1 x n window starting at (index, 0) in the two minor dimensions.
  XlaOp first_column = ConstantR0<int32>(&builder, 0);
  XlaOp selected_row =
      DynamicSliceInMinorDims(a, {index, first_column}, {1, n});
  XlaOp row_transposed = TransposeInMinorDims(row);
  // The last op added to the builder is the computation's result.
  BatchDot(selected_row, row_transposed);

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
245
// Same computation as RowBatchDot, but the contraction is expressed as the
// einsum "abc,adc->abd" (batch a, contract over c) instead of an explicit
// transpose + BatchDot.
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // A 1 x n window starting at (index, 0) in the two minor dimensions.
  XlaOp first_column = ConstantR0<int32>(&builder, 0);
  XlaOp selected_row =
      DynamicSliceInMinorDims(a, {index, first_column}, {1, n});
  // The last op added to the builder is the computation's result.
  Einsum(selected_row, row, "abc,adc->abd");

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
266
// Checks that ParseEinsumString accepts well-formed "x,y->o" strings and
// rejects malformed ones.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Expected parse of one operand/output spec. Every label becomes its
  // character code. The three '.' of an ellipsis become -3, -2, -1 in order.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64> v;
    v.reserve(s.size());
    int64 e = -3;
    for (char c : s) {
      v.push_back(c == '.' ? e++ : int64{c});
    }
    return v;
  };

  // Joins two operand specs and an output spec into "x,y->o".
  auto make_einsum = [](absl::string_view x, absl::string_view y,
                        absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Operand ranks are passed as the spec lengths below, so each ellipsis
  // always stands for exactly three dimensions.
  std::vector<std::vector<string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"},
      {"...ab", "...bc", "...ac"},
      {"a...bc", "...abd", "cbd..."},
      {"...ab", "...bc", "ac"},
      {"...b", "...bc", "...c"},
      {"...abz", "...bc", "...ac"},
      {"...ab", "...bcz", "...ac"},
      {"abz", "bc", "ac"},
      {"ab", "bcz", "ac"},

      // These look semantically questionable (e.g. output labels that appear
      // in neither operand), yet they are expected to parse; presumably
      // ParseEinsumString checks syntax only.
      {"a", "b", "c"},
      {"...a", "...b", "...c"},
      {"abb", "bcc", "ac"},
      {"ab", "bc", "ad"},
  };
  for (const auto& test_case : good_test_cases) {
    const string einsum =
        make_einsum(test_case[0], test_case[1], test_case[2]);
    SCOPED_TRACE(einsum);
    auto parse_result_or_status =
        ParseEinsumString(einsum, test_case[0].size(), test_case[1].size());
    // ASSERT (not EXPECT): on failure ValueOrDie() below would abort the whole
    // test binary instead of reporting which string was rejected.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << parse_result_or_status.status().ToString();
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
    }
  }

  std::vector<string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
  };
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    SCOPED_TRACE(test_case);
    auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
    EXPECT_FALSE(parse_result_or_status.status().ok());
  }
}
321
// NormalizeEinsumString appends an implicit output ("->...") to einsum strings
// that lack one, and returns an empty string when an output is already given.
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
  const std::vector<std::pair<string, string>> cases = {
      {"a,b->ab", ""},
      {"ba", "ba->ab"},
      {"ab,dc", "ab,dc->abcd"},
      {"a,b", "a,b->ab"},
      {"...ba,ca...", "...ba,ca...->...bc"},
  };
  for (const auto& input_and_expected : cases) {
    EXPECT_EQ(NormalizeEinsumString(input_and_expected.first),
              input_and_expected.second)
        << "input: " << input_and_expected.first;
  }
}
329
330 } // namespace
331 } // namespace xla
332