1 /*
2 * Copyright (c) 2017-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"
26 #include "arm_compute/runtime/Tensor.h"
27 #include "arm_compute/runtime/TensorAllocator.h"
28 #include "tests/NEON/Accessor.h"
29 #include "tests/PaddingCalculator.h"
30 #include "tests/datasets/ConvertPolicyDataset.h"
31 #include "tests/datasets/ShapeDatasets.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Macros.h"
34 #include "tests/framework/datasets/Datasets.h"
35 #include "tests/validation/Validation.h"
36 #include "tests/validation/fixtures/ArithmeticOperationsFixture.h"
37
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 namespace
45 {
46 #ifdef __aarch64__
47 constexpr AbsoluteTolerance<float> tolerance_qasymm8(0); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
48 #else //__aarch64__
49 constexpr AbsoluteTolerance<float> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
50 #endif //__aarch64__
51 constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
52
53 /** Input data sets **/
54 const auto ArithmeticSubtractionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8),
55 framework::dataset::make("DataType", DataType::QASYMM8)),
56 framework::dataset::make("DataType", DataType::QASYMM8));
57
58 const auto ArithmeticSubtractionQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED),
59 framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
60 framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
61
62 const auto ArithmeticSubtractionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16),
63 framework::dataset::make("DataType", DataType::QSYMM16)),
64 framework::dataset::make("DataType", DataType::QSYMM16));
65
66 const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8),
67 framework::dataset::make("DataType", DataType::U8)),
68 framework::dataset::make("DataType", DataType::U8));
69
70 const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }),
71 framework::dataset::make("DataType", DataType::S16)),
72 framework::dataset::make("DataType", DataType::S16));
73
74 const auto ArithmeticSubtractionS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32),
75 framework::dataset::make("DataType", DataType::S32)),
76 framework::dataset::make("DataType", DataType::S32));
77 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
78 const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16),
79 framework::dataset::make("DataType", DataType::F16)),
80 framework::dataset::make("DataType", DataType::F16));
81 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
82 const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32),
83 framework::dataset::make("DataType", DataType::F32)),
84 framework::dataset::make("DataType", DataType::F32));
85
86 const auto ArithmeticSubtractionQuantizationInfoDataset = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(10, 120) }),
87 framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(20, 110) })),
88 framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(15, 125) }));
89 const auto ArithmeticSubtractionQuantizationInfoSignedDataset = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(0.5f, 10) }),
90 framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(0.5f, 20) })),
91 framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(0.5f, 50) }));
92 const auto ArithmeticSubtractionQuantizationInfoSymmetric = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(0.3f, 0) }),
93 framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(0.7f, 0) })),
94 framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(0.2f, 0) }));
95 const auto InPlaceDataSet = framework::dataset::make("InPlace", { false, true });
96 const auto OutOfPlaceDataSet = framework::dataset::make("InPlace", { false });
97 } // namespace
98
99 TEST_SUITE(NEON)
100 TEST_SUITE(ArithmeticSubtraction)
101
102 template <typename T>
103 using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
104
105 // *INDENT-OFF*
106 // clang-format off
107 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
108 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
109 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
110 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
111 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
112 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::QASYMM8), // Mismatching types
113 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Invalid convert policy
114 }),
115 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
116 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
117 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
118 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
119 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
120 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
121 })),
122 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
123 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
124 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
125 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
126 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
127 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
128 })),
129 framework::dataset::make("ConvertPolicy",{ ConvertPolicy::WRAP,
130 ConvertPolicy::SATURATE,
131 ConvertPolicy::SATURATE,
132 ConvertPolicy::WRAP,
133 ConvertPolicy::WRAP,
134 ConvertPolicy::WRAP,
135 })),
136 framework::dataset::make("Expected", { true, true, false, false, false, false})),
137 input1_info, input2_info, output_info, policy, expected)
138 {
139 ARM_COMPUTE_EXPECT(bool(NEArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), policy)) == expected, framework::LogLevel::ERRORS);
140 }
141 // clang-format on
142 // *INDENT-ON*
143
144 TEST_SUITE(InPlaceValidate)
TEST_CASE(SingleTensor,framework::DatasetMode::ALL)145 TEST_CASE(SingleTensor, framework::DatasetMode::ALL)
146 {
147 const auto random_shape = TensorShape{ 9, 9 };
148 const auto single_tensor_info = TensorInfo{ random_shape, 1, DataType::F32 };
149
150 Status result = NEArithmeticSubtraction::validate(&single_tensor_info, &single_tensor_info, &single_tensor_info, ConvertPolicy::WRAP);
151 ARM_COMPUTE_EXPECT(bool(result) == true, framework::LogLevel::ERRORS);
152 }
153
TEST_CASE(ValidBroadCast,framework::DatasetMode::ALL)154 TEST_CASE(ValidBroadCast, framework::DatasetMode::ALL)
155 {
156 const auto larger_shape = TensorShape{ 27U, 13U, 2U };
157 const auto smaller_shape = TensorShape{ 1U, 13U, 2U };
158
159 const auto larger_tensor_info = TensorInfo{ larger_shape, 1, DataType::F32 };
160 const auto smaller_tensor_info = TensorInfo{ smaller_shape, 1, DataType::F32 };
161
162 Status result = NEArithmeticSubtraction::validate(&larger_tensor_info, &smaller_tensor_info, &larger_tensor_info, ConvertPolicy::WRAP);
163 ARM_COMPUTE_EXPECT(bool(result) == true, framework::LogLevel::ERRORS);
164 }
165
TEST_CASE(InvalidBroadcastOutput,framework::DatasetMode::ALL)166 TEST_CASE(InvalidBroadcastOutput, framework::DatasetMode::ALL)
167 {
168 const auto larger_shape = TensorShape{ 27U, 13U, 2U };
169 const auto smaller_shape = TensorShape{ 1U, 13U, 2U };
170
171 const auto larger_tensor_info = TensorInfo{ larger_shape, 1, DataType::F32 };
172 const auto smaller_tensor_info = TensorInfo{ smaller_shape, 1, DataType::F32 };
173
174 Status result = NEArithmeticSubtraction::validate(&larger_tensor_info, &smaller_tensor_info, &smaller_tensor_info, ConvertPolicy::WRAP);
175 ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
176 }
177
TEST_CASE(InvalidBroadcastBoth,framework::DatasetMode::ALL)178 TEST_CASE(InvalidBroadcastBoth, framework::DatasetMode::ALL)
179 {
180 const auto shape0 = TensorShape{ 9U, 9U };
181 const auto shape1 = TensorShape{ 9U, 1U, 2U };
182
183 const auto info0 = TensorInfo{ shape0, 1, DataType::F32 };
184 const auto info1 = TensorInfo{ shape1, 1, DataType::F32 };
185
186 Status result{};
187
188 result = NEArithmeticSubtraction::validate(&info0, &info1, &info0, ConvertPolicy::WRAP);
189 ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
190
191 result = NEArithmeticSubtraction::validate(&info0, &info1, &info1, ConvertPolicy::WRAP);
192 ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
193 }
194 TEST_SUITE_END() // InPlaceValidate
195
TEST_SUITE(U8)196 TEST_SUITE(U8)
197 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
198 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
199 OutOfPlaceDataSet))
200 {
201 // Validate output
202 validate(Accessor(_target), _reference);
203 }
204 TEST_SUITE_END() // U8
205
206 using NEArithmeticSubtractionQASYMM8Fixture = ArithmeticSubtractionValidationQuantizedFixture<Tensor, Accessor, NEArithmeticSubtraction, uint8_t>;
207 using NEArithmeticSubtractionQASYMM8SignedFixture = ArithmeticSubtractionValidationQuantizedFixture<Tensor, Accessor, NEArithmeticSubtraction, int8_t>;
208 using NEArithmeticSubtractionQASYMM8SignedBroadcastFixture = ArithmeticSubtractionValidationQuantizedBroadcastFixture<Tensor, Accessor, NEArithmeticSubtraction, int8_t>;
209 using NEArithmeticSubtractionQSYMM16Fixture = ArithmeticSubtractionValidationQuantizedFixture<Tensor, Accessor, NEArithmeticSubtraction, int16_t>;
210
211 TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)212 TEST_SUITE(QASYMM8)
213 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQASYMM8Dataset),
214 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
215 ArithmeticSubtractionQuantizationInfoDataset),
216 InPlaceDataSet))
217 {
218 // Validate output
219 validate(Accessor(_target), _reference, tolerance_qasymm8);
220 }
221 TEST_SUITE_END() // QASYMM8
222
TEST_SUITE(QASYMM8_SIGNED)223 TEST_SUITE(QASYMM8_SIGNED)
224 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8SignedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
225 datasets::SmallShapes(),
226 ArithmeticSubtractionQASYMM8SIGNEDDataset),
227 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
228 ArithmeticSubtractionQuantizationInfoSignedDataset),
229 InPlaceDataSet))
230 {
231 // Validate output
232 validate(Accessor(_target), _reference, tolerance_qasymm8);
233 }
234
235 FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionQASYMM8SignedBroadcastFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
236 datasets::SmallShapesBroadcast(),
237 ArithmeticSubtractionQASYMM8SIGNEDDataset),
238 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
239 ArithmeticSubtractionQuantizationInfoSignedDataset),
240 OutOfPlaceDataSet))
241 {
242 // Validate output
243 validate(Accessor(_target), _reference, tolerance_qasymm8);
244 }
245 TEST_SUITE_END() // QASYMM8_SIGNED
246
TEST_SUITE(QSYMM16)247 TEST_SUITE(QSYMM16)
248 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQSYMM16Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
249 datasets::SmallShapes(),
250 ArithmeticSubtractionQSYMM16Dataset),
251 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
252 ArithmeticSubtractionQuantizationInfoSymmetric),
253 OutOfPlaceDataSet))
254 {
255 // Validate output
256 validate(Accessor(_target), _reference, tolerance_qsymm16);
257 }
258 TEST_SUITE_END() // QSYMM16
TEST_SUITE_END()259 TEST_SUITE_END() // Quantized
260
261 TEST_SUITE(S16)
262 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
263 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
264 OutOfPlaceDataSet))
265 {
266 // Validate output
267 validate(Accessor(_target), _reference);
268 }
269
270 FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
271 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
272 OutOfPlaceDataSet))
273 {
274 // Validate output
275 validate(Accessor(_target), _reference);
276 }
277 TEST_SUITE_END() // S16
278
TEST_SUITE(S32)279 TEST_SUITE(S32)
280 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS32Dataset),
281 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
282 OutOfPlaceDataSet))
283 {
284 // Validate output
285 validate(Accessor(_target), _reference);
286 }
287
288 FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS32Dataset),
289 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
290 OutOfPlaceDataSet))
291 {
292 // Validate output
293 validate(Accessor(_target), _reference);
294 }
295 TEST_SUITE_END() // S32
296
TEST_SUITE(Float)297 TEST_SUITE(Float)
298 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
299 TEST_SUITE(F16)
300 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP16Dataset),
301 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
302 OutOfPlaceDataSet))
303 {
304 // Validate output
305 validate(Accessor(_target), _reference);
306 }
307 TEST_SUITE_END() // F16
308 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
309
TEST_SUITE(F32)310 TEST_SUITE(F32)
311 FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP32Dataset),
312 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
313 InPlaceDataSet))
314 {
315 // Validate output
316 validate(Accessor(_target), _reference);
317 }
318
319 FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionFP32Dataset),
320 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
321 OutOfPlaceDataSet))
322 {
323 // Validate output
324 validate(Accessor(_target), _reference);
325 }
326
327 template <typename T>
328 using NEArithmeticSubtractionBroadcastFixture = ArithmeticSubtractionBroadcastValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
329
330 FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapesBroadcast(),
331 ArithmeticSubtractionFP32Dataset),
332 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
333 OutOfPlaceDataSet))
334 {
335 // Validate output
336 validate(Accessor(_target), _reference);
337 }
338
339 FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEArithmeticSubtractionBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapesBroadcast(),
340 ArithmeticSubtractionFP32Dataset),
341 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
342 OutOfPlaceDataSet))
343 {
344 // Validate output
345 validate(Accessor(_target), _reference);
346 }
347 TEST_SUITE_END() // F32
348 TEST_SUITE_END() // Float
349
350 TEST_SUITE_END() // ArithmeticSubtraction
351 TEST_SUITE_END() // NEON
352 } // namespace validation
353 } // namespace test
354 } // namespace arm_compute
355