1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/schema/schema_generated.h" 25 26 namespace tflite { 27 namespace xnnpack { 28 29 class DepthwiseConv2DTester { 30 public: 31 DepthwiseConv2DTester() = default; 32 DepthwiseConv2DTester(const DepthwiseConv2DTester&) = delete; 33 DepthwiseConv2DTester& operator=(const DepthwiseConv2DTester&) = delete; 34 BatchSize(int32_t batch_size)35 inline DepthwiseConv2DTester& BatchSize(int32_t batch_size) { 36 EXPECT_GT(batch_size, 0); 37 batch_size_ = batch_size; 38 return *this; 39 } 40 BatchSize()41 inline int32_t BatchSize() const { return batch_size_; } 42 InputChannels(int32_t input_channels)43 inline DepthwiseConv2DTester& InputChannels(int32_t input_channels) { 44 EXPECT_GT(input_channels, 0); 45 input_channels_ = input_channels; 46 return *this; 47 } 48 InputChannels()49 inline int32_t InputChannels() const { return input_channels_; } 50 DepthMultiplier(int32_t depth_multiplier)51 inline DepthwiseConv2DTester& DepthMultiplier(int32_t depth_multiplier) { 52 EXPECT_GT(depth_multiplier, 0); 53 depth_multiplier_ = depth_multiplier; 54 return *this; 55 } 56 DepthMultiplier()57 inline int32_t DepthMultiplier() const { return depth_multiplier_; } 58 OutputChannels()59 inline int32_t OutputChannels() const { 60 return DepthMultiplier() * InputChannels(); 61 } 62 InputHeight(int32_t input_height)63 inline DepthwiseConv2DTester& InputHeight(int32_t input_height) { 64 EXPECT_GT(input_height, 0); 65 input_height_ = input_height; 66 return *this; 67 } 68 InputHeight()69 inline int32_t InputHeight() const { return input_height_; } 70 InputWidth(int32_t input_width)71 inline DepthwiseConv2DTester& InputWidth(int32_t input_width) { 72 EXPECT_GT(input_width, 0); 73 input_width_ = input_width; 74 return *this; 75 } 76 InputWidth()77 inline int32_t InputWidth() const { return input_width_; } 78 OutputWidth()79 inline int32_t OutputWidth() const { 80 if (Padding() == ::tflite::Padding_SAME) { 81 EXPECT_GE(InputWidth(), 1); 82 return (InputWidth() - 1) / StrideWidth() + 1; 83 } else { 84 EXPECT_GE(InputWidth(), DilatedKernelWidth()); 85 return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth(); 86 } 87 } 88 OutputHeight()89 inline int32_t OutputHeight() const { 90 if (Padding() == ::tflite::Padding_SAME) { 91 EXPECT_GE(InputHeight(), 1); 92 return (InputHeight() - 1) / StrideHeight() + 1; 93 } else { 94 EXPECT_GE(InputHeight(), DilatedKernelHeight()); 95 return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight(); 96 } 97 } 98 KernelHeight(int32_t kernel_height)99 inline DepthwiseConv2DTester& KernelHeight(int32_t kernel_height) { 100 EXPECT_GT(kernel_height, 0); 101 kernel_height_ = kernel_height; 102 return *this; 103 } 104 KernelHeight()105 inline int32_t KernelHeight() const { return kernel_height_; } 106 KernelWidth(int32_t kernel_width)107 inline DepthwiseConv2DTester& KernelWidth(int32_t kernel_width) { 108 EXPECT_GT(kernel_width, 0); 109 kernel_width_ = kernel_width; 110 return *this; 111 } 112 KernelWidth()113 inline int32_t KernelWidth() const { return kernel_width_; } 114 StrideHeight(int32_t stride_height)115 inline DepthwiseConv2DTester& StrideHeight(int32_t stride_height) { 116 EXPECT_GT(stride_height, 0); 117 stride_height_ = stride_height; 118 return *this; 119 } 120 StrideHeight()121 inline int32_t StrideHeight() const { return stride_height_; } 122 StrideWidth(int32_t stride_width)123 inline DepthwiseConv2DTester& StrideWidth(int32_t stride_width) { 124 EXPECT_GT(stride_width, 0); 125 stride_width_ = stride_width; 126 return *this; 127 } 128 StrideWidth()129 inline int32_t StrideWidth() const { return stride_width_; } 130 DilationHeight(int32_t dilation_height)131 inline DepthwiseConv2DTester& DilationHeight(int32_t dilation_height) { 132 EXPECT_GT(dilation_height, 0); 133 dilation_height_ = dilation_height; 134 return *this; 135 } 136 DilationHeight()137 inline int32_t DilationHeight() const { return dilation_height_; } 138 DilationWidth(int32_t dilation_width)139 inline DepthwiseConv2DTester& DilationWidth(int32_t dilation_width) { 140 EXPECT_GT(dilation_width, 0); 141 dilation_width_ = dilation_width; 142 return *this; 143 } 144 DilationWidth()145 inline int32_t DilationWidth() const { return dilation_width_; } 146 DilatedKernelHeight()147 inline int32_t DilatedKernelHeight() const { 148 return (KernelHeight() - 1) * DilationHeight() + 1; 149 } 150 DilatedKernelWidth()151 inline int32_t DilatedKernelWidth() const { 152 return (KernelWidth() - 1) * DilationWidth() + 1; 153 } 154 FP16Weights()155 inline DepthwiseConv2DTester& FP16Weights() { 156 fp16_weights_ = true; 157 return *this; 158 } 159 FP16Weights()160 inline bool FP16Weights() const { return fp16_weights_; } 161 SparseWeights()162 inline DepthwiseConv2DTester& SparseWeights() { 163 sparse_weights_ = true; 164 return *this; 165 } 166 SparseWeights()167 inline bool SparseWeights() const { return sparse_weights_; } 168 SamePadding()169 inline DepthwiseConv2DTester& SamePadding() { 170 padding_ = ::tflite::Padding_SAME; 171 return *this; 172 } 173 ValidPadding()174 inline DepthwiseConv2DTester& ValidPadding() { 175 padding_ = ::tflite::Padding_VALID; 176 return *this; 177 } 178 ReluActivation()179 inline DepthwiseConv2DTester& ReluActivation() { 180 activation_ = ::tflite::ActivationFunctionType_RELU; 181 return *this; 182 } 183 Relu6Activation()184 inline DepthwiseConv2DTester& Relu6Activation() { 185 activation_ = ::tflite::ActivationFunctionType_RELU6; 186 return *this; 187 } 188 ReluMinus1To1Activation()189 inline DepthwiseConv2DTester& ReluMinus1To1Activation() { 190 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 191 return *this; 192 } 193 TanhActivation()194 inline DepthwiseConv2DTester& TanhActivation() { 195 activation_ = ::tflite::ActivationFunctionType_TANH; 196 return *this; 197 } 198 SignBitActivation()199 inline DepthwiseConv2DTester& SignBitActivation() { 200 activation_ = ::tflite::ActivationFunctionType_SIGN_BIT; 201 return *this; 202 } 203 204 void Test(TfLiteDelegate* delegate) const; 205 206 private: 207 std::vector<char> CreateTfLiteModel() const; 208 Padding()209 inline ::tflite::Padding Padding() const { return padding_; } 210 Activation()211 inline ::tflite::ActivationFunctionType Activation() const { 212 return activation_; 213 } 214 215 int32_t batch_size_ = 1; 216 int32_t input_channels_ = 1; 217 int32_t depth_multiplier_ = 1; 218 int32_t input_height_ = 1; 219 int32_t input_width_ = 1; 220 int32_t kernel_height_ = 1; 221 int32_t kernel_width_ = 1; 222 int32_t stride_height_ = 1; 223 int32_t stride_width_ = 1; 224 int32_t dilation_height_ = 1; 225 int32_t dilation_width_ = 1; 226 bool fp16_weights_ = false; 227 bool sparse_weights_ = false; 228 ::tflite::Padding padding_ = ::tflite::Padding_VALID; 229 ::tflite::ActivationFunctionType activation_ = 230 ::tflite::ActivationFunctionType_NONE; 231 }; 232 233 } // namespace xnnpack 234 } // namespace tflite 235 236 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_ 237