• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/matrix.h"
17 
18 #include <limits>
19 #include <map>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 
34 namespace xla {
35 namespace {
36 
37 class MatrixTest : public ClientLibraryTestBase {
38  protected:
39   template <typename T>
40   void TestMatrixDiagonal();
41   template <typename T>
42   void TestMatrixDiagonal4D();
43   template <typename T>
44   void TestSetMatrixDiagonal();
45 
46   template <typename T>
k_and_expected() const47   std::map<int, Array2D<T>> k_and_expected() const {
48     return std::map<int, Array2D<T>>{
49         {0, {{0, 5, 10}, {12, 17, 22}}},
50         {1, {{1, 6, 11}, {13, 18, 23}}},
51         {2, {{2, 7}, {14, 19}}},
52         {3, {{3}, {15}}},
53         {4, {{}, {}}},
54         {-1, {{4, 9}, {16, 21}}},
55         {-2, {{8}, {20}}},
56         {-3, {{}, {}}},
57         {-4, {{}, {}}},
58     };
59   }
60 };
61 
// LowerTriangle on a batch of two 3x4 matrices: per the expected values, the
// entries above each main diagonal come back as zero and the rest are kept.
XLA_TEST_F(MatrixTest, Triangle) {
  XlaBuilder b(TestName());

  Array3D<int32> operand(2, 3, 4);
  operand.FillIota(0);
  const Array3D<int32> want(
      {{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
       {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});

  XlaOp param;
  auto param_data = CreateR3Parameter<int32>(operand, 0, "a", &b, &param);
  LowerTriangle(param);

  ComputeAndCompareR3<int32>(&b, want, {param_data.get()});
}
75 
// Symmetrize on a real matrix, once per triangle. The triangle that is not
// selected holds NaNs, so any NaN in the output means the wrong half was read.
XLA_TEST_F(MatrixTest, Symmetrize) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  // Only the lower triangle (diagonal included) carries real data.
  const Array<float> lower_only = {
      {1, kNaN, kNaN},
      {2, 3, kNaN},
      {4, 5, 6},
  };
  const Array<float> symmetric = {
      {1, 2, 4},
      {2, 3, 5},
      {4, 5, 6},
  };

  for (const bool lower : {false, true}) {
    XlaBuilder builder(TestName());

    XlaOp a;
    auto a_data = CreateParameter<float>(lower_only, 0, "a", &builder, &a);
    if (lower) {
      Symmetrize(a, /*lower=*/lower);
    } else {
      // Transposing moves the real data into the upper triangle.
      Symmetrize(TransposeInMinorDims(a), /*lower=*/lower);
    }

    ComputeAndCompare<float>(&builder, symmetric, {a_data.get()});
  }
}
99 
// Symmetrize on a complex matrix, once per triangle. The unused triangle and
// the imaginary parts of the diagonal are NaN in the input; the expected
// output has a real diagonal and conjugated mirror entries.
XLA_TEST_F(MatrixTest, SymmetrizeComplex) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const Array<complex64> lower_only = {
      {complex64{1, kNaN}, kNaN, kNaN},
      {complex64{2, 7}, complex64{3, kNaN}, kNaN},
      {complex64{4, 8}, complex64{5, 9}, complex64{6, kNaN}},
  };
  const Array<complex64> hermitian = {
      {1, complex64{2, -7}, complex64{4, -8}},
      {complex64{2, 7}, 3, complex64{5, -9}},
      {complex64{4, 8}, complex64{5, 9}, 6},
  };

  for (const bool lower : {false, true}) {
    XlaBuilder builder(TestName());

    XlaOp a;
    auto a_data =
        CreateParameter<complex64>(lower_only, 0, "a", &builder, &a);
    if (lower) {
      Symmetrize(a, /*lower=*/lower);
    } else {
      // The conjugate transpose moves the real data into the upper triangle.
      Symmetrize(Conj(TransposeInMinorDims(a)), /*lower=*/lower);
    }

    ComputeAndCompare<complex64>(&builder, hermitian, {a_data.get()});
  }
}
123 
124 template <typename T>
TestMatrixDiagonal()125 void MatrixTest::TestMatrixDiagonal() {
126   XlaBuilder builder("SetMatrixDiagonal");
127   Array3D<T> input(2, 3, 4);
128   input.FillIota(0);
129   for (const auto& kv : k_and_expected<T>()) {
130     XlaOp a;
131     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
132     GetMatrixDiagonal(a, kv.first);
133 
134     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
135   }
136 }
137 
138 template <typename T>
TestSetMatrixDiagonal()139 void MatrixTest::TestSetMatrixDiagonal() {
140   XlaBuilder builder("GetMatrixDiagonal");
141   Array3D<T> input(2, 3, 4);
142   input.FillIota(0);
143   for (const auto& kv : k_and_expected<T>()) {
144     XlaOp a;
145     XlaOp b;
146     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
147     auto new_diag =
148         CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
149 
150     GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
151                       kv.first) -
152         ScalarLike(b, 1);
153 
154     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
155   }
156 }
157 
// SetMatrixDiagonal coverage: one instantiation of the typed driver per
// element type.

// 32-bit signed integers.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
  this->TestSetMatrixDiagonal<int32>();
}

