• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_
18 
19 #include <cstdint>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace xnnpack {
28 
29 class Conv2DTester {
30  public:
31   Conv2DTester() = default;
32   Conv2DTester(const Conv2DTester&) = delete;
33   Conv2DTester& operator=(const Conv2DTester&) = delete;
34 
BatchSize(int32_t batch_size)35   inline Conv2DTester& BatchSize(int32_t batch_size) {
36     EXPECT_GT(batch_size, 0);
37     batch_size_ = batch_size;
38     return *this;
39   }
40 
BatchSize()41   inline int32_t BatchSize() const { return batch_size_; }
42 
InputChannels(int32_t input_channels)43   inline Conv2DTester& InputChannels(int32_t input_channels) {
44     EXPECT_GT(input_channels, 0);
45     input_channels_ = input_channels;
46     return *this;
47   }
48 
InputChannels()49   inline int32_t InputChannels() const { return input_channels_; }
50 
OutputChannels(int32_t output_channels)51   inline Conv2DTester& OutputChannels(int32_t output_channels) {
52     EXPECT_GT(output_channels, 0);
53     output_channels_ = output_channels;
54     return *this;
55   }
56 
OutputChannels()57   inline int32_t OutputChannels() const { return output_channels_; }
58 
InputHeight(int32_t input_height)59   inline Conv2DTester& InputHeight(int32_t input_height) {
60     EXPECT_GT(input_height, 0);
61     input_height_ = input_height;
62     return *this;
63   }
64 
InputHeight()65   inline int32_t InputHeight() const { return input_height_; }
66 
InputWidth(int32_t input_width)67   inline Conv2DTester& InputWidth(int32_t input_width) {
68     EXPECT_GT(input_width, 0);
69     input_width_ = input_width;
70     return *this;
71   }
72 
InputWidth()73   inline int32_t InputWidth() const { return input_width_; }
74 
OutputWidth()75   inline int32_t OutputWidth() const {
76     if (Padding() == ::tflite::Padding_SAME) {
77       EXPECT_GE(InputWidth(), 1);
78       return (InputWidth() - 1) / StrideWidth() + 1;
79     } else {
80       EXPECT_GE(InputWidth(), DilatedKernelWidth());
81       return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth();
82     }
83   }
84 
OutputHeight()85   inline int32_t OutputHeight() const {
86     if (Padding() == ::tflite::Padding_SAME) {
87       EXPECT_GE(InputHeight(), 1);
88       return (InputHeight() - 1) / StrideHeight() + 1;
89     } else {
90       EXPECT_GE(InputHeight(), DilatedKernelHeight());
91       return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight();
92     }
93   }
94 
KernelHeight(int32_t kernel_height)95   inline Conv2DTester& KernelHeight(int32_t kernel_height) {
96     EXPECT_GT(kernel_height, 0);
97     kernel_height_ = kernel_height;
98     return *this;
99   }
100 
KernelHeight()101   inline int32_t KernelHeight() const { return kernel_height_; }
102 
KernelWidth(int32_t kernel_width)103   inline Conv2DTester& KernelWidth(int32_t kernel_width) {
104     EXPECT_GT(kernel_width, 0);
105     kernel_width_ = kernel_width;
106     return *this;
107   }
108 
KernelWidth()109   inline int32_t KernelWidth() const { return kernel_width_; }
110 
StrideHeight(int32_t stride_height)111   inline Conv2DTester& StrideHeight(int32_t stride_height) {
112     EXPECT_GT(stride_height, 0);
113     stride_height_ = stride_height;
114     return *this;
115   }
116 
StrideHeight()117   inline int32_t StrideHeight() const { return stride_height_; }
118 
StrideWidth(int32_t stride_width)119   inline Conv2DTester& StrideWidth(int32_t stride_width) {
120     EXPECT_GT(stride_width, 0);
121     stride_width_ = stride_width;
122     return *this;
123   }
124 
StrideWidth()125   inline int32_t StrideWidth() const { return stride_width_; }
126 
DilationHeight(int32_t dilation_height)127   inline Conv2DTester& DilationHeight(int32_t dilation_height) {
128     EXPECT_GT(dilation_height, 0);
129     dilation_height_ = dilation_height;
130     return *this;
131   }
132 
DilationHeight()133   inline int32_t DilationHeight() const { return dilation_height_; }
134 
DilationWidth(int32_t dilation_width)135   inline Conv2DTester& DilationWidth(int32_t dilation_width) {
136     EXPECT_GT(dilation_width, 0);
137     dilation_width_ = dilation_width;
138     return *this;
139   }
140 
DilationWidth()141   inline int32_t DilationWidth() const { return dilation_width_; }
142 
DilatedKernelHeight()143   inline int32_t DilatedKernelHeight() const {
144     return (KernelHeight() - 1) * DilationHeight() + 1;
145   }
146 
DilatedKernelWidth()147   inline int32_t DilatedKernelWidth() const {
148     return (KernelWidth() - 1) * DilationWidth() + 1;
149   }
150 
FP16Weights()151   inline Conv2DTester& FP16Weights() {
152     fp16_weights_ = true;
153     return *this;
154   }
155 
FP16Weights()156   inline bool FP16Weights() const { return fp16_weights_; }
157 
SparseWeights()158   inline Conv2DTester& SparseWeights() {
159     sparse_weights_ = true;
160     return *this;
161   }
162 
SparseWeights()163   inline bool SparseWeights() const { return sparse_weights_; }
164 
SamePadding()165   inline Conv2DTester& SamePadding() {
166     padding_ = ::tflite::Padding_SAME;
167     return *this;
168   }
169 
ValidPadding()170   inline Conv2DTester& ValidPadding() {
171     padding_ = ::tflite::Padding_VALID;
172     return *this;
173   }
174 
ReluActivation()175   inline Conv2DTester& ReluActivation() {
176     activation_ = ::tflite::ActivationFunctionType_RELU;
177     return *this;
178   }
179 
Relu6Activation()180   inline Conv2DTester& Relu6Activation() {
181     activation_ = ::tflite::ActivationFunctionType_RELU6;
182     return *this;
183   }
184 
ReluMinus1To1Activation()185   inline Conv2DTester& ReluMinus1To1Activation() {
186     activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
187     return *this;
188   }
189 
TanhActivation()190   inline Conv2DTester& TanhActivation() {
191     activation_ = ::tflite::ActivationFunctionType_TANH;
192     return *this;
193   }
194 
SignBitActivation()195   inline Conv2DTester& SignBitActivation() {
196     activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
197     return *this;
198   }
199 
200   void Test(TfLiteDelegate* delegate) const;
201 
202  private:
203   std::vector<char> CreateTfLiteModel() const;
204 
Padding()205   inline ::tflite::Padding Padding() const { return padding_; }
206 
Activation()207   inline ::tflite::ActivationFunctionType Activation() const {
208     return activation_;
209   }
210 
211   int32_t batch_size_ = 1;
212   int32_t input_channels_ = 1;
213   int32_t output_channels_ = 1;
214   int32_t input_height_ = 1;
215   int32_t input_width_ = 1;
216   int32_t kernel_height_ = 1;
217   int32_t kernel_width_ = 1;
218   int32_t stride_height_ = 1;
219   int32_t stride_width_ = 1;
220   int32_t dilation_height_ = 1;
221   int32_t dilation_width_ = 1;
222   bool fp16_weights_ = false;
223   bool sparse_weights_ = false;
224   ::tflite::Padding padding_ = ::tflite::Padding_VALID;
225   ::tflite::ActivationFunctionType activation_ =
226       ::tflite::ActivationFunctionType_NONE;
227 };
228 
229 }  // namespace xnnpack
230 }  // namespace tflite
231 
232 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_
233