• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_
18 
19 #include <cstdint>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace xnnpack {
28 
29 class DepthwiseConv2DTester {
30  public:
31   DepthwiseConv2DTester() = default;
32   DepthwiseConv2DTester(const DepthwiseConv2DTester&) = delete;
33   DepthwiseConv2DTester& operator=(const DepthwiseConv2DTester&) = delete;
34 
BatchSize(int32_t batch_size)35   inline DepthwiseConv2DTester& BatchSize(int32_t batch_size) {
36     EXPECT_GT(batch_size, 0);
37     batch_size_ = batch_size;
38     return *this;
39   }
40 
BatchSize()41   inline int32_t BatchSize() const { return batch_size_; }
42 
InputChannels(int32_t input_channels)43   inline DepthwiseConv2DTester& InputChannels(int32_t input_channels) {
44     EXPECT_GT(input_channels, 0);
45     input_channels_ = input_channels;
46     return *this;
47   }
48 
InputChannels()49   inline int32_t InputChannels() const { return input_channels_; }
50 
DepthMultiplier(int32_t depth_multiplier)51   inline DepthwiseConv2DTester& DepthMultiplier(int32_t depth_multiplier) {
52     EXPECT_GT(depth_multiplier, 0);
53     depth_multiplier_ = depth_multiplier;
54     return *this;
55   }
56 
DepthMultiplier()57   inline int32_t DepthMultiplier() const { return depth_multiplier_; }
58 
OutputChannels()59   inline int32_t OutputChannels() const {
60     return DepthMultiplier() * InputChannels();
61   }
62 
InputHeight(int32_t input_height)63   inline DepthwiseConv2DTester& InputHeight(int32_t input_height) {
64     EXPECT_GT(input_height, 0);
65     input_height_ = input_height;
66     return *this;
67   }
68 
InputHeight()69   inline int32_t InputHeight() const { return input_height_; }
70 
InputWidth(int32_t input_width)71   inline DepthwiseConv2DTester& InputWidth(int32_t input_width) {
72     EXPECT_GT(input_width, 0);
73     input_width_ = input_width;
74     return *this;
75   }
76 
InputWidth()77   inline int32_t InputWidth() const { return input_width_; }
78 
OutputWidth()79   inline int32_t OutputWidth() const {
80     if (Padding() == ::tflite::Padding_SAME) {
81       EXPECT_GE(InputWidth(), 1);
82       return (InputWidth() - 1) / StrideWidth() + 1;
83     } else {
84       EXPECT_GE(InputWidth(), DilatedKernelWidth());
85       return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth();
86     }
87   }
88 
OutputHeight()89   inline int32_t OutputHeight() const {
90     if (Padding() == ::tflite::Padding_SAME) {
91       EXPECT_GE(InputHeight(), 1);
92       return (InputHeight() - 1) / StrideHeight() + 1;
93     } else {
94       EXPECT_GE(InputHeight(), DilatedKernelHeight());
95       return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight();
96     }
97   }
98 
KernelHeight(int32_t kernel_height)99   inline DepthwiseConv2DTester& KernelHeight(int32_t kernel_height) {
100     EXPECT_GT(kernel_height, 0);
101     kernel_height_ = kernel_height;
102     return *this;
103   }
104 
KernelHeight()105   inline int32_t KernelHeight() const { return kernel_height_; }
106 
KernelWidth(int32_t kernel_width)107   inline DepthwiseConv2DTester& KernelWidth(int32_t kernel_width) {
108     EXPECT_GT(kernel_width, 0);
109     kernel_width_ = kernel_width;
110     return *this;
111   }
112 
KernelWidth()113   inline int32_t KernelWidth() const { return kernel_width_; }
114 
StrideHeight(int32_t stride_height)115   inline DepthwiseConv2DTester& StrideHeight(int32_t stride_height) {
116     EXPECT_GT(stride_height, 0);
117     stride_height_ = stride_height;
118     return *this;
119   }
120 
StrideHeight()121   inline int32_t StrideHeight() const { return stride_height_; }
122 
StrideWidth(int32_t stride_width)123   inline DepthwiseConv2DTester& StrideWidth(int32_t stride_width) {
124     EXPECT_GT(stride_width, 0);
125     stride_width_ = stride_width;
126     return *this;
127   }
128 
StrideWidth()129   inline int32_t StrideWidth() const { return stride_width_; }
130 
DilationHeight(int32_t dilation_height)131   inline DepthwiseConv2DTester& DilationHeight(int32_t dilation_height) {
132     EXPECT_GT(dilation_height, 0);
133     dilation_height_ = dilation_height;
134     return *this;
135   }
136 
DilationHeight()137   inline int32_t DilationHeight() const { return dilation_height_; }
138 
DilationWidth(int32_t dilation_width)139   inline DepthwiseConv2DTester& DilationWidth(int32_t dilation_width) {
140     EXPECT_GT(dilation_width, 0);
141     dilation_width_ = dilation_width;
142     return *this;
143   }
144 
DilationWidth()145   inline int32_t DilationWidth() const { return dilation_width_; }
146 
DilatedKernelHeight()147   inline int32_t DilatedKernelHeight() const {
148     return (KernelHeight() - 1) * DilationHeight() + 1;
149   }
150 
DilatedKernelWidth()151   inline int32_t DilatedKernelWidth() const {
152     return (KernelWidth() - 1) * DilationWidth() + 1;
153   }
154 
FP16Weights()155   inline DepthwiseConv2DTester& FP16Weights() {
156     fp16_weights_ = true;
157     return *this;
158   }
159 
FP16Weights()160   inline bool FP16Weights() const { return fp16_weights_; }
161 
SparseWeights()162   inline DepthwiseConv2DTester& SparseWeights() {
163     sparse_weights_ = true;
164     return *this;
165   }
166 
SparseWeights()167   inline bool SparseWeights() const { return sparse_weights_; }
168 
SamePadding()169   inline DepthwiseConv2DTester& SamePadding() {
170     padding_ = ::tflite::Padding_SAME;
171     return *this;
172   }
173 
ValidPadding()174   inline DepthwiseConv2DTester& ValidPadding() {
175     padding_ = ::tflite::Padding_VALID;
176     return *this;
177   }
178 
ReluActivation()179   inline DepthwiseConv2DTester& ReluActivation() {
180     activation_ = ::tflite::ActivationFunctionType_RELU;
181     return *this;
182   }
183 
Relu6Activation()184   inline DepthwiseConv2DTester& Relu6Activation() {
185     activation_ = ::tflite::ActivationFunctionType_RELU6;
186     return *this;
187   }
188 
ReluMinus1To1Activation()189   inline DepthwiseConv2DTester& ReluMinus1To1Activation() {
190     activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
191     return *this;
192   }
193 
TanhActivation()194   inline DepthwiseConv2DTester& TanhActivation() {
195     activation_ = ::tflite::ActivationFunctionType_TANH;
196     return *this;
197   }
198 
SignBitActivation()199   inline DepthwiseConv2DTester& SignBitActivation() {
200     activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
201     return *this;
202   }
203 
204   void Test(TfLiteDelegate* delegate) const;
205 
206  private:
207   std::vector<char> CreateTfLiteModel() const;
208 
Padding()209   inline ::tflite::Padding Padding() const { return padding_; }
210 
Activation()211   inline ::tflite::ActivationFunctionType Activation() const {
212     return activation_;
213   }
214 
215   int32_t batch_size_ = 1;
216   int32_t input_channels_ = 1;
217   int32_t depth_multiplier_ = 1;
218   int32_t input_height_ = 1;
219   int32_t input_width_ = 1;
220   int32_t kernel_height_ = 1;
221   int32_t kernel_width_ = 1;
222   int32_t stride_height_ = 1;
223   int32_t stride_width_ = 1;
224   int32_t dilation_height_ = 1;
225   int32_t dilation_width_ = 1;
226   bool fp16_weights_ = false;
227   bool sparse_weights_ = false;
228   ::tflite::Padding padding_ = ::tflite::Padding_VALID;
229   ::tflite::ActivationFunctionType activation_ =
230       ::tflite::ActivationFunctionType_NONE;
231 };
232 
233 }  // namespace xnnpack
234 }  // namespace tflite
235 
236 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_
237