// 64-bit signed integers.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
  this->TestSetMatrixDiagonal<int64>();
}

// Single-precision floats.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
  this->TestSetMatrixDiagonal<float>();
}
167 
// GetMatrixDiagonal coverage on the rank-3 input: one instantiation of the
// typed driver per element type.

// 32-bit signed integers.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  this->TestMatrixDiagonal<int32>();
}

// 64-bit signed integers.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  this->TestMatrixDiagonal<int64>();
}

// Single-precision floats.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  this->TestMatrixDiagonal<float>();
}
173 
174 template <typename T>
TestMatrixDiagonal4D()175 void MatrixTest::TestMatrixDiagonal4D() {
176   XlaBuilder builder("GetMatrixDiagonal");
177   Array4D<T> input(2, 2, 4, 3);
178   input.FillIota(0);
179   std::map<int, Array3D<T>> k_and_expected = {
180       {0, {{{0, 4, 8}, {12, 16, 20}}, {{24, 28, 32}, {36, 40, 44}}}},
181       {1, {{{1, 5}, {13, 17}}, {{25, 29}, {37, 41}}}},
182       {2, {{{2}, {14}}, {{26}, {38}}}},
183       {3, {{{}, {}}, {{}, {}}}},
184       {4, {{{}, {}}, {{}, {}}}},
185       {-1, {{{3, 7, 11}, {15, 19, 23}}, {{27, 31, 35}, {39, 43, 47}}}},
186       {-2, {{{6, 10}, {18, 22}}, {{30, 34}, {42, 46}}}},
187       {-3, {{{9}, {21}}, {{33}, {45}}}},
188       {-4, {{{}, {}}, {{}, {}}}},
189   };
190   for (const auto& kv : k_and_expected) {
191     XlaOp a;
192     auto a_data = CreateR4Parameter<T>(input, 0, "a", &builder, &a);
193     GetMatrixDiagonal(a, kv.first);
194 
195     ComputeAndCompareR3<T>(&builder, kv.second, {a_data.get()});
196   }
197 }
198 
// GetMatrixDiagonal coverage on the rank-4 input: one instantiation of the
// typed driver per element type.

// 32-bit signed integers.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) {
  this->TestMatrixDiagonal4D<int32>();
}

// 64-bit signed integers.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) {
  this->TestMatrixDiagonal4D<int64>();
}

