• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
18 
19 #include <stddef.h>
20 
21 #include <cstdint>
22 #include <memory>
23 #include <optional>
24 #include <ostream>
25 #include <string>
26 #include <tuple>
27 #include <utility>
28 #include <vector>
29 
30 #include <gmock/gmock.h>
31 #include <gtest/gtest.h>
32 #include "absl/status/status.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/lite/interpreter.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/string_type.h"
37 
38 namespace tflite {
39 
40 // These two functions implement usability printing for TfLiteTensor dimensions
41 // and coordinates. By default dimensions are interpreted depending on the size:
42 // 1:Linear, 2:HW, 3: HWC, 4:BHWC. If there are more than 4 dimensions,
43 // absl::nullopt will be returned.
44 absl::optional<std::string> ShapeToString(TfLiteIntArray* shape);
45 absl::optional<std::string> CoordinateToString(TfLiteIntArray* shape,
46                                                int linear);
47 
48 template <typename TupleMatcher>
49 class TensorEqMatcher {
50  public:
TensorEqMatcher(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)51   TensorEqMatcher(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
52       : tuple_matcher_(tuple_matcher), rhs_(rhs) {}
53 
54   // Make TensorEqMatcher movable only (The copy operations are implicitly
55   // deleted).
56   TensorEqMatcher(TensorEqMatcher&& other) = default;
57   TensorEqMatcher& operator=(TensorEqMatcher&& other) = default;
58 
59   template <typename T>
60   operator testing::Matcher<T>() const {  // NOLINT
61     return testing::Matcher<T>(new Impl(tuple_matcher_, rhs_));
62   }
63 
64   class Impl : public testing::MatcherInterface<TfLiteTensor> {
65    public:
66     typedef ::std::tuple<float, float> InnerMatcherArg;
67 
Impl(const TupleMatcher & tuple_matcher,const TfLiteTensor & rhs)68     Impl(const TupleMatcher& tuple_matcher, const TfLiteTensor& rhs)
69         : mono_tuple_matcher_(
70               testing::SafeMatcherCast<InnerMatcherArg>(tuple_matcher)),
71           rhs_(rhs) {}
72 
73     // Make Impl movable only (The copy operations are implicitly deleted).
74     Impl(Impl&& other) = default;
75     Impl& operator=(Impl&& other) = default;
76 
77     // Define what gtest framework will print for the Expected field.
DescribeTo(std::ostream * os)78     void DescribeTo(std::ostream* os) const override {
79       std::string shape;
80       absl::optional<std::string> result = ShapeToString(rhs_.dims);
81       if (result.has_value()) {
82         shape = std::move(result.value());
83       } else {
84         shape = "[error: unsupported number of dimensions]";
85       }
86       *os << "tensor which has the shape of " << shape
87           << ", where each value and its corresponding expected value ";
88       mono_tuple_matcher_.DescribeTo(os);
89     }
90 
MatchAndExplain(TfLiteTensor lhs,testing::MatchResultListener * listener)91     bool MatchAndExplain(
92         TfLiteTensor lhs,
93         testing::MatchResultListener* listener) const override {
94       // 1. Check that TfLiteTensor data type is supported.
95       // Support for other data types will be added on demand.
96       if (lhs.type != kTfLiteFloat32 || rhs_.type != kTfLiteFloat32) {
97         *listener << "which data type is not float32, which is not currently "
98                      "supported.";
99         return false;
100       }
101 
102       // 2. Check that dimensions' sizes match. Otherwise, we are not able to
103       // compare tensors.
104       if (lhs.dims->size != rhs_.dims->size) {
105         *listener << "which is different from the expected shape of size "
106                   << rhs_.dims->size;
107         return false;
108       }
109       // 3. Check that dimensions' values are equal as well. We are not able to
110       // compare tensors of different shapes, even if the total elements count
111       // matches.
112       bool dims_are_equal = true;
113       for (int i = 0; i < lhs.dims->size; i++) {
114         dims_are_equal &= lhs.dims->data[i] == rhs_.dims->data[i];
115       }
116       if (!dims_are_equal) {
117         std::string shape;
118         absl::optional<std::string> result = ShapeToString(rhs_.dims);
119         if (result.has_value()) {
120           shape = std::move(result.value());
121         } else {
122           shape = "[error: unsupported number of dimensions]";
123         }
124         *listener << "which is different from the expected shape " << shape;
125         return false;
126       }
127 
128       // 4. Proceed to data comparison. Iterate through elements as they lay
129       // flat. If some pair of elements don't match, deduct the coordinate
130       // basing on the dimensions, then return.
131       absl::Span<float> lhs_span(lhs.data.f, lhs.bytes / sizeof(float));
132       absl::Span<float> rhs_span(rhs_.data.f, rhs_.bytes / sizeof(float));
133 
134       auto left = lhs_span.begin();
135       auto right = rhs_span.begin();
136       for (size_t i = 0; i != lhs_span.size(); ++i, ++left, ++right) {
137         if (listener->IsInterested()) {
138           testing::StringMatchResultListener inner_listener;
139           if (!mono_tuple_matcher_.MatchAndExplain({*left, *right},
140                                                    &inner_listener)) {
141             *listener << "where the value pair (";
142             testing::internal::UniversalPrint(*left, listener->stream());
143             *listener << ", ";
144             testing::internal::UniversalPrint(*right, listener->stream());
145             std::string coordinate;
146             absl::optional<std::string> result =
147                 CoordinateToString(lhs.dims, i);
148             if (result.has_value()) {
149               coordinate = std::move(result.value());
150             } else {
151               coordinate = "[error: unsupported number of dimensions]";
152             }
153             *listener << ") with coordinate " << coordinate << " don't match";
154             testing::internal::PrintIfNotEmpty(inner_listener.str(),
155                                                listener->stream());
156             return false;
157           }
158         } else {
159           if (!mono_tuple_matcher_.Matches({*left, *right})) return false;
160         }
161       }
162 
163       return true;
164     }
165 
166    private:
167     const testing::Matcher<InnerMatcherArg> mono_tuple_matcher_;
168     const TfLiteTensor rhs_;
169   };
170 
171  private:
172   const TupleMatcher tuple_matcher_;
173   const TfLiteTensor rhs_;
174 };
175 
176 // Builds interpreter for a model, allocates tensors.
177 absl::Status BuildInterpreter(const Model* model,
178                               std::unique_ptr<Interpreter>* interpreter);
179 
180 // Allocates tensors for a given interpreter.
181 absl::Status AllocateTensors(std::unique_ptr<Interpreter>* interpreter);
182 
183 // Modifies graph with given delegate.
184 absl::Status ModifyGraphWithDelegate(std::unique_ptr<Interpreter>* interpreter,
185                                      TfLiteDelegate* delegate);
186 
187 // Initializes inputs with consequent values of some fixed range.
188 void InitializeInputs(int left, int right,
189                       std::unique_ptr<Interpreter>* interpreter);
190 
191 // Invokes a prebuilt interpreter.
192 absl::Status Invoke(std::unique_ptr<Interpreter>* interpreter);
193 
// Parameter bundle handed to value-parameterized gtest tests.
struct TestParams {
  // Name used to label the generated gtest instance.
  std::string name{};

  // Raw bytes of the TFLite model exercised by this test case.
  std::vector<uint8_t> model{};
};

// Prints `param` to `os`; used to describe a test instance on the command
// line when something fails during testing.
auto operator<<(std::ostream& os, const TestParams& param) -> std::ostream&;
207 
208 }  // namespace tflite
209 
210 // Gtest framework uses this function to describe TfLiteTensor if something
211 // fails. TfLiteTensor is defined in global namespace, same should be done for
212 // streaming operator.
213 std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor);
214 
215 // Defines a matcher to compare two TfLiteTensors pointwise using the given
216 // tuple matcher for comparing their values.
217 template <typename TupleMatcherT>
TensorEq(const TupleMatcherT & matcher,const TfLiteTensor & rhs)218 inline tflite::TensorEqMatcher<TupleMatcherT> TensorEq(
219     const TupleMatcherT& matcher, const TfLiteTensor& rhs) {
220   return tflite::TensorEqMatcher<TupleMatcherT>(matcher, rhs);
221 }
222 
223 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
224