1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/schema/schema_generated.h" 25 26 namespace tflite { 27 namespace xnnpack { 28 29 class BinaryElementwiseTester { 30 public: 31 BinaryElementwiseTester() = default; 32 BinaryElementwiseTester(const BinaryElementwiseTester&) = delete; 33 BinaryElementwiseTester& operator=(const BinaryElementwiseTester&) = delete; 34 Input1Shape(std::initializer_list<int32_t> shape)35 inline BinaryElementwiseTester& Input1Shape( 36 std::initializer_list<int32_t> shape) { 37 for (auto it = shape.begin(); it != shape.end(); ++it) { 38 EXPECT_GT(*it, 0); 39 } 40 input1_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 41 return *this; 42 } 43 Input1Shape()44 inline const std::vector<int32_t>& Input1Shape() const { 45 return input1_shape_; 46 } 47 Input2Shape(std::initializer_list<int32_t> shape)48 inline BinaryElementwiseTester& Input2Shape( 49 std::initializer_list<int32_t> shape) { 50 for (auto it = shape.begin(); it != shape.end(); ++it) { 51 EXPECT_GT(*it, 0); 52 } 53 input2_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 54 return *this; 55 } 56 Input2Shape()57 inline const std::vector<int32_t>& Input2Shape() const { 58 return input2_shape_; 59 } 60 61 std::vector<int32_t> OutputShape() const; 62 Input1Static(bool is_static)63 inline BinaryElementwiseTester& Input1Static(bool is_static) { 64 input1_static_ = is_static; 65 return *this; 66 } 67 Input1Static()68 inline bool Input1Static() const { return input1_static_; } 69 Input2Static(bool is_static)70 inline BinaryElementwiseTester& Input2Static(bool is_static) { 71 input2_static_ = is_static; 72 return *this; 73 } 74 Input2Static()75 inline bool Input2Static() const { return input2_static_; } 76 FP16Weights()77 inline BinaryElementwiseTester& FP16Weights() { 78 fp16_weights_ = true; 79 return *this; 80 } 81 FP16Weights()82 inline bool FP16Weights() const { return fp16_weights_; } 83 INT8Weights()84 inline BinaryElementwiseTester& INT8Weights() { 85 int8_weights_ = true; 86 return *this; 87 } 88 INT8Weights()89 inline bool INT8Weights() const { return int8_weights_; } 90 INT8ChannelWiseWeights()91 inline BinaryElementwiseTester& INT8ChannelWiseWeights() { 92 int8_channel_wise_weights_ = true; 93 return *this; 94 } 95 INT8ChannelWiseWeights()96 inline bool INT8ChannelWiseWeights() const { 97 return int8_channel_wise_weights_; 98 } 99 SparseWeights()100 inline BinaryElementwiseTester& SparseWeights() { 101 sparse_weights_ = true; 102 return *this; 103 } 104 SparseWeights()105 inline bool SparseWeights() const { return sparse_weights_; } 106 ReluActivation()107 inline BinaryElementwiseTester& ReluActivation() { 108 activation_ = ::tflite::ActivationFunctionType_RELU; 109 return *this; 110 } 111 Relu6Activation()112 inline BinaryElementwiseTester& Relu6Activation() { 113 activation_ = ::tflite::ActivationFunctionType_RELU6; 114 return *this; 115 } 116 ReluMinus1To1Activation()117 inline BinaryElementwiseTester& ReluMinus1To1Activation() { 118 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 119 return *this; 120 } 121 TanhActivation()122 inline BinaryElementwiseTester& TanhActivation() { 123 activation_ = ::tflite::ActivationFunctionType_TANH; 124 return *this; 125 } 126 SignBitActivation()127 inline BinaryElementwiseTester& SignBitActivation() { 128 activation_ = ::tflite::ActivationFunctionType_SIGN_BIT; 129 return *this; 130 } 131 132 void Test(tflite::BuiltinOperator binary_op, TfLiteDelegate* delegate) const; 133 134 private: 135 std::vector<char> CreateTfLiteModel(tflite::BuiltinOperator binary_op) const; 136 Activation()137 inline ::tflite::ActivationFunctionType Activation() const { 138 return activation_; 139 } 140 141 static int32_t ComputeSize(const std::vector<int32_t>& shape); 142 143 std::vector<int32_t> input1_shape_; 144 std::vector<int32_t> input2_shape_; 145 bool input1_static_ = false; 146 bool input2_static_ = false; 147 bool fp16_weights_ = false; 148 bool int8_weights_ = false; 149 bool int8_channel_wise_weights_ = false; 150 bool sparse_weights_ = false; 151 ::tflite::ActivationFunctionType activation_ = 152 ::tflite::ActivationFunctionType_NONE; 153 }; 154 155 } // namespace xnnpack 156 } // namespace tflite 157 158 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ 159