• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_
18 
19 #include <cstdint>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26 
27 namespace tflite {
28 namespace xnnpack {
29 
30 class Conv2DTester {
31  public:
32   Conv2DTester() = default;
33   Conv2DTester(const Conv2DTester&) = delete;
34   Conv2DTester& operator=(const Conv2DTester&) = delete;
35 
BatchSize(int32_t batch_size)36   inline Conv2DTester& BatchSize(int32_t batch_size) {
37     EXPECT_GT(batch_size, 0);
38     batch_size_ = batch_size;
39     return *this;
40   }
41 
BatchSize()42   inline int32_t BatchSize() const { return batch_size_; }
43 
InputChannels(int32_t input_channels)44   inline Conv2DTester& InputChannels(int32_t input_channels) {
45     EXPECT_GT(input_channels, 0);
46     input_channels_ = input_channels;
47     return *this;
48   }
49 
InputChannels()50   inline int32_t InputChannels() const { return input_channels_; }
51 
OutputChannels(int32_t output_channels)52   inline Conv2DTester& OutputChannels(int32_t output_channels) {
53     EXPECT_GT(output_channels, 0);
54     output_channels_ = output_channels;
55     return *this;
56   }
57 
OutputChannels()58   inline int32_t OutputChannels() const { return output_channels_; }
59 
Groups(int32_t groups)60   inline Conv2DTester& Groups(int32_t groups) {
61     EXPECT_EQ(InputChannels() % groups, 0);
62     EXPECT_EQ(OutputChannels() % groups, 0);
63     groups_ = groups;
64     return *this;
65   }
66 
Groups()67   inline int32_t Groups() const { return groups_; }
68 
KernelInputChannels()69   inline int32_t KernelInputChannels() const {
70     return input_channels_ / groups_;
71   }
72 
InputHeight(int32_t input_height)73   inline Conv2DTester& InputHeight(int32_t input_height) {
74     EXPECT_GT(input_height, 0);
75     input_height_ = input_height;
76     return *this;
77   }
78 
InputHeight()79   inline int32_t InputHeight() const { return input_height_; }
80 
InputWidth(int32_t input_width)81   inline Conv2DTester& InputWidth(int32_t input_width) {
82     EXPECT_GT(input_width, 0);
83     input_width_ = input_width;
84     return *this;
85   }
86 
InputWidth()87   inline int32_t InputWidth() const { return input_width_; }
88 
OutputWidth()89   inline int32_t OutputWidth() const {
90     if (Padding() == ::tflite::Padding_SAME) {
91       EXPECT_GE(InputWidth(), 1);
92       return (InputWidth() - 1) / StrideWidth() + 1;
93     } else {
94       EXPECT_GE(InputWidth(), DilatedKernelWidth());
95       return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth();
96     }
97   }
98 
OutputHeight()99   inline int32_t OutputHeight() const {
100     if (Padding() == ::tflite::Padding_SAME) {
101       EXPECT_GE(InputHeight(), 1);
102       return (InputHeight() - 1) / StrideHeight() + 1;
103     } else {
104       EXPECT_GE(InputHeight(), DilatedKernelHeight());
105       return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight();
106     }
107   }
108 
KernelHeight(int32_t kernel_height)109   inline Conv2DTester& KernelHeight(int32_t kernel_height) {
110     EXPECT_GT(kernel_height, 0);
111     kernel_height_ = kernel_height;
112     return *this;
113   }
114 
KernelHeight()115   inline int32_t KernelHeight() const { return kernel_height_; }
116 
KernelWidth(int32_t kernel_width)117   inline Conv2DTester& KernelWidth(int32_t kernel_width) {
118     EXPECT_GT(kernel_width, 0);
119     kernel_width_ = kernel_width;
120     return *this;
121   }
122 
KernelWidth()123   inline int32_t KernelWidth() const { return kernel_width_; }
124 
StrideHeight(int32_t stride_height)125   inline Conv2DTester& StrideHeight(int32_t stride_height) {
126     EXPECT_GT(stride_height, 0);
127     stride_height_ = stride_height;
128     return *this;
129   }
130 
StrideHeight()131   inline int32_t StrideHeight() const { return stride_height_; }
132 
StrideWidth(int32_t stride_width)133   inline Conv2DTester& StrideWidth(int32_t stride_width) {
134     EXPECT_GT(stride_width, 0);
135     stride_width_ = stride_width;
136     return *this;
137   }
138 
StrideWidth()139   inline int32_t StrideWidth() const { return stride_width_; }
140 
DilationHeight(int32_t dilation_height)141   inline Conv2DTester& DilationHeight(int32_t dilation_height) {
142     EXPECT_GT(dilation_height, 0);
143     dilation_height_ = dilation_height;
144     return *this;
145   }
146 
DilationHeight()147   inline int32_t DilationHeight() const { return dilation_height_; }
148 
DilationWidth(int32_t dilation_width)149   inline Conv2DTester& DilationWidth(int32_t dilation_width) {
150     EXPECT_GT(dilation_width, 0);
151     dilation_width_ = dilation_width;
152     return *this;
153   }
154 
DilationWidth()155   inline int32_t DilationWidth() const { return dilation_width_; }
156 
DilatedKernelHeight()157   inline int32_t DilatedKernelHeight() const {
158     return (KernelHeight() - 1) * DilationHeight() + 1;
159   }
160 
DilatedKernelWidth()161   inline int32_t DilatedKernelWidth() const {
162     return (KernelWidth() - 1) * DilationWidth() + 1;
163   }
164 
FP16Weights()165   inline Conv2DTester& FP16Weights() {
166     fp16_weights_ = true;
167     return *this;
168   }
169 
FP16Weights()170   inline bool FP16Weights() const { return fp16_weights_; }
171 
INT8Weights()172   inline Conv2DTester& INT8Weights() {
173     int8_weights_ = true;
174     return *this;
175   }
176 
INT8Weights()177   inline bool INT8Weights() const { return int8_weights_; }
178 
INT8ChannelWiseWeights()179   inline Conv2DTester& INT8ChannelWiseWeights() {
180     int8_channel_wise_weights_ = true;
181     return *this;
182   }
183 
INT8ChannelWiseWeights()184   inline bool INT8ChannelWiseWeights() const {
185     return int8_channel_wise_weights_;
186   }
187 
SparseWeights()188   inline Conv2DTester& SparseWeights() {
189     sparse_weights_ = true;
190     return *this;
191   }
192 
SparseWeights()193   inline bool SparseWeights() const { return sparse_weights_; }
194 
SamePadding()195   inline Conv2DTester& SamePadding() {
196     padding_ = ::tflite::Padding_SAME;
197     return *this;
198   }
199 
ValidPadding()200   inline Conv2DTester& ValidPadding() {
201     padding_ = ::tflite::Padding_VALID;
202     return *this;
203   }
204 
ReluActivation()205   inline Conv2DTester& ReluActivation() {
206     activation_ = ::tflite::ActivationFunctionType_RELU;
207     return *this;
208   }
209 
Relu6Activation()210   inline Conv2DTester& Relu6Activation() {
211     activation_ = ::tflite::ActivationFunctionType_RELU6;
212     return *this;
213   }
214 
ReluMinus1To1Activation()215   inline Conv2DTester& ReluMinus1To1Activation() {
216     activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
217     return *this;
218   }
219 
TanhActivation()220   inline Conv2DTester& TanhActivation() {
221     activation_ = ::tflite::ActivationFunctionType_TANH;
222     return *this;
223   }
224 
SignBitActivation()225   inline Conv2DTester& SignBitActivation() {
226     activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
227     return *this;
228   }
229 
WeightsCache(TfLiteXNNPackDelegateWeightsCache * weights_cache)230   inline Conv2DTester& WeightsCache(
231       TfLiteXNNPackDelegateWeightsCache* weights_cache) {
232     weights_cache_ = weights_cache;
233     return *this;
234   }
235 
236   void Test(TfLiteDelegate* delegate) const;
237 
238   std::vector<char> CreateTfLiteModel() const;
239 
240  private:
Padding()241   inline ::tflite::Padding Padding() const { return padding_; }
242 
Activation()243   inline ::tflite::ActivationFunctionType Activation() const {
244     return activation_;
245   }
246 
247   int32_t batch_size_ = 1;
248   int32_t input_channels_ = 1;
249   int32_t output_channels_ = 1;
250   int32_t groups_ = 1;
251   int32_t input_height_ = 1;
252   int32_t input_width_ = 1;
253   int32_t kernel_height_ = 1;
254   int32_t kernel_width_ = 1;
255   int32_t stride_height_ = 1;
256   int32_t stride_width_ = 1;
257   int32_t dilation_height_ = 1;
258   int32_t dilation_width_ = 1;
259   bool fp16_weights_ = false;
260   bool int8_weights_ = false;
261   bool int8_channel_wise_weights_ = false;
262   bool sparse_weights_ = false;
263   ::tflite::Padding padding_ = ::tflite::Padding_VALID;
264   ::tflite::ActivationFunctionType activation_ =
265       ::tflite::ActivationFunctionType_NONE;
266   TfLiteXNNPackDelegateWeightsCache* weights_cache_ = nullptr;
267 };
268 
269 }  // namespace xnnpack
270 }  // namespace tflite
271 
272 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_
273