1 /*
2 * Copyright (c) 2017-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/CL/CLTensor.h"
26 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
27 #include "arm_compute/runtime/CL/functions/CLElementwiseOperations.h"
28 #include "tests/CL/CLAccessor.h"
29 #include "tests/PaddingCalculator.h"
30 #include "tests/datasets/ConvertPolicyDataset.h"
31 #include "tests/datasets/ShapeDatasets.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Macros.h"
34 #include "tests/framework/datasets/Datasets.h"
35 #include "tests/validation/Validation.h"
36 #include "tests/validation/fixtures/ArithmeticOperationsFixture.h"
37
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 /** Synced with tests/validation/dynamic_fusion/gpu/cl/Sub.cpp from the dynamic fusion interface.
45 * Please check there for any differences in the coverage
46 */
47 namespace
48 {
49 /** Input data sets **/
50 const auto EmptyActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
51 { ActivationLayerInfo() });
52 const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
53 {
54 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 0.75f, 0.25f),
55 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 0.75f, 0.25f)
56 });
57 const auto InPlaceDataSet = framework::dataset::make("InPlace", { false, true });
58 const auto OutOfPlaceDataSet = framework::dataset::make("InPlace", { false });
59 } // namespace
60
61 TEST_SUITE(CL)
TEST_SUITE(ArithmeticSubtraction)62 TEST_SUITE(ArithmeticSubtraction)
63
64 // *INDENT-OFF*
65 // clang-format off
66 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
67 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
68 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
69 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
70 }),
71 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
72 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
73 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
74 })),
75 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
76 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
77 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
78 })),
79 framework::dataset::make("Expected", { true, false, false})),
80 input1_info, input2_info, output_info, expected)
81 {
82 ARM_COMPUTE_EXPECT(bool(CLArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS);
83 }
84 // clang-format on
85 // *INDENT-ON*
86
87 TEST_SUITE(InPlaceValidate)
/** validate() must accept the same F32 descriptor used as both inputs and as the output. */
TEST_CASE(SingleTensor, framework::DatasetMode::ALL)
{
    // A single 9x9 F32 descriptor shared by all three arguments.
    const TensorInfo shared_info(TensorShape(9, 9), 1, DataType::F32);

    const Status result = CLArithmeticSubtraction::validate(&shared_info, &shared_info, &shared_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == true, framework::LogLevel::ERRORS);
}
96
/** validate() must accept a second input of extent 1 in the first dimension when the
 *  larger descriptor is used as both the first input and the output.
 */
TEST_CASE(ValidBroadCast, framework::DatasetMode::ALL)
{
    const TensorInfo full_info(TensorShape(27U, 13U, 2U), 1, DataType::F32);
    const TensorInfo narrow_info(TensorShape(1U, 13U, 2U), 1, DataType::F32);

    // Output shares the larger descriptor.
    const Status result = CLArithmeticSubtraction::validate(&full_info, &narrow_info, &full_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == true, framework::LogLevel::ERRORS);
}
108
/** validate() must reject the configuration where the smaller (extent-1) descriptor is
 *  used as both the second input and the output.
 */
TEST_CASE(InvalidBroadcastOutput, framework::DatasetMode::ALL)
{
    const TensorInfo full_info(TensorShape(27U, 13U, 2U), 1, DataType::F32);
    const TensorInfo narrow_info(TensorShape(1U, 13U, 2U), 1, DataType::F32);

    // Output shares the smaller descriptor.
    const Status result = CLArithmeticSubtraction::validate(&full_info, &narrow_info, &narrow_info, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
}
120
/** validate() must reject shapes (9, 9) and (9, 1, 2) as inputs, whichever of the two
 *  descriptors is also passed as the output.
 */
TEST_CASE(InvalidBroadcastBoth, framework::DatasetMode::ALL)
{
    const TensorInfo info0(TensorShape(9U, 9U), 1, DataType::F32);
    const TensorInfo info1(TensorShape(9U, 1U, 2U), 1, DataType::F32);

    // First attempt: output shares the first input's descriptor.
    Status result = CLArithmeticSubtraction::validate(&info0, &info1, &info0, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);

    // Second attempt: output shares the second input's descriptor.
    result = CLArithmeticSubtraction::validate(&info0, &info1, &info1, ConvertPolicy::WRAP);
    ARM_COMPUTE_EXPECT(bool(result) == false, framework::LogLevel::ERRORS);
}
137 TEST_SUITE_END() // InPlaceValidate
138
139 template <typename T>
140 using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
141
142 TEST_SUITE(Integer)
TEST_SUITE(U8)143 TEST_SUITE(U8)
144 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
145 DataType::U8)),
146 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
147 OutOfPlaceDataSet))
148 {
149 // Validate output
150 validate(CLAccessor(_target), _reference);
151 }
152 TEST_SUITE_END() // U8
153
TEST_SUITE(S16)154 TEST_SUITE(S16)
155 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
156 DataType::S16)),
157 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
158 OutOfPlaceDataSet))
159 {
160 // Validate output
161 validate(CLAccessor(_target), _reference);
162 }
163
164 FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
165 DataType::S16)),
166 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
167 OutOfPlaceDataSet))
168 {
169 // Validate output
170 validate(CLAccessor(_target), _reference);
171 }
172 TEST_SUITE_END() // S16
173 TEST_SUITE_END() // Integer
174
175 template <typename T>
176 using CLArithmeticSubtractionQuantizedFixture = ArithmeticSubtractionValidationQuantizedFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
177
178 TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)179 TEST_SUITE(QASYMM8)
180 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
181 framework::dataset::make("DataType", DataType::QASYMM8)),
182 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
183 framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
184 framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
185 framework::dataset::make("OutQInfo", { QuantizationInfo(1.f / 255.f, 5) })),
186 OutOfPlaceDataSet))
187 {
188 // Validate output
189 validate(CLAccessor(_target), _reference);
190 }
191 FIXTURE_DATA_TEST_CASE(RunTinyInPlace, CLArithmeticSubtractionQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::TinyShapes(),
192 framework::dataset::make("DataType", DataType::QASYMM8)),
193 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
194 framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
195 framework::dataset::make("Src1QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
196 framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 255.f, 20) })),
197 InPlaceDataSet))
198 {
199 // Validate output
200 validate(CLAccessor(_target), _reference);
201 }
202 TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)203 TEST_SUITE(QASYMM8_SIGNED)
204 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
205 framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
206 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
207 framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 10) })),
208 framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
209 framework::dataset::make("OutQInfo", { QuantizationInfo(1.f / 255.f, 5) })),
210 OutOfPlaceDataSet))
211 {
212 // Validate output
213 validate(CLAccessor(_target), _reference);
214 }
215 TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE(QSYMM16)216 TEST_SUITE(QSYMM16)
217 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
218 framework::dataset::make("DataType", DataType::QSYMM16)),
219 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
220 framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
221 framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
222 framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })),
223 OutOfPlaceDataSet))
224 {
225 // Validate output
226 validate(CLAccessor(_target), _reference);
227 }
228 TEST_SUITE_END() // QSYMM16
229 TEST_SUITE_END() // Quantized
230
231 template <typename T>
232 using CLArithmeticSubtractionFloatFixture = ArithmeticSubtractionValidationFloatFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
233
234 TEST_SUITE(Float)
TEST_SUITE(FP16)235 TEST_SUITE(FP16)
236 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
237 DataType::F16)),
238 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
239 EmptyActivationFunctionsDataset),
240 OutOfPlaceDataSet))
241 {
242 // Validate output
243 validate(CLAccessor(_target), _reference);
244 }
245 FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapes(),
246 framework::dataset::make("DataType", DataType::F16)),
247 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
248 ActivationFunctionsDataset),
249 InPlaceDataSet))
250 {
251 // Validate output
252 validate(CLAccessor(_target), _reference);
253 }
254 TEST_SUITE_END() // FP16
255
TEST_SUITE(FP32)256 TEST_SUITE(FP32)
257 FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
258 framework::dataset::make("DataType", DataType::F32)),
259 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
260 EmptyActivationFunctionsDataset),
261 OutOfPlaceDataSet))
262 {
263 // Validate output
264 validate(CLAccessor(_target), _reference);
265 }
266 FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapes(),
267 framework::dataset::make("DataType", DataType::F32)),
268 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
269 ActivationFunctionsDataset),
270 InPlaceDataSet))
271 {
272 // Validate output
273 validate(CLAccessor(_target), _reference);
274 }
275
276 FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
277 framework::dataset::make("DataType", DataType::F32)),
278 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
279 EmptyActivationFunctionsDataset),
280 OutOfPlaceDataSet))
281 {
282 // Validate output
283 validate(CLAccessor(_target), _reference);
284 }
285
286 template <typename T>
287 using CLArithmeticSubtractionBroadcastFloatFixture = ArithmeticSubtractionBroadcastValidationFloatFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
288
289 FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapesBroadcast(),
290 framework::dataset::make("DataType", DataType::F32)),
291 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
292 EmptyActivationFunctionsDataset),
293 OutOfPlaceDataSet))
294 {
295 // Validate output
296 validate(CLAccessor(_target), _reference);
297 }
298 FIXTURE_DATA_TEST_CASE(RunTinyBroadcastInplace, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::PRECOMMIT,
299 combine(combine(combine(combine(datasets::TinyShapesBroadcastInplace(),
300 framework::dataset::make("DataType", DataType::F32)),
301 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
302 EmptyActivationFunctionsDataset),
303 InPlaceDataSet))
304 {
305 // Validate output
306 validate(CLAccessor(_target), _reference);
307 }
308 FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapesBroadcast(),
309 framework::dataset::make("DataType", DataType::F32)),
310 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
311 ActivationFunctionsDataset),
312 OutOfPlaceDataSet))
313 {
314 // Validate output
315 validate(CLAccessor(_target), _reference);
316 }
317
318 FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapesBroadcast(),
319 framework::dataset::make("DataType", DataType::F32)),
320 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
321 EmptyActivationFunctionsDataset),
322 OutOfPlaceDataSet))
323 {
324 // Validate output
325 validate(CLAccessor(_target), _reference);
326 }
327 TEST_SUITE_END() // FP32
328 TEST_SUITE_END() // Float
329
330 TEST_SUITE_END() // ArithmeticSubtraction
331 TEST_SUITE_END() // CL
332 } // namespace validation
333 } // namespace test
334 } // namespace arm_compute
335