• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/svd.h"

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 
38 class SVDTest : public ClientLibraryTestBase {
39  protected:
SetUp()40   void SetUp() override {
41     ClientLibraryTestBase::SetUp();
42     batch_3d_4x5_ = Array3D<float>{
43         {
44             {4, 6, 8, 10, 1},
45             {6, 45, 54, 63, 1},
46             {8, 54, 146, 166, 1},
47             {10, 63, 166, 310, 1},
48         },
49         {
50             {16, 24, 8, 12, 6},
51             {24, 61, 82, 48, 5},
52             {8, 82, 100, 6, 4},
53             {12, 48, 6, 62, 3},
54         },
55     };
56   }
TearDown()57   void TearDown() override { ClientLibraryTestBase::TearDown(); }
58 
GetUnitMatrix3D(int32 batch_dim,int32 mat_dim)59   Array3D<float> GetUnitMatrix3D(int32 batch_dim, int32 mat_dim) {
60     Array3D<float> result(batch_dim, mat_dim, mat_dim, 0.0);
61     for (int i = 0; i < batch_dim; ++i) {
62       for (int j = 0; j < mat_dim; ++j) {
63         result({i, j, j}) = 1.0;
64       }
65     }
66     return result;
67   }
68 
ComputeMatmulUDVT(SVDResult result,XlaBuilder * builder)69   XlaOp ComputeMatmulUDVT(SVDResult result, XlaBuilder* builder) {
70     Shape u_shape = builder->GetShape(result.u).ValueOrDie();
71     Shape v_shape = builder->GetShape(result.v).ValueOrDie();
72 
73     int64 m = ShapeUtil::GetDimension(u_shape, -1);
74     int64 n = ShapeUtil::GetDimension(v_shape, -1);
75 
76     auto v = result.v;
77     auto u = result.u;
78     auto d = result.d;
79 
80     auto zero = Zero(builder, S32);
81     if (m > n) {
82       u = DynamicSliceInMinorDims(u, {zero, zero}, {m, n});
83     } else if (m < n) {
84       v = DynamicSliceInMinorDims(v, {zero, zero}, {n, m});
85     }
86 
87     int num_dims = u_shape.rank();
88     std::vector<int64> broadcast_dims(num_dims - 1);
89     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
90     broadcast_dims[num_dims - 2] = num_dims - 1;
91     return BatchDot(Mul(u, d, broadcast_dims), TransposeInMinorDims(v),
92                     PrecisionConfig::HIGHEST);
93   }
94 
ExtractTriangularMatrix(const Array3D<float> & matrix,bool lower)95   Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
96                                          bool lower) {
97     Array3D<float> result(matrix);
98     for (int i = 0; i < result.n1(); ++i) {
99       for (int j = 0; j < result.n2(); ++j) {
100         if (lower) {
101           for (int k = j + 1; k < result.n3(); ++k) {
102             result({i, j, k}) = 0.0;
103           }
104         } else {
105           for (int k = 0; k < j; ++k) {
106             result({i, j, k}) = 0.0;
107           }
108         }
109       }
110     }
111     return result;
112   }
113 
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)114   XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
115     Shape shape = builder->GetShape(m1).ValueOrDie();
116     int64 size = 1;
117     for (auto d : shape.dimensions()) {
118       size *= d;
119     }
120     return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
121                      CreateScalarAddComputation(F32, builder)) /
122            ConstantR0WithType(builder, F32, size);
123   }
124 
GenerateRandomMatrix(int xsize,int ysize)125   Array2D<float> GenerateRandomMatrix(int xsize, int ysize) {
126     Array2D<float> result{xsize, ysize, 0.0};
127     result.FillRandom(10 /* stddev */, 2 /* mean */);
128     return result;
129   }
130 
131   Array3D<float> batch_3d_4x5_;
132 };
133 
// Factorizes a single 4x4 matrix and checks that U * diag(D) * V^T
// reproduces it.
XLA_TEST_F(SVDTest, Simple2D) {
  XlaBuilder builder(TestName());

  const Array2D<float> input{
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  };
  XlaOp a;
  auto a_data = CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  // The reconstruction is the last op built, so it is what gets compared.
  ComputeMatmulUDVT(SVD(a, 100, 1e-6), &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR2<float>(&builder, input, {a_data.get()}, tolerance);
}
151 
// Reconstructs each 4x5 matrix of the batch from its SVD factors and checks
// that U * diag(D) * V^T matches the input.
XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const auto factors = SVD(input, 100, 1e-8);
  ComputeMatmulUDVT(factors, &builder);

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR3<float>(&builder, batch_3d_4x5_, {input_data.get()},
                             tolerance);
}
163 
// Checks that the left singular vectors are orthonormal: U * U^T equals the
// identity for each of the two 4x4 U factors in the batch.
XLA_TEST_F(SVDTest, Test_Orthogonality_U) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &a);
  auto result = SVD(a, 100, 1e-8);
  // The last op built is the value compared below, so no reconstruction op is
  // needed here. HIGHEST precision matches Test_Orthogonality_V, so the check
  // does not depend on the backend's default dot precision.
  BatchDot(result.u, TransposeInMinorDims(result.u), PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(2, 4), {a_data.get()},
                             ErrorSpec(1e-2, 1e-2));
}
176 
// Checks that the right singular vectors are orthonormal: V * V^T equals the
// identity for each of the two 5x5 V factors in the batch.
XLA_TEST_F(SVDTest, Test_Orthogonality_V) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &a);
  const XlaOp v = SVD(a, 100, 1e-8).v;
  const XlaOp v_transposed = TransposeInMinorDims(v);
  BatchDot(v, v_transposed, PrecisionConfig::HIGHEST);

  const Array3D<float> identity = GetUnitMatrix3D(2, 5);
  ComputeAndCompareR3<float>(&builder, identity, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
188 
// Checks the singular values of both batch entries against reference values.
// The test name indicates the reference values came from NumPy.
XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) {
  XlaBuilder builder(TestName());

  const Array2D<float> expected_d{
      {431.05153007, 49.88334164, 20.94464584, 3.24845468},
      {179.73128591, 68.05162245, 21.77679503, 13.94319712},
  };

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &a);
  const XlaOp d = SVD(a, 100, 1e-8).d;
  // Adding zeros emits a fresh op equal to `d`, making it the last op built
  // and therefore the value compared below.
  Add(d, ZerosLike(d));

  ComputeAndCompareR2<float>(&builder, expected_d, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
205 
// Reconstructs a random 512x128 matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with a 1e-3 tolerance, which
// in effect bounds it from above.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) {
  XlaBuilder builder(TestName());
  const Array2D<float> matrix = GenerateRandomMatrix(512, 128);

  XlaOp param;
  auto param_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &param);
  const auto factors = SVD(param, 100, 1e-6);
  const XlaOp reconstruction = ComputeMatmulUDVT(factors, &builder);
  GetAverageAbsoluteError(reconstruction, param, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {param_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
217 
// Reconstructs a random 128x256 (wide) matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with a 1e-3 tolerance, which
// in effect bounds it from above.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) {
  XlaBuilder builder(TestName());
  const Array2D<float> matrix = GenerateRandomMatrix(128, 256);

  XlaOp param;
  auto param_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &param);
  const auto factors = SVD(param, 100, 1e-6);
  const XlaOp reconstruction = ComputeMatmulUDVT(factors, &builder);
  GetAverageAbsoluteError(reconstruction, param, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {param_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
229 
// Reconstructs a random 256x128 (tall) matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with a 1e-3 tolerance, which
// in effect bounds it from above.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) {
  XlaBuilder builder(TestName());
  const Array2D<float> matrix = GenerateRandomMatrix(256, 128);

  XlaOp param;
  auto param_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &param);
  const auto factors = SVD(param, 100, 1e-6);
  const XlaOp reconstruction = ComputeMatmulUDVT(factors, &builder);
  GetAverageAbsoluteError(reconstruction, param, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {param_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
241 
// Reconstructs a random 128x512 (wide) matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with a 1e-3 tolerance, which
// in effect bounds it from above.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x512) {
  XlaBuilder builder(TestName());
  const Array2D<float> matrix = GenerateRandomMatrix(128, 512);

  XlaOp param;
  auto param_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &param);
  const auto factors = SVD(param, 100, 1e-6);
  const XlaOp reconstruction = ComputeMatmulUDVT(factors, &builder);
  GetAverageAbsoluteError(reconstruction, param, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {param_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
253 
// Reconstructs a random 512x256 (tall) matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with a 1e-3 tolerance, which
// in effect bounds it from above.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) {
  XlaBuilder builder(TestName());
  const Array2D<float> matrix = GenerateRandomMatrix(512, 256);

  XlaOp param;
  auto param_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &param);
  const auto factors = SVD(param, 100, 1e-6);
  const XlaOp reconstruction = ComputeMatmulUDVT(factors, &builder);
  GetAverageAbsoluteError(reconstruction, param, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {param_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
265 
// Reconstructs a random 512x512 (square) matrix from its SVD. The mean
// absolute reconstruction error is compared against 1e-3 with a 1e-3
// tolerance, which in effect bounds it from above.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) {
  XlaBuilder builder(TestName());
  const Array2D<float> matrix = GenerateRandomMatrix(512, 512);

  XlaOp param;
  auto param_data = CreateR2Parameter<float>(matrix, 0, "a", &builder, &param);
  const auto factors = SVD(param, 100, 1e-6);
  const XlaOp reconstruction = ComputeMatmulUDVT(factors, &builder);
  GetAverageAbsoluteError(reconstruction, param, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {param_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
277 
278 }  // namespace xla
279