1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/matrix.h"
17
18 #include "absl/strings/string_view.h"
19 #include "tensorflow/compiler/xla/client/lib/slicing.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21 #include "tensorflow/compiler/xla/status.h"
22 #include "tensorflow/compiler/xla/statusor.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
25 #include "tensorflow/compiler/xla/tests/test_macros.h"
26 #include "tensorflow/compiler/xla/types.h"
27
28 namespace xla {
29 namespace {
30
31 class MatrixTest : public ClientLibraryTestBase {
32 protected:
33 template <typename T>
34 void TestMatrixDiagonal();
35 };
36
// Applies LowerTriangle to a 2x3x4 int32 input whose elements count up from 0
// and compares against the input with every above-diagonal entry zeroed.
XLA_TEST_F(MatrixTest, Triangle) {
  XlaBuilder builder(TestName());

  Array3D<int32> operand(2, 3, 4);
  operand.FillIota(0);

  XlaOp param;
  auto param_data =
      CreateR3Parameter<int32>(operand, 0, "a", &builder, &param);
  LowerTriangle(param);

  // One 3x4 matrix per batch entry; entries above the main diagonal are 0.
  Array3D<int32> want({
      {{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
      {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}},
  });

  ComputeAndCompareR3<int32>(&builder, want, {param_data.get()});
}
50
51 template <typename T>
TestMatrixDiagonal()52 void MatrixTest::TestMatrixDiagonal() {
53 XlaBuilder builder("GetMatrixDiagonal");
54 Array3D<T> input(2, 3, 4);
55 input.FillIota(0);
56 std::map<int, Array2D<T>> k_and_expected = {
57 {0, {{0, 5, 10}, {12, 17, 22}}},
58 {1, {{1, 6, 11}, {13, 18, 23}}},
59 {2, {{2, 7}, {14, 19}}},
60 {3, {{3}, {15}}},
61 {4, {{}, {}}},
62 {-1, {{4, 9}, {16, 21}}},
63 {-2, {{8}, {20}}},
64 {-3, {{}, {}}},
65 {-4, {{}, {}}},
66 };
67 for (const auto& kv : k_and_expected) {
68 XlaOp a;
69 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
70 GetMatrixDiagonal(a, kv.first);
71
72 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
73 }
74 }
75
// GetMatrixDiagonal check instantiated for int32 elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  this->TestMatrixDiagonal<int32>();
}
77
// GetMatrixDiagonal check instantiated for int64 elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  this->TestMatrixDiagonal<int64>();
}
79
// GetMatrixDiagonal check instantiated for float elements.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  this->TestMatrixDiagonal<float>();
}
81
BatchedAValsFull()82 Array3D<float> BatchedAValsFull() {
83 return {{
84 {2, 0, 1, 2},
85 {3, 6, 0, 1},
86 {4, 7, 9, 0},
87 {5, 8, 10, 11},
88 },
89 {
90 {16, 24, 8, 12},
91 {24, 61, 82, 48},
92 {8, 82, 456, 106},
93 {12, 48, 106, 62},
94 }};
95 }
96
// Slices one row (chosen by a runtime index) out of every matrix in the
// batch and multiplies it with the transposed "row" operand via BatchDot.
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());

  const int n = 4;

  XlaOp matrices, row_vec, row_index;
  auto matrices_data = CreateR3Parameter<float>(BatchedAValsFull(), 0, "a",
                                                &builder, &matrices);
  auto row_vec_data = CreateR3Parameter<float>(
      {{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, "row", &builder, &row_vec);
  // Index 1 picks {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
  auto row_index_data =
      CreateR0Parameter<int>(1, 2, "index", &builder, &row_index);

  // A 1 x n window starting at (index, 0) in the two minor dimensions.
  auto zero = ConstantR0<int32>(&builder, 0);
  auto selected_row =
      DynamicSliceInMinorDims(matrices, {row_index, zero}, {1, n});
  auto row_vec_transposed = TransposeInMinorDims(row_vec);
  BatchDot(selected_row, row_vec_transposed);

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(
      &builder, {{{33}}, {{292}}},
      {matrices_data.get(), row_vec_data.get(), row_index_data.get()});
}
117
// Same computation as the RowBatchDot test, but expressed through Einsum
// with the config "abc,adc->abd" instead of BatchDot on a transposed operand.
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());

  const int n = 4;

  XlaOp matrices, row_vec, row_index;
  auto matrices_data = CreateR3Parameter<float>(BatchedAValsFull(), 0, "a",
                                                &builder, &matrices);
  auto row_vec_data = CreateR3Parameter<float>(
      {{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, "row", &builder, &row_vec);
  // Index 1 picks {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
  auto row_index_data =
      CreateR0Parameter<int>(1, 2, "index", &builder, &row_index);

  // A 1 x n window starting at (index, 0) in the two minor dimensions.
  auto zero = ConstantR0<int32>(&builder, 0);
  auto selected_row =
      DynamicSliceInMinorDims(matrices, {row_index, zero}, {1, n});
  Einsum(selected_row, row_vec, "abc,adc->abd");

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(
      &builder, {{{33}}, {{292}}},
      {matrices_data.get(), row_vec_data.get(), row_index_data.get()});
}
138
// Exercises ParseEinsumString and ValidateEinsumNumericDimensions on three
// groups of einsum strings: strings that parse and validate, strings that
// fail to parse, and strings that parse but fail numeric validation.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Converts each character of a label string to its numeric value, the form
  // the parsed labels are compared against below.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64> v;
    v.reserve(s.size());
    for (const char c : s) {
      v.push_back(int64{c});
    }
    return v;
  };

  // Assembles the einsum config "x,y->o" from its three label strings.
  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Each case is {lhs labels, rhs labels, output labels}.
  const std::vector<std::vector<string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"}};
  for (const auto& test_case : good_test_cases) {
    const string einsum_string =
        to_string(test_case[0], test_case[1], test_case[2]);
    auto parse_result_or_status = ParseEinsumString(einsum_string);
    // Fatal assertion: calling ValueOrDie() on a failed result would take
    // down the whole test binary instead of reporting a test failure.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << "failed to parse: " << einsum_string;
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]))
          << "operand " << i << " of: " << einsum_string;
    }
    EXPECT_TRUE(ValidateEinsumNumericDimensions(
                    parse_result[0], parse_result[1], parse_result[2])
                    .ok())
        << "failed numeric validation: " << einsum_string;
  }

  const std::vector<string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b,bc->a...c"};
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    auto parse_result_or_status = ParseEinsumString(test_case);
    EXPECT_FALSE(parse_result_or_status.status().ok())
        << "unexpectedly parsed: \"" << test_case << "\"";
  }

  const std::vector<string> einsum_strings_that_fail_numeric_validation = {
      "a,b->c", "ab,bc->acd", "abz,bc->ac", "ab,bcz->ac"};
  for (const auto& test_case : einsum_strings_that_fail_numeric_validation) {
    auto parse_result_or_status = ParseEinsumString(test_case);
    // Fatal for the same reason as above: ValueOrDie() follows immediately.
    ASSERT_TRUE(parse_result_or_status.status().ok())
        << "failed to parse: " << test_case;
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    EXPECT_FALSE(ValidateEinsumNumericDimensions(
                     parse_result[0], parse_result[1], parse_result[2])
                     .ok())
        << "unexpectedly passed numeric validation: " << test_case;
  }
}
189
190 } // namespace
191 } // namespace xla
192