1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests that our utility functions for dealing with literals are correctly
17 // implemented.
18
19 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
20
21 #include <vector>
22
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/core/lib/io/path.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/test.h"
29
30 namespace xla {
31 namespace {
32
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
  // A two-element int32 tuple literal must compare equal to itself.
  const Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<int32>(42), LiteralUtil::CreateR0<int32>(64)});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
40
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
  // Builds the two-element int32 tuple literal {first, second}.
  auto make_int_pair = [](int32 first, int32 second) {
    return LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<int32>(first),
        LiteralUtil::CreateR0<int32>(second),
    });
  };

  // CHECK-asserting that two tuples holding the same elements in opposite
  // order are equal must fail. A CHECK failure terminates the process instead
  // of reporting a recoverable assertion, so the only way to observe it from
  // the test is a death assertion.
  auto assert_swapped_tuples_equal = [&make_int_pair] {
    Literal lhs = make_int_pair(42, 64);
    Literal rhs = make_int_pair(64, 42);
    CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal";
  };
  ASSERT_DEATH(assert_swapped_tuples_equal(), "LHS and RHS are unequal");
}
58
TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
  // Performs a Near() comparison that cannot succeed (2 vs. 4 with a 0.001
  // tolerance) and CHECK-fails on it, so it must run under a death test.
  auto dummy_lambda = [] {
    auto two = LiteralUtil::CreateR0<float>(2);
    auto four = LiteralUtil::CreateR0<float>(4);
    ErrorSpec error(0.001);
    CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
  };

  tensorflow::Env* env = tensorflow::Env::Default();
  string pattern =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/tempfile-*");

  // Remove leftovers from earlier runs so the count below only reflects the
  // files produced by the failing comparison in this run.
  std::vector<string> files;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &files));
  for (const auto& f : files) {
    TF_CHECK_OK(env->DeleteFile(f)) << f;
  }

  ASSERT_DEATH(dummy_lambda(), "two is not near four");

  // Now check we wrote temporary files to the temporary directory that we can
  // read.
  std::vector<string> results;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));

  LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]";
  // Unsigned literal: comparing a signed int against size_t inside
  // EXPECT_EQ triggers -Wsign-compare.
  EXPECT_EQ(3u, results.size());
  for (const string& result : results) {
    LiteralProto literal_proto;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(env, result, &literal_proto));
    Literal literal =
        Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
    // Each dumped file is identified by a tag in its name.
    if (result.find("expected") != string::npos) {
      EXPECT_EQ("f32[] 2", literal.ToString());
    } else if (result.find("actual") != string::npos) {
      EXPECT_EQ("f32[] 4", literal.ToString());
    } else if (result.find("mismatches") != string::npos) {
      EXPECT_EQ("pred[] true", literal.ToString());
    } else {
      FAIL() << "unknown file in temporary directory: " << result;
    }
  }
}
102
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
  // Comparing two different R1 literals must fail, and the failure message
  // must print both the expected and the actual literal.
  auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
  auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
  ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
  // Without this, the message checks below would be inspecting the message of
  // a comparison that was never shown to have failed.
  EXPECT_FALSE(result);
  EXPECT_THAT(result.message(),
              ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}"));
  EXPECT_THAT(result.message(),
              ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}"));
}
112
TEST(LiteralTestUtilTest, NearComparatorR1) {
  // Two rank-1 f32 literals built from the same values must be reported as
  // near each other.
  const std::vector<float> values = {0.0, 0.1, 0.2, 0.3, 0.4,
                                     0.5, 0.6, 0.7, 0.8};
  auto a = LiteralUtil::CreateR1<float>(values);
  auto b = LiteralUtil::CreateR1<float>(values);
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
120
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
  // A NaN in the same position of both literals must not prevent them from
  // being reported as near each other.
  const std::vector<float> values_with_nan = {0.0, 0.1, 0.2, 0.3, NAN,
                                              0.5, 0.6, 0.7, 0.8};
  auto a = LiteralUtil::CreateR1<float>(values_with_nan);
  auto b = LiteralUtil::CreateR1<float>(values_with_nan);
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
128
// Suite name aligned with the rest of this file (previously `LiteralTestUtil`).
TEST(LiteralTestUtilTest, NearComparatorDifferentLengths) {
  // Rank-1 literals of different lengths (9 vs. 8 elements) must never be
  // reported as near, regardless of argument order.
  auto a = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
  auto b =
      LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
  EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
  EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
137
138 } // namespace
139 } // namespace xla
140