// Single-precision floats.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) {
  this->TestMatrixDiagonal4D<float>();
}
210 
BatchedAValsFull()211 Array3D<float> BatchedAValsFull() {
212   return {{
213               {2, 0, 1, 2},
214               {3, 6, 0, 1},
215               {4, 7, 9, 0},
216               {5, 8, 10, 11},
217           },
218           {
219               {16, 24, 8, 12},
220               {24, 61, 82, 48},
221               {8, 82, 456, 106},
222               {12, 48, 106, 62},
223           }};
224 }
225 
// Slices one row out of every matrix in the batch and contracts it with a
// per-batch row vector through BatchDot.
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int row_width = 4;

  XlaOp matrices;
  auto matrices_data = CreateR3Parameter<float>(BatchedAValsFull(), 0, "a",
                                                &builder, &matrices);

  XlaOp row_vectors;
  auto row_vectors_data = CreateR3Parameter<float>(
      {{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, "row", &builder, &row_vectors);

  // Row index 1 picks {{3, 6, 0, 1}, {24, 61, 82, 48}} out of
  // BatchedAValsFull().
  XlaOp row_index;
  auto row_index_data =
      CreateR0Parameter<int>(1, 2, "index", &builder, &row_index);

  auto selected_rows = DynamicSliceInMinorDims(
      matrices, {row_index, ConstantR0<int32>(&builder, 0)}, {1, row_width});
  BatchDot(selected_rows, TransposeInMinorDims(row_vectors));

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(
      &builder, {{{33}}, {{292}}},
      {matrices_data.get(), row_vectors_data.get(), row_index_data.get()});
}
245 
// Same computation as the RowBatchDot test, expressed through Einsum with the
// spec "abc,adc->abd".
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  const int row_width = 4;

  XlaOp lhs;
  auto lhs_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &lhs);

  XlaOp rhs;
  auto rhs_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &rhs);

  // Row index 1 picks {{3, 6, 0, 1}, {24, 61, 82, 48}} out of
  // BatchedAValsFull().
  XlaOp which_row;
  auto which_row_data =
      CreateR0Parameter<int>(1, 2, "index", &builder, &which_row);

  auto lhs_rows = DynamicSliceInMinorDims(
      lhs, {which_row, ConstantR0<int32>(&builder, 0)}, {1, row_width});
  Einsum(lhs_rows, rhs, "abc,adc->abd");

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(
      &builder, {{{33}}, {{292}}},
      {lhs_data.get(), rhs_data.get(), which_row_data.get()});
}
266 
// Checks ParseEinsumString against a list of specs that must parse and a list
// of specs that must be rejected.
//
// For the accepted specs, each of the three parsed components (lhs, rhs,
// output) is compared with the encoding produced by `to_vec` below.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Expected encoding of one operand spec: a label character maps to its
  // character code, and the dots of an ellipsis map to -3, -2, -1 in order.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64> v;
    v.reserve(s.size());
    int e = -3;
    for (auto c : s) {
      v.push_back(c == '.' ? e++ : int64{c});
    }
    return v;
  };

  // Joins the three components into a full "x,y->o" einsum string.
  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Each entry is {lhs, rhs, output}.
  std::vector<std::vector<string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"},
      {"...ab", "...bc", "...ac"},
      {"a...bc", "...abd", "cbd..."},
      {"...ab", "...bc", "ac"},
      {"...b", "...bc", "...c"},
      {"...abz", "...bc", "...ac"},
      {"...ab", "...bcz", "...ac"},
      {"abz", "bc", "ac"},
      {"ab", "bcz", "ac"},

      {"a", "b", "c"},
      {"...a", "...b", "...c"},
      {"abb", "bcc", "ac"},
      {"ab", "bc", "ad"},
  };
  // Iterate by const reference: each case is a vector of strings and does not
  // need to be copied.
  for (const auto& test_case : good_test_cases) {
    const auto einsum = to_string(test_case[0], test_case[1], test_case[2]);
    auto parse_result_or_status =
        ParseEinsumString(einsum, test_case[0].size(), test_case[1].size());
    // Fatal assertion: ValueOrDie() below must not be reached with a non-OK
    // status, otherwise the test process dies instead of reporting a failure.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << "failed to parse: " << einsum;
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
    }
  }

  std::vector<string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
  };
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
    EXPECT_FALSE(parse_result_or_status.status().ok())
        << "unexpectedly parsed: " << test_case;
  }
}
321 
// Table-driven check of NormalizeEinsumString: each row pairs an input spec
// with the exact string the function is expected to return.
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
  struct Case {
    const char* input;
    const char* normalized;
  };
  const Case kCases[] = {
      // An input that already has an explicit output yields "".
      {"a,b->ab", ""},
      {"ba", "ba->ab"},
      {"ab,dc", "ab,dc->abcd"},
      {"a,b", "a,b->ab"},
      {"...ba,ca...", "...ba,ca...->...bc"},
  };
  for (const Case& c : kCases) {
    EXPECT_EQ(NormalizeEinsumString(c.input), c.normalized);
  }
}
329 
330 }  // namespace
331 }  // namespace xla
332