1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/schema/schema_generated.h" 25 26 namespace tflite { 27 namespace xnnpack { 28 29 class Conv2DTester { 30 public: 31 Conv2DTester() = default; 32 Conv2DTester(const Conv2DTester&) = delete; 33 Conv2DTester& operator=(const Conv2DTester&) = delete; 34 BatchSize(int32_t batch_size)35 inline Conv2DTester& BatchSize(int32_t batch_size) { 36 EXPECT_GT(batch_size, 0); 37 batch_size_ = batch_size; 38 return *this; 39 } 40 BatchSize()41 inline int32_t BatchSize() const { return batch_size_; } 42 InputChannels(int32_t input_channels)43 inline Conv2DTester& InputChannels(int32_t input_channels) { 44 EXPECT_GT(input_channels, 0); 45 input_channels_ = input_channels; 46 return *this; 47 } 48 InputChannels()49 inline int32_t InputChannels() const { return input_channels_; } 50 OutputChannels(int32_t output_channels)51 inline Conv2DTester& OutputChannels(int32_t output_channels) { 52 EXPECT_GT(output_channels, 0); 53 output_channels_ = output_channels; 54 return *this; 55 } 56 OutputChannels()57 inline int32_t OutputChannels() const { return output_channels_; } 58 InputHeight(int32_t input_height)59 inline Conv2DTester& InputHeight(int32_t input_height) { 60 EXPECT_GT(input_height, 0); 61 input_height_ = input_height; 62 return *this; 63 } 64 InputHeight()65 inline int32_t InputHeight() const { return input_height_; } 66 InputWidth(int32_t input_width)67 inline Conv2DTester& InputWidth(int32_t input_width) { 68 EXPECT_GT(input_width, 0); 69 input_width_ = input_width; 70 return *this; 71 } 72 InputWidth()73 inline int32_t InputWidth() const { return input_width_; } 74 OutputWidth()75 inline int32_t OutputWidth() const { 76 if (Padding() == ::tflite::Padding_SAME) { 77 EXPECT_GE(InputWidth(), 1); 78 return (InputWidth() - 1) / StrideWidth() + 1; 79 } else { 80 EXPECT_GE(InputWidth(), DilatedKernelWidth()); 81 return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth(); 82 } 83 } 84 OutputHeight()85 inline int32_t OutputHeight() const { 86 if (Padding() == ::tflite::Padding_SAME) { 87 EXPECT_GE(InputHeight(), 1); 88 return (InputHeight() - 1) / StrideHeight() + 1; 89 } else { 90 EXPECT_GE(InputHeight(), DilatedKernelHeight()); 91 return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight(); 92 } 93 } 94 KernelHeight(int32_t kernel_height)95 inline Conv2DTester& KernelHeight(int32_t kernel_height) { 96 EXPECT_GT(kernel_height, 0); 97 kernel_height_ = kernel_height; 98 return *this; 99 } 100 KernelHeight()101 inline int32_t KernelHeight() const { return kernel_height_; } 102 KernelWidth(int32_t kernel_width)103 inline Conv2DTester& KernelWidth(int32_t kernel_width) { 104 EXPECT_GT(kernel_width, 0); 105 kernel_width_ = kernel_width; 106 return *this; 107 } 108 KernelWidth()109 inline int32_t KernelWidth() const { return kernel_width_; } 110 StrideHeight(int32_t stride_height)111 inline Conv2DTester& StrideHeight(int32_t stride_height) { 112 EXPECT_GT(stride_height, 0); 113 stride_height_ = stride_height; 114 return *this; 115 } 116 StrideHeight()117 inline int32_t StrideHeight() const { return stride_height_; } 118 StrideWidth(int32_t stride_width)119 inline Conv2DTester& StrideWidth(int32_t stride_width) { 120 EXPECT_GT(stride_width, 0); 121 stride_width_ = stride_width; 122 return *this; 123 } 124 StrideWidth()125 inline int32_t StrideWidth() const { return stride_width_; } 126 DilationHeight(int32_t dilation_height)127 inline Conv2DTester& DilationHeight(int32_t dilation_height) { 128 EXPECT_GT(dilation_height, 0); 129 dilation_height_ = dilation_height; 130 return *this; 131 } 132 DilationHeight()133 inline int32_t DilationHeight() const { return dilation_height_; } 134 DilationWidth(int32_t dilation_width)135 inline Conv2DTester& DilationWidth(int32_t dilation_width) { 136 EXPECT_GT(dilation_width, 0); 137 dilation_width_ = dilation_width; 138 return *this; 139 } 140 DilationWidth()141 inline int32_t DilationWidth() const { return dilation_width_; } 142 DilatedKernelHeight()143 inline int32_t DilatedKernelHeight() const { 144 return (KernelHeight() - 1) * DilationHeight() + 1; 145 } 146 DilatedKernelWidth()147 inline int32_t DilatedKernelWidth() const { 148 return (KernelWidth() - 1) * DilationWidth() + 1; 149 } 150 FP16Weights()151 inline Conv2DTester& FP16Weights() { 152 fp16_weights_ = true; 153 return *this; 154 } 155 FP16Weights()156 inline bool FP16Weights() const { return fp16_weights_; } 157 SparseWeights()158 inline Conv2DTester& SparseWeights() { 159 sparse_weights_ = true; 160 return *this; 161 } 162 SparseWeights()163 inline bool SparseWeights() const { return sparse_weights_; } 164 SamePadding()165 inline Conv2DTester& SamePadding() { 166 padding_ = ::tflite::Padding_SAME; 167 return *this; 168 } 169 ValidPadding()170 inline Conv2DTester& ValidPadding() { 171 padding_ = ::tflite::Padding_VALID; 172 return *this; 173 } 174 ReluActivation()175 inline Conv2DTester& ReluActivation() { 176 activation_ = ::tflite::ActivationFunctionType_RELU; 177 return *this; 178 } 179 Relu6Activation()180 inline Conv2DTester& Relu6Activation() { 181 activation_ = ::tflite::ActivationFunctionType_RELU6; 182 return *this; 183 } 184 ReluMinus1To1Activation()185 inline Conv2DTester& ReluMinus1To1Activation() { 186 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 187 return *this; 188 } 189 TanhActivation()190 inline Conv2DTester& TanhActivation() { 191 activation_ = ::tflite::ActivationFunctionType_TANH; 192 return *this; 193 } 194 SignBitActivation()195 inline Conv2DTester& SignBitActivation() { 196 activation_ = ::tflite::ActivationFunctionType_SIGN_BIT; 197 return *this; 198 } 199 200 void Test(TfLiteDelegate* delegate) const; 201 202 private: 203 std::vector<char> CreateTfLiteModel() const; 204 Padding()205 inline ::tflite::Padding Padding() const { return padding_; } 206 Activation()207 inline ::tflite::ActivationFunctionType Activation() const { 208 return activation_; 209 } 210 211 int32_t batch_size_ = 1; 212 int32_t input_channels_ = 1; 213 int32_t output_channels_ = 1; 214 int32_t input_height_ = 1; 215 int32_t input_width_ = 1; 216 int32_t kernel_height_ = 1; 217 int32_t kernel_width_ = 1; 218 int32_t stride_height_ = 1; 219 int32_t stride_width_ = 1; 220 int32_t dilation_height_ = 1; 221 int32_t dilation_width_ = 1; 222 bool fp16_weights_ = false; 223 bool sparse_weights_ = false; 224 ::tflite::Padding padding_ = ::tflite::Padding_VALID; 225 ::tflite::ActivationFunctionType activation_ = 226 ::tflite::ActivationFunctionType_NONE; 227 }; 228 229 } // namespace xnnpack 230 } // namespace tflite 231 232 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_ 233