1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <tuple>
#include <vector>
20
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/matrix.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/tensor_float_32_utils.h"
34
35 namespace xla {
36 namespace {
37
// Fixture for the hand-written (non-parameterized) Cholesky tests below. It is
// a plain alias and adds nothing beyond ClientLibraryTestBase.
using CholeskyTest = ClientLibraryTestBase;
39
// The 3x3 all-ones matrix has rank 1, so it is not positive definite and has
// no Cholesky factor. The test expects every element of the output to be NaN
// in that case.
XLA_TEST_F(CholeskyTest, NonPSDInput) {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> singular_input(3, 3, 1.0f);
  const Array2D<float> all_nan(3, 3, nan);

  XlaBuilder builder(TestName());
  XlaOp param;
  auto param_data =
      CreateR2Parameter<float>(singular_input, 0, "a", &builder, &param);
  Cholesky(param, /*lower=*/true);

  ComputeAndCompareR2<float>(&builder, all_nan, {param_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
63
// A batch that mixes a factorizable matrix with a singular one. The test
// expects batch 0 to be factored normally and batch 1 (all ones, rank 1) to
// come back entirely NaN, i.e. failure in one batch element does not affect
// the other.
XLA_TEST_F(CholeskyTest, NonPSDBatched) {
  const float nan = std::numeric_limits<float>::quiet_NaN();

  // Only the lower triangle of batch 0 is populated; with lower=true that is
  // the part being factored.
  const Array3D<float> batched_input({
      {{10, 0, 0}, {1, 20, 0}, {1, 1, 30}},
      {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}},
  });

  // Lower Cholesky factor of batch 0, with zeros above the diagonal.
  const float first_factor[3][3] = {
      {3.16227766, 0., 0.},
      {0.31622777, 4.4609416, 0.},
      {0.31622777, 0.20175113, 5.46436606},
  };
  // Start from "everything failed" and overwrite the batch that succeeds.
  Array3D<float> expected(2, 3, 3, nan);
  for (int64 row = 0; row < 3; ++row) {
    for (int64 col = 0; col < 3; ++col) {
      expected(0, row, col) = first_factor[row][col];
    }
  }

  XlaBuilder builder(TestName());
  XlaOp param;
  auto param_data =
      CreateR3Parameter<float>(batched_input, 0, "a", &builder, &param);
  Cholesky(param, /*lower=*/true);

  ComputeAndCompareR3<float>(&builder, expected, {param_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
101
// Lower-triangular factorization of a 4x4 SPD matrix given only by its lower
// triangle. The strictly-upper input entries are NaN, so any read of them
// would be expected to propagate into, and fail, the comparison.
XLA_TEST_F(CholeskyTest, Lower) {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> lower_input({
      {4, nan, nan, nan},
      {6, 45, nan, nan},
      {8, 54, 146, nan},
      {10, 63, 166, 310},
  });
  // L with L * L^T equal to the symmetric completion of lower_input.
  const Array2D<float> expected_factor({
      {2, 0, 0, 0},
      {3, 6, 0, 0},
      {4, 7, 9, 0},
      {5, 8, 10, 11},
  });

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(lower_input, 0, "a", &builder, &input);
  // Mask to the lower triangle so only the factor itself is compared.
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
127
// Upper-triangular factorization of a 4x4 SPD matrix given only by its upper
// triangle. The strictly-lower input entries are NaN, so any read of them
// would be expected to propagate into, and fail, the comparison.
XLA_TEST_F(CholeskyTest, Upper) {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> upper_input({
      {4, 6, 8, 10},
      {nan, 45, 54, 63},
      {nan, nan, 146, 166},
      {nan, nan, nan, 310},
  });
  // U with U^T * U equal to the symmetric completion of upper_input.
  const Array2D<float> expected_factor({
      {2, 3, 4, 5},
      {0, 6, 7, 8},
      {0, 0, 9, 10},
      {0, 0, 0, 11},
  });

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(upper_input, 0, "a", &builder, &input);
  // Mask to the upper triangle so only the factor itself is compared.
  UpperTriangle(Cholesky(input, /*lower=*/false));

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
153
// Lower factorization of a fully populated symmetric 4x4 matrix whose
// Cholesky factor has small integer entries.
XLA_TEST_F(CholeskyTest, Simple2) {
  const Array2D<float> spd_input({
      {16, 24, 8, 12},
      {24, 61, 82, 48},
      {8, 82, 456, 106},
      {12, 48, 106, 62},
  });
  const Array2D<float> expected_factor({
      {4, 0, 0, 0},
      {6, 5, 0, 0},
      {2, 14, 16, 0},
      {3, 6, 1, 4},
  });

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(spd_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
176
// Two independent symmetric positive-definite 4x4 matrices factored in one
// batched call. Each batch element must match its own lower factor.
XLA_TEST_F(CholeskyTest, SimpleBatched) {
  const Array3D<float> batched_input({
      {{4, 6, 8, 10},
       {6, 45, 54, 63},
       {8, 54, 146, 166},
       {10, 63, 166, 310}},
      {{16, 24, 8, 12},
       {24, 61, 82, 48},
       {8, 82, 456, 106},
       {12, 48, 106, 62}},
  });
  const Array3D<float> expected_factors({
      {{2, 0, 0, 0},
       {3, 6, 0, 0},
       {4, 7, 9, 0},
       {5, 8, 10, 11}},
      {{4, 0, 0, 0},
       {6, 5, 0, 0},
       {2, 14, 16, 0},
       {3, 6, 1, 4}},
  });

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batched_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR3<float>(&builder, expected_factors, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
215
216 using CholeskyTestCase = std::tuple<int64, int64, bool>;
217
218 class RandomCholeskyTest
219 : public ClientLibraryTestBase,
220 public ::testing::WithParamInterface<CholeskyTestCase> {};
221
// Factors a random real symmetric positive-(semi)definite batch A * A^T and
// checks that the factor reconstructs it. The checked quantity is the sum of
// squared residuals over all batches (the squared Frobenius norm of
// matrix - reconstruction), which must be ~0.
XLA_TEST_P(RandomCholeskyTest, Real) {
  // Test fails with TensorFloat-32 enabled
  tensorflow::enable_tensor_float_32_execution(false);
  XlaBuilder builder(TestName());

  const CholeskyTestCase& params = GetParam();
  const int64 batch = std::get<0>(params);
  const int64 n = std::get<1>(params);
  const bool lower = std::get<2>(params);
  const Shape shape = ShapeUtil::MakeShape(F32, {batch, n, n});
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_literal,
      LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));

  XlaOp input = Parameter(&builder, 0, shape, "input");
  // A * A^T is symmetric positive semi-definite by construction.
  XlaOp spd =
      BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST);
  XlaOp factor = Triangle(Cholesky(spd, lower), lower);

  // Lower: L * L^T.  Upper: U^T * U.
  XlaOp reconstructed =
      lower ? BatchDot(factor, TransposeInMinorDims(factor),
                       PrecisionConfig::HIGHEST)
            : BatchDot(TransposeInMinorDims(factor), factor,
                       PrecisionConfig::HIGHEST);
  XlaOp residual = spd - reconstructed;
  Reduce(residual * residual, ConstantR0<float>(&builder, 0.0),
         CreateScalarAddComputation(F32, &builder), {0, 1, 2});

  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
                          client_->TransferToServer(input_literal));
  ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
260
// Factors a random complex Hermitian positive-(semi)definite batch A * A^H and
// checks that the factor reconstructs it. The checked quantity is
// sum |matrix - reconstruction|^2 over all elements and batches, which must
// be ~0.
XLA_TEST_P(RandomCholeskyTest, Complex) {
  // Test fails with TensorFloat-32 enabled
  tensorflow::enable_tensor_float_32_execution(false);
  XlaBuilder builder(TestName());

  auto test_params = GetParam();
  std::vector<int64> dimensions = {std::get<0>(test_params),
                                   std::get<1>(test_params),
                                   std::get<1>(test_params)};
  bool lower = std::get<2>(test_params);
  Shape shape = ShapeUtil::MakeShape(F32, dimensions);

  // Draw the real and imaginary parts from ONE engine. The engine-less
  // CreateRandomLiteral overload default-constructs a fresh engine on every
  // call, so two such calls return identical literals. The input would then be
  // (1 + i) * A and the "Hermitian" matrix would collapse to the purely real
  // 2 * A * A^T, leaving the complex code path effectively untested.
  std::minstd_rand0 engine;
  TF_ASSERT_OK_AND_ASSIGN(
      auto literal_real,
      LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 0.0, 1.0));
  TF_ASSERT_OK_AND_ASSIGN(
      auto literal_imag,
      LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 0.0, 1.0));

  auto input_real = Parameter(&builder, 0, shape, "input_real");
  auto input_imag = Parameter(&builder, 1, shape, "input_imag");
  auto input = Complex(input_real, input_imag);
  // A * A^H is Hermitian positive semi-definite by construction.
  auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)),
                         PrecisionConfig::HIGHEST);

  auto cholesky = Triangle(Cholesky(matrix, lower), lower);

  // Reconstruct: L * L^H for the lower factor, U^H * U for the upper one.
  XlaOp verification;
  if (lower) {
    verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)),
                            PrecisionConfig::HIGHEST);
  } else {
    verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky,
                            PrecisionConfig::HIGHEST);
  }
  auto delta = matrix - verification;
  // delta * conj(delta) == |delta|^2 (real-valued), so Abs yields a real sum.
  Reduce(Abs(delta * Conj(delta)), ConstantR0<float>(&builder, 0.0),
         CreateScalarAddComputation(F32, &builder), {0, 1, 2});

  TF_ASSERT_OK_AND_ASSIGN(auto input_data_real,
                          client_->TransferToServer(literal_real));
  TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag,
                          client_->TransferToServer(literal_imag));
  ComputeAndCompareR0<float>(&builder, 0.0,
                             {input_data_real.get(), input_data_imag.get()},
                             ErrorSpec(1e-4, 1e-4));
}
309
310 INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
311 ::testing::Values(CholeskyTestCase{1, 1, true},
312 CholeskyTestCase{1, 2, true},
313 CholeskyTestCase{1, 50, true},
314 CholeskyTestCase{1, 50, false},
315 CholeskyTestCase{1, 255, false},
316 CholeskyTestCase{10, 5, true},
317 CholeskyTestCase{5, 10, false},
318 CholeskyTestCase{2, 20, true},
319 CholeskyTestCase{2, 129, true}));
320
321 } // namespace
322 } // namespace xla
323