• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"
26 #include "arm_compute/runtime/Tensor.h"
27 #include "arm_compute/runtime/TensorAllocator.h"
28 #include "tests/NEON/Accessor.h"
29 #include "tests/PaddingCalculator.h"
30 #include "tests/datasets/ConvertPolicyDataset.h"
31 #include "tests/datasets/ShapeDatasets.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Macros.h"
34 #include "tests/framework/datasets/Datasets.h"
35 #include "tests/validation/Validation.h"
36 #include "tests/validation/fixtures/ArithmeticOperationsFixture.h"
37 
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 namespace
45 {
46 #ifdef __aarch64__
47 constexpr AbsoluteTolerance<float> tolerance_qasymm8(0);   /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
48 #else                                                      //__aarch64__
49 constexpr AbsoluteTolerance<float> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
50 #endif                                                     //__aarch64__
51 constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
52 
53 /** Input data sets **/
54 const auto ArithmeticSubtractionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8),
55                                                                  framework::dataset::make("DataType", DataType::QASYMM8)),
56                                                          framework::dataset::make("DataType", DataType::QASYMM8));
57 
58 const auto ArithmeticSubtractionQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED),
59                                                                        framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
60                                                                framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
61 
62 const auto ArithmeticSubtractionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16),
63                                                                  framework::dataset::make("DataType", DataType::QSYMM16)),
64                                                          framework::dataset::make("DataType", DataType::QSYMM16));
65 
66 const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8),
67                                                             framework::dataset::make("DataType", DataType::U8)),
68                                                     framework::dataset::make("DataType", DataType::U8));
69 
70 const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }),
71                                                              framework::dataset::make("DataType", DataType::S16)),
72                                                      framework::dataset::make("DataType", DataType::S16));
73 
74 const auto ArithmeticSubtractionS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32),
75                                                              framework::dataset::make("DataType", DataType::S32)),
76                                                      framework::dataset::make("DataType", DataType::S32));
77 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
78 const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16),
79                                                               framework::dataset::make("DataType", DataType::F16)),
80                                                       framework::dataset::make("DataType", DataType::F16));
81 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
82 const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32),
83                                                               framework::dataset::make("DataType", DataType::F32)),
84                                                       framework::dataset::make("DataType", DataType::F32));
85 
86 const auto ArithmeticSubtractionQuantizationInfoDataset = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(10, 120) }),
87                                                                           framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(20, 110) })),
88                                                                   framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(15, 125) }));
89 const auto ArithmeticSubtractionQuantizationInfoSignedDataset = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(0.5f, 10) }),
90                                                                                 framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(0.5f, 20) })),
91                                                                         framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(0.5f, 50) }));
92 const auto ArithmeticSubtractionQuantizationInfoSymmetric = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(0.3f, 0) }),
93                                                                             framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(0.7f, 0) })),
94                                                                     framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(0.2f, 0) }));
95 const auto InPlaceDataSet    = framework::dataset::make("InPlace", { false, true });
96 const auto OutOfPlaceDataSet = framework::dataset::make("InPlace", { false });
97 } // namespace
98 
99 TEST_SUITE(NEON)
100 TEST_SUITE(ArithmeticSubtraction)
101 
102 template <typename T>
103 using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
104 
105 // *INDENT-OFF*
106 // clang-format off
107 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
108         framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
109                                                  TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
110                                                  TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),      // Invalid data type combination
111                                                  TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),     // Mismatching shapes
112                                                  TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::QASYMM8), // Mismatching types
113                                                  TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Invalid convert policy
114         }),
115         framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
116                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
117                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
118                                                 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
119                                                 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
120                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
121         })),
122         framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
123                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
124                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
125                                                 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
126                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
127                                                 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
128         })),
129         framework::dataset::make("ConvertPolicy",{ ConvertPolicy::WRAP,
130                                                 ConvertPolicy::SATURATE,
131                                                 ConvertPolicy::SATURATE,
132                                                 ConvertPolicy::WRAP,
133                                                 ConvertPolicy::WRAP,
134                                                 ConvertPolicy::WRAP,
135         })),
136         framework::dataset::make("Expected", { true, true, false, false, false, false})),
137         input1_info, input2_info, output_info, policy, expected)
138 {
139     ARM_COMPUTE_EXPECT(bool(NEArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), policy)) == expected, framework::LogLevel::ERRORS);
140 }
141 // clang-format on
142 // *INDENT-ON*
143 
144 TEST_SUITE(InPlaceValidate)
/** validate() must accept one and the same tensor info as first input, second input and output. */
TEST_CASE(SingleTensor, framework::DatasetMode::ALL)
{
    // A single 9x9 F32 tensor info shared by all three arguments.
    const TensorInfo shared_info{ TensorShape{ 9, 9 }, 1, DataType::F32 };

    Status result = NEArithmeticSubtraction::validate(&shared_info, &shared_info, &shared_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == true, framework::LogLevel::ERRORS);
}
153 
/** validate() must accept a second input with first dimension 1 when the output reuses the larger first input. */
TEST_CASE(ValidBroadCast, framework::DatasetMode::ALL)
{
    // 27x13x2 first input, also passed as the output.
    const TensorInfo wide_info{ TensorShape{ 27U, 13U, 2U }, 1, DataType::F32 };
    // 1x13x2 second input.
    const TensorInfo narrow_info{ TensorShape{ 1U, 13U, 2U }, 1, DataType::F32 };

    Status result = NEArithmeticSubtraction::validate(&wide_info, &narrow_info, &wide_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == true, framework::LogLevel::ERRORS);
}
165 
/** validate() must reject the configuration where the smaller second input is also passed as the output. */
TEST_CASE(InvalidBroadcastOutput, framework::DatasetMode::ALL)
{
    // 27x13x2 first input.
    const TensorInfo wide_info{ TensorShape{ 27U, 13U, 2U }, 1, DataType::F32 };
    // 1x13x2 second input, also passed as the output.
    const TensorInfo narrow_info{ TensorShape{ 1U, 13U, 2U }, 1, DataType::F32 };

    Status result = NEArithmeticSubtraction::validate(&wide_info, &narrow_info, &narrow_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
}
177 
/** validate() must reject both choices of output when the inputs are 9x9 and 9x1x2. */
TEST_CASE(InvalidBroadcastBoth, framework::DatasetMode::ALL)
{
    const TensorInfo lhs_info{ TensorShape{ 9U, 9U }, 1, DataType::F32 };
    const TensorInfo rhs_info{ TensorShape{ 9U, 1U, 2U }, 1, DataType::F32 };

    Status result{};

    // Output reuses the first input.
    result = NEArithmeticSubtraction::validate(&lhs_info, &rhs_info, &lhs_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);

    // Output reuses the second input.
    result = NEArithmeticSubtraction::validate(&lhs_info, &rhs_info, &rhs_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
}
194 TEST_SUITE_END() // InPlaceValidate
195 
TEST_SUITE(U8)196 TEST_SUITE(U8)
197 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
198                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
199                                                                                                                      OutOfPlaceDataSet))
200 {
201     // Validate output
202     validate(Accessor(_target), _reference);
203 }
204 TEST_SUITE_END() // U8
205 
206 using NEArithmeticSubtractionQASYMM8Fixture                = ArithmeticSubtractionValidationQuantizedFixture<Tensor, Accessor, NEArithmeticSubtraction, uint8_t>;
207 using NEArithmeticSubtractionQASYMM8SignedFixture          = ArithmeticSubtractionValidationQuantizedFixture<Tensor, Accessor, NEArithmeticSubtraction, int8_t>;
208 using NEArithmeticSubtractionQASYMM8SignedBroadcastFixture = ArithmeticSubtractionValidationQuantizedBroadcastFixture<Tensor, Accessor, NEArithmeticSubtraction, int8_t>;
209 using NEArithmeticSubtractionQSYMM16Fixture                = ArithmeticSubtractionValidationQuantizedFixture<Tensor, Accessor, NEArithmeticSubtraction, int16_t>;
210 
211 TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)212 TEST_SUITE(QASYMM8)
213 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQASYMM8Dataset),
214                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
215                                                                                                                      ArithmeticSubtractionQuantizationInfoDataset),
216                                                                                                              InPlaceDataSet))
217 {
218     // Validate output
219     validate(Accessor(_target), _reference, tolerance_qasymm8);
220 }
221 TEST_SUITE_END() // QASYMM8
222 
TEST_SUITE(QASYMM8_SIGNED)223 TEST_SUITE(QASYMM8_SIGNED)
224 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8SignedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
225                                                                                                                        datasets::SmallShapes(),
226                                                                                                                        ArithmeticSubtractionQASYMM8SIGNEDDataset),
227                                                                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
228                                                                                                                    ArithmeticSubtractionQuantizationInfoSignedDataset),
229                                                                                                                    InPlaceDataSet))
230 {
231     // Validate output
232     validate(Accessor(_target), _reference, tolerance_qasymm8);
233 }
234 
235 FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionQASYMM8SignedBroadcastFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
236                            datasets::SmallShapesBroadcast(),
237                            ArithmeticSubtractionQASYMM8SIGNEDDataset),
238                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
239                        ArithmeticSubtractionQuantizationInfoSignedDataset),
240                        OutOfPlaceDataSet))
241 {
242     // Validate output
243     validate(Accessor(_target), _reference, tolerance_qasymm8);
244 }
245 TEST_SUITE_END() // QASYMM8_SIGNED
246 
TEST_SUITE(QSYMM16)247 TEST_SUITE(QSYMM16)
248 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQSYMM16Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
249         datasets::SmallShapes(),
250         ArithmeticSubtractionQSYMM16Dataset),
251                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
252                                                                                                                      ArithmeticSubtractionQuantizationInfoSymmetric),
253                                                                                                              OutOfPlaceDataSet))
254 {
255     // Validate output
256     validate(Accessor(_target), _reference, tolerance_qsymm16);
257 }
258 TEST_SUITE_END() // QSYMM16
TEST_SUITE_END()259 TEST_SUITE_END() // Quantized
260 
261 TEST_SUITE(S16)
262 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
263                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
264                                                                                                                      OutOfPlaceDataSet))
265 {
266     // Validate output
267     validate(Accessor(_target), _reference);
268 }
269 
270 FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
271                                                                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
272                                                                                                                    OutOfPlaceDataSet))
273 {
274     // Validate output
275     validate(Accessor(_target), _reference);
276 }
277 TEST_SUITE_END() // S16
278 
TEST_SUITE(S32)279 TEST_SUITE(S32)
280 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS32Dataset),
281                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
282                                                                                                                      OutOfPlaceDataSet))
283 {
284     // Validate output
285     validate(Accessor(_target), _reference);
286 }
287 
288 FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS32Dataset),
289                                                                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
290                                                                                                                    OutOfPlaceDataSet))
291 {
292     // Validate output
293     validate(Accessor(_target), _reference);
294 }
295 TEST_SUITE_END() // S32
296 
TEST_SUITE(Float)297 TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16)
// Small shapes x F16 data types x {SATURATE, WRAP} x out-of-place; compiled only with FP16 vector arithmetic.
FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(datasets::SmallShapes(),
                                               ArithmeticSubtractionFP16Dataset),
                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                               OutOfPlaceDataSet))
{
    // Compare the target output with the reference; no tolerance is passed.
    validate(Accessor(_target), _reference);
}
TEST_SUITE_END() // F16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
309 
TEST_SUITE(F32)310 TEST_SUITE(F32)
311 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP32Dataset),
312                                                                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
313                                                                                                                    InPlaceDataSet))
314 {
315     // Validate output
316     validate(Accessor(_target), _reference);
317 }
318 
319 FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionFP32Dataset),
320                                                                                                                  framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
321                                                                                                                  OutOfPlaceDataSet))
322 {
323     // Validate output
324     validate(Accessor(_target), _reference);
325 }
326 
327 template <typename T>
328 using NEArithmeticSubtractionBroadcastFixture = ArithmeticSubtractionBroadcastValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
329 
330 FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapesBroadcast(),
331                        ArithmeticSubtractionFP32Dataset),
332                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
333                        OutOfPlaceDataSet))
334 {
335     // Validate output
336     validate(Accessor(_target), _reference);
337 }
338 
339 FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEArithmeticSubtractionBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapesBroadcast(),
340                        ArithmeticSubtractionFP32Dataset),
341                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
342                        OutOfPlaceDataSet))
343 {
344     // Validate output
345     validate(Accessor(_target), _reference);
346 }
347 TEST_SUITE_END() // F32
348 TEST_SUITE_END() // Float
349 
350 TEST_SUITE_END() // ArithmeticSubtraction
351 TEST_SUITE_END() // NEON
352 } // namespace validation
353 } // namespace test
354 } // namespace arm_compute
355