1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/schema/schema_generated.h" 25 26 namespace tflite { 27 namespace xnnpack { 28 29 class BinaryElementwiseTester { 30 public: 31 BinaryElementwiseTester() = default; 32 BinaryElementwiseTester(const BinaryElementwiseTester&) = delete; 33 BinaryElementwiseTester& operator=(const BinaryElementwiseTester&) = delete; 34 Input1Shape(std::initializer_list<int32_t> shape)35 inline BinaryElementwiseTester& Input1Shape( 36 std::initializer_list<int32_t> shape) { 37 for (auto it = shape.begin(); it != shape.end(); ++it) { 38 EXPECT_GT(*it, 0); 39 } 40 input1_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 41 return *this; 42 } 43 Input1Shape()44 inline const std::vector<int32_t>& Input1Shape() const { 45 return input1_shape_; 46 } 47 Input2Shape(std::initializer_list<int32_t> shape)48 inline BinaryElementwiseTester& Input2Shape( 49 std::initializer_list<int32_t> shape) { 50 for (auto it = shape.begin(); it != shape.end(); ++it) { 51 EXPECT_GT(*it, 0); 52 } 53 input2_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 54 return *this; 55 } 56 Input2Shape()57 inline const std::vector<int32_t>& Input2Shape() const { 58 return input2_shape_; 59 } 60 61 std::vector<int32_t> OutputShape() const; 62 Input1Static(bool is_static)63 inline BinaryElementwiseTester& Input1Static(bool is_static) { 64 input1_static_ = is_static; 65 return *this; 66 } 67 Input1Static()68 inline bool Input1Static() const { return input1_static_; } 69 Input2Static(bool is_static)70 inline BinaryElementwiseTester& Input2Static(bool is_static) { 71 input2_static_ = is_static; 72 return *this; 73 } 74 Input2Static()75 inline bool Input2Static() const { return input2_static_; } 76 FP16Weights()77 inline BinaryElementwiseTester& FP16Weights() { 78 fp16_weights_ = true; 79 return *this; 80 } 81 FP16Weights()82 inline bool FP16Weights() const { return fp16_weights_; } 83 SparseWeights()84 inline BinaryElementwiseTester& SparseWeights() { 85 sparse_weights_ = true; 86 return *this; 87 } 88 SparseWeights()89 inline bool SparseWeights() const { return sparse_weights_; } 90 ReluActivation()91 inline BinaryElementwiseTester& ReluActivation() { 92 activation_ = ::tflite::ActivationFunctionType_RELU; 93 return *this; 94 } 95 Relu6Activation()96 inline BinaryElementwiseTester& Relu6Activation() { 97 activation_ = ::tflite::ActivationFunctionType_RELU6; 98 return *this; 99 } 100 ReluMinus1To1Activation()101 inline BinaryElementwiseTester& ReluMinus1To1Activation() { 102 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 103 return *this; 104 } 105 TanhActivation()106 inline BinaryElementwiseTester& TanhActivation() { 107 activation_ = ::tflite::ActivationFunctionType_TANH; 108 return *this; 109 } 110 SignBitActivation()111 inline BinaryElementwiseTester& SignBitActivation() { 112 activation_ = ::tflite::ActivationFunctionType_SIGN_BIT; 113 return *this; 114 } 115 116 void Test(tflite::BuiltinOperator binary_op, TfLiteDelegate* delegate) const; 117 118 private: 119 std::vector<char> CreateTfLiteModel(tflite::BuiltinOperator binary_op) const; 120 Activation()121 inline ::tflite::ActivationFunctionType Activation() const { 122 return activation_; 123 } 124 125 static int32_t ComputeSize(const std::vector<int32_t>& shape); 126 127 std::vector<int32_t> input1_shape_; 128 std::vector<int32_t> input2_shape_; 129 bool input1_static_ = false; 130 bool input2_static_ = false; 131 bool fp16_weights_ = false; 132 bool sparse_weights_ = false; 133 ::tflite::ActivationFunctionType activation_ = 134 ::tflite::ActivationFunctionType_NONE; 135 }; 136 137 } // namespace xnnpack 138 } // namespace tflite 139 140 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ 141