• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_FULLY_CONNECTED_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_FULLY_CONNECTED_TESTER_H_
18 
#include <cstdint>
#include <initializer_list>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 namespace xnnpack {
30 
31 class QuantizedFullyConnectedTester {
32  public:
33   QuantizedFullyConnectedTester() = default;
34   QuantizedFullyConnectedTester(const QuantizedFullyConnectedTester&) = delete;
35   QuantizedFullyConnectedTester& operator=(
36       const QuantizedFullyConnectedTester&) = delete;
37 
InputShape(std::initializer_list<int32_t> shape)38   inline QuantizedFullyConnectedTester& InputShape(
39       std::initializer_list<int32_t> shape) {
40     for (auto it = shape.begin(); it != shape.end(); ++it) {
41       EXPECT_GT(*it, 0);
42     }
43     input_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
44     input_size_ = ComputeSize(input_shape_);
45     return *this;
46   }
47 
InputShape()48   inline const std::vector<int32_t>& InputShape() const { return input_shape_; }
49 
InputSize()50   inline int32_t InputSize() const { return input_size_; }
51 
InputChannels(int32_t input_channels)52   inline QuantizedFullyConnectedTester& InputChannels(int32_t input_channels) {
53     EXPECT_GT(input_channels, 0);
54     input_channels_ = input_channels;
55     return *this;
56   }
57 
InputChannels()58   inline int32_t InputChannels() const { return input_channels_; }
59 
OutputChannels(int32_t output_channels)60   inline QuantizedFullyConnectedTester& OutputChannels(
61       int32_t output_channels) {
62     EXPECT_GT(output_channels, 0);
63     output_channels_ = output_channels;
64     return *this;
65   }
66 
OutputChannels()67   inline int32_t OutputChannels() const { return output_channels_; }
68 
69   std::vector<int32_t> OutputShape() const;
70 
InputZeroPoint(int32_t input_zero_point)71   inline QuantizedFullyConnectedTester& InputZeroPoint(
72       int32_t input_zero_point) {
73     input_zero_point_ = input_zero_point;
74     return *this;
75   }
76 
InputZeroPoint()77   inline int32_t InputZeroPoint() const { return input_zero_point_; }
78 
FilterZeroPoint(int32_t filter_zero_point)79   inline QuantizedFullyConnectedTester& FilterZeroPoint(
80       int32_t filter_zero_point) {
81     filter_zero_point_ = filter_zero_point;
82     return *this;
83   }
84 
FilterZeroPoint()85   inline int32_t FilterZeroPoint() const { return filter_zero_point_; }
86 
OutputZeroPoint(int32_t output_zero_point)87   inline QuantizedFullyConnectedTester& OutputZeroPoint(
88       int32_t output_zero_point) {
89     output_zero_point_ = output_zero_point;
90     return *this;
91   }
92 
OutputZeroPoint()93   inline int32_t OutputZeroPoint() const { return output_zero_point_; }
94 
InputScale(float input_scale)95   inline QuantizedFullyConnectedTester& InputScale(float input_scale) {
96     input_scale_ = input_scale;
97     return *this;
98   }
99 
InputScale()100   inline float InputScale() const { return input_scale_; }
101 
FilterScale(float filter_scale)102   inline QuantizedFullyConnectedTester& FilterScale(float filter_scale) {
103     filter_scale_ = filter_scale;
104     return *this;
105   }
106 
FilterScale()107   inline float FilterScale() const { return filter_scale_; }
108 
OutputScale(float output_scale)109   inline QuantizedFullyConnectedTester& OutputScale(float output_scale) {
110     output_scale_ = output_scale;
111     return *this;
112   }
113 
OutputScale()114   inline float OutputScale() const { return output_scale_; }
115 
KeepDims(bool keep_dims)116   inline QuantizedFullyConnectedTester& KeepDims(bool keep_dims) {
117     keep_dims_ = keep_dims;
118     return *this;
119   }
120 
KeepDims()121   inline bool KeepDims() const { return keep_dims_; }
122 
Unsigned()123   inline bool Unsigned() const { return filter_zero_point_ != 0; }
124 
NoBias()125   inline QuantizedFullyConnectedTester& NoBias() {
126     has_bias_ = false;
127     return *this;
128   }
129 
WithBias()130   inline QuantizedFullyConnectedTester& WithBias() {
131     has_bias_ = true;
132     return *this;
133   }
134 
ReluActivation()135   inline QuantizedFullyConnectedTester& ReluActivation() {
136     activation_ = ::tflite::ActivationFunctionType_RELU;
137     return *this;
138   }
139 
Relu6Activation()140   inline QuantizedFullyConnectedTester& Relu6Activation() {
141     activation_ = ::tflite::ActivationFunctionType_RELU6;
142     return *this;
143   }
144 
ReluMinus1To1Activation()145   inline QuantizedFullyConnectedTester& ReluMinus1To1Activation() {
146     activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
147     return *this;
148   }
149 
WeightsCache(TfLiteXNNPackDelegateWeightsCache * weights_cache)150   inline QuantizedFullyConnectedTester& WeightsCache(
151       TfLiteXNNPackDelegateWeightsCache* weights_cache) {
152     weights_cache_ = weights_cache;
153     return *this;
154   }
155 
156   template <class T>
157   void Test(Interpreter* delegate_interpreter,
158             Interpreter* default_interpreter) const;
159 
160   void Test(TfLiteDelegate* delegate) const;
161 
162  private:
163   std::vector<char> CreateTfLiteModel() const;
164 
HasBias()165   inline bool HasBias() const { return has_bias_; }
166 
Activation()167   inline ::tflite::ActivationFunctionType Activation() const {
168     return activation_;
169   }
170 
171   static int32_t ComputeSize(const std::vector<int32_t>& shape);
172 
173   std::vector<int32_t> input_shape_;
174   int32_t input_size_ = 1;
175   int32_t input_channels_ = 1;
176   int32_t output_channels_ = 1;
177   int32_t input_zero_point_ = 0;
178   int32_t filter_zero_point_ = 0;
179   int32_t output_zero_point_ = 0;
180   float input_scale_ = 0.8f;
181   float filter_scale_ = 0.75f;
182   float output_scale_ = 1.5f;
183   bool keep_dims_ = false;
184   bool has_bias_ = true;
185   ::tflite::ActivationFunctionType activation_ =
186       ::tflite::ActivationFunctionType_NONE;
187   TfLiteXNNPackDelegateWeightsCache* weights_cache_ = nullptr;
188 };
189 
190 }  // namespace xnnpack
191 }  // namespace tflite
192 
193 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_FULLY_CONNECTED_TESTER_H_
194