1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests that our utility functions for dealing with literals are correctly
17 // implemented.
18
#include "tensorflow/compiler/xla/tests/literal_test_util.h"

#include <cmath>
#include <limits>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
30
31 namespace xla {
32 namespace {
33
// A tuple of two int32 scalars must compare equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
  const Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<int32>(42), LiteralUtil::CreateR0<int32>(64)});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
41
// A tuple of two complex64 scalars must compare equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) {
  const Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
       LiteralUtil::CreateR0<complex64>({64.0, 42.0})});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
49
// A tuple of two complex128 scalars must compare equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) {
  const Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
       LiteralUtil::CreateR0<complex128>({64.0, 42.0})});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
57
// Two-element complex64 tuples must compare unequal when the elements are
// swapped, when only the real part of the first element differs, or when only
// the imaginary part of the second element differs.
TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) {
  // Builds the tuple (first, second) out of two complex64 scalar literals.
  auto make_pair_tuple = [](complex64 first, complex64 second) {
    return LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<complex64>(first),
         LiteralUtil::CreateR0<complex64>(second)});
  };

  const Literal base = make_pair_tuple({42.0, 64.0}, {64.0, 42.0});
  const Literal swapped = make_pair_tuple({64.0, 42.0}, {42.0, 64.0});
  const Literal first_real_changed =
      make_pair_tuple({42.42, 64.0}, {64.0, 42.0});
  const Literal second_imag_changed =
      make_pair_tuple({42.0, 64.0}, {64.0, 42.42});

  EXPECT_FALSE(LiteralTestUtil::Equal(base, swapped));
  EXPECT_FALSE(LiteralTestUtil::Equal(base, first_real_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(base, second_imag_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(first_real_changed, second_imag_changed));
}
80
// Two-element complex128 tuples must compare unequal when the elements are
// swapped, when only the real part of the first element differs, or when only
// the imaginary part of the second element differs.
TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) {
  // Builds the tuple (first, second) out of two complex128 scalar literals.
  auto make_pair_tuple = [](complex128 first, complex128 second) {
    return LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<complex128>(first),
         LiteralUtil::CreateR0<complex128>(second)});
  };

  const Literal base = make_pair_tuple({42.0, 64.0}, {64.0, 42.0});
  const Literal swapped = make_pair_tuple({64.0, 42.0}, {42.0, 64.0});
  const Literal first_real_changed =
      make_pair_tuple({42.42, 64.0}, {64.0, 42.0});
  const Literal second_imag_changed =
      make_pair_tuple({42.0, 64.0}, {64.0, 42.42});

  EXPECT_FALSE(LiteralTestUtil::Equal(base, swapped));
  EXPECT_FALSE(LiteralTestUtil::Equal(base, first_real_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(base, second_imag_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(first_real_changed, second_imag_changed));
}
103
// Tuples holding the same int32 values in a different order must not compare
// equal.
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
  // The comparison result is fed to CHECK, so a mismatch terminates the
  // process with the streamed message. ASSERT_DEATH runs the lambda in a child
  // process and requires exactly that termination. The test therefore passes
  // only if Equal reports these swapped tuples as different.
  auto check_swapped_tuples_are_equal = [] {
    Literal forward = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<int32>(42), LiteralUtil::CreateR0<int32>(64)});
    Literal backward = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<int32>(64), LiteralUtil::CreateR0<int32>(42)});
    CHECK(LiteralTestUtil::Equal(forward, backward))
        << "LHS and RHS are unequal";
  };
  ASSERT_DEATH(check_swapped_tuples_are_equal(), "LHS and RHS are unequal");
}
121
// When a LiteralTestUtil::Near comparison fails, binary LiteralProto files
// matching "tempfile-*.pb" must appear in the test output directory. Their
// names contain "expected", "actual" or "mismatches", and each holds the
// corresponding literal.
TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
  auto dummy_lambda = [] {
    auto two = LiteralUtil::CreateR0<float>(2);
    auto four = LiteralUtil::CreateR0<float>(4);
    ErrorSpec error(0.001);
    CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
  };

  tensorflow::Env* env = tensorflow::Env::Default();

  // Prefer the undeclared-outputs directory when one is available; otherwise
  // fall back to the test's temporary directory.
  string outdir;
  if (!tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) {
    outdir = tensorflow::testing::TmpDir();
  }
  const string pattern = tensorflow::io::JoinPath(outdir, "tempfile-*.pb");

  // Delete any files already matching the pattern, so the count below only
  // sees files produced by the death test.
  std::vector<string> files;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &files));
  for (const auto& f : files) {
    TF_CHECK_OK(env->DeleteFile(f)) << f;
  }

  ASSERT_DEATH(dummy_lambda(), "two is not near four");

  // Now check we wrote temporary files to the temporary directory that we can
  // read.
  std::vector<string> results;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));

  LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]";
  // results.size() is a size_t; compare against an unsigned literal so the
  // gtest comparison does not mix signed and unsigned operands.
  EXPECT_EQ(3u, results.size());
  for (const string& result : results) {
    LiteralProto literal_proto;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(env, result, &literal_proto));
    Literal literal =
        Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
    if (result.find("expected") != string::npos) {
      EXPECT_EQ("f32[] 2", literal.ToString());
    } else if (result.find("actual") != string::npos) {
      EXPECT_EQ("f32[] 4", literal.ToString());
    } else if (result.find("mismatches") != string::npos) {
      EXPECT_EQ("pred[] true", literal.ToString());
    } else {
      FAIL() << "unknown file in temporary directory: " << result;
    }
  }
}
169
// The failure message from Equal must print both the expected and the actual
// literal.
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
  using ::testing::HasSubstr;

  const auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
  const auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
  const ::testing::AssertionResult result =
      LiteralTestUtil::Equal(expected, actual);

  const char* const message = result.message();
  EXPECT_THAT(message, HasSubstr("Expected literal:\ns32[3] {1, 2, 3}"));
  EXPECT_THAT(message, HasSubstr("Actual literal:\ns32[3] {4, 5, 6}"));
}
179
// Two independently built but identical f32 vectors must be "near".
TEST(LiteralTestUtilTest, NearComparatorR1) {
  auto make_ramp = [] {
    return LiteralUtil::CreateR1<float>(
        {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
  };
  const auto lhs = make_ramp();
  const auto rhs = make_ramp();
  EXPECT_TRUE(LiteralTestUtil::Near(lhs, rhs, ErrorSpec{0.0001}));
}
187
// complex64 vectors are "near" only if both the real and the imaginary parts
// of every element are within tolerance. c differs from a in the real part of
// its last element, and d differs in the imaginary part.
TEST(LiteralTestUtilTest, NearComparatorR1Complex64) {
  // All test vectors share their first eight elements; only the ninth varies.
  auto make_vector = [](complex64 last) {
    return LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
                                             {0.1, 1.1},
                                             {0.2, 1.2},
                                             {0.3, 1.3},
                                             {0.4, 1.4},
                                             {0.5, 1.5},
                                             {0.6, 1.6},
                                             {0.7, 1.7},
                                             last});
  };
  const auto a = make_vector({0.8, 1.8});
  const auto b = make_vector({0.8, 1.8});
  const auto c = make_vector({0.9, 1.8});
  const auto d = make_vector({0.8, 1.9});

  const ErrorSpec tolerance{0.0001};
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, c, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, d, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(c, d, tolerance));
}
230
// complex128 vectors are "near" only if both the real and the imaginary parts
// of every element are within tolerance. c differs from a in the real part of
// its last element, and d differs in the imaginary part.
TEST(LiteralTestUtilTest, NearComparatorR1Complex128) {
  // All test vectors share their first eight elements; only the ninth varies.
  auto make_vector = [](complex128 last) {
    return LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
                                              {0.1, 1.1},
                                              {0.2, 1.2},
                                              {0.3, 1.3},
                                              {0.4, 1.4},
                                              {0.5, 1.5},
                                              {0.6, 1.6},
                                              {0.7, 1.7},
                                              last});
  };
  const auto a = make_vector({0.8, 1.8});
  const auto b = make_vector({0.8, 1.8});
  const auto c = make_vector({0.9, 1.8});
  const auto d = make_vector({0.8, 1.9});

  const ErrorSpec tolerance{0.0001};
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, c, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, d, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(c, d, tolerance));
}
273
// A NaN in the same position of both literals must not prevent two otherwise
// identical f32 vectors from being "near". NAN comes from <cmath>.
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
  auto a = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
  auto b = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
281
// f32 vectors of different lengths must never be "near", whichever one is
// passed as the expected value. The suite name matches the rest of this file
// (LiteralTestUtilTest), so `--gtest_filter=LiteralTestUtilTest.*` selects this
// test too.
TEST(LiteralTestUtilTest, NearComparatorDifferentLengths) {
  auto a = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
  auto b =
      LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
  EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
  EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
290
// A double literal whose value (twice the largest finite float) cannot be
// represented as a finite float must still compare "near" to itself.
// std::numeric_limits is declared in <limits>.
TEST(LiteralTestUtilTest, ExpectNearDoubleOutsideFloatValueRange) {
  const auto two_times_float_max =
      LiteralUtil::CreateR0<double>(2.0 * std::numeric_limits<float>::max());
  const ErrorSpec error(0.001);
  EXPECT_TRUE(
      LiteralTestUtil::Near(two_times_float_max, two_times_float_max, error));
}
298
299 } // namespace
300 } // namespace xla
301