• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
25 #define ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "tests/AssetsLibrary.h"
30 #include "tests/Globals.h"
31 #include "tests/IAccessor.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Fixture.h"
34 #include "tests/validation/Helpers.h"
35 #include "tests/validation/reference/ActivationLayer.h"
36 #include "tests/validation/reference/ArithmeticOperations.h"
37 
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45 class ArithmeticOperationGenericFixture : public framework::Fixture
46 {
47 public:
48     template <typename...>
setup(reference::ArithmeticOperation op,const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,ActivationLayerInfo act_info,bool in_place)49     void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1,
50                DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
51                QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool in_place)
52     {
53         _op        = op;
54         _act_info  = act_info;
55         _in_place  = in_place;
56         _target    = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
57         _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
58     }
59 
60 protected:
61     template <typename U>
fill(U && tensor,int i)62     void fill(U &&tensor, int i)
63     {
64         library->fill_tensor_uniform(tensor, i);
65     }
66 
compute_target(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out)67     TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
68                               QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
69     {
70         // Create tensors
71         TensorType  ref_src1   = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0);
72         TensorType  ref_src2   = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
73         TensorType  dst        = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
74         TensorType *dst_to_use = _in_place ? &ref_src1 : &dst;
75 
76         // Create and configure function
77         FunctionType arith_op;
78         arith_op.configure(&ref_src1, &ref_src2, dst_to_use, convert_policy, _act_info);
79 
80         ARM_COMPUTE_EXPECT(ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
81         ARM_COMPUTE_EXPECT(ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
82         ARM_COMPUTE_EXPECT(dst_to_use->info()->is_resizable(), framework::LogLevel::ERRORS);
83 
84         // Allocate tensors
85         ref_src1.allocator()->allocate();
86         ref_src2.allocator()->allocate();
87         dst_to_use->allocator()->allocate();
88 
89         ARM_COMPUTE_EXPECT(!ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
90         ARM_COMPUTE_EXPECT(!ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
91         ARM_COMPUTE_EXPECT(!dst_to_use->info()->is_resizable(), framework::LogLevel::ERRORS);
92 
93         // Fill tensors
94         fill(AccessorType(ref_src1), 0);
95         fill(AccessorType(ref_src2), 1);
96 
97         // Compute function
98         arith_op.run();
99 
100         if(_in_place)
101         {
102             return ref_src1;
103         }
104         return dst;
105     }
106 
compute_reference(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out)107     SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1,
108                                       DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
109                                       QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
110     {
111         // current in-place implementation only supports same metadata of input and output tensors.
112         // By ignoring output quantization information here, we can make test cases implementation much simpler.
113         QuantizationInfo output_qinfo = _in_place ? qinfo0 : qinfo_out;
114 
115         // Create reference
116         SimpleTensor<T> ref_src1{ shape0, data_type0, 1, qinfo0 };
117         SimpleTensor<T> ref_src2{ shape1, data_type1, 1, qinfo1 };
118         SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, output_qinfo };
119 
120         // Fill reference
121         fill(ref_src1, 0);
122         fill(ref_src2, 1);
123 
124         auto result = reference::arithmetic_operation<T>(_op, ref_src1, ref_src2, ref_dst, convert_policy);
125         return _act_info.enabled() ? reference::activation_layer(result, _act_info, output_qinfo) : result;
126     }
127 
128     TensorType                     _target{};
129     SimpleTensor<T>                _reference{};
130     reference::ArithmeticOperation _op{ reference::ArithmeticOperation::ADD };
131     ActivationLayerInfo            _act_info{};
132     bool                           _in_place{};
133 };
134 
135 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
136 class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
137 {
138 public:
139     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy)140     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
141     {
142         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
143                                                                                             output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
144     }
145 };
146 
147 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
148 class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
149 {
150 public:
151     template <typename...>
setup(const TensorShape & shape,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy)152     void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
153     {
154         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
155                                                                                             output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
156     }
157 };
158 
159 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
160 class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
161 {
162 public:
163     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info)164     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
165     {
166         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
167                                                                                             output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
168     }
169 };
170 
171 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
172 class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
173 {
174 public:
175     template <typename...>
setup(const TensorShape & shape,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info)176     void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
177     {
178         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
179                                                                                             output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
180     }
181 };
182 
183 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
184 class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
185 {
186 public:
187     template <typename...>
setup(const TensorShape & shape,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out)188     void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
189                QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
190 
191     {
192         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
193                                                                                             output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
194     }
195 };
196 
197 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
198 class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
199 {
200 public:
201     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out)202     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
203                ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
204     {
205         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1,
206                                                                                             data_type0, data_type1, output_data_type, convert_policy,
207                                                                                             qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
208     }
209 };
210 
211 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
212 class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
213 {
214 public:
215     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,bool in_place)216     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
217     {
218         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
219                                                                                             data_type0, data_type1, output_data_type, convert_policy,
220                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
221     }
222 };
223 
224 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
225 class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
226 {
227 public:
228     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info,bool in_place)229     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
230                bool in_place)
231     {
232         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
233                                                                                             data_type0, data_type1, output_data_type, convert_policy,
234                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place);
235     }
236 };
237 
238 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
239 class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
240 {
241 public:
242     template <typename...>
setup(const TensorShape & shape,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,bool in_place)243     void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
244     {
245         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
246                                                                                             data_type0, data_type1, output_data_type, convert_policy,
247                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
248     }
249 };
250 
251 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
252 class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
253 {
254 public:
255     template <typename...>
setup(const TensorShape & shape,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info,bool in_place)256     void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place)
257     {
258         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
259                                                                                             data_type0, data_type1, output_data_type, convert_policy,
260                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place);
261     }
262 };
263 
264 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
265 class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
266 {
267 public:
268     template <typename...>
setup(const TensorShape & shape,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,bool in_place)269     void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
270                QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
271 
272     {
273         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
274                                                                                             data_type0, data_type1, output_data_type,
275                                                                                             convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
276     }
277 };
278 
279 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
280 class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
281 {
282 public:
283     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type0,DataType data_type1,DataType output_data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,bool in_place)284     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
285                ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
286     {
287         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
288                                                                                             data_type0, data_type1, output_data_type, convert_policy,
289                                                                                             qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
290     }
291 };
292 } // namespace validation
293 } // namespace test
294 } // namespace arm_compute
295 #endif /* ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE */
296