• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
34 
35 namespace xla {
36 
37 class SelfAdjointEigTest : public ClientLibraryTestBase {
38  protected:
SetUp()39   void SetUp() override {
40     ClientLibraryTestBase::SetUp();
41     batch_3d_4x4_ = Array3D<float>{
42         {
43             {4, 6, 8, 10},
44             {6, 45, 54, 63},
45             {8, 54, 146, 166},
46             {10, 63, 166, 310},
47         },
48         {
49             {16, 24, 8, 12},
50             {24, 61, 82, 48},
51             {8, 82, 100, 6},
52             {12, 48, 6, 62},
53         },
54     };
55     matrix2d_8x8_ = Array2D<float>{
56         {14., 123., 49., 112., 115., 173., 182., 125.},
57         {123., 14., 60., 118., 150., 130., 91., 72.},
58         {49., 60., 138., 111., 106., 101., 115., 142.},
59         {112., 118., 111., 142., 91., 130., 25., 61.},
60         {115., 150., 106., 91., 116., 121., 128., 85.},
61         {173., 130., 101., 130., 121., 70., 151., 132.},
62         {182., 91., 115., 25., 128., 151., 66., 92.},
63         {125., 72., 142., 61., 85., 132., 92., 156.},
64     };
65     low_rank_4x4_ = Array2D<float>{
66         // x = [[1, 2, 3, 4], [1, -1, 1, -1]]
67         // matmul(x.T, x)
68         {2, 1, 4, 3},
69         {1, 5, 5, 9},
70         {4, 5, 10, 11},
71         {3, 9, 11, 17},
72     };
73   }
TearDown()74   void TearDown() override { ClientLibraryTestBase::TearDown(); }
75 
GetUnitMatrix3D(const Array3D<float> & matrix)76   Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
77     Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
78     for (int i = 0; i < matrix.n1(); ++i) {
79       for (int j = 0; j < matrix.n2(); ++j) {
80         result({i, j, j}) = 1.0;
81       }
82     }
83     return result;
84   }
85 
ExtractTriangularMatrix(const Array3D<float> & matrix,bool lower)86   Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
87                                          bool lower) {
88     Array3D<float> result(matrix);
89     for (int i = 0; i < result.n1(); ++i) {
90       for (int j = 0; j < result.n2(); ++j) {
91         if (lower) {
92           for (int k = j + 1; k < result.n3(); ++k) {
93             result({i, j, k}) = 0.0;
94           }
95         } else {
96           for (int k = 0; k < j; ++k) {
97             result({i, j, k}) = 0.0;
98           }
99         }
100       }
101     }
102     return result;
103   }
104 
105   Array3D<float> batch_3d_4x4_;
106   Array2D<float> matrix2d_8x8_;
107   Array2D<float> low_rank_4x4_;
108   Array2D<int> wrong_type_4x4_;
109 };
110 
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)111 XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
112   Shape shape = builder->GetShape(m1).ValueOrDie();
113   int64_t size = ShapeUtil::ElementsIn(shape);
114   return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
115                    CreateScalarAddComputation(F32, builder)) /
116          ConstantR0WithType(builder, F32, std::max<int64_t>(1, size));
117 }
118 
ComputeMatmulVWVt(SelfAdjointEigResult result,XlaBuilder * builder)119 XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
120   Shape shape = builder->GetShape(result.v).ValueOrDie();
121   absl::Span<const int64_t> out_dims = shape.dimensions();
122   std::vector<int64_t> broadcast_dims(shape.rank() - 1);
123   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
124 
125   broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
126   auto vw =
127       Mul(result.v,
128           BroadcastInDim(ConvertElementType(result.w, shape.element_type()),
129                          out_dims, broadcast_dims));
130   return BatchDot(vw, MaybeConjugate(TransposeInMinorDims(result.v), true),
131                   PrecisionConfig::HIGHEST);
132 }
133 
// Checks that V * diag(w) * V^H reproduces the batched 4x4 input, both with
// and without eigenvalue sorting.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
  for (const bool sort_eigenvalues : {false, true}) {
    XlaBuilder builder(TestName());

    XlaOp input;
    const auto input_data =
        CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &input);
    const auto eig = SelfAdjointEig(input, /*lower=*/true, /*max_iter=*/15,
                                    /*tol=*/1e-5, sort_eigenvalues);
    // The returned op is intentionally unused; presumably the last op added
    // to the builder is what gets evaluated below.
    ComputeMatmulVWVt(eig, &builder);

    ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                               ErrorSpec(1e-3, 1e-3));
  }
}
148 
// Checks that V * diag(w) * V^H reproduces a 3x3 complex Hermitian input.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) {
  XlaBuilder builder(TestName());

  // Real diagonal; entries below the diagonal are the conjugates of the
  // mirrored entries above it.
  Array<complex64> hermitian = {
      {1, complex64{2, -7}, complex64{4, -8}},
      {complex64{2, 7}, 3, complex64{5, -9}},
      {complex64{4, 8}, complex64{5, 9}, 6},
  };

  XlaOp input;
  const auto input_data =
      CreateParameter<complex64>(hermitian, 0, "a", &builder, &input);
  const auto eig = SelfAdjointEig(input);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompare<complex64>(&builder, hermitian, {input_data.get()},
                               ErrorSpec(1e-3, 1e-3));
}
164 
// Feeds only the lower triangle of each input matrix (the strictly upper part
// is zeroed) and expects the reconstruction to equal the full symmetric
// matrix.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
  XlaBuilder builder(TestName());

  const Array3D<float> lower_only =
      ExtractTriangularMatrix(batch_3d_4x4_, true);

  XlaOp input;
  const auto input_data =
      CreateR3Parameter<float>(lower_only, 0, "a", &builder, &input);
  // Relies on SelfAdjointEig's default arguments.
  const auto eig = SelfAdjointEig(input);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
