/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {

class SelfAdjointEigTest : public ClientLibraryTestBase {
 protected:
  void SetUp() override {
    ClientLibraryTestBase::SetUp();
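    // A batch of two symmetric 4x4 matrices.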
    batch_3d_4x4_ = Array3D<float>{
        {
            {4, 6, 8, 10},
            {6, 45, 54, 63},
            {8, 54, 146, 166},
            {10, 63, 166, 310},
        },
        {
            {16, 24, 8, 12},
            {24, 61, 82, 48},
            {8, 82, 100, 6},
            {12, 48, 6, 62},
        },
    };
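    // A symmetric 8x8 matrix; Test_Eigen_8x8 checks its eigenvalues against a
    // numpy.linalg.eigh reference.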
    matrix2d_8x8_ = Array2D<float>{
        {14., 123., 49., 112., 115., 173., 182., 125.},
        {123., 14., 60., 118., 150., 130., 91., 72.},
        {49., 60., 138., 111., 106., 101., 115., 142.},
        {112., 118., 111., 142., 91., 130., 25., 61.},
        {115., 150., 106., 91., 116., 121., 128., 85.},
        {173., 130., 101., 130., 121., 70., 151., 132.},
        {182., 91., 115., 25., 128., 151., 66., 92.},
        {125., 72., 142., 61., 85., 132., 92., 156.},
    };
    low_rank_4x4_ = Array2D<float>{
        // x = [[1, 2, 3, 4], [1, -1, 1, -1]]
        // matmul(x.T, x)
        {2, 1, 4, 3},
        {1, 5, 5, 9},
        {4, 5, 10, 11},
        {3, 9, 11, 17},
    };
  }
  void TearDown() override { ClientLibraryTestBase::TearDown(); }

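  // Builds a batch of identity matrices with the same batch count and matrix
  // dimensions as `matrix`.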
  Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
    Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
    for (int i = 0; i < matrix.n1(); ++i) {
      for (int j = 0; j < matrix.n2(); ++j) {
        result({i, j, j}) = 1.0;
      }
    }
    return result;
  }

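  // Copies `matrix`, zeroing the strict upper triangle of each matrix in the
  // batch when `lower` is true, or the strict lower triangle otherwise.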
  Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
                                         bool lower) {
    Array3D<float> result(matrix);
    for (int i = 0; i < result.n1(); ++i) {
      for (int j = 0; j < result.n2(); ++j) {
        if (lower) {
          for (int k = j + 1; k < result.n3(); ++k) {
            result({i, j, k}) = 0.0;
          }
        } else {
          for (int k = 0; k < j; ++k) {
            result({i, j, k}) = 0.0;
          }
        }
      }
    }
    return result;
  }

  Array3D<float> batch_3d_4x4_;
  Array2D<float> matrix2d_8x8_;
  Array2D<float> low_rank_4x4_;
  Array2D<int> wrong_type_4x4_;
};

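// Computes the mean absolute elementwise difference between m1 and m2.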
XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
  Shape shape = builder->GetShape(m1).ValueOrDie();
  int64_t size = ShapeUtil::ElementsIn(shape);
  return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
                   CreateScalarAddComputation(F32, builder)) /
         ConstantR0WithType(builder, F32, std::max<int64_t>(1, size));
}

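// Reconstructs the input matrix as V * diag(w) * V^H from the eigenvalues w
// and eigenvectors V returned by SelfAdjointEig.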
XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
  Shape shape = builder->GetShape(result.v).ValueOrDie();
  absl::Span<const int64_t> out_dims = shape.dimensions();
  std::vector<int64_t> broadcast_dims(shape.rank() - 1);
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);

  broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
  auto vw =
      Mul(result.v,
          BroadcastInDim(ConvertElementType(result.w, shape.element_type()),
                         out_dims, broadcast_dims));
  return BatchDot(vw, MaybeConjugate(TransposeInMinorDims(result.v), true),
                  PrecisionConfig::HIGHEST);
}

XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
  for (bool sort_eigenvalues : {false, true}) {
    XlaBuilder builder(TestName());

    XlaOp a;
    auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
    auto result = SelfAdjointEig(a, /*lower=*/true, /*max_iter=*/15,
                                 /*tol=*/1e-5, sort_eigenvalues);
    ComputeMatmulVWVt(result, &builder);

    ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
                               ErrorSpec(1e-3, 1e-3));
  }
}

XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) {
  XlaBuilder builder(TestName());
  Array<complex64> input = {
      {1, complex64{2, -7}, complex64{4, -8}},
      {complex64{2, 7}, 3, complex64{5, -9}},
      {complex64{4, 8}, complex64{5, 9}, 6},
  };
  XlaOp a;
  auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  ComputeMatmulVWVt(result, &builder);

  ComputeAndCompare<complex64>(&builder, input, {a_data.get()},
                               ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(
      ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  ComputeMatmulVWVt(result, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(
      ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a, false);
  ComputeMatmulVWVt(result, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(batch_3d_4x4_),
                             {a_data.get()}, ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  ComputeMatmulVWVt(result, &builder);

  ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
  XlaBuilder builder(TestName());

  // This is computed by numpy.linalg.eigh with float32.
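  // Roughly: w, _ = np.linalg.eigh(a), with `a` holding the entries of
  // matrix2d_8x8_ as a float32 array (variable names illustrative only).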
  std::vector<float> expected{-182.69205, -116.86245, -105.74489, -9.545369,
                              37.81711,   104.732285, 120.29153,  868.00385};

  XlaOp a;
  auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  Add(result.w, ZerosLike(result.w));

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
  XlaBuilder builder(TestName());

  float expected_vals = 1e-3;

  XlaOp a;
  auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  // np.sum(abs(eye(n) - matmul(conj(T(v)), v))) / n**2
  GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8),
                          BatchDot(TransposeInMinorDims(result.v), result.v),
                          &builder);

  ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}

XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  EXPECT_FALSE(result.v.valid());
  EXPECT_FALSE(result.w.valid());
}

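// Generates a size x size matrix of random values and mirrors its lower
// triangle into the upper triangle so the result is symmetric.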
Array2D<float> GenerateRandomSymmetricMatrix(int size) {
  Array2D<float> result{size, size, 0.0};
  // TODO(b/128001705): This seed should not be needed but makes the test
  // avoid inputs which trigger numerical instability.
  result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */);
  for (int i = 0; i < size; ++i) {
    for (int j = 0; j < i; ++j) {
      result({j, i}) = result({i, j});
    }
  }
  return result;
}

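// Each test parameter is the dimension of the random symmetric matrix to
// decompose.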
using EighTestCase = int64_t;
class RandomEighTest : public ClientLibraryTestBase,
                       public ::testing::WithParamInterface<EighTestCase> {};

XLA_TEST_P(RandomEighTest, Random) {
  XlaBuilder builder(TestName());
  int64_t size = GetParam();
  Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
  XlaOp a;
  auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);

  // TODO(phawkins): this would be better expressed as <= 6e-3.
  ComputeAndCompareR0<float>(&builder, 3e-3, {a_data.get()},
                             ErrorSpec(3e-3, 0));
}

#ifndef XLA_TEST_BACKEND_CPU
INSTANTIATE_TEST_SUITE_P(
    RandomEighTestInstantiation, RandomEighTest,
    ::testing::Values(0, 1, 2, 3, 8, 16, 32, 77, 129, 203, 256, 257, 493, 511,
                      512,
                      // Large tests are slow on CPU.
                      513, 1000),
    [](const ::testing::TestParamInfo<EighTestCase>& info) {
      const int64_t size = info.param;
      return absl::StrCat(size);
    });
#else
INSTANTIATE_TEST_SUITE_P(
    RandomEighTestInstantiation, RandomEighTest,
    ::testing::Values(0, 1, 2, 3, 8, 16, 32, 77, 129, 203, 256, 257, 493, 511,
                      512),
    [](const ::testing::TestParamInfo<EighTestCase>& info) {
      const int64_t size = info.param;
      return absl::StrCat(size);
    });
#endif  // XLA_TEST_BACKEND_CPU

}  // namespace xla