1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/svd.h"

#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
37
38 namespace xla {
39
40 class SVDTest : public ClientLibraryTestBase {
41 protected:
SetUp()42 void SetUp() override {
43 ClientLibraryTestBase::SetUp();
44 batch_3d_4x5_ = Array3D<float>{
45 {
46 {4, 6, 8, 10, 1},
47 {6, 45, 54, 63, 1},
48 {8, 54, 146, 166, 1},
49 {10, 63, 166, 310, 1},
50 },
51 {
52 {16, 24, 8, 12, 6},
53 {24, 61, 82, 48, 5},
54 {8, 82, 100, 6, 4},
55 {12, 48, 6, 62, 3},
56 },
57 };
58
59 // Test fails with TensorFloat-32 enabled
60 tensorflow::enable_tensor_float_32_execution(false);
61 }
TearDown()62 void TearDown() override { ClientLibraryTestBase::TearDown(); }
63
GetUnitMatrix3D(int32_t batch_dim,int32_t mat_dim)64 Array3D<float> GetUnitMatrix3D(int32_t batch_dim, int32_t mat_dim) {
65 Array3D<float> result(batch_dim, mat_dim, mat_dim, 0.0);
66 for (int i = 0; i < batch_dim; ++i) {
67 for (int j = 0; j < mat_dim; ++j) {
68 result({i, j, j}) = 1.0;
69 }
70 }
71 return result;
72 }
73
ComputeMatmulUDVT(SVDResult result,XlaBuilder * builder)74 XlaOp ComputeMatmulUDVT(SVDResult result, XlaBuilder* builder) {
75 Shape u_shape = builder->GetShape(result.u).ValueOrDie();
76 Shape v_shape = builder->GetShape(result.v).ValueOrDie();
77
78 int64_t m = ShapeUtil::GetDimension(u_shape, -1);
79 int64_t n = ShapeUtil::GetDimension(v_shape, -1);
80
81 auto v = result.v;
82 auto u = result.u;
83 auto d = result.d;
84
85 if (m > n) {
86 u = SliceInMinorDims(u, {0, 0}, {m, n});
87 } else if (m < n) {
88 v = SliceInMinorDims(v, {0, 0}, {n, m});
89 }
90
91 int num_dims = u_shape.rank();
92 std::vector<int64_t> broadcast_dims(num_dims - 1);
93 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
94 broadcast_dims[num_dims - 2] = num_dims - 1;
95 return BatchDot(Mul(u, d, broadcast_dims), TransposeInMinorDims(v),
96 PrecisionConfig::HIGHEST);
97 }
98
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)99 XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
100 Shape shape = builder->GetShape(m1).ValueOrDie();
101 int64_t size = 1;
102 for (auto d : shape.dimensions()) {
103 size *= d;
104 }
105 return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
106 CreateScalarAddComputation(F32, builder)) /
107 ConstantR0WithType(builder, F32, size);
108 }
109
GenerateRandomMatrix(int xsize,int ysize)110 Array2D<float> GenerateRandomMatrix(int xsize, int ysize) {
111 Array2D<float> result{xsize, ysize, 0.0};
112 result.FillRandom(10 /* stddev */, 2 /* mean */);
113 return result;
114 }
115
116 Array3D<float> batch_3d_4x5_;
117 };
118
// Decomposes a single 4x4 matrix and checks that the product rebuilt by
// ComputeMatmulUDVT matches the input.
XLA_TEST_F(SVDTest, Simple2D) {
  constexpr int kMaxIterations = 100;
  constexpr double kEpsilon = 1e-6;
  constexpr double kTolerance = 1e-3;

  XlaBuilder builder(TestName());

  const auto input_matrix = Array2D<float>{
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  };

  XlaOp input_op;
  const auto input_data =
      CreateR2Parameter<float>(input_matrix, 0, "a", &builder, &input_op);
  const auto svd = SVD(input_op, kMaxIterations, kEpsilon);
  // The reconstruction is the last op added to the builder.
  ComputeMatmulUDVT(svd, &builder);

  ComputeAndCompareR2<float>(&builder, input_matrix, {input_data.get()},
                             ErrorSpec(kTolerance, kTolerance));
}
136
// Decomposes the batched fixture input and checks that the product rebuilt by
// ComputeMatmulUDVT matches it.
XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) {
  constexpr int kMaxIterations = 100;
  constexpr double kEpsilon = 1e-8;
  constexpr double kTolerance = 1e-3;

  XlaBuilder builder(TestName());

  XlaOp input_op;
  const auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input_op);
  const auto svd = SVD(input_op, kMaxIterations, kEpsilon);
  // The reconstruction is the last op added to the builder.
  ComputeMatmulUDVT(svd, &builder);

  ComputeAndCompareR3<float>(&builder, batch_3d_4x5_, {input_data.get()},
                             ErrorSpec(kTolerance, kTolerance));
}
148
// Checks that u . u^T is close to the identity for every matrix in the batch.
XLA_TEST_F(SVDTest, Test_Orthogonality_U) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &a);
  auto result = SVD(a, 100, 1e-8);
  // Only the orthogonality product is built here. It is the last op added to
  // the builder, which is the value compared below.
  BatchDot(result.u, TransposeInMinorDims(result.u));

  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(2, 4), {a_data.get()},
                             ErrorSpec(1e-2, 1e-2));
}
161
// Checks that v . v^T is close to the identity for every matrix in the batch.
XLA_TEST_F(SVDTest, Test_Orthogonality_V) {
  constexpr int kMaxIterations = 100;
  constexpr double kEpsilon = 1e-8;
  constexpr double kTolerance = 1e-3;

  XlaBuilder builder(TestName());

  XlaOp input_op;
  const auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input_op);
  const auto svd = SVD(input_op, kMaxIterations, kEpsilon);
  const XlaOp v_transposed = TransposeInMinorDims(svd.v);
  BatchDot(svd.v, v_transposed, PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(2, 5),
                             {input_data.get()},
                             ErrorSpec(kTolerance, kTolerance));
}
173
// Compares the singular values of the batched fixture input against
// precomputed reference values (one row per matrix in the batch).
XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) {
  constexpr int kMaxIterations = 100;
  constexpr double kEpsilon = 1e-8;
  constexpr double kTolerance = 1e-3;

  XlaBuilder builder(TestName());

  const auto expected_singular_values = Array2D<float>{
      {431.05153007, 49.88334164, 20.94464584, 3.24845468},
      {179.73128591, 68.05162245, 21.77679503, 13.94319712},
  };

  XlaOp input_op;
  const auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input_op);
  const auto svd = SVD(input_op, kMaxIterations, kEpsilon);
  // Adding zeros leaves the values unchanged; presumably this exists so that
  // the singular values are the last op added to the builder.
  const XlaOp zeros = ZerosLike(svd.d);
  Add(svd.d, zeros);

  ComputeAndCompareR2<float>(&builder, expected_singular_values,
                             {input_data.get()},
                             ErrorSpec(kTolerance, kTolerance));
}
190
191 // Too slow on the interpreter backend.
XLA_TEST_F(SVDTest,DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_512x128))192 XLA_TEST_F(SVDTest,
193 DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) {
194 XlaBuilder builder(TestName());
195 Array2D<float> a_val = GenerateRandomMatrix(512, 128);
196 XlaOp a;
197 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
198 auto result = SVD(a, 100, 1e-4);
199 GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
200
201 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
202 ErrorSpec(1e-3, 1e-3));
203 }
204
// Checks the average absolute reconstruction error for a random 128x256
// matrix.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) {
  constexpr int kMaxIterations = 100;
  constexpr double kEpsilon = 1e-4;
  // Used both as the expected average error and as the comparison tolerance.
  constexpr double kErrorBound = 1e-3;

  XlaBuilder builder(TestName());
  const Array2D<float> input_matrix = GenerateRandomMatrix(128, 256);

  XlaOp input_op;
  const auto input_data =
      CreateR2Parameter<float>(input_matrix, 0, "a", &builder, &input_op);
  const auto svd = SVD(input_op, kMaxIterations, kEpsilon);
  const XlaOp reconstructed = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstructed, input_op, &builder);

  ComputeAndCompareR0<float>(&builder, kErrorBound, {input_data.get()},
                             ErrorSpec(kErrorBound, kErrorBound));
}
216
// Checks the average absolute reconstruction error for a random 256x128
// matrix.
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) {
  constexpr int kMaxIterations = 100;
  constexpr double kEpsilon = 1e-4;
  // Used both as the expected average error and as the comparison tolerance.
  constexpr double kErrorBound = 1e-3;

  XlaBuilder builder(TestName());
  const Array2D<float> input_matrix = GenerateRandomMatrix(256, 128);

  XlaOp input_op;
  const auto input_data =
      CreateR2Parameter<float>(input_matrix, 0, "a", &builder, &input_op);
  const auto svd = SVD(input_op, kMaxIterations, kEpsilon);
  const XlaOp reconstructed = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstructed, input_op, &builder);

  ComputeAndCompareR0<float>(&builder, kErrorBound, {input_data.get()},
                             ErrorSpec(kErrorBound, kErrorBound));
}
228
229 // Too slow on the interpreter backend.
XLA_TEST_F(SVDTest,DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_128x512))230 XLA_TEST_F(SVDTest,
231 DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) {
232 XlaBuilder builder(TestName());
233 Array2D<float> a_val = GenerateRandomMatrix(128, 512);
234 XlaOp a;
235 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
236 auto result = SVD(a, 100, 1e-4);
237 GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
238
239 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
240 ErrorSpec(1e-3, 1e-3));
241 }
242
243 // Too slow on the interpreter and CPU backends.
XLA_TEST_F(SVDTest,DISABLED_ON_CPU (DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_512x256)))244 XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
245 Various_Size_Random_Matrix_512x256))) {
246 XlaBuilder builder(TestName());
247 Array2D<float> a_val = GenerateRandomMatrix(512, 256);
248 XlaOp a;
249 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
250 auto result = SVD(a, 100, 1e-4);
251 GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
252
253 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
254 ErrorSpec(1e-3, 1e-3));
255 }
256
257 // Too slow on the CPU, GPU and interpreter backends.
XLA_TEST_F(SVDTest,DISABLED_ON_GPU (DISABLED_ON_CPU (DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_512x512))))258 XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
259 Various_Size_Random_Matrix_512x512)))) {
260 XlaBuilder builder(TestName());
261 Array2D<float> a_val = GenerateRandomMatrix(512, 512);
262 XlaOp a;
263 auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
264 auto result = SVD(a, 100, 1e-4);
265 GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
266
267 ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
268 ErrorSpec(1e-3, 1e-3));
269 }
270
271 } // namespace xla
272