177 
// Feeds only the upper triangle of each input matrix (the strictly lower part
// is zeroed) with lower=false and expects the reconstruction to equal the
// full symmetric matrix.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
  XlaBuilder builder(TestName());

  const Array3D<float> upper_only =
      ExtractTriangularMatrix(batch_3d_4x4_, false);

  XlaOp input;
  const auto input_data =
      CreateR3Parameter<float>(upper_only, 0, "a", &builder, &input);
  const auto eig = SelfAdjointEig(input, /*lower=*/false);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
190 
// Checks that V * V^T equals the identity for every matrix in the batch.
XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
  XlaBuilder builder(TestName());

  XlaOp input;
  const auto input_data =
      CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &input);
  const auto eig = SelfAdjointEig(input);

  XlaOp v_transposed = TransposeInMinorDims(eig.v);
  BatchDot(eig.v, v_transposed, PrecisionConfig::HIGHEST);

  const Array3D<float> expected_identity = GetUnitMatrix3D(batch_3d_4x4_);
  ComputeAndCompareR3<float>(&builder, expected_identity, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
202 
// Checks that the decomposition still reconstructs the input when the 4x4
// matrix is rank deficient (it is built as x^T x from a 2x4 matrix x).
XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
  XlaBuilder builder(TestName());

  XlaOp input;
  const auto input_data =
      CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &input);
  const auto eig = SelfAdjointEig(input);
  ComputeMatmulVWVt(eig, &builder);

  ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
214 
// Compares the eigenvalues of the 8x8 fixture matrix against reference values.
XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
  XlaBuilder builder(TestName());

  // Reference eigenvalues computed by numpy.linalg.eigh with float32, listed
  // in ascending order.
  const std::vector<float> expected_eigenvalues{
      -182.69205, -116.86245, -105.74489, -9.545369,
      37.81711,   104.732285, 120.29153,  868.00385};

  XlaOp input;
  const auto input_data =
      CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &input);
  const auto eig = SelfAdjointEig(input);
  // Adding zeros leaves w unchanged; the op is appended so that an op
  // yielding w is the last one in the builder (presumably what gets compared).
  Add(eig.w, ZerosLike(eig.w));

  ComputeAndCompareR1<float>(&builder, expected_eigenvalues,
                             {input_data.get()}, ErrorSpec(1e-3, 1e-3));
}
230 
// Checks that the eigenvectors of the 8x8 fixture matrix are close to
// orthonormal by measuring the mean absolute entry of I - V^T V.
XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
  XlaBuilder builder(TestName());

  // Combined with ErrorSpec(1e-3, 1e-3) below, this accepts any average error
  // from zero up to roughly 2e-3.
  float expected_vals = 1e-3;

  XlaOp a;
  auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  // np.sum(np.abs(eye(n) - matmul(T(v), v))) / n**2
  //
  // The product is computed at the highest precision, like the other matrix
  // products in this file, so that the measured error reflects the
  // eigenvectors rather than reduced-precision matmul rounding.
  GetAverageAbsoluteError(
      IdentityMatrix(&builder, F32, 8, 8),
      BatchDot(TransposeInMinorDims(result.v), result.v,
               PrecisionConfig::HIGHEST),
      &builder);

  ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
247 
// Checks that an integer-typed input yields invalid result ops for both the
// eigenvectors and the eigenvalues.
XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
  XlaBuilder builder(TestName());

  XlaOp int_input;
  // The returned handle is kept alive for the duration of the test even
  // though nothing is executed.
  auto int_input_data =
      CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &int_input);

  const auto eig = SelfAdjointEig(int_input);
  EXPECT_FALSE(eig.v.valid());
  EXPECT_FALSE(eig.w.valid());
}
257 
GenerateRandomSymmetricMatrix(int size)258 Array2D<float> GenerateRandomSymmetricMatrix(int size) {
259   Array2D<float> result{size, size, 0.0};
260   // TODO(b/128001705): This seed should not be needed but makes the test
261   // avoid inputs which trigger numerical instability.
262   result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */);
263   for (int i = 0; i < size; ++i) {
264     for (int j = 0; j < i; ++j) {
265       result({j, i}) = result({i, j});
266     }
267   }
268   return result;
269 }
270 
271 using EighTestCase = int64_t;
272 class RandomEighTest : public ClientLibraryTestBase,
273                        public ::testing::WithParamInterface<EighTestCase> {};
274 
// Decomposes a random symmetric matrix of the parameterized size and checks
// that the mean absolute error of the reconstruction V * diag(w) * V^H against
// the input stays small.
XLA_TEST_P(RandomEighTest, Random) {
  XlaBuilder builder(TestName());

  const int64_t matrix_size = GetParam();
  const Array2D<float> random_symmetric =
      GenerateRandomSymmetricMatrix(matrix_size);

  XlaOp input;
  const auto input_data =
      CreateR2Parameter<float>(random_symmetric, 0, "a", &builder, &input);
  const auto eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  GetAverageAbsoluteError(reconstructed, input, &builder);

  // TODO(phawkins): this would be better expressed as <= 6e-3.
  // Expected value 3e-3 with absolute tolerance 3e-3 accepts [0, 6e-3].
  ComputeAndCompareR0<float>(&builder, 3e-3, {input_data.get()},
                             ErrorSpec(3e-3, 0));
}
288 
289 #ifndef XLA_TEST_BACKEND_CPU
290 INSTANTIATE_TEST_SUITE_P(
291     RandomEighTestInstantiation, RandomEighTest,
292     ::testing::Values(0, 1, 2, 3, 8, 16, 32, 77, 129, 203, 256, 257, 493, 511,
293                       512,
294                       // Large tests are slow on CPU.
295                       513, 1000),
__anon4b17a7960102(const ::testing::TestParamInfo<EighTestCase>& info) 296     [](const ::testing::TestParamInfo<EighTestCase>& info) {
297       const int64_t size = info.param;
298       return absl::StrCat(size);
299     });
300 #else
301 INSTANTIATE_TEST_SUITE_P(
302     RandomEighTestInstantiation, RandomEighTest,
303     ::testing::Values(0, 1, 2, 3, 8, 16, 32, 77, 129, 203, 256, 257, 493, 511,
304                       512),
__anon4b17a7960202(const ::testing::TestParamInfo<EighTestCase>& info) 305     [](const ::testing::TestParamInfo<EighTestCase>& info) {
306       const int64_t size = info.param;
307       return absl::StrCat(size);
308     });
309 #endif  // XLA_TEST_BACKEND_CPU
310 
311 }  // namespace xla
312