• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
17 
18 #include "tensorflow/compiler/xla/array2d.h"
19 #include "tensorflow/compiler/xla/array3d.h"
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 
33 namespace xla {
34 
35 class SelfAdjointEigTest : public ClientLibraryTestBase {
36  protected:
SetUp()37   void SetUp() override {
38     ClientLibraryTestBase::SetUp();
39     batch_3d_4x4_ = Array3D<float>{
40         {
41             {4, 6, 8, 10},
42             {6, 45, 54, 63},
43             {8, 54, 146, 166},
44             {10, 63, 166, 310},
45         },
46         {
47             {16, 24, 8, 12},
48             {24, 61, 82, 48},
49             {8, 82, 100, 6},
50             {12, 48, 6, 62},
51         },
52     };
53     matrix2d_8x8_ = Array2D<float>{
54         {14., 123., 49., 112., 115., 173., 182., 125.},
55         {123., 14., 60., 118., 150., 130., 91., 72.},
56         {49., 60., 138., 111., 106., 101., 115., 142.},
57         {112., 118., 111., 142., 91., 130., 25., 61.},
58         {115., 150., 106., 91., 116., 121., 128., 85.},
59         {173., 130., 101., 130., 121., 70., 151., 132.},
60         {182., 91., 115., 25., 128., 151., 66., 92.},
61         {125., 72., 142., 61., 85., 132., 92., 156.},
62     };
63     low_rank_4x4_ = Array2D<float>{
64         // x = [[1, 2, 3, 4], [1, -1, 1, -1]]
65         // matmul(x.T, x)
66         {2, 1, 4, 3},
67         {1, 5, 5, 9},
68         {4, 5, 10, 11},
69         {3, 9, 11, 17},
70     };
71   }
TearDown()72   void TearDown() override { ClientLibraryTestBase::TearDown(); }
73 
GetUnitMatrix3D(const Array3D<float> & matrix)74   Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
75     Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
76     for (int i = 0; i < matrix.n1(); ++i) {
77       for (int j = 0; j < matrix.n2(); ++j) {
78         result({i, j, j}) = 1.0;
79       }
80     }
81     return result;
82   }
83 
ExtractTriangularMatrix(const Array3D<float> & matrix,bool lower)84   Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
85                                          bool lower) {
86     Array3D<float> result(matrix);
87     for (int i = 0; i < result.n1(); ++i) {
88       for (int j = 0; j < result.n2(); ++j) {
89         if (lower) {
90           for (int k = j + 1; k < result.n3(); ++k) {
91             result({i, j, k}) = 0.0;
92           }
93         } else {
94           for (int k = 0; k < j; ++k) {
95             result({i, j, k}) = 0.0;
96           }
97         }
98       }
99     }
100     return result;
101   }
102 
ComputeMatmulVWVt(SelfAdjointEigResult result,XlaBuilder * builder)103   XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
104     Shape shape = builder->GetShape(result.v).ValueOrDie();
105     std::vector<int64> out_dims = shape.dimensions();
106     std::vector<int64> broadcast_dims(shape.rank() - 1);
107     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
108 
109     broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
110     auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims));
111     return BatchDot(vw, TransposeInMinorDims(result.v),
112                     PrecisionConfig::HIGHEST);
113   }
114 
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)115   XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
116     Shape shape = builder->GetShape(m1).ValueOrDie();
117     int64 size = 1;
118     for (auto d : shape.dimensions()) {
119       size *= d;
120     }
121     return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
122                      CreateScalarAddComputation(F32, builder)) /
123            ConstantR0WithType(builder, F32, size);
124   }
125 
GenerateRandomSymmetricMatrix(int size)126   Array2D<float> GenerateRandomSymmetricMatrix(int size) {
127     Array2D<float> result{size, size, 0.0};
128     // TODO(b/128001705): This seed should not be needed but makes the test
129     // avoid inputs which trigger numerical instability.
130     result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */);
131     for (int i = 0; i < size; ++i) {
132       for (int j = 0; j < i; ++j) {
133         result({j, i}) = result({i, j});
134       }
135     }
136     return result;
137   }
138 
139   Array3D<float> batch_3d_4x4_;
140   Array2D<float> matrix2d_8x8_;
141   Array2D<float> low_rank_4x4_;
142   Array2D<int> wrong_type_4x4_;
143 };
144 
// Decomposes the full batched 4x4 input and checks that V * diag(w) * V^T
// reproduces it within tolerance.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  // Built only for its effect on `builder`; the returned handle is unused.
  ComputeMatmulVWVt(eig, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             tolerance);
}
156 
// Feeds only the lower triangle of the batched input (upper part zeroed) and
// checks that V * diag(w) * V^T still reproduces the full symmetric matrices.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
  XlaBuilder builder(TestName());

  const Array3D<float> lower_only =
      ExtractTriangularMatrix(batch_3d_4x4_, true);

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(lower_only, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  // Built only for its effect on `builder`; the returned handle is unused.
  ComputeMatmulVWVt(eig, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             tolerance);
}
169 
// Feeds only the upper triangle of the batched input (lower part zeroed),
// passes `false` as SelfAdjointEig's second argument, and checks that
// V * diag(w) * V^T reproduces the full symmetric matrices.
XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
  XlaBuilder builder(TestName());

  const Array3D<float> upper_only =
      ExtractTriangularMatrix(batch_3d_4x4_, false);

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(upper_only, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input, false);
  // Built only for its effect on `builder`; the returned handle is unused.
  ComputeMatmulVWVt(eig, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {input_data.get()},
                             tolerance);
}
182 
// Checks that V * V^T equals the identity for each matrix in the batch.
XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp v_transposed = TransposeInMinorDims(eig.v);
  // Built only for its effect on `builder`; the returned handle is unused.
  BatchDot(eig.v, v_transposed, PrecisionConfig::HIGHEST);

  const Array3D<float> identity = GetUnitMatrix3D(batch_3d_4x4_);
  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR3<float>(&builder, identity, {input_data.get()},
                             tolerance);
}
194 
// Decomposes the rank-deficient 4x4 matrix and checks that
// V * diag(w) * V^T reproduces it within tolerance.
XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  // Built only for its effect on `builder`; the returned handle is unused.
  ComputeMatmulVWVt(eig, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {input_data.get()},
                             tolerance);
}
206 
// Compares the eigenvalues of the 8x8 fixture matrix against reference values
// precomputed with numpy.
XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
  XlaBuilder builder(TestName());

  // This is computed by numpy.linalg.eigh with float32.
  std::vector<float> reference_eigenvalues{
      -182.69205, -116.86245, -105.74489, -9.545369,
      37.81711,   104.732285, 120.29153,  868.00385};

  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  // w + 0: adds one more op that evaluates to w. Presumably this makes the
  // eigenvalues (not the eigenvectors) the value that gets compared.
  XlaOp zeros = ZerosLike(eig.w);
  Add(eig.w, zeros);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR1<float>(&builder, reference_eigenvalues,
                             {input_data.get()}, tolerance);
}
222 
// Checks that the eigenvectors of the 8x8 fixture matrix are orthonormal by
// measuring the mean absolute entry of (I - V^T V).
XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
  XlaBuilder builder(TestName());

  // Together with the tolerance below, this is presumably meant to accept any
  // mean error from zero up to roughly 2e-3 rather than to pin an exact value.
  float expected_vals = 1e-3;

  XlaOp a;
  auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
  auto result = SelfAdjointEig(a);
  // numpy equivalent: np.sum(np.abs(np.eye(n) - np.matmul(v.T, v))) / n**2
  XlaOp identity = IdentityMatrix(&builder, F32, 8, 8);
  // HIGHEST precision matches the other V/V^T products in this file, so the
  // measured error reflects the eigensolver and not a reduced-precision dot.
  XlaOp vt_v = BatchDot(TransposeInMinorDims(result.v), result.v,
                        PrecisionConfig::HIGHEST);
  // Built only for its effect on `builder`; the returned handle is unused.
  GetAverageAbsoluteError(identity, vt_v, &builder);

  ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
239 
// SelfAdjointEig applied to an int-typed parameter must return invalid ops
// for both the eigenvectors and the eigenvalues.
XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
  XlaBuilder builder(TestName());

  XlaOp int_input;
  // Nothing is executed in this test; the data handle is only held so the
  // parameter is created the same way as in the other tests.
  auto int_input_data =
      CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &int_input);
  const SelfAdjointEigResult eig = SelfAdjointEig(int_input);
  EXPECT_FALSE(eig.v.valid());
  EXPECT_FALSE(eig.w.valid());
}
249 
// Decomposes a random symmetric 8x8 matrix and checks the mean absolute
// difference between V * diag(w) * V^T and the input.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) {
  XlaBuilder builder(TestName());
  const int dim = 8;
  const Array2D<float> matrix = GenerateRandomSymmetricMatrix(dim);

  XlaOp input;
  auto input_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  // Built only for its effect on `builder`; the returned handle is unused.
  GetAverageAbsoluteError(reconstructed, input, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()}, tolerance);
}
262 
// Decomposes a random symmetric 16x16 matrix and checks the mean absolute
// difference between V * diag(w) * V^T and the input.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) {
  XlaBuilder builder(TestName());
  const int dim = 16;
  const Array2D<float> matrix = GenerateRandomSymmetricMatrix(dim);

  XlaOp input;
  auto input_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  // Built only for its effect on `builder`; the returned handle is unused.
  GetAverageAbsoluteError(reconstructed, input, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()}, tolerance);
}
275 
// Decomposes a random symmetric 32x32 matrix and checks the mean absolute
// difference between V * diag(w) * V^T and the input.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) {
  XlaBuilder builder(TestName());
  const int dim = 32;
  const Array2D<float> matrix = GenerateRandomSymmetricMatrix(dim);

  XlaOp input;
  auto input_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  // Built only for its effect on `builder`; the returned handle is unused.
  GetAverageAbsoluteError(reconstructed, input, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()}, tolerance);
}
288 
// Decomposes a random symmetric 256x256 matrix and checks the mean absolute
// difference between V * diag(w) * V^T and the input.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) {
  XlaBuilder builder(TestName());
  const int dim = 256;
  const Array2D<float> matrix = GenerateRandomSymmetricMatrix(dim);

  XlaOp input;
  auto input_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  // Built only for its effect on `builder`; the returned handle is unused.
  GetAverageAbsoluteError(reconstructed, input, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()}, tolerance);
}
301 
// Decomposes a random symmetric 512x512 matrix and checks the mean absolute
// difference between V * diag(w) * V^T and the input.
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) {
  XlaBuilder builder(TestName());
  const int dim = 512;
  const Array2D<float> matrix = GenerateRandomSymmetricMatrix(dim);

  XlaOp input;
  auto input_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &input);
  const SelfAdjointEigResult eig = SelfAdjointEig(input);
  XlaOp reconstructed = ComputeMatmulVWVt(eig, &builder);
  // Built only for its effect on `builder`; the returned handle is unused.
  GetAverageAbsoluteError(reconstructed, input, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR0<float>(&builder, 1e-3, {input_data.get()}, tolerance);
}
314 
315 }  // namespace xla
316