• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_
18 
19 #include <cstdint>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace xnnpack {
28 
29 class BinaryElementwiseTester {
30  public:
31   BinaryElementwiseTester() = default;
32   BinaryElementwiseTester(const BinaryElementwiseTester&) = delete;
33   BinaryElementwiseTester& operator=(const BinaryElementwiseTester&) = delete;
34 
Input1Shape(std::initializer_list<int32_t> shape)35   inline BinaryElementwiseTester& Input1Shape(
36       std::initializer_list<int32_t> shape) {
37     for (auto it = shape.begin(); it != shape.end(); ++it) {
38       EXPECT_GT(*it, 0);
39     }
40     input1_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
41     return *this;
42   }
43 
Input1Shape()44   inline const std::vector<int32_t>& Input1Shape() const {
45     return input1_shape_;
46   }
47 
Input2Shape(std::initializer_list<int32_t> shape)48   inline BinaryElementwiseTester& Input2Shape(
49       std::initializer_list<int32_t> shape) {
50     for (auto it = shape.begin(); it != shape.end(); ++it) {
51       EXPECT_GT(*it, 0);
52     }
53     input2_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
54     return *this;
55   }
56 
Input2Shape()57   inline const std::vector<int32_t>& Input2Shape() const {
58     return input2_shape_;
59   }
60 
61   std::vector<int32_t> OutputShape() const;
62 
Input1Static(bool is_static)63   inline BinaryElementwiseTester& Input1Static(bool is_static) {
64     input1_static_ = is_static;
65     return *this;
66   }
67 
Input1Static()68   inline bool Input1Static() const { return input1_static_; }
69 
Input2Static(bool is_static)70   inline BinaryElementwiseTester& Input2Static(bool is_static) {
71     input2_static_ = is_static;
72     return *this;
73   }
74 
Input2Static()75   inline bool Input2Static() const { return input2_static_; }
76 
FP16Weights()77   inline BinaryElementwiseTester& FP16Weights() {
78     fp16_weights_ = true;
79     return *this;
80   }
81 
FP16Weights()82   inline bool FP16Weights() const { return fp16_weights_; }
83 
SparseWeights()84   inline BinaryElementwiseTester& SparseWeights() {
85     sparse_weights_ = true;
86     return *this;
87   }
88 
SparseWeights()89   inline bool SparseWeights() const { return sparse_weights_; }
90 
ReluActivation()91   inline BinaryElementwiseTester& ReluActivation() {
92     activation_ = ::tflite::ActivationFunctionType_RELU;
93     return *this;
94   }
95 
Relu6Activation()96   inline BinaryElementwiseTester& Relu6Activation() {
97     activation_ = ::tflite::ActivationFunctionType_RELU6;
98     return *this;
99   }
100 
ReluMinus1To1Activation()101   inline BinaryElementwiseTester& ReluMinus1To1Activation() {
102     activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
103     return *this;
104   }
105 
TanhActivation()106   inline BinaryElementwiseTester& TanhActivation() {
107     activation_ = ::tflite::ActivationFunctionType_TANH;
108     return *this;
109   }
110 
SignBitActivation()111   inline BinaryElementwiseTester& SignBitActivation() {
112     activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
113     return *this;
114   }
115 
116   void Test(tflite::BuiltinOperator binary_op, TfLiteDelegate* delegate) const;
117 
118  private:
119   std::vector<char> CreateTfLiteModel(tflite::BuiltinOperator binary_op) const;
120 
Activation()121   inline ::tflite::ActivationFunctionType Activation() const {
122     return activation_;
123   }
124 
125   static int32_t ComputeSize(const std::vector<int32_t>& shape);
126 
127   std::vector<int32_t> input1_shape_;
128   std::vector<int32_t> input2_shape_;
129   bool input1_static_ = false;
130   bool input2_static_ = false;
131   bool fp16_weights_ = false;
132   bool sparse_weights_ = false;
133   ::tflite::ActivationFunctionType activation_ =
134       ::tflite::ActivationFunctionType_NONE;
135 };
136 
137 }  // namespace xnnpack
138 }  // namespace tflite
139 
140 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_
141