• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_
17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_
18 
#include <cstdint>
#include <initializer_list>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 namespace xnnpack {
28 
29 class BinaryElementwiseTester {
30  public:
31   BinaryElementwiseTester() = default;
32   BinaryElementwiseTester(const BinaryElementwiseTester&) = delete;
33   BinaryElementwiseTester& operator=(const BinaryElementwiseTester&) = delete;
34 
Input1Shape(std::initializer_list<int32_t> shape)35   inline BinaryElementwiseTester& Input1Shape(
36       std::initializer_list<int32_t> shape) {
37     for (auto it = shape.begin(); it != shape.end(); ++it) {
38       EXPECT_GT(*it, 0);
39     }
40     input1_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
41     return *this;
42   }
43 
Input1Shape()44   inline const std::vector<int32_t>& Input1Shape() const {
45     return input1_shape_;
46   }
47 
Input2Shape(std::initializer_list<int32_t> shape)48   inline BinaryElementwiseTester& Input2Shape(
49       std::initializer_list<int32_t> shape) {
50     for (auto it = shape.begin(); it != shape.end(); ++it) {
51       EXPECT_GT(*it, 0);
52     }
53     input2_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
54     return *this;
55   }
56 
Input2Shape()57   inline const std::vector<int32_t>& Input2Shape() const {
58     return input2_shape_;
59   }
60 
61   std::vector<int32_t> OutputShape() const;
62 
Input1Static(bool is_static)63   inline BinaryElementwiseTester& Input1Static(bool is_static) {
64     input1_static_ = is_static;
65     return *this;
66   }
67 
Input1Static()68   inline bool Input1Static() const { return input1_static_; }
69 
Input2Static(bool is_static)70   inline BinaryElementwiseTester& Input2Static(bool is_static) {
71     input2_static_ = is_static;
72     return *this;
73   }
74 
Input2Static()75   inline bool Input2Static() const { return input2_static_; }
76 
FP16Weights()77   inline BinaryElementwiseTester& FP16Weights() {
78     fp16_weights_ = true;
79     return *this;
80   }
81 
FP16Weights()82   inline bool FP16Weights() const { return fp16_weights_; }
83 
INT8Weights()84   inline BinaryElementwiseTester& INT8Weights() {
85     int8_weights_ = true;
86     return *this;
87   }
88 
INT8Weights()89   inline bool INT8Weights() const { return int8_weights_; }
90 
INT8ChannelWiseWeights()91   inline BinaryElementwiseTester& INT8ChannelWiseWeights() {
92     int8_channel_wise_weights_ = true;
93     return *this;
94   }
95 
INT8ChannelWiseWeights()96   inline bool INT8ChannelWiseWeights() const {
97     return int8_channel_wise_weights_;
98   }
99 
SparseWeights()100   inline BinaryElementwiseTester& SparseWeights() {
101     sparse_weights_ = true;
102     return *this;
103   }
104 
SparseWeights()105   inline bool SparseWeights() const { return sparse_weights_; }
106 
ReluActivation()107   inline BinaryElementwiseTester& ReluActivation() {
108     activation_ = ::tflite::ActivationFunctionType_RELU;
109     return *this;
110   }
111 
Relu6Activation()112   inline BinaryElementwiseTester& Relu6Activation() {
113     activation_ = ::tflite::ActivationFunctionType_RELU6;
114     return *this;
115   }
116 
ReluMinus1To1Activation()117   inline BinaryElementwiseTester& ReluMinus1To1Activation() {
118     activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
119     return *this;
120   }
121 
TanhActivation()122   inline BinaryElementwiseTester& TanhActivation() {
123     activation_ = ::tflite::ActivationFunctionType_TANH;
124     return *this;
125   }
126 
SignBitActivation()127   inline BinaryElementwiseTester& SignBitActivation() {
128     activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
129     return *this;
130   }
131 
132   void Test(tflite::BuiltinOperator binary_op, TfLiteDelegate* delegate) const;
133 
134  private:
135   std::vector<char> CreateTfLiteModel(tflite::BuiltinOperator binary_op) const;
136 
Activation()137   inline ::tflite::ActivationFunctionType Activation() const {
138     return activation_;
139   }
140 
141   static int32_t ComputeSize(const std::vector<int32_t>& shape);
142 
143   std::vector<int32_t> input1_shape_;
144   std::vector<int32_t> input2_shape_;
145   bool input1_static_ = false;
146   bool input2_static_ = false;
147   bool fp16_weights_ = false;
148   bool int8_weights_ = false;
149   bool int8_channel_wise_weights_ = false;
150   bool sparse_weights_ = false;
151   ::tflite::ActivationFunctionType activation_ =
152       ::tflite::ActivationFunctionType_NONE;
153 };
154 
155 }  // namespace xnnpack
156 }  // namespace tflite
157 
158 #endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_
159