1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
17
18 #include "tensorflow/compiler/xla/array2d.h"
19 #include "tensorflow/compiler/xla/array3d.h"
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32
33 namespace xla {
34
35 class SelfAdjointEigTest : public ClientLibraryTestBase {
36 protected:
SetUp()37 void SetUp() override {
38 ClientLibraryTestBase::SetUp();
39 batch_3d_4x4_ = Array3D<float>{
40 {
41 {4, 6, 8, 10},
42 {6, 45, 54, 63},
43 {8, 54, 146, 166},
44 {10, 63, 166, 310},
45 },
46 {
47 {16, 24, 8, 12},
48 {24, 61, 82, 48},
49 {8, 82, 100, 6},
50 {12, 48, 6, 62},
51 },
52 };
53 matrix2d_8x8_ = Array2D<float>{
54 {14., 123., 49., 112., 115., 173., 182., 125.},
55 {123., 14., 60., 118., 150., 130., 91., 72.},
56 {49., 60., 138., 111., 106., 101., 115., 142.},
57 {112., 118., 111., 142., 91., 130., 25., 61.},
58 {115., 150., 106., 91., 116., 121., 128., 85.},
59 {173., 130., 101., 130., 121., 70., 151., 132.},
60 {182., 91., 115., 25., 128., 151., 66., 92.},
61 {125., 72., 142., 61., 85., 132., 92., 156.},
62 };
63 low_rank_4x4_ = Array2D<float>{
64 // x = [[1, 2, 3, 4], [1, -1, 1, -1]]
65 // matmul(x.T, x)
66 {2, 1, 4, 3},
67 {1, 5, 5, 9},
68 {4, 5, 10, 11},
69 {3, 9, 11, 17},
70 };
71 }
TearDown()72 void TearDown() override { ClientLibraryTestBase::TearDown(); }
73
GetUnitMatrix3D(const Array3D<float> & matrix)74 Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
75 Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
76 for (int i = 0; i < matrix.n1(); ++i) {
77 for (int j = 0; j < matrix.n2(); ++j) {
78 result({i, j, j}) = 1.0;
79 }
80 }
81 return result;
82 }
83
ExtractTriangularMatrix(const Array3D<float> & matrix,bool lower)84 Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
85 bool lower) {
86 Array3D<float> result(matrix);
87 for (int i = 0; i < result.n1(); ++i) {
88 for (int j = 0; j < result.n2(); ++j) {
89 if (lower) {
90 for (int k = j + 1; k < result.n3(); ++k) {
91 result({i, j, k}) = 0.0;
92 }
93 } else {
94 for (int k = 0; k < j; ++k) {
95 result({i, j, k}) = 0.0;
96 }
97 }
98 }
99 }
100 return result;
101 }
102
ComputeMatmulVWVt(SelfAdjointEigResult result,XlaBuilder * builder)103 XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
104 Shape shape = builder->GetShape(result.v).ValueOrDie();
105 std::vector<int64> out_dims = shape.dimensions();
106 std::vector<int64> broadcast_dims(shape.rank() - 1);
107 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
108
109 broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
110 auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims));
111 return BatchDot(vw, TransposeInMinorDims(result.v),
112 PrecisionConfig::HIGHEST);
113 }
114
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)115 XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
116 Shape shape = builder->GetShape(m1).ValueOrDie();
117 int64 size = 1;
118 for (auto d : shape.dimensions()) {
119 size *= d;
120 }
121 return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
122 CreateScalarAddComputation(F32, builder)) /
123 ConstantR0WithType(builder, F32, size);
124 }
125
GenerateRandomSymmetricMatrix(int size)126 Array2D<float> GenerateRandomSymmetricMatrix(int size) {
127 Array2D<float> result{size, size, 0.0};
128 // TODO(b/128001705): This seed should not be needed but makes the test
129 // avoid inputs which trigger numerical instability.
130 result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */);
131 for (int i = 0; i < size; ++i) {
132 for (int j = 0; j < i; ++j) {
133 result({j, i}) = result({i, j});
134 }
135 }
136 return result;
137 }
138
139 Array3D<float> batch_3d_4x4_;
140 Array2D<float> matrix2d_8x8_;
141 Array2D<float> low_rank_4x4_;
142 Array2D<int> wrong_type_4x4_;
143 };
144
// Decomposes each matrix in the batch and checks that V * diag(w) * V^T
// reproduces the original input.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
  XlaBuilder builder(TestName());
  const ErrorSpec tolerance(1e-3, 1e-3);

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             tolerance);
}
156
// Feeds only the lower triangle of each matrix. The reconstruction
// V * diag(w) * V^T must still equal the full symmetric matrix.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
  XlaBuilder builder(TestName());
  const ErrorSpec tolerance(1e-3, 1e-3);
  const Array3D<float> lower_only =
      ExtractTriangularMatrix(batch_3d_4x4_, /*lower=*/true);

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(lower_only, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             tolerance);
}
169
// Feeds only the upper triangle of each matrix and tells SelfAdjointEig so.
// The reconstruction V * diag(w) * V^T must equal the full symmetric matrix.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
  XlaBuilder builder(TestName());
  const ErrorSpec tolerance(1e-3, 1e-3);
  const Array3D<float> upper_only =
      ExtractTriangularMatrix(batch_3d_4x4_, /*lower=*/false);

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(upper_only, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input, /*lower=*/false);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             tolerance);
}
182
// Checks that the eigenvector matrix is orthogonal: V * V^T must be the
// identity for every matrix in the batch.
XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
  XlaBuilder builder(TestName());
  const ErrorSpec tolerance(1e-3, 1e-3);

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp v_transposed = TransposeInMinorDims(eig.v);
  BatchDot(eig.v, v_transposed, PrecisionConfig::HIGHEST);

  const Array3D<float> identity = GetUnitMatrix3D(batch_3d_4x4_);
  ComputeAndCompareR3<float>(&builder, identity, {input_data.get()},
                             tolerance);
}
194
// Decomposes a rank-deficient symmetric matrix and checks that it is still
// reconstructed exactly. Despite the "VtWV" in the test name, the product
// checked here is V * diag(w) * V^T.
XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
  XlaBuilder builder(TestName());
  const ErrorSpec tolerance(1e-3, 1e-3);

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {input_data.get()},
                             tolerance);
}
206
// Compares the eigenvalues of matrix2d_8x8_ against reference values.
XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
  XlaBuilder builder(TestName());

  // Reference values computed by numpy.linalg.eigh with float32 (ascending).
  const std::vector<float> expected = {
      -182.69205, -116.86245, -105.74489, -9.545369,
      37.81711,   104.732285, 120.29153,  868.00385};

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  // NOTE(review): adding zero leaves w unchanged. It appears intended to make
  // w the last op built, and therefore the value being compared below.
  Add(eig.w, ZerosLike(eig.w));

  ComputeAndCompareR1<float>(&builder, expected, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
222
// Checks that the eigenvectors of matrix2d_8x8_ form an orthogonal matrix by
// measuring the mean absolute deviation of V^T * V from the identity.
XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
  XlaBuilder builder(TestName());

  // The measured error is compared against this value with an absolute
  // tolerance of 1e-3. It therefore presumably acts as an upper bound
  // (roughly 2e-3) rather than an exact expectation.
  const float expected_vals = 1e-3;

  XlaOp a;
  auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  // mean(abs(eye(n) - matmul(T(v), v))). HIGHEST precision matches the other
  // matmuls in this file. On backends whose default dot precision is reduced,
  // this keeps the matmul itself from dominating the measured error.
  GetAverageAbsoluteError(
      IdentityMatrix(&builder, F32, 8, 8),
      BatchDot(TransposeInMinorDims(result.v), result.v,
               PrecisionConfig::HIGHEST),
      &builder);

  ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
239
// An integer-typed input must be rejected. The rejection is observed as
// invalid XlaOps in both fields of the result.
XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  EXPECT_FALSE(eig.v.valid());
  EXPECT_FALSE(eig.w.valid());
}
249
// Reconstruction check on a random symmetric 8x8 matrix. The mean absolute
// difference between V * diag(w) * V^T and the input is compared against
// 1e-3, with a 1e-3 tolerance.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) {
  XlaBuilder builder(TestName());
  const int size = 8;
  const Array2D<float> input_values = GenerateRandomSymmetricMatrix(size);

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  GetAverageAbsoluteError(reconstructed, input, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
262
// Reconstruction check on a random symmetric 16x16 matrix. The mean absolute
// difference between V * diag(w) * V^T and the input is compared against
// 1e-3, with a 1e-3 tolerance.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) {
  XlaBuilder builder(TestName());
  const int size = 16;
  const Array2D<float> input_values = GenerateRandomSymmetricMatrix(size);

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  GetAverageAbsoluteError(reconstructed, input, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
275
// Reconstruction check on a random symmetric 32x32 matrix. The mean absolute
// difference between V * diag(w) * V^T and the input is compared against
// 1e-3, with a 1e-3 tolerance.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) {
  XlaBuilder builder(TestName());
  const int size = 32;
  const Array2D<float> input_values = GenerateRandomSymmetricMatrix(size);

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  GetAverageAbsoluteError(reconstructed, input, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
288
// Reconstruction check on a random symmetric 256x256 matrix. The mean absolute
// difference between V * diag(w) * V^T and the input is compared against
// 1e-3, with a 1e-3 tolerance.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) {
  XlaBuilder builder(TestName());
  const int size = 256;
  const Array2D<float> input_values = GenerateRandomSymmetricMatrix(size);

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  GetAverageAbsoluteError(reconstructed, input, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
301
// Reconstruction check on a random symmetric 512x512 matrix. The mean absolute
// difference between V * diag(w) * V^T and the input is compared against
// 1e-3, with a 1e-3 tolerance.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) {
  XlaBuilder builder(TestName());
  const int size = 512;
  const Array2D<float> input_values = GenerateRandomSymmetricMatrix(size);

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  GetAverageAbsoluteError(reconstructed, input, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
314
315 } // namespace xla
316