1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
18
19 #include <stddef.h>
20
21 #include <cstdint>
22 #include <memory>
23 #include <optional>
24 #include <ostream>
25 #include <string>
26 #include <tuple>
27 #include <utility>
28 #include <vector>
29
30 #include <gmock/gmock.h>
31 #include <gtest/gtest.h>
32 #include "absl/status/status.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/lite/interpreter.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/string_type.h"
37
38 namespace tflite {
39
40 // These two functions implement usability printing for TfLiteTensor dimensions
41 // and coordinates. By default dimensions are interpreted depending on the size:
42 // 1:Linear, 2:HW, 3: HWC, 4:BHWC. If there are more than 4 dimensions,
43 // absl::nullopt will be returned.
44 absl::optional<std::string> ShapeToString(TfLiteIntArray* shape);
45 absl::optional<std::string> CoordinateToString(TfLiteIntArray* shape,
46 int linear);
47
48 template <typename TupleMatcher>
49 class TensorEqMatcher {
50 public:
TensorEqMatcher(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)51 TensorEqMatcher(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
52 : tuple_matcher_(tuple_matcher), rhs_(rhs) {}
53
54 // Make TensorEqMatcher movable only (The copy operations are implicitly
55 // deleted).
56 TensorEqMatcher(TensorEqMatcher&& other) = default;
57 TensorEqMatcher& operator=(TensorEqMatcher&& other) = default;
58
59 template <typename T>
60 operator testing::Matcher<T>() const { // NOLINT
61 return testing::Matcher<T>(new Impl(tuple_matcher_, rhs_));
62 }
63
64 class Impl : public testing::MatcherInterface<TfLiteTensor> {
65 public:
66 typedef ::std::tuple<float, float> InnerMatcherArg;
67
Impl(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)68 Impl(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
69 : mono_tuple_matcher_(
70 testing::SafeMatcherCast<InnerMatcherArg>(tuple_matcher)),
71 rhs_(rhs) {}
72
73 // Make Impl movable only (The copy operations are implicitly deleted).
74 Impl(Impl&& other) = default;
75 Impl& operator=(Impl&& other) = default;
76
77 // Define what gtest framework will print for the Expected field.
DescribeTo(std::ostream * os)78 void DescribeTo(std::ostream* os) const override {
79 std::string shape;
80 absl::optional<std::string> result = ShapeToString(rhs_.dims);
81 if (result.has_value()) {
82 shape = std::move(result.value());
83 } else {
84 shape = "[error: unsupported number of dimensions]";
85 }
86 *os << "tensor which has the shape of " << shape
87 << ", where each value and its corresponding expected value ";
88 mono_tuple_matcher_.DescribeTo(os);
89 }
90
MatchAndExplain(TfLiteTensor lhs,testing::MatchResultListener * listener)91 bool MatchAndExplain(
92 TfLiteTensor lhs,
93 testing::MatchResultListener* listener) const override {
94 // 1. Check that TfLiteTensor data type is supported.
95 // Support for other data types will be added on demand.
96 if (lhs.type != kTfLiteFloat32 || rhs_.type != kTfLiteFloat32) {
97 *listener << "which data type is not float32, which is not currently "
98 "supported.";
99 return false;
100 }
101
102 // 2. Check that dimensions' sizes match. Otherwise, we are not able to
103 // compare tensors.
104 if (lhs.dims->size != rhs_.dims->size) {
105 *listener << "which is different from the expected shape of size "
106 << rhs_.dims->size;
107 return false;
108 }
109 // 3. Check that dimensions' values are equal as well. We are not able to
110 // compare tensors of different shapes, even if the total elements count
111 // matches.
112 bool dims_are_equal = true;
113 for (int i = 0; i < lhs.dims->size; i++) {
114 dims_are_equal &= lhs.dims->data[i] == rhs_.dims->data[i];
115 }
116 if (!dims_are_equal) {
117 std::string shape;
118 absl::optional<std::string> result = ShapeToString(rhs_.dims);
119 if (result.has_value()) {
120 shape = std::move(result.value());
121 } else {
122 shape = "[error: unsupported number of dimensions]";
123 }
124 *listener << "which is different from the expected shape " << shape;
125 return false;
126 }
127
128 // 4. Proceed to data comparison. Iterate through elements as they lay
129 // flat. If some pair of elements don't match, deduct the coordinate
130 // basing on the dimensions, then return.
131 absl::Span<float> lhs_span(lhs.data.f, lhs.bytes / sizeof(float));
132 absl::Span<float> rhs_span(rhs_.data.f, rhs_.bytes / sizeof(float));
133
134 auto left = lhs_span.begin();
135 auto right = rhs_span.begin();
136 for (size_t i = 0; i != lhs_span.size(); ++i, ++left, ++right) {
137 if (listener->IsInterested()) {
138 testing::StringMatchResultListener inner_listener;
139 if (!mono_tuple_matcher_.MatchAndExplain({*left, *right},
140 &inner_listener)) {
141 *listener << "where the value pair (";
142 testing::internal::UniversalPrint(*left, listener->stream());
143 *listener << ", ";
144 testing::internal::UniversalPrint(*right, listener->stream());
145 std::string coordinate;
146 absl::optional<std::string> result =
147 CoordinateToString(lhs.dims, i);
148 if (result.has_value()) {
149 coordinate = std::move(result.value());
150 } else {
151 coordinate = "[error: unsupported number of dimensions]";
152 }
153 *listener << ") with coordinate " << coordinate << " don't match";
154 testing::internal::PrintIfNotEmpty(inner_listener.str(),
155 listener->stream());
156 return false;
157 }
158 } else {
159 if (!mono_tuple_matcher_.Matches({*left, *right})) return false;
160 }
161 }
162
163 return true;
164 }
165
166 private:
167 const testing::Matcher<InnerMatcherArg> mono_tuple_matcher_;
168 const TfLiteTensor rhs_;
169 };
170
171 private:
172 const TupleMatcher tuple_matcher_;
173 const TfLiteTensor rhs_;
174 };
175
176 // Builds interpreter for a model, allocates tensors.
177 absl::Status BuildInterpreter(const Model* model,
178 std::unique_ptr<Interpreter>* interpreter);
179
180 // Allocates tensors for a given interpreter.
181 absl::Status AllocateTensors(std::unique_ptr<Interpreter>* interpreter);
182
183 // Modifies graph with given delegate.
184 absl::Status ModifyGraphWithDelegate(std::unique_ptr<Interpreter>* interpreter,
185 TfLiteDelegate* delegate);
186
187 // Initializes inputs with consequent values of some fixed range.
188 void InitializeInputs(int left, int right,
189 std::unique_ptr<Interpreter>* interpreter);
190
191 // Invokes a prebuilt interpreter.
192 absl::Status Invoke(std::unique_ptr<Interpreter>* interpreter);
193
// Parameter bundle handed to the parameterized tests: one instance describes
// one generated test case.
struct TestParams {
  // Name that gtest uses for the generated test.
  std::string name;

  // Bytes of the TFLite model associated with this test name.
  std::vector<uint8_t> model;
};
203
204 // Defines how the TestParams should be printed into the command line if
205 // something fails during testing.
206 std::ostream& operator<<(std::ostream& os, const TestParams& param);
207
208 } // namespace tflite
209
210 // Gtest framework uses this function to describe TfLiteTensor if something
211 // fails. TfLiteTensor is defined in global namespace, same should be done for
212 // streaming operator.
213 std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor);
214
215 // Defines a matcher to compare two TfLiteTensors pointwise using the given
216 // tuple matcher for comparing their values.
217 template <typename TupleMatcherT>
TensorEq(const TupleMatcherT & matcher,const TfLiteTensor & rhs)218 inline tflite::TensorEqMatcher<TupleMatcherT> TensorEq(
219 const TupleMatcherT& matcher, const TfLiteTensor& rhs) {
220 return tflite::TensorEqMatcher<TupleMatcherT>(matcher, rhs);
221 }
222
223 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
224