1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
18
19 #include <initializer_list>
20 #include <memory>
21 #include <random>
22 #include <string>
23
24 #include "absl/types/optional.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/array2d.h"
27 #include "tensorflow/compiler/xla/array3d.h"
28 #include "tensorflow/compiler/xla/array4d.h"
29 #include "tensorflow/compiler/xla/error_spec.h"
30 #include "tensorflow/compiler/xla/literal.h"
31 #include "tensorflow/compiler/xla/literal_util.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40
41 namespace xla {
42
43 // Utility class for making expectations/assertions related to XLA literals.
44 class LiteralTestUtil {
45 public:
46 // Asserts that the given shapes have the same rank, dimension sizes, and
47 // primitive types.
48 static ::testing::AssertionResult EqualShapes(
49 const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
50
51 // Asserts that the provided shapes are equal as defined in AssertEqualShapes
52 // and that they have the same layout.
53 static ::testing::AssertionResult EqualShapesAndLayouts(
54 const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
55
56 static ::testing::AssertionResult Equal(const LiteralSlice& expected,
57 const LiteralSlice& actual)
58 TF_MUST_USE_RESULT;
59
60 // Asserts the given literal are (bitwise) equal to given expected values.
61 template <typename NativeT>
62 static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
63
64 template <typename NativeT>
65 static void ExpectR1Equal(absl::Span<const NativeT> expected,
66 const LiteralSlice& actual);
67 template <typename NativeT>
68 static void ExpectR2Equal(
69 std::initializer_list<std::initializer_list<NativeT>> expected,
70 const LiteralSlice& actual);
71
72 template <typename NativeT>
73 static void ExpectR3Equal(
74 std::initializer_list<
75 std::initializer_list<std::initializer_list<NativeT>>>
76 expected,
77 const LiteralSlice& actual);
78
79 // Asserts the given literal are (bitwise) equal to given array.
80 template <typename NativeT>
81 static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
82 const LiteralSlice& actual);
83 template <typename NativeT>
84 static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
85 const LiteralSlice& actual);
86 template <typename NativeT>
87 static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
88 const LiteralSlice& actual);
89
90 // Decorates literal_comparison::Near() with an AssertionResult return type.
91 //
92 // See comment on literal_comparison::Near().
93 static ::testing::AssertionResult Near(
94 const LiteralSlice& expected, const LiteralSlice& actual,
95 const ErrorSpec& error_spec,
96 absl::optional<bool> detailed_message = absl::nullopt) TF_MUST_USE_RESULT;
97
98 // Asserts the given literal are within the given error bound of the given
99 // expected values. Only supported for floating point values.
100 template <typename NativeT>
101 static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
102 const ErrorSpec& error);
103
104 template <typename NativeT>
105 static void ExpectR1Near(absl::Span<const NativeT> expected,
106 const LiteralSlice& actual, const ErrorSpec& error);
107
108 template <typename NativeT>
109 static void ExpectR2Near(
110 std::initializer_list<std::initializer_list<NativeT>> expected,
111 const LiteralSlice& actual, const ErrorSpec& error);
112
113 template <typename NativeT>
114 static void ExpectR3Near(
115 std::initializer_list<
116 std::initializer_list<std::initializer_list<NativeT>>>
117 expected,
118 const LiteralSlice& actual, const ErrorSpec& error);
119
120 template <typename NativeT>
121 static void ExpectR4Near(
122 std::initializer_list<std::initializer_list<
123 std::initializer_list<std::initializer_list<NativeT>>>>
124 expected,
125 const LiteralSlice& actual, const ErrorSpec& error);
126
127 // Asserts the given literal are within the given error bound to the given
128 // array. Only supported for floating point values.
129 template <typename NativeT>
130 static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
131 const LiteralSlice& actual,
132 const ErrorSpec& error);
133
134 template <typename NativeT>
135 static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
136 const LiteralSlice& actual,
137 const ErrorSpec& error);
138
139 template <typename NativeT>
140 static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
141 const LiteralSlice& actual,
142 const ErrorSpec& error);
143
144 // If the error spec is given, returns whether the expected and the actual are
145 // within the error bound; otherwise, returns whether they are equal. Tuples
146 // will be compared recursively.
147 static ::testing::AssertionResult NearOrEqual(
148 const LiteralSlice& expected, const LiteralSlice& actual,
149 const absl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
150
151 private:
152 TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
153 };
154
155 template <typename NativeT>
ExpectR0Equal(NativeT expected,const LiteralSlice & actual)156 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
157 const LiteralSlice& actual) {
158 EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
159 }
160
161 template <typename NativeT>
ExpectR1Equal(absl::Span<const NativeT> expected,const LiteralSlice & actual)162 /* static */ void LiteralTestUtil::ExpectR1Equal(
163 absl::Span<const NativeT> expected, const LiteralSlice& actual) {
164 EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
165 }
166
167 template <typename NativeT>
ExpectR2Equal(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual)168 /* static */ void LiteralTestUtil::ExpectR2Equal(
169 std::initializer_list<std::initializer_list<NativeT>> expected,
170 const LiteralSlice& actual) {
171 EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
172 }
173
174 template <typename NativeT>
ExpectR3Equal(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual)175 /* static */ void LiteralTestUtil::ExpectR3Equal(
176 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
177 expected,
178 const LiteralSlice& actual) {
179 EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
180 }
181
182 template <typename NativeT>
ExpectR2EqualArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual)183 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
184 const Array2D<NativeT>& expected, const LiteralSlice& actual) {
185 EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
186 }
187
188 template <typename NativeT>
ExpectR3EqualArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual)189 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
190 const Array3D<NativeT>& expected, const LiteralSlice& actual) {
191 EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
192 }
193
194 template <typename NativeT>
ExpectR4EqualArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual)195 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
196 const Array4D<NativeT>& expected, const LiteralSlice& actual) {
197 EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
198 }
199
200 template <typename NativeT>
ExpectR0Near(NativeT expected,const LiteralSlice & actual,const ErrorSpec & error)201 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
202 const LiteralSlice& actual,
203 const ErrorSpec& error) {
204 EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
205 }
206
207 template <typename NativeT>
ExpectR1Near(absl::Span<const NativeT> expected,const LiteralSlice & actual,const ErrorSpec & error)208 /* static */ void LiteralTestUtil::ExpectR1Near(
209 absl::Span<const NativeT> expected, const LiteralSlice& actual,
210 const ErrorSpec& error) {
211 EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
212 }
213
214 template <typename NativeT>
ExpectR2Near(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual,const ErrorSpec & error)215 /* static */ void LiteralTestUtil::ExpectR2Near(
216 std::initializer_list<std::initializer_list<NativeT>> expected,
217 const LiteralSlice& actual, const ErrorSpec& error) {
218 EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
219 }
220
221 template <typename NativeT>
ExpectR3Near(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual,const ErrorSpec & error)222 /* static */ void LiteralTestUtil::ExpectR3Near(
223 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
224 expected,
225 const LiteralSlice& actual, const ErrorSpec& error) {
226 EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
227 }
228
229 template <typename NativeT>
ExpectR4Near(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> expected,const LiteralSlice & actual,const ErrorSpec & error)230 /* static */ void LiteralTestUtil::ExpectR4Near(
231 std::initializer_list<std::initializer_list<
232 std::initializer_list<std::initializer_list<NativeT>>>>
233 expected,
234 const LiteralSlice& actual, const ErrorSpec& error) {
235 EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
236 }
237
238 template <typename NativeT>
ExpectR2NearArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)239 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
240 const Array2D<NativeT>& expected, const LiteralSlice& actual,
241 const ErrorSpec& error) {
242 EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
243 }
244
245 template <typename NativeT>
ExpectR3NearArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)246 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
247 const Array3D<NativeT>& expected, const LiteralSlice& actual,
248 const ErrorSpec& error) {
249 EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
250 }
251
252 template <typename NativeT>
ExpectR4NearArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)253 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
254 const Array4D<NativeT>& expected, const LiteralSlice& actual,
255 const ErrorSpec& error) {
256 EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
257 }
258
259 } // namespace xla
260
261 #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
262