1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" 25 #include "tensorflow/lite/schema/schema_generated.h" 26 27 namespace tflite { 28 namespace xnnpack { 29 30 class Conv2DTester { 31 public: 32 Conv2DTester() = default; 33 Conv2DTester(const Conv2DTester&) = delete; 34 Conv2DTester& operator=(const Conv2DTester&) = delete; 35 BatchSize(int32_t batch_size)36 inline Conv2DTester& BatchSize(int32_t batch_size) { 37 EXPECT_GT(batch_size, 0); 38 batch_size_ = batch_size; 39 return *this; 40 } 41 BatchSize()42 inline int32_t BatchSize() const { return batch_size_; } 43 InputChannels(int32_t input_channels)44 inline Conv2DTester& InputChannels(int32_t input_channels) { 45 EXPECT_GT(input_channels, 0); 46 input_channels_ = input_channels; 47 return *this; 48 } 49 InputChannels()50 inline int32_t InputChannels() const { return input_channels_; } 51 OutputChannels(int32_t output_channels)52 inline Conv2DTester& OutputChannels(int32_t output_channels) { 53 EXPECT_GT(output_channels, 0); 54 output_channels_ = output_channels; 55 return *this; 56 } 57 OutputChannels()58 inline int32_t OutputChannels() const { return output_channels_; } 59 Groups(int32_t groups)60 inline Conv2DTester& Groups(int32_t groups) { 61 EXPECT_EQ(InputChannels() % groups, 0); 62 EXPECT_EQ(OutputChannels() % groups, 0); 63 groups_ = groups; 64 return *this; 65 } 66 Groups()67 inline int32_t Groups() const { return groups_; } 68 KernelInputChannels()69 inline int32_t KernelInputChannels() const { 70 return input_channels_ / groups_; 71 } 72 InputHeight(int32_t input_height)73 inline Conv2DTester& InputHeight(int32_t input_height) { 74 EXPECT_GT(input_height, 0); 75 input_height_ = input_height; 76 return *this; 77 } 78 InputHeight()79 inline int32_t InputHeight() const { return input_height_; } 80 InputWidth(int32_t input_width)81 inline Conv2DTester& InputWidth(int32_t input_width) { 82 EXPECT_GT(input_width, 0); 83 input_width_ = input_width; 84 return *this; 85 } 86 InputWidth()87 inline int32_t InputWidth() const { return input_width_; } 88 OutputWidth()89 inline int32_t OutputWidth() const { 90 if (Padding() == ::tflite::Padding_SAME) { 91 EXPECT_GE(InputWidth(), 1); 92 return (InputWidth() - 1) / StrideWidth() + 1; 93 } else { 94 EXPECT_GE(InputWidth(), DilatedKernelWidth()); 95 return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth(); 96 } 97 } 98 OutputHeight()99 inline int32_t OutputHeight() const { 100 if (Padding() == ::tflite::Padding_SAME) { 101 EXPECT_GE(InputHeight(), 1); 102 return (InputHeight() - 1) / StrideHeight() + 1; 103 } else { 104 EXPECT_GE(InputHeight(), DilatedKernelHeight()); 105 return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight(); 106 } 107 } 108 KernelHeight(int32_t kernel_height)109 inline Conv2DTester& KernelHeight(int32_t kernel_height) { 110 EXPECT_GT(kernel_height, 0); 111 kernel_height_ = kernel_height; 112 return *this; 113 } 114 KernelHeight()115 inline int32_t KernelHeight() const { return kernel_height_; } 116 KernelWidth(int32_t kernel_width)117 inline Conv2DTester& KernelWidth(int32_t kernel_width) { 118 EXPECT_GT(kernel_width, 0); 119 kernel_width_ = kernel_width; 120 return *this; 121 } 122 KernelWidth()123 inline int32_t KernelWidth() const { return kernel_width_; } 124 StrideHeight(int32_t stride_height)125 inline Conv2DTester& StrideHeight(int32_t stride_height) { 126 EXPECT_GT(stride_height, 0); 127 stride_height_ = stride_height; 128 return *this; 129 } 130 StrideHeight()131 inline int32_t StrideHeight() const { return stride_height_; } 132 StrideWidth(int32_t stride_width)133 inline Conv2DTester& StrideWidth(int32_t stride_width) { 134 EXPECT_GT(stride_width, 0); 135 stride_width_ = stride_width; 136 return *this; 137 } 138 StrideWidth()139 inline int32_t StrideWidth() const { return stride_width_; } 140 DilationHeight(int32_t dilation_height)141 inline Conv2DTester& DilationHeight(int32_t dilation_height) { 142 EXPECT_GT(dilation_height, 0); 143 dilation_height_ = dilation_height; 144 return *this; 145 } 146 DilationHeight()147 inline int32_t DilationHeight() const { return dilation_height_; } 148 DilationWidth(int32_t dilation_width)149 inline Conv2DTester& DilationWidth(int32_t dilation_width) { 150 EXPECT_GT(dilation_width, 0); 151 dilation_width_ = dilation_width; 152 return *this; 153 } 154 DilationWidth()155 inline int32_t DilationWidth() const { return dilation_width_; } 156 DilatedKernelHeight()157 inline int32_t DilatedKernelHeight() const { 158 return (KernelHeight() - 1) * DilationHeight() + 1; 159 } 160 DilatedKernelWidth()161 inline int32_t DilatedKernelWidth() const { 162 return (KernelWidth() - 1) * DilationWidth() + 1; 163 } 164 FP16Weights()165 inline Conv2DTester& FP16Weights() { 166 fp16_weights_ = true; 167 return *this; 168 } 169 FP16Weights()170 inline bool FP16Weights() const { return fp16_weights_; } 171 INT8Weights()172 inline Conv2DTester& INT8Weights() { 173 int8_weights_ = true; 174 return *this; 175 } 176 INT8Weights()177 inline bool INT8Weights() const { return int8_weights_; } 178 INT8ChannelWiseWeights()179 inline Conv2DTester& INT8ChannelWiseWeights() { 180 int8_channel_wise_weights_ = true; 181 return *this; 182 } 183 INT8ChannelWiseWeights()184 inline bool INT8ChannelWiseWeights() const { 185 return int8_channel_wise_weights_; 186 } 187 SparseWeights()188 inline Conv2DTester& SparseWeights() { 189 sparse_weights_ = true; 190 return *this; 191 } 192 SparseWeights()193 inline bool SparseWeights() const { return sparse_weights_; } 194 SamePadding()195 inline Conv2DTester& SamePadding() { 196 padding_ = ::tflite::Padding_SAME; 197 return *this; 198 } 199 ValidPadding()200 inline Conv2DTester& ValidPadding() { 201 padding_ = ::tflite::Padding_VALID; 202 return *this; 203 } 204 ReluActivation()205 inline Conv2DTester& ReluActivation() { 206 activation_ = ::tflite::ActivationFunctionType_RELU; 207 return *this; 208 } 209 Relu6Activation()210 inline Conv2DTester& Relu6Activation() { 211 activation_ = ::tflite::ActivationFunctionType_RELU6; 212 return *this; 213 } 214 ReluMinus1To1Activation()215 inline Conv2DTester& ReluMinus1To1Activation() { 216 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 217 return *this; 218 } 219 TanhActivation()220 inline Conv2DTester& TanhActivation() { 221 activation_ = ::tflite::ActivationFunctionType_TANH; 222 return *this; 223 } 224 SignBitActivation()225 inline Conv2DTester& SignBitActivation() { 226 activation_ = ::tflite::ActivationFunctionType_SIGN_BIT; 227 return *this; 228 } 229 WeightsCache(TfLiteXNNPackDelegateWeightsCache * weights_cache)230 inline Conv2DTester& WeightsCache( 231 TfLiteXNNPackDelegateWeightsCache* weights_cache) { 232 weights_cache_ = weights_cache; 233 return *this; 234 } 235 236 void Test(TfLiteDelegate* delegate) const; 237 238 std::vector<char> CreateTfLiteModel() const; 239 240 private: Padding()241 inline ::tflite::Padding Padding() const { return padding_; } 242 Activation()243 inline ::tflite::ActivationFunctionType Activation() const { 244 return activation_; 245 } 246 247 int32_t batch_size_ = 1; 248 int32_t input_channels_ = 1; 249 int32_t output_channels_ = 1; 250 int32_t groups_ = 1; 251 int32_t input_height_ = 1; 252 int32_t input_width_ = 1; 253 int32_t kernel_height_ = 1; 254 int32_t kernel_width_ = 1; 255 int32_t stride_height_ = 1; 256 int32_t stride_width_ = 1; 257 int32_t dilation_height_ = 1; 258 int32_t dilation_width_ = 1; 259 bool fp16_weights_ = false; 260 bool int8_weights_ = false; 261 bool int8_channel_wise_weights_ = false; 262 bool sparse_weights_ = false; 263 ::tflite::Padding padding_ = ::tflite::Padding_VALID; 264 ::tflite::ActivationFunctionType activation_ = 265 ::tflite::ActivationFunctionType_NONE; 266 TfLiteXNNPackDelegateWeightsCache* weights_cache_ = nullptr; 267 }; 268 269 } // namespace xnnpack 270 } // namespace tflite 271 272 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_CONV_2D_TESTER_H_ 273