1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/svd.h"

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace xla {
37
38 class SVDTest : public ClientLibraryTestBase {
39 protected:
SetUp()40 void SetUp() override {
41 ClientLibraryTestBase::SetUp();
42 batch_3d_4x5_ = Array3D<float>{
43 {
44 {4, 6, 8, 10, 1},
45 {6, 45, 54, 63, 1},
46 {8, 54, 146, 166, 1},
47 {10, 63, 166, 310, 1},
48 },
49 {
50 {16, 24, 8, 12, 6},
51 {24, 61, 82, 48, 5},
52 {8, 82, 100, 6, 4},
53 {12, 48, 6, 62, 3},
54 },
55 };
56 }
TearDown()57 void TearDown() override { ClientLibraryTestBase::TearDown(); }
58
GetUnitMatrix3D(int32 batch_dim,int32 mat_dim)59 Array3D<float> GetUnitMatrix3D(int32 batch_dim, int32 mat_dim) {
60 Array3D<float> result(batch_dim, mat_dim, mat_dim, 0.0);
61 for (int i = 0; i < batch_dim; ++i) {
62 for (int j = 0; j < mat_dim; ++j) {
63 result({i, j, j}) = 1.0;
64 }
65 }
66 return result;
67 }
68
ComputeMatmulUDVT(SVDResult result,XlaBuilder * builder)69 XlaOp ComputeMatmulUDVT(SVDResult result, XlaBuilder* builder) {
70 Shape u_shape = builder->GetShape(result.u).ValueOrDie();
71 Shape v_shape = builder->GetShape(result.v).ValueOrDie();
72
73 int64 m = ShapeUtil::GetDimension(u_shape, -1);
74 int64 n = ShapeUtil::GetDimension(v_shape, -1);
75
76 auto v = result.v;
77 auto u = result.u;
78 auto d = result.d;
79
80 auto zero = Zero(builder, S32);
81 if (m > n) {
82 u = DynamicSliceInMinorDims(u, {zero, zero}, {m, n});
83 } else if (m < n) {
84 v = DynamicSliceInMinorDims(v, {zero, zero}, {n, m});
85 }
86
87 int num_dims = u_shape.rank();
88 std::vector<int64> broadcast_dims(num_dims - 1);
89 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
90 broadcast_dims[num_dims - 2] = num_dims - 1;
91 return BatchDot(Mul(u, d, broadcast_dims), TransposeInMinorDims(v),
92 PrecisionConfig::HIGHEST);
93 }
94
ExtractTriangularMatrix(const Array3D<float> & matrix,bool lower)95 Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
96 bool lower) {
97 Array3D<float> result(matrix);
98 for (int i = 0; i < result.n1(); ++i) {
99 for (int j = 0; j < result.n2(); ++j) {
100 if (lower) {
101 for (int k = j + 1; k < result.n3(); ++k) {
102 result({i, j, k}) = 0.0;
103 }
104 } else {
105 for (int k = 0; k < j; ++k) {
106 result({i, j, k}) = 0.0;
107 }
108 }
109 }
110 }
111 return result;
112 }
113
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)114 XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
115 Shape shape = builder->GetShape(m1).ValueOrDie();
116 int64 size = 1;
117 for (auto d : shape.dimensions()) {
118 size *= d;
119 }
120 return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
121 CreateScalarAddComputation(F32, builder)) /
122 ConstantR0WithType(builder, F32, size);
123 }
124
GenerateRandomMatrix(int xsize,int ysize)125 Array2D<float> GenerateRandomMatrix(int xsize, int ysize) {
126 Array2D<float> result{xsize, ysize, 0.0};
127 result.FillRandom(10 /* stddev */, 2 /* mean */);
128 return result;
129 }
130
131 Array3D<float> batch_3d_4x5_;
132 };
133
// Decomposes a symmetric 4x4 matrix and checks that U * diag(d) * V^T
// reproduces the input.
XLA_TEST_F(SVDTest, Simple2D) {
  XlaBuilder builder(TestName());

  Array2D<float> input{
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  };

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  // The reconstruction is the last op built, so it is what gets compared.
  ComputeMatmulUDVT(svd, &builder);

  ComputeAndCompareR2<float>(&builder, input, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
151
// Checks that U * diag(d) * V^T reconstructs each matrix of the batched
// 2x4x5 input.
XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) {
  XlaBuilder builder(TestName());

  XlaOp input;
  const auto input_handle =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const SVDResult svd = SVD(input, 100, 1e-8);

  // Last op built: the reconstructed batch.
  ComputeMatmulUDVT(svd, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x5_, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
163
// Checks that the U factor of the batched 4x5 SVD is orthogonal, i.e.
// U * U^T equals the 4x4 identity for every batch entry.
XLA_TEST_F(SVDTest, Test_Orthogonality_U) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &a);
  auto result = SVD(a, 100, 1e-8);
  // HIGHEST precision, matching Test_Orthogonality_V, so the orthogonality
  // check is not skewed on backends whose default dot precision is reduced.
  BatchDot(result.u, TransposeInMinorDims(result.u), PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(2, 4), {a_data.get()},
                             ErrorSpec(1e-2, 1e-2));
}
176
// Checks that the V factor of the batched 4x5 SVD is orthogonal, i.e.
// V * V^T equals the 5x5 identity for every batch entry.
XLA_TEST_F(SVDTest, Test_Orthogonality_V) {
  XlaBuilder builder(TestName());

  XlaOp input;
  const auto input_handle =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const SVDResult svd = SVD(input, 100, 1e-8);

  const XlaOp v = svd.v;
  BatchDot(v, TransposeInMinorDims(v), PrecisionConfig::HIGHEST);

  const Array3D<float> identity = GetUnitMatrix3D(2, 5);
  ComputeAndCompareR3<float>(&builder, identity, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
188
// Checks the singular values of the batched 4x5 input against reference
// values (one row per batch entry) that the test name attributes to NumPy.
XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) {
  XlaBuilder builder(TestName());

  auto expected_singular_values = Array2D<float>{
      {431.05153007, 49.88334164, 20.94464584, 3.24845468},
      {179.73128591, 68.05162245, 21.77679503, 13.94319712},
  };

  XlaOp input;
  const auto input_handle =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const SVDResult svd = SVD(input, 100, 1e-8);

  // Emit d as the last op (adding zeros leaves it unchanged) so that the
  // singular values are the compared output.
  const XlaOp d = svd.d;
  Add(d, ZerosLike(d));

  ComputeAndCompareR2<float>(&builder, expected_singular_values,
                             {input_handle.get()}, ErrorSpec(1e-3, 1e-3));
}
205
// Random tall 512x128 matrix: the mean absolute error of U * diag(d) * V^T
// versus the input must lie within ErrorSpec(1e-3, 1e-3) of 1e-3.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(512, 128);

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  const XlaOp reconstruction = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstruction, a, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
217
// Random wide 128x256 matrix: the mean absolute error of U * diag(d) * V^T
// versus the input must lie within ErrorSpec(1e-3, 1e-3) of 1e-3.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(128, 256);

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  const XlaOp reconstruction = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstruction, a, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
229
// Random tall 256x128 matrix: the mean absolute error of U * diag(d) * V^T
// versus the input must lie within ErrorSpec(1e-3, 1e-3) of 1e-3.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(256, 128);

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  const XlaOp reconstruction = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstruction, a, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
241
// Random wide 128x512 matrix: the mean absolute error of U * diag(d) * V^T
// versus the input must lie within ErrorSpec(1e-3, 1e-3) of 1e-3.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x512) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(128, 512);

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  const XlaOp reconstruction = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstruction, a, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
253
// Random tall 512x256 matrix: the mean absolute error of U * diag(d) * V^T
// versus the input must lie within ErrorSpec(1e-3, 1e-3) of 1e-3.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(512, 256);

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  const XlaOp reconstruction = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstruction, a, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
265
// Random square 512x512 matrix: the mean absolute error of U * diag(d) * V^T
// versus the input must lie within ErrorSpec(1e-3, 1e-3) of 1e-3.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(512, 512);

  XlaOp a;
  const auto input_handle =
      CreateR2Parameter<float>(input, 0, "a", &builder, &a);
  const SVDResult svd = SVD(a, 100, 1e-6);

  const XlaOp reconstruction = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstruction, a, &builder);

  ComputeAndCompareR0<float>(&builder, 1e-3, {input_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
277
278 } // namespace xla
279