1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_FULLY_CONNECTED_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_FULLY_CONNECTED_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" 25 #include "tensorflow/lite/interpreter.h" 26 #include "tensorflow/lite/schema/schema_generated.h" 27 28 namespace tflite { 29 namespace xnnpack { 30 31 class QuantizedFullyConnectedTester { 32 public: 33 QuantizedFullyConnectedTester() = default; 34 QuantizedFullyConnectedTester(const QuantizedFullyConnectedTester&) = delete; 35 QuantizedFullyConnectedTester& operator=( 36 const QuantizedFullyConnectedTester&) = delete; 37 InputShape(std::initializer_list<int32_t> shape)38 inline QuantizedFullyConnectedTester& InputShape( 39 std::initializer_list<int32_t> shape) { 40 for (auto it = shape.begin(); it != shape.end(); ++it) { 41 EXPECT_GT(*it, 0); 42 } 43 input_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 44 input_size_ = ComputeSize(input_shape_); 45 return *this; 46 } 47 InputShape()48 inline const std::vector<int32_t>& InputShape() const { return input_shape_; } 49 InputSize()50 inline int32_t InputSize() const { return input_size_; } 51 InputChannels(int32_t input_channels)52 inline QuantizedFullyConnectedTester& InputChannels(int32_t input_channels) { 53 EXPECT_GT(input_channels, 0); 54 input_channels_ = input_channels; 55 return *this; 56 } 57 InputChannels()58 inline int32_t InputChannels() const { return input_channels_; } 59 OutputChannels(int32_t output_channels)60 inline QuantizedFullyConnectedTester& OutputChannels( 61 int32_t output_channels) { 62 EXPECT_GT(output_channels, 0); 63 output_channels_ = output_channels; 64 return *this; 65 } 66 OutputChannels()67 inline int32_t OutputChannels() const { return output_channels_; } 68 69 std::vector<int32_t> OutputShape() const; 70 InputZeroPoint(int32_t input_zero_point)71 inline QuantizedFullyConnectedTester& InputZeroPoint( 72 int32_t input_zero_point) { 73 input_zero_point_ = input_zero_point; 74 return *this; 75 } 76 InputZeroPoint()77 inline int32_t InputZeroPoint() const { return input_zero_point_; } 78 FilterZeroPoint(int32_t filter_zero_point)79 inline QuantizedFullyConnectedTester& FilterZeroPoint( 80 int32_t filter_zero_point) { 81 filter_zero_point_ = filter_zero_point; 82 return *this; 83 } 84 FilterZeroPoint()85 inline int32_t FilterZeroPoint() const { return filter_zero_point_; } 86 OutputZeroPoint(int32_t output_zero_point)87 inline QuantizedFullyConnectedTester& OutputZeroPoint( 88 int32_t output_zero_point) { 89 output_zero_point_ = output_zero_point; 90 return *this; 91 } 92 OutputZeroPoint()93 inline int32_t OutputZeroPoint() const { return output_zero_point_; } 94 InputScale(float input_scale)95 inline QuantizedFullyConnectedTester& InputScale(float input_scale) { 96 input_scale_ = input_scale; 97 return *this; 98 } 99 InputScale()100 inline float InputScale() const { return input_scale_; } 101 FilterScale(float filter_scale)102 inline QuantizedFullyConnectedTester& FilterScale(float filter_scale) { 103 filter_scale_ = filter_scale; 104 return *this; 105 } 106 FilterScale()107 inline float FilterScale() const { return filter_scale_; } 108 OutputScale(float output_scale)109 inline QuantizedFullyConnectedTester& OutputScale(float output_scale) { 110 output_scale_ = output_scale; 111 return *this; 112 } 113 OutputScale()114 inline float OutputScale() const { return output_scale_; } 115 KeepDims(bool keep_dims)116 inline QuantizedFullyConnectedTester& KeepDims(bool keep_dims) { 117 keep_dims_ = keep_dims; 118 return *this; 119 } 120 KeepDims()121 inline bool KeepDims() const { return keep_dims_; } 122 Unsigned()123 inline bool Unsigned() const { return filter_zero_point_ != 0; } 124 NoBias()125 inline QuantizedFullyConnectedTester& NoBias() { 126 has_bias_ = false; 127 return *this; 128 } 129 WithBias()130 inline QuantizedFullyConnectedTester& WithBias() { 131 has_bias_ = true; 132 return *this; 133 } 134 ReluActivation()135 inline QuantizedFullyConnectedTester& ReluActivation() { 136 activation_ = ::tflite::ActivationFunctionType_RELU; 137 return *this; 138 } 139 Relu6Activation()140 inline QuantizedFullyConnectedTester& Relu6Activation() { 141 activation_ = ::tflite::ActivationFunctionType_RELU6; 142 return *this; 143 } 144 ReluMinus1To1Activation()145 inline QuantizedFullyConnectedTester& ReluMinus1To1Activation() { 146 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 147 return *this; 148 } 149 WeightsCache(TfLiteXNNPackDelegateWeightsCache * weights_cache)150 inline QuantizedFullyConnectedTester& WeightsCache( 151 TfLiteXNNPackDelegateWeightsCache* weights_cache) { 152 weights_cache_ = weights_cache; 153 return *this; 154 } 155 156 template <class T> 157 void Test(Interpreter* delegate_interpreter, 158 Interpreter* default_interpreter) const; 159 160 void Test(TfLiteDelegate* delegate) const; 161 162 private: 163 std::vector<char> CreateTfLiteModel() const; 164 HasBias()165 inline bool HasBias() const { return has_bias_; } 166 Activation()167 inline ::tflite::ActivationFunctionType Activation() const { 168 return activation_; 169 } 170 171 static int32_t ComputeSize(const std::vector<int32_t>& shape); 172 173 std::vector<int32_t> input_shape_; 174 int32_t input_size_ = 1; 175 int32_t input_channels_ = 1; 176 int32_t output_channels_ = 1; 177 int32_t input_zero_point_ = 0; 178 int32_t filter_zero_point_ = 0; 179 int32_t output_zero_point_ = 0; 180 float input_scale_ = 0.8f; 181 float filter_scale_ = 0.75f; 182 float output_scale_ = 1.5f; 183 bool keep_dims_ = false; 184 bool has_bias_ = true; 185 ::tflite::ActivationFunctionType activation_ = 186 ::tflite::ActivationFunctionType_NONE; 187 TfLiteXNNPackDelegateWeightsCache* weights_cache_ = nullptr; 188 }; 189 190 } // namespace xnnpack 191 } // namespace tflite 192 193 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_FULLY_CONNECTED_TESTER_H_ 194