• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <tuple>
#include <vector>
20 
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/matrix.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/tensor_float_32_utils.h"
34 
35 namespace xla {
36 namespace {
37 
38 using CholeskyTest = ClientLibraryTestBase;
39 
XLA_TEST_F(CholeskyTest, NonPSDInput) {
  // The all-ones matrix is singular (rank 1), so it is not strictly positive
  // definite and has no Cholesky factor. The expected result is an output
  // filled entirely with NaN.
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> singular_input({
      {1, 1, 1},
      {1, 1, 1},
      {1, 1, 1},
  });
  const Array2D<float> all_nan({
      {kNaN, kNaN, kNaN},
      {kNaN, kNaN, kNaN},
      {kNaN, kNaN, kNaN},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(singular_input, 0, "a", &builder, &input);
  Cholesky(input, /*lower=*/true);

  ComputeAndCompareR2<float>(&builder, all_nan, {input_data.get()}, tolerance);
}
63 
XLA_TEST_F(CholeskyTest, NonPSDBatched) {
  // Batch of two 3x3 problems decomposed with lower=true:
  //  * element 0 is given by its lower triangle (upper part is zero); the
  //    expected factor is that of the symmetric matrix
  //    [[10, 1, 1], [1, 20, 1], [1, 1, 30]];
  //  * element 1 is the singular all-ones matrix, whose factor is all NaN.
  // The expected output checks that the failing element does not disturb the
  // valid one.
  const Array3D<float> batched_input({
      {
          {10, 0, 0},
          {1, 20, 0},
          {1, 1, 30},
      },
      {
          {1, 1, 1},
          {1, 1, 1},
          {1, 1, 1},
      },
  });

  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const Array3D<float> expected_factors({
      {
          {3.16227766, 0., 0.},
          {0.31622777, 4.4609416, 0.},
          {0.31622777, 0.20175113, 5.46436606},
      },
      {
          {kNaN, kNaN, kNaN},
          {kNaN, kNaN, kNaN},
          {kNaN, kNaN, kNaN},
      },
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batched_input, 0, "a", &builder, &input);
  Cholesky(input, /*lower=*/true);

  ComputeAndCompareR3<float>(&builder, expected_factors, {input_data.get()},
                             tolerance);
}
101 
XLA_TEST_F(CholeskyTest, Lower) {
  // With lower=true only the lower triangle of the input is meant to be used.
  // The strictly-upper entries are NaN, so reading them would be expected to
  // propagate NaN into the result and fail the comparison.
  const float kPoison = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> lower_input({
      {4, kPoison, kPoison, kPoison},
      {6, 45, kPoison, kPoison},
      {8, 54, 146, kPoison},
      {10, 63, 166, 310},
  });

  // L such that L * L^T equals the symmetric matrix whose lower triangle is
  // given above.
  const Array2D<float> expected_factor({
      {2, 0, 0, 0},
      {3, 6, 0, 0},
      {4, 7, 9, 0},
      {5, 8, 10, 11},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(lower_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_data.get()},
                             tolerance);
}
127 
XLA_TEST_F(CholeskyTest, Upper) {
  // Mirror image of the Lower test: with lower=false only the upper triangle
  // is meant to be used, so the strictly-lower entries are set to NaN.
  const float kPoison = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> upper_input({
      {4, 6, 8, 10},
      {kPoison, 45, 54, 63},
      {kPoison, kPoison, 146, 166},
      {kPoison, kPoison, kPoison, 310},
  });

  // U such that U^T * U equals the symmetric matrix whose upper triangle is
  // given above.
  const Array2D<float> expected_factor({
      {2, 3, 4, 5},
      {0, 6, 7, 8},
      {0, 0, 9, 10},
      {0, 0, 0, 11},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(upper_input, 0, "a", &builder, &input);
  UpperTriangle(Cholesky(input, /*lower=*/false));

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_data.get()},
                             tolerance);
}
153 
XLA_TEST_F(CholeskyTest, Simple2) {
  // Fully populated symmetric input; the expected lower factor L satisfies
  // L * L^T == symmetric_input exactly in integer arithmetic.
  const Array2D<float> symmetric_input({
      {16, 24, 8, 12},
      {24, 61, 82, 48},
      {8, 82, 456, 106},
      {12, 48, 106, 62},
  });
  const Array2D<float> expected_factor({
      {4, 0, 0, 0},
      {6, 5, 0, 0},
      {2, 14, 16, 0},
      {3, 6, 1, 4},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(symmetric_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_data.get()},
                             tolerance);
}
176 
XLA_TEST_F(CholeskyTest, SimpleBatched) {
  // Two independent 4x4 symmetric positive-definite problems stacked along a
  // leading batch dimension; each is decomposed into its lower factor.
  const Array3D<float> batched_input({
      {
          {4, 6, 8, 10},
          {6, 45, 54, 63},
          {8, 54, 146, 166},
          {10, 63, 166, 310},
      },
      {
          {16, 24, 8, 12},
          {24, 61, 82, 48},
          {8, 82, 456, 106},
          {12, 48, 106, 62},
      },
  });
  const Array3D<float> expected_factors({
      {
          {2, 0, 0, 0},
          {3, 6, 0, 0},
          {4, 7, 9, 0},
          {5, 8, 10, 11},
      },
      {
          {4, 0, 0, 0},
          {6, 5, 0, 0},
          {2, 14, 16, 0},
          {3, 6, 1, 4},
      },
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batched_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR3<float>(&builder, expected_factors, {input_data.get()},
                             tolerance);
}
215 
216 using CholeskyTestCase = std::tuple<int64, int64, bool>;
217 
218 class RandomCholeskyTest
219     : public ClientLibraryTestBase,
220       public ::testing::WithParamInterface<CholeskyTestCase> {};
221 
XLA_TEST_P(RandomCholeskyTest, Real) {
  // This test fails with TensorFloat-32 enabled, so disable it while the test
  // runs. The flag is process-wide: restore the previous setting on scope exit
  // (including early returns from failing TF_ASSERT_* macros) so the change
  // does not leak into other tests in the same binary.
  struct TensorFloat32Restorer {
    const bool previously_enabled =
        tensorflow::tensor_float_32_execution_enabled();
    ~TensorFloat32Restorer() {
      tensorflow::enable_tensor_float_32_execution(previously_enabled);
    }
  } tf32_restorer;
  tensorflow::enable_tensor_float_32_execution(false);

  XlaBuilder builder(TestName());

  // Parameters are (batch, n, lower); the input has shape [batch, n, n].
  auto test_params = GetParam();
  std::vector<int64> dimensions = {std::get<0>(test_params),
                                   std::get<1>(test_params),
                                   std::get<1>(test_params)};
  bool lower = std::get<2>(test_params);
  Shape shape = ShapeUtil::MakeShape(F32, dimensions);
  TF_ASSERT_OK_AND_ASSIGN(
      auto literal, LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));

  auto input = Parameter(&builder, 0, shape, "input");
  // Form a random positive (semi-)definite matrix as input * input^T.
  auto matrix =
      BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST);

  auto cholesky = Triangle(Cholesky(matrix, lower), lower);

  // Reconstruct the matrix from its factor: L * L^T for the lower factor,
  // U^T * U for the upper one.
  XlaOp verification;
  if (lower) {
    verification = BatchDot(cholesky, TransposeInMinorDims(cholesky),
                            PrecisionConfig::HIGHEST);
  } else {
    verification = BatchDot(TransposeInMinorDims(cholesky), cholesky,
                            PrecisionConfig::HIGHEST);
  }
  // Verify that the squared Frobenius norm ||matrix - verification||_F^2
  // (the sum of squared element-wise differences over all batches) ~= 0.
  auto delta = matrix - verification;
  Reduce(delta * delta, ConstantR0<float>(&builder, 0.0),
         CreateScalarAddComputation(F32, &builder), {0, 1, 2});

  TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
  ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
260 
XLA_TEST_P(RandomCholeskyTest, Complex) {
  // This test fails with TensorFloat-32 enabled, so disable it while the test
  // runs. The flag is process-wide: restore the previous setting on scope exit
  // (including early returns from failing TF_ASSERT_* macros) so the change
  // does not leak into other tests in the same binary.
  struct TensorFloat32Restorer {
    const bool previously_enabled =
        tensorflow::tensor_float_32_execution_enabled();
    ~TensorFloat32Restorer() {
      tensorflow::enable_tensor_float_32_execution(previously_enabled);
    }
  } tf32_restorer;
  tensorflow::enable_tensor_float_32_execution(false);

  XlaBuilder builder(TestName());

  // Parameters are (batch, n, lower); each input part has shape [batch, n, n].
  auto test_params = GetParam();
  std::vector<int64> dimensions = {std::get<0>(test_params),
                                   std::get<1>(test_params),
                                   std::get<1>(test_params)};
  bool lower = std::get<2>(test_params);
  Shape shape = ShapeUtil::MakeShape(F32, dimensions);

  // Draw the real and imaginary parts from ONE shared engine. The engine-less
  // overload of LiteralUtil::CreateRandomLiteral default-constructs a fresh
  // engine on every call (see literal_util.h), so two such calls would return
  // identical parts. The input would then be (1 + i) * x, the Hermitian matrix
  // below would be purely real, and the test would never exercise genuinely
  // complex inputs.
  std::minstd_rand0 engine;
  TF_ASSERT_OK_AND_ASSIGN(
      auto literal_real,
      LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 0.0, 1.0));
  TF_ASSERT_OK_AND_ASSIGN(
      auto literal_imag,
      LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 0.0, 1.0));

  auto input_real = Parameter(&builder, 0, shape, "input_real");
  auto input_imag = Parameter(&builder, 1, shape, "input_imag");
  auto input = Complex(input_real, input_imag);
  // Form a random Hermitian positive (semi-)definite matrix as
  // input * conj(input)^T.
  auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)),
                         PrecisionConfig::HIGHEST);

  auto cholesky = Triangle(Cholesky(matrix, lower), lower);

  // Reconstruct the matrix from its factor using the conjugate transpose:
  // L * L^H for the lower factor, U^H * U for the upper one.
  XlaOp verification;
  if (lower) {
    verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)),
                            PrecisionConfig::HIGHEST);
  } else {
    verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky,
                            PrecisionConfig::HIGHEST);
  }
  // Verify that the squared Frobenius norm ||matrix - verification||_F^2 ~= 0.
  // delta * conj(delta) is |delta|^2 per element; Abs turns it into a real F32
  // value so that it can be summed with the F32 add computation.
  auto delta = matrix - verification;
  Reduce(Abs(delta * Conj(delta)), ConstantR0<float>(&builder, 0.0),
         CreateScalarAddComputation(F32, &builder), {0, 1, 2});

  TF_ASSERT_OK_AND_ASSIGN(auto input_data_real,
                          client_->TransferToServer(literal_real));
  TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag,
                          client_->TransferToServer(literal_imag));
  ComputeAndCompareR0<float>(&builder, 0.0,
                             {input_data_real.get(), input_data_imag.get()},
                             ErrorSpec(1e-4, 1e-4));
}
309 
310 INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
311                          ::testing::Values(CholeskyTestCase{1, 1, true},
312                                            CholeskyTestCase{1, 2, true},
313                                            CholeskyTestCase{1, 50, true},
314                                            CholeskyTestCase{1, 50, false},
315                                            CholeskyTestCase{1, 255, false},
316                                            CholeskyTestCase{10, 5, true},
317                                            CholeskyTestCase{5, 10, false},
318                                            CholeskyTestCase{2, 20, true},
319                                            CholeskyTestCase{2, 129, true}));
320 
321 }  // namespace
322 }  // namespace xla
323