1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
18 
19 #include <initializer_list>
20 #include <memory>
21 #include <optional>
22 #include <random>
23 #include <string>
24 
25 #include "absl/base/attributes.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/error_spec.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/platform/test.h"
39 
40 namespace xla {
41 
// Utility class for making expectations/assertions related to XLA literals.
//
// Every member is static; the class only groups related helpers. The
// predicate-style functions (EqualShapes, EqualShapesAndLayouts, Equal, Near,
// NearOrEqual) return a ::testing::AssertionResult and are [[nodiscard]], so
// the caller is expected to feed the result to EXPECT_TRUE/ASSERT_TRUE. The
// Expect* helpers return void and perform the EXPECT_TRUE themselves (see
// their definitions at the bottom of this header).
class LiteralTestUtil {
 public:
  // Checks that the given shapes have the same rank, dimension sizes, and
  // primitive types.
  [[nodiscard]] static ::testing::AssertionResult EqualShapes(
      const Shape& expected, const Shape& actual);

  // Checks that the provided shapes are equal as defined in EqualShapes and
  // that they have the same layout.
  [[nodiscard]] static ::testing::AssertionResult EqualShapesAndLayouts(
      const Shape& expected, const Shape& actual);

  // Returns an AssertionResult describing whether `expected` and `actual` are
  // equal. The ExpectR*Equal helpers below are implemented in terms of this
  // function.
  [[nodiscard]] static ::testing::AssertionResult Equal(
      const LiteralSlice& expected, const LiteralSlice& actual);

  // Expects that the given literal is (bitwise) equal to the given expected
  // values. Each helper builds a literal from `expected` with the matching
  // LiteralUtil::CreateR<N> factory and wraps Equal() in EXPECT_TRUE.
  template <typename NativeT>
  static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);

  template <typename NativeT>
  static void ExpectR1Equal(absl::Span<const NativeT> expected,
                            const LiteralSlice& actual);
  template <typename NativeT>
  static void ExpectR2Equal(
      std::initializer_list<std::initializer_list<NativeT>> expected,
      const LiteralSlice& actual);

  template <typename NativeT>
  static void ExpectR3Equal(
      std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>
          expected,
      const LiteralSlice& actual);

  // Expects that the given literal is (bitwise) equal to the given array. Each
  // helper builds a literal from `expected` with the matching
  // LiteralUtil::CreateR<N>FromArray<N>D factory and wraps Equal() in
  // EXPECT_TRUE.
  template <typename NativeT>
  static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
                                   const LiteralSlice& actual);
  template <typename NativeT>
  static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
                                   const LiteralSlice& actual);
  template <typename NativeT>
  static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
                                   const LiteralSlice& actual);

  // Decorates literal_comparison::Near() with an AssertionResult return type.
  //
  // See comment on literal_comparison::Near().
  [[nodiscard]] static ::testing::AssertionResult Near(
      const LiteralSlice& expected, const LiteralSlice& actual,
      const ErrorSpec& error_spec,
      std::optional<bool> detailed_message = std::nullopt);

  // Expects that the given literal is within the given error bound of the
  // given expected values. Only supported for floating point values. Each
  // helper wraps Near() in EXPECT_TRUE and leaves `detailed_message` at its
  // default.
  template <typename NativeT>
  static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
                           const ErrorSpec& error);

  template <typename NativeT>
  static void ExpectR1Near(absl::Span<const NativeT> expected,
                           const LiteralSlice& actual, const ErrorSpec& error);

  template <typename NativeT>
  static void ExpectR2Near(
      std::initializer_list<std::initializer_list<NativeT>> expected,
      const LiteralSlice& actual, const ErrorSpec& error);

  template <typename NativeT>
  static void ExpectR3Near(
      std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>
          expected,
      const LiteralSlice& actual, const ErrorSpec& error);

  template <typename NativeT>
  static void ExpectR4Near(
      std::initializer_list<std::initializer_list<
          std::initializer_list<std::initializer_list<NativeT>>>>
          expected,
      const LiteralSlice& actual, const ErrorSpec& error);

  // Expects that the given literal is within the given error bound of the
  // given array. Only supported for floating point values.
  template <typename NativeT>
  static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
                                  const LiteralSlice& actual,
                                  const ErrorSpec& error);

  template <typename NativeT>
  static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
                                  const LiteralSlice& actual,
                                  const ErrorSpec& error);

  template <typename NativeT>
  static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
                                  const LiteralSlice& actual,
                                  const ErrorSpec& error);

  // If the error spec is given, returns whether the expected and the actual are
  // within the error bound; otherwise, returns whether they are equal. Tuples
  // will be compared recursively.
  [[nodiscard]] static ::testing::AssertionResult NearOrEqual(
      const LiteralSlice& expected, const LiteralSlice& actual,
      const std::optional<ErrorSpec>& error);

 private:
  // Not copyable. Declaring the copy constructor (even as deleted) also
  // suppresses the implicit default constructor, so this static-only class
  // cannot be instantiated.
  LiteralTestUtil(const LiteralTestUtil&) = delete;
  LiteralTestUtil& operator=(const LiteralTestUtil&) = delete;
};
153 
154 template <typename NativeT>
ExpectR0Equal(NativeT expected,const LiteralSlice & actual)155 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
156                                                  const LiteralSlice& actual) {
157   EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
158 }
159 
160 template <typename NativeT>
ExpectR1Equal(absl::Span<const NativeT> expected,const LiteralSlice & actual)161 /* static */ void LiteralTestUtil::ExpectR1Equal(
162     absl::Span<const NativeT> expected, const LiteralSlice& actual) {
163   EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
164 }
165 
166 template <typename NativeT>
ExpectR2Equal(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual)167 /* static */ void LiteralTestUtil::ExpectR2Equal(
168     std::initializer_list<std::initializer_list<NativeT>> expected,
169     const LiteralSlice& actual) {
170   EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
171 }
172 
173 template <typename NativeT>
ExpectR3Equal(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual)174 /* static */ void LiteralTestUtil::ExpectR3Equal(
175     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
176         expected,
177     const LiteralSlice& actual) {
178   EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
179 }
180 
181 template <typename NativeT>
ExpectR2EqualArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual)182 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
183     const Array2D<NativeT>& expected, const LiteralSlice& actual) {
184   EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
185 }
186 
187 template <typename NativeT>
ExpectR3EqualArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual)188 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
189     const Array3D<NativeT>& expected, const LiteralSlice& actual) {
190   EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
191 }
192 
193 template <typename NativeT>
ExpectR4EqualArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual)194 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
195     const Array4D<NativeT>& expected, const LiteralSlice& actual) {
196   EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
197 }
198 
199 template <typename NativeT>
ExpectR0Near(NativeT expected,const LiteralSlice & actual,const ErrorSpec & error)200 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
201                                                 const LiteralSlice& actual,
202                                                 const ErrorSpec& error) {
203   EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
204 }
205 
206 template <typename NativeT>
ExpectR1Near(absl::Span<const NativeT> expected,const LiteralSlice & actual,const ErrorSpec & error)207 /* static */ void LiteralTestUtil::ExpectR1Near(
208     absl::Span<const NativeT> expected, const LiteralSlice& actual,
209     const ErrorSpec& error) {
210   EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
211 }
212 
213 template <typename NativeT>
ExpectR2Near(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual,const ErrorSpec & error)214 /* static */ void LiteralTestUtil::ExpectR2Near(
215     std::initializer_list<std::initializer_list<NativeT>> expected,
216     const LiteralSlice& actual, const ErrorSpec& error) {
217   EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
218 }
219 
220 template <typename NativeT>
ExpectR3Near(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual,const ErrorSpec & error)221 /* static */ void LiteralTestUtil::ExpectR3Near(
222     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
223         expected,
224     const LiteralSlice& actual, const ErrorSpec& error) {
225   EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
226 }
227 
228 template <typename NativeT>
ExpectR4Near(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> expected,const LiteralSlice & actual,const ErrorSpec & error)229 /* static */ void LiteralTestUtil::ExpectR4Near(
230     std::initializer_list<std::initializer_list<
231         std::initializer_list<std::initializer_list<NativeT>>>>
232         expected,
233     const LiteralSlice& actual, const ErrorSpec& error) {
234   EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
235 }
236 
237 template <typename NativeT>
ExpectR2NearArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)238 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
239     const Array2D<NativeT>& expected, const LiteralSlice& actual,
240     const ErrorSpec& error) {
241   EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
242 }
243 
244 template <typename NativeT>
ExpectR3NearArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)245 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
246     const Array3D<NativeT>& expected, const LiteralSlice& actual,
247     const ErrorSpec& error) {
248   EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
249 }
250 
251 template <typename NativeT>
ExpectR4NearArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)252 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
253     const Array4D<NativeT>& expected, const LiteralSlice& actual,
254     const ErrorSpec& error) {
255   EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
256 }
257 
258 }  // namespace xla
259 
260 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
261