• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/svd.h"

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
37 
38 namespace xla {
39 
40 class SVDTest : public ClientLibraryTestBase {
41  protected:
SetUp()42   void SetUp() override {
43     ClientLibraryTestBase::SetUp();
44     batch_3d_4x5_ = Array3D<float>{
45         {
46             {4, 6, 8, 10, 1},
47             {6, 45, 54, 63, 1},
48             {8, 54, 146, 166, 1},
49             {10, 63, 166, 310, 1},
50         },
51         {
52             {16, 24, 8, 12, 6},
53             {24, 61, 82, 48, 5},
54             {8, 82, 100, 6, 4},
55             {12, 48, 6, 62, 3},
56         },
57     };
58 
59     // Test fails with TensorFloat-32 enabled
60     tensorflow::enable_tensor_float_32_execution(false);
61   }
TearDown()62   void TearDown() override { ClientLibraryTestBase::TearDown(); }
63 
GetUnitMatrix3D(int32_t batch_dim,int32_t mat_dim)64   Array3D<float> GetUnitMatrix3D(int32_t batch_dim, int32_t mat_dim) {
65     Array3D<float> result(batch_dim, mat_dim, mat_dim, 0.0);
66     for (int i = 0; i < batch_dim; ++i) {
67       for (int j = 0; j < mat_dim; ++j) {
68         result({i, j, j}) = 1.0;
69       }
70     }
71     return result;
72   }
73 
ComputeMatmulUDVT(SVDResult result,XlaBuilder * builder)74   XlaOp ComputeMatmulUDVT(SVDResult result, XlaBuilder* builder) {
75     Shape u_shape = builder->GetShape(result.u).ValueOrDie();
76     Shape v_shape = builder->GetShape(result.v).ValueOrDie();
77 
78     int64_t m = ShapeUtil::GetDimension(u_shape, -1);
79     int64_t n = ShapeUtil::GetDimension(v_shape, -1);
80 
81     auto v = result.v;
82     auto u = result.u;
83     auto d = result.d;
84 
85     if (m > n) {
86       u = SliceInMinorDims(u, {0, 0}, {m, n});
87     } else if (m < n) {
88       v = SliceInMinorDims(v, {0, 0}, {n, m});
89     }
90 
91     int num_dims = u_shape.rank();
92     std::vector<int64_t> broadcast_dims(num_dims - 1);
93     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
94     broadcast_dims[num_dims - 2] = num_dims - 1;
95     return BatchDot(Mul(u, d, broadcast_dims), TransposeInMinorDims(v),
96                     PrecisionConfig::HIGHEST);
97   }
98 
GetAverageAbsoluteError(XlaOp m1,XlaOp m2,XlaBuilder * builder)99   XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
100     Shape shape = builder->GetShape(m1).ValueOrDie();
101     int64_t size = 1;
102     for (auto d : shape.dimensions()) {
103       size *= d;
104     }
105     return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
106                      CreateScalarAddComputation(F32, builder)) /
107            ConstantR0WithType(builder, F32, size);
108   }
109 
GenerateRandomMatrix(int xsize,int ysize)110   Array2D<float> GenerateRandomMatrix(int xsize, int ysize) {
111     Array2D<float> result{xsize, ysize, 0.0};
112     result.FillRandom(10 /* stddev */, 2 /* mean */);
113     return result;
114   }
115 
116   Array3D<float> batch_3d_4x5_;
117 };
118 
// Decomposes a small symmetric 4x4 matrix and checks that U * diag(D) * V^T
// reproduces it.
XLA_TEST_F(SVDTest, Simple2D) {
  XlaBuilder builder(TestName());

  // A plain local: the trailing underscore is reserved for data members.
  Array2D<float> simple_2d_4x4{
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  };
  XlaOp a;
  auto a_data = CreateR2Parameter<float>(simple_2d_4x4, 0, "a", &builder, &a);
  auto result = SVD(a, 100, 1e-6);
  // The reconstruction is the last op added, so it is the value compared below.
  ComputeMatmulUDVT(result, &builder);

  ComputeAndCompareR2<float>(&builder, simple_2d_4x4, {a_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
136 
// Checks that U * diag(D) * V^T rebuilds every 4x5 slice of the batched
// fixture input.
XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const SVDResult svd = SVD(input, /*max_iter=*/100, /*epsilon=*/1e-8);
  ComputeMatmulUDVT(svd, &builder);  // Last op added: the reconstruction.

  const ErrorSpec tolerance(1e-3, 1e-3);
  ComputeAndCompareR3<float>(&builder, batch_3d_4x5_, {input_data.get()},
                             tolerance);
}
148 
// Checks that U * U^T is the 4x4 identity for each batch of the fixture input.
XLA_TEST_F(SVDTest, Test_Orthogonality_U) {
  XlaBuilder builder(TestName());

  XlaOp a;
  auto a_data = CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &a);
  auto result = SVD(a, 100, 1e-8);
  // Request HIGHEST precision, as Test_Orthogonality_V does, so that a
  // backend's possibly-reduced default matmul precision does not inflate the
  // error. (The UDV^T reconstruction previously built here was never the
  // compared value, so it has been dropped.)
  BatchDot(result.u, TransposeInMinorDims(result.u), PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(2, 4), {a_data.get()},
                             ErrorSpec(1e-2, 1e-2));
}
161 
// Checks that V * V^T equals the 5x5 identity for each batch of the fixture.
XLA_TEST_F(SVDTest, Test_Orthogonality_V) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const SVDResult svd = SVD(input, 100, 1e-8);
  const XlaOp v_transposed = TransposeInMinorDims(svd.v);
  BatchDot(svd.v, v_transposed, PrecisionConfig::HIGHEST);

  const Array3D<float> identity = GetUnitMatrix3D(2, 5);
  ComputeAndCompareR3<float>(&builder, identity, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
173 
// Compares the singular values D of both fixture batches against reference
// values (from NumPy, per the test name).
XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) {
  XlaBuilder builder(TestName());

  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batch_3d_4x5_, 0, "a", &builder, &input);
  const SVDResult svd = SVD(input, 100, 1e-8);
  // Adding zeros leaves D unchanged; it just makes D the last op added, and
  // therefore the value compared below.
  Add(svd.d, ZerosLike(svd.d));

  const Array2D<float> expected_singular_values{
      {431.05153007, 49.88334164, 20.94464584, 3.24845468},
      {179.73128591, 68.05162245, 21.77679503, 13.94319712},
  };
  ComputeAndCompareR2<float>(&builder, expected_singular_values,
                             {input_data.get()}, ErrorSpec(1e-3, 1e-3));
}
190 
191 // Too slow on the interpreter backend.
XLA_TEST_F(SVDTest,DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_512x128))192 XLA_TEST_F(SVDTest,
193            DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) {
194   XlaBuilder builder(TestName());
195   Array2D<float> a_val = GenerateRandomMatrix(512, 128);
196   XlaOp a;
197   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
198   auto result = SVD(a, 100, 1e-4);
199   GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
200 
201   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
202                              ErrorSpec(1e-3, 1e-3));
203 }
204 
// Reconstructs a random 128x256 matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with ErrorSpec(1e-3, 1e-3).
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(128, 256);
  XlaOp x;
  auto x_data = CreateR2Parameter<float>(input, 0, "a", &builder, &x);
  const SVDResult svd = SVD(x, 100, 1e-4);
  const XlaOp reconstructed = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstructed, x, &builder);  // Compared value.

  constexpr float kTargetError = 1e-3;
  ComputeAndCompareR0<float>(&builder, kTargetError, {x_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
216 
// Reconstructs a random 256x128 matrix from its SVD. The mean absolute
// reconstruction error is compared against 1e-3 with ErrorSpec(1e-3, 1e-3).
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) {
  XlaBuilder builder(TestName());
  const Array2D<float> input = GenerateRandomMatrix(256, 128);
  XlaOp x;
  auto x_data = CreateR2Parameter<float>(input, 0, "a", &builder, &x);
  const SVDResult svd = SVD(x, 100, 1e-4);
  const XlaOp reconstructed = ComputeMatmulUDVT(svd, &builder);
  GetAverageAbsoluteError(reconstructed, x, &builder);  // Compared value.

  constexpr float kTargetError = 1e-3;
  ComputeAndCompareR0<float>(&builder, kTargetError, {x_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
228 
229 // Too slow on the interpreter backend.
XLA_TEST_F(SVDTest,DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_128x512))230 XLA_TEST_F(SVDTest,
231            DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) {
232   XlaBuilder builder(TestName());
233   Array2D<float> a_val = GenerateRandomMatrix(128, 512);
234   XlaOp a;
235   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
236   auto result = SVD(a, 100, 1e-4);
237   GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
238 
239   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
240                              ErrorSpec(1e-3, 1e-3));
241 }
242 
243 // Too slow on the interpreter and CPU backends.
XLA_TEST_F(SVDTest,DISABLED_ON_CPU (DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_512x256)))244 XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
245                         Various_Size_Random_Matrix_512x256))) {
246   XlaBuilder builder(TestName());
247   Array2D<float> a_val = GenerateRandomMatrix(512, 256);
248   XlaOp a;
249   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
250   auto result = SVD(a, 100, 1e-4);
251   GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
252 
253   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
254                              ErrorSpec(1e-3, 1e-3));
255 }
256 
257 // Too slow on the CPU, GPU and interpreter backends.
XLA_TEST_F(SVDTest,DISABLED_ON_GPU (DISABLED_ON_CPU (DISABLED_ON_INTERPRETER (Various_Size_Random_Matrix_512x512))))258 XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
259                         Various_Size_Random_Matrix_512x512)))) {
260   XlaBuilder builder(TestName());
261   Array2D<float> a_val = GenerateRandomMatrix(512, 512);
262   XlaOp a;
263   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
264   auto result = SVD(a, 100, 1e-4);
265   GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
266 
267   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
268                              ErrorSpec(1e-3, 1e-3));
269 }
270 
271 }  // namespace xla
272