• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <map>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
27 
28 namespace xla {
29 namespace {
30 
31 class MatrixTest : public ClientLibraryTestBase {
32  protected:
33   template <typename T>
34   void TestMatrixDiagonal();
35 };
36 
XLA_TEST_F(MatrixTest, Triangle) {
  // Two batches of 3x4 int32 matrices holding 0, 1, ..., 23.
  Array3D<int32> input(2, 3, 4);
  input.FillIota(0);

  XlaBuilder builder(TestName());
  XlaOp a;
  auto a_data = CreateR3Parameter<int32>(input, 0, "a", &builder, &a);
  LowerTriangle(a);

  // Everything strictly above the main diagonal becomes zero; the diagonal
  // and the entries below it are kept unchanged.
  const Array3D<int32> want({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
                             {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});
  ComputeAndCompareR3<int32>(&builder, want, {a_data.get()});
}
50 
51 template <typename T>
TestMatrixDiagonal()52 void MatrixTest::TestMatrixDiagonal() {
53   XlaBuilder builder("GetMatrixDiagonal");
54   Array3D<T> input(2, 3, 4);
55   input.FillIota(0);
56   std::map<int, Array2D<T>> k_and_expected = {
57       {0, {{0, 5, 10}, {12, 17, 22}}},
58       {1, {{1, 6, 11}, {13, 18, 23}}},
59       {2, {{2, 7}, {14, 19}}},
60       {3, {{3}, {15}}},
61       {4, {{}, {}}},
62       {-1, {{4, 9}, {16, 21}}},
63       {-2, {{8}, {20}}},
64       {-3, {{}, {}}},
65       {-4, {{}, {}}},
66   };
67   for (const auto& kv : k_and_expected) {
68     XlaOp a;
69     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
70     GetMatrixDiagonal(a, kv.first);
71 
72     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
73   }
74 }
75 
// Instantiate the shared diagonal checks once per supported element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  TestMatrixDiagonal<int32>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  TestMatrixDiagonal<int64>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  TestMatrixDiagonal<float>();
}
81 
BatchedAValsFull()82 Array3D<float> BatchedAValsFull() {
83   return {{
84               {2, 0, 1, 2},
85               {3, 6, 0, 1},
86               {4, 7, 9, 0},
87               {5, 8, 10, 11},
88           },
89           {
90               {16, 24, 8, 12},
91               {24, 61, 82, 48},
92               {8, 82, 456, 106},
93               {12, 48, 106, 62},
94           }};
95 }
96 
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // index == 1 picks {3, 6, 0, 1} from batch 0 and {24, 61, 82, 48} from
  // batch 1, each as a 1 x n slice of BatchedAValsFull().
  const XlaOp zero = ConstantR0<int32>(&builder, 0);
  const XlaOp selected_row = DynamicSliceInMinorDims(a, {index, zero}, {1, n});

  // Per batch: 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  BatchDot(selected_row, TransposeInMinorDims(row));

  const Array3D<float> expected = {{{33}}, {{292}}};
  ComputeAndCompareR3<float>(&builder, expected,
                             {a_data.get(), row_data.get(), index_data.get()});
}
117 
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // index == 1 picks {3, 6, 0, 1} from batch 0 and {24, 61, 82, 48} from
  // batch 1, each as a 1 x n slice of BatchedAValsFull().
  const XlaOp zero = ConstantR0<int32>(&builder, 0);
  const XlaOp selected_row = DynamicSliceInMinorDims(a, {index, zero}, {1, n});

  // "abc,adc->abd" contracts the shared minor dimension c:
  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  Einsum(selected_row, row, "abc,adc->abd");

  const Array3D<float> expected = {{{33}}, {{292}}};
  ComputeAndCompareR3<float>(&builder, expected,
                             {a_data.get(), row_data.get(), index_data.get()});
}
138 
// Exercises ParseEinsumString and ValidateEinsumNumericDimensions on three
// groups of inputs:
//   1. well-formed specs, which must parse and validate;
//   2. strings the parser must reject;
//   3. strings that parse but must fail numeric-dimension validation.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Expected parse of one operand: each label character as its int64 code.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64> v;
    v.reserve(s.size());
    for (auto c : s) {
      v.push_back(int64{c});
    }
    return v;
  };

  // Builds the einsum string "x,y->o".
  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  const std::vector<std::vector<string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"}};
  for (const auto& test_case : good_test_cases) {
    const string einsum = to_string(test_case[0], test_case[1], test_case[2]);
    SCOPED_TRACE(einsum);
    auto parse_result_or_status = ParseEinsumString(einsum);
    // ASSERT rather than EXPECT: ValueOrDie() below crashes on a non-OK
    // status, so a parse failure must end this test before reaching it.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << parse_result_or_status.status().ToString();
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
    }
    EXPECT_TRUE(ValidateEinsumNumericDimensions(
                    parse_result[0], parse_result[1], parse_result[2])
                    .ok());
  }

  const std::vector<string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b,bc->a...c"};
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    SCOPED_TRACE(test_case);
    EXPECT_FALSE(ParseEinsumString(test_case).status().ok());
  }

  const std::vector<string> einsum_strings_that_fail_numeric_validation = {
      "a,b->c", "ab,bc->acd", "abz,bc->ac", "ab,bcz->ac"};
  for (const auto& test_case : einsum_strings_that_fail_numeric_validation) {
    SCOPED_TRACE(test_case);
    auto parse_result_or_status = ParseEinsumString(test_case);
    // These strings must parse; guard ValueOrDie() for the same reason as
    // above.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << parse_result_or_status.status().ToString();
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    EXPECT_FALSE(ValidateEinsumNumericDimensions(
                     parse_result[0], parse_result[1], parse_result[2])
                     .ok());
  }
}
189 
190 }  // namespace
191 }  // namespace xla
192