1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/logdet.h"

#include <limits>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"
29
30 namespace {
31
32 using LogDetTest = xla::ClientLibraryTestBase;
33
// Single symmetric positive-definite 4x4 input. It factors as L * L^T with L
// lower triangular and diag(L) = (2, 6, 9, 11), so det = 1188^2 = 1411344 and
// log(det) = 2 * ln(1188) ~= 14.1601. SLogDet must report sign +1 and agree
// with LogDet.
XLA_TEST_F(LogDetTest, Simple) {
  xla::XlaBuilder builder(TestName());

  // The expected literal does not touch the builder, so it can be set up first.
  const float kExpectedSign = 1.f;
  const float kExpectedLogDet = 14.1601f;
  xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
      xla::LiteralUtil::CreateR0<float>(kExpectedSign),
      xla::LiteralUtil::CreateR0<float>(kExpectedLogDet),
      xla::LiteralUtil::CreateR0<float>(kExpectedLogDet));

  xla::Array2D<float> matrix({
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  });

  xla::XlaOp matrix_op;
  auto matrix_handle =
      CreateR2Parameter<float>(matrix, 0, "a", &builder, &matrix_op);

  // The tuple is the last op added, so it is the computation's root.
  const xla::SignAndLogDet sign_and_logdet = xla::SLogDet(matrix_op);
  const xla::XlaOp logdet_only = xla::LogDet(matrix_op);
  xla::Tuple(&builder,
             {sign_and_logdet.sign, sign_and_logdet.logdet, logdet_only});

  ComputeAndCompareLiteral(&builder, expected, {matrix_handle.get()},
                           xla::ErrorSpec(1e-4));
}
56
// Despite its name, the input is not triangular as given: subtracting row 0
// from rows 1 and 3 leaves an upper-triangular matrix with diagonal
// (4, -45, -146, 310). Hence det = 8146800 > 0 and log(det) ~= 15.9131, and
// SLogDet and LogDet must agree.
XLA_TEST_F(LogDetTest, SimpleTriangle) {
  xla::XlaBuilder builder(TestName());

  xla::Array2D<float> matrix({
      {4, 6, 8, 10},
      {4, -39, 62, 73},
      {0, 0, -146, 166},
      {4, 6, 8, 320},
  });

  xla::XlaOp matrix_op;
  auto matrix_handle =
      CreateR2Parameter<float>(matrix, 0, "a", &builder, &matrix_op);

  const xla::SignAndLogDet sign_and_logdet = xla::SLogDet(matrix_op);
  const xla::XlaOp logdet_only = xla::LogDet(matrix_op);
  xla::Tuple(&builder,
             {sign_and_logdet.sign, sign_and_logdet.logdet, logdet_only});

  const float kExpectedLogDet = 15.9131355f;
  xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
      xla::LiteralUtil::CreateR0<float>(1.f),
      xla::LiteralUtil::CreateR0<float>(kExpectedLogDet),
      xla::LiteralUtil::CreateR0<float>(kExpectedLogDet));

  ComputeAndCompareLiteral(&builder, expected, {matrix_handle.get()},
                           xla::ErrorSpec(1e-4));
}
80
// Batched SLogDet/LogDet over four 4x4 matrices covering each sign case:
//   [0] SPD, Cholesky diagonal (2, 6, 9, 11): det = 1188^2, log ~= 14.1601.
//   [1] SPD, Cholesky diagonal (4, 5, 16, 4): det = 1280^2, log ~= 14.3092.
//   [2] Non-symmetric with det = -12: SLogDet reports sign -1 and
//       log|det| = ln(12) ~= 2.4849, while LogDet is expected to yield NaN.
//   [3] All zeros (singular): sign 0 and a log-determinant of -inf from both.
XLA_TEST_F(LogDetTest, SimpleBatched) {
  xla::XlaBuilder builder(TestName());

  xla::Array3D<float> a_vals({
      {
          {4, 6, 8, 10},
          {6, 45, 54, 63},
          {8, 54, 146, 166},
          {10, 63, 166, 310},
      },
      {
          {16, 24, 8, 12},
          {24, 61, 82, 48},
          {8, 82, 456, 106},
          {12, 48, 106, 62},
      },
      {{2, 2, 3, 4}, {4, 5, 6, 7}, {7, 8, 9, 8}, {10, 11, 12, 13}},
      {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
  });

  xla::XlaOp a;
  auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
  xla::SignAndLogDet slogdet = xla::SLogDet(a);
  xla::XlaOp logdet = xla::LogDet(a);
  xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});

  // Special values used by the negative-determinant and singular cases.
  constexpr float kInf = std::numeric_limits<float>::infinity();
  constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();

  xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
      xla::LiteralUtil::CreateR1<float>({1.f, 1.f, -1.f, 0.f}),
      xla::LiteralUtil::CreateR1<float>({14.1601f, 14.3092f, 2.4849f, -kInf}),
      xla::LiteralUtil::CreateR1<float>({14.1601f, 14.3092f, kNaN, -kInf}));

  ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
                           xla::ErrorSpec(1e-4));
}
118
// Batch of three dense, non-symmetric 7x7 matrices, all with positive
// determinant (expected sign +1), so SLogDet's log|det| and LogDet must match.
// NOTE(review): the reference log-determinants cannot be derived by hand;
// presumably they come from an external reference implementation (e.g. NumPy)
// -- confirm provenance before changing them.
XLA_TEST_F(LogDetTest, LogdetOfLargerMatricesBatched) {
  xla::XlaBuilder builder(TestName());

  xla::Array<float> matrices = {
      {{7.2393, 1.1413, 4.1883, -4.8272, 3.2831, -0.0568, -2.4776},
       {0.4347, 3.4095, 1.6259, -4.7100, 1.5942, 1.4217, -2.8009},
       {3.6964, 0.4882, 6.5276, -1.2128, 1.3851, 0.7417, -3.8515},
       {-3.7986, -5.1188, -1.9410, 14.0205, -5.4515, 3.1831, 5.1488},
       {1.5621, 3.0426, 1.4819, -4.5938, 10.1397, 4.9312, -2.8351},
       {-1.5436, -0.0287, -0.1139, 4.4499, 2.5894, 6.1216, 2.7201},
       {-3.7241, -2.7670, -3.8162, 4.5961, -1.7251, -0.4190, 8.6562}},

      {{3.3789, -2.3607, -1.2471, 2.1503, 0.6062, -0.6057, 1.7748},
       {-1.8670, 11.0947, 0.1229, 0.0599, 3.1714, -4.7941, -4.5442},
       {-0.6905, -0.0829, 5.2156, 2.9528, 2.6200, 6.1638, 1.8652},
       {3.0521, 2.2174, 0.7444, 10.7268, 0.6443, -2.7732, 1.6840},
       {1.8479, 3.0821, 4.5671, 2.9254, 6.1338, 5.2066, 2.3662},
       {-0.0360, -5.5341, 5.9687, -0.3297, 2.1174, 13.0016, 4.0118},
       {0.4380, -4.6683, 3.1548, 0.0924, 0.7176, 6.4679, 6.1819}},

      {{10.0487, 4.0350, -0.8471, -1.2887, -0.8172, -3.3698, 1.3191},
       {4.8678, 4.6081, 0.8419, -0.2454, -3.2599, -1.2386, 2.4070},
       {1.4877, 0.8362, 2.6077, 1.1782, -0.1116, 1.7130, -1.1883},
       {-0.9245, -0.7435, -0.9456, 2.5936, 1.9887, -0.1324, -0.1453},
       {0.2918, -0.5301, -0.8775, 1.0478, 8.9262, 2.4731, -0.4393},
       {-3.5759, -1.5619, 2.4410, 1.3046, 4.2678, 7.3587, -4.0935},
       {-1.1187, 0.9150, -1.8253, 0.0390, -2.5684, -4.0778, 4.1447}}};

  xla::XlaOp matrices_op;
  auto matrices_handle =
      CreateParameter<float>(matrices, 0, "a", &builder, &matrices_op);

  const xla::SignAndLogDet sign_and_logdet = xla::SLogDet(matrices_op);
  const xla::XlaOp logdet_only = xla::LogDet(matrices_op);
  xla::Tuple(&builder,
             {sign_and_logdet.sign, sign_and_logdet.logdet, logdet_only});

  // Both log-determinant outputs are checked against the same reference row.
  const float kExpectedSigns[] = {1.f, 1.f, 1.f};
  const float kExpectedLogDets[] = {8.93788053, 6.77846303, 7.4852403};
  xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
      xla::LiteralUtil::CreateR1<float>(kExpectedSigns),
      xla::LiteralUtil::CreateR1<float>(kExpectedLogDets),
      xla::LiteralUtil::CreateR1<float>(kExpectedLogDets));

  ComputeAndCompareLiteral(&builder, expected, {matrices_handle.get()},
                           xla::ErrorSpec(1e-4));
}
160
161 } // namespace
162