/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "tests/CL/CLAccessor.h"
#include "tests/CL/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/GEMMFixture.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
using namespace arm_compute::misc::shape_calculator;

// Create function for CLGEMMReshapeLHSMatrixKernel
using CLGEMMReshapeLHSMatrix = CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;

// Create function for CLGEMMReshapeRHSMatrixKernel
using CLGEMMReshapeRHSMatrix = CLSynthetizeFunction<CLGEMMReshapeRHSMatrixKernel>;

// Create function for CLGEMMMatrixMultiplyReshapedKernel
using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
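
// The fixtures below exercise these three functions as a pipeline: both
// operands are first reshaped into the blocked layouts described by
// GEMMLHSMatrixInfo/GEMMRHSMatrixInfo, then the reshaped operands feed the
// multiply kernel. A minimal sketch of that flow (allocation and validation
// omitted; tensor names are illustrative only, not part of this file):
//
//   CLGEMMReshapeLHSMatrix       reshape_lhs;
//   CLGEMMReshapeRHSMatrix       reshape_rhs;
//   CLGEMMMatrixMultiplyReshaped gemm;
//   reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
//   reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
//   gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
//   reshape_lhs.run();
//   reshape_rhs.run();
//   gemm.run();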

// Fixture for CLGEMMMatrixMultiplyReshaped
template <typename T>
using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
template <typename T>
using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
    GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;

// Fixture for CLGEMMMatrixMultiplyReshaped3D
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
    GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
76
77 namespace
78 {
79 // *INDENT-OFF*
80 // clang-format off
81 RelativeTolerance<float> rel_tolerance_f32(0.001f);
82 constexpr float abs_tolerance_f32(0.0001f);
83
84 RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
85 constexpr float abs_tolerance_f16_mixed_precision(0.01f);
86
87 RelativeTolerance<float> rel_tolerance_f16(0.001f);
88 constexpr float abs_tolerance_f16(0.01f);
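
// Note: the validate() calls below consume these as (relative tolerance,
// allowed mismatch ratio, absolute tolerance); the 0.f argument in each call
// is the fraction of elements permitted to fall outside the tolerances before
// the test fails.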

/** M values to test */
const auto m_values = framework::dataset::make("M", 17);

/** M_W values to test */
const auto m_w_values = framework::dataset::make("M_W", 5);

/** M_H values to test */
const auto m_h_values = framework::dataset::make("M_H", 7);

/** N values to test */
const auto n_values = framework::dataset::make("N", 21);

/** K values to test */
const auto k_values = framework::dataset::make("K", 13);

/** Batch size values to test */
const auto b_values = framework::dataset::make("batch_size", 2, 3);
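
// Note: the two-argument make(name, start, end) overload used for ranges
// (batch_size above, V0/H0 below) enumerates the half-open interval
// [start, end), so batch_size is exercised with the single value 2 and the
// V0/H0 ranges cover { 1, 2 }. Stated here as a reading aid based on the
// framework's range dataset semantics.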

/** Activation values to test */
const auto act_values = framework::dataset::make("Activation",
{
    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
});

/** Alpha values to test - Precommit */
const auto a_values_precommit = framework::dataset::make("alpha", { -0.75f });

/** Beta values to test - Precommit */
const auto beta_values_precommit = framework::dataset::make("beta", { -0.35f });

/** M0 values to test - Precommit */
const auto m0_values_precommit = framework::dataset::make("M0", { 4 });

/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });

/** K0 values to test - Precommit */
const auto k0_values_precommit = framework::dataset::make("K0", { 4 });

/** V0 values to test - Precommit */
const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);

/** H0 values to test - Precommit */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);

/** Alpha values to test - Nightly */
const auto a_values_nightly = framework::dataset::make("alpha", { 1.0f });

/** Beta values to test - Nightly */
const auto beta_values_nightly = framework::dataset::make("beta", { 1.0f });

/** M0 values to test - Nightly */
const auto m0_values_nightly = framework::dataset::make("M0", { 8 });

/** N0 values to test - Nightly */
const auto n0_values_nightly = framework::dataset::make("N0", { 8 });

/** K0 values to test - Nightly */
const auto k0_values_nightly = framework::dataset::make("K0", { 4 });

/** N0 values to test with export to OpenCL image object - Nightly */
const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0", { 4, 8, 16 });

/** K0 values to test with export to OpenCL image object - Nightly */
const auto k0_export_to_cl_image_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });

/** V0 values to test - Nightly */
const auto v0_values_nightly = framework::dataset::make("V0", 1, 3);

/** H0 values to test - Nightly */
const auto h0_values_nightly = framework::dataset::make("H0", 1, 3);

/** Interleave values to test with LHS matrix */
const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });

/** Interleave values to test with RHS matrix */
const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });

/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true });

/** LHS transposed values */
const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true });

} // namespace

TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)

// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
               framework::dataset::make("Input0Info", {
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),     // OK
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8), // Data type not supported
                   TensorInfo(TensorShape(10U, 5U, 2U), 1, DataType::F32),     // Incorrect dimension bias
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),     // Mismatching shapes
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK, do not broadcast bias
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK, wider accumulation
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK, RHS 4,4,2
               }),
               framework::dataset::make("Input1Info", {
                   TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),
                   TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(128U, 3U, 2U), 1, DataType::F16),
               })),
               framework::dataset::make("Input2Info", {
                   TensorInfo(TensorShape(21U), 1, DataType::F32),
                   TensorInfo(TensorShape(21U), 1, DataType::F16),
                   TensorInfo(TensorShape(21U), 1, DataType::QASYMM8),
                   TensorInfo(TensorShape(21U), 1, DataType::F32),
                   TensorInfo(TensorShape(21U), 1, DataType::F32),
                   TensorInfo(TensorShape(21U, 17U), 1, DataType::F16),
                   TensorInfo(TensorShape(21U, 17U), 1, DataType::F16),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
               })),
               framework::dataset::make("OutputInfo", {
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::QASYMM8),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
               })),
               framework::dataset::make("LHSMInfo", {
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 2, 4, false, false),
                   GEMMLHSMatrixInfo(4, 2, 4, false, false),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
               })),
               framework::dataset::make("RHSMInfo", {
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                   GEMMRHSMatrixInfo(2, 2, 1, true, false, false),
                   GEMMRHSMatrixInfo(2, 2, 1, true, false, false),
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                   GEMMRHSMatrixInfo(4, 4, 2, true, false, false),
               })),
               framework::dataset::make("GEMMInfo", {
                   GEMMKernelInfo(17    /**< M: Number of LHS rows */,
                                  21    /**< N: Number of RHS columns */,
                                  13    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(4, 4, 1, false, true),
                                  GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(17    /**< M: Number of LHS rows */,
                                  21    /**< N: Number of RHS columns */,
                                  13    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(4, 4, 1, false, true),
                                  GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(),
                   GEMMKernelInfo(),
                   GEMMKernelInfo(),
                   GEMMKernelInfo(17    /**< M: Number of LHS rows */,
                                  21    /**< N: Number of RHS columns */,
                                  13    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  false /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(4, 4, 1, false, true),
                                  GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(17    /**< M: Number of LHS rows */,
                                  21    /**< N: Number of RHS columns */,
                                  13    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  false /**< Flag used to broadcast the bias addition */,
                                  true  /**< Wider accumulation */,
                                  true  /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(4, 4, 1, false, true),
                                  GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(17    /**< M: Number of LHS rows */,
                                  21    /**< N: Number of RHS columns */,
                                  13    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  false /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(4, 4, 1, false, true),
                                  GEMMRHSMatrixInfo(4, 4, 2, true, false, false),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
               })),
               framework::dataset::make("Expected", { true, true, false, false, false, true, true, true })),
               input0_info, input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == expected, framework::LogLevel::ERRORS);
}
TEST_SUITE(Float)
TEST_SUITE(FP32)

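// The fixture test cases below build their parameter grids by nesting
// combine(), which forms the cartesian product of two datasets. A toy example
// of the mechanism (hypothetical values, not part of this suite):
//
//   const auto m0 = framework::dataset::make("M0", { 2, 4 });
//   const auto n0 = framework::dataset::make("N0", { 4, 8 });
//   // combine(m0, n0) yields the four pairs (2,4), (2,8), (4,4), (4,8)
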
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_precommit),
                               beta_values_precommit),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::DISABLED,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_values_nightly),
                               k0_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_nightly),
                               beta_values_nightly),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_precommit),
                               beta_values_precommit),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::DISABLED,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_values_nightly),
                               k0_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_nightly),
                               beta_values_nightly),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
TEST_SUITE(ExportToCLImage)
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
               framework::dataset::make("Input0Info", {
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // Incorrect k0
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // Incorrect n0
               }),
               framework::dataset::make("Input1Info", {
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F32),
               })),
               framework::dataset::make("Input2Info", {
                   TensorInfo(TensorShape(64U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U), 1, DataType::F32),
               })),
               framework::dataset::make("OutputInfo", {
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
               })),
               framework::dataset::make("LHSMInfo", {
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 8, 1, false, true),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 2, 1, false, false),
                   GEMMLHSMatrixInfo(4, 4, 1, false, false),
               })),
               framework::dataset::make("RHSMInfo", {
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
                   GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
                   GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
                   GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
                   GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
               })),
               framework::dataset::make("GEMMInfo", {
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */)
               })),
               framework::dataset::make("Expected", { true,
                                                      true,
                                                      true,
                                                      false,
                                                      false })),
               input0_info, input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
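    // On devices without cl_khr_image2d_from_buffer the kernel is expected to
    // reject export-to-image configurations, so the expected result is gated
    // on the runtime capability query below.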
    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
}

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_precommit),
                               beta_values_precommit),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_export_to_cl_image_values_nightly),
                               k0_export_to_cl_image_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_nightly),
                               beta_values_nightly),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_precommit),
                               beta_values_precommit),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_export_to_cl_image_values_nightly),
                               k0_export_to_cl_image_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F32)),
                               a_values_nightly),
                               beta_values_nightly),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END() // FP32

TEST_SUITE(FP16)

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_precommit),
                               beta_values_precommit),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::DISABLED,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_values_nightly),
                               k0_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_nightly),
                               beta_values_nightly),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_precommit),
                               beta_values_precommit),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::DISABLED,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_values_nightly),
                               k0_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_nightly),
                               beta_values_nightly),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

TEST_SUITE(ExportToCLImage)
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
               framework::dataset::make("Input0Info", {
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect k0
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect n0
               }),
               framework::dataset::make("Input1Info", {
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F16),
               })),
               framework::dataset::make("Input2Info", {
                   TensorInfo(TensorShape(64U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U), 1, DataType::F16),
               })),
               framework::dataset::make("OutputInfo", {
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
                   TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
               })),
               framework::dataset::make("LHSMInfo", {
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 8, 1, false, true),
                   GEMMLHSMatrixInfo(4, 4, 1, false, true),
                   GEMMLHSMatrixInfo(4, 2, 1, false, false),
                   GEMMLHSMatrixInfo(4, 4, 1, false, false),
               })),
               framework::dataset::make("RHSMInfo", {
                   GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
                   GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
                   GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
                   GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
                   GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
               })),
               framework::dataset::make("GEMMInfo", {
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */),
                   GEMMKernelInfo(64    /**< M: Number of LHS rows */,
                                  64    /**< N: Number of RHS columns */,
                                  64    /**< K: Number of LHS columns or RHS rows */,
                                  0     /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                                  false /**< Reinterpret the input as 3D */,
                                  true  /**< Flag used to broadcast the bias addition */,
                                  false /**< Wider accumulation */,
                                  false /**< Has pad y */,
                                  ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                  1     /**< Multiplication factor for the width of the 1xW transposed block */,
                                  1     /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                  GEMMLHSMatrixInfo(),
                                  GEMMRHSMatrixInfo(),
                                  0     /**< Offset to be added to each element of the matrix A */,
                                  0     /**< Offset to be added to each element of the matrix B */)
               })),
               framework::dataset::make("Expected", { true,
                                                      true,
                                                      true,
                                                      false,
                                                      false })),
               input0_info, input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
}

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_precommit),
                               beta_values_precommit),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_export_to_cl_image_values_nightly),
                               k0_export_to_cl_image_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_nightly),
                               beta_values_nightly),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_precommit),
                               beta_values_precommit),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_export_to_cl_image_values_nightly),
                               k0_export_to_cl_image_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", true)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_nightly),
                               beta_values_nightly),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output only if validate() is successful
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END() // FP16

TEST_SUITE(MixedPrecision)

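// "Mixed precision" here means F16 operands combined with a wider accumulator,
// requested through the fixtures' trailing `true` template argument; this is
// why the dedicated rel/abs_tolerance_f16_mixed_precision values are used in
// the validations below.
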
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_precommit),
                               beta_values_precommit),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_values,
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_values_nightly),
                               k0_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_nightly),
                               beta_values_nightly),
                               broadcast_bias_values),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_precommit),
                               n0_values_precommit),
                               k0_values_precommit),
                               v0_values_precommit),
                               h0_values_precommit),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_precommit),
                               beta_values_precommit),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                               m_w_values,
                               m_h_values),
                               n_values),
                               k_values),
                               b_values),
                               m0_values_nightly),
                               n0_values_nightly),
                               k0_values_nightly),
                               v0_values_nightly),
                               h0_values_nightly),
                               i_values_lhs),
                               i_values_rhs),
                               framework::dataset::make("export_to_cl_image_rhs", false)),
                               framework::dataset::make("DataType", DataType::F16)),
                               a_values_nightly),
                               beta_values_nightly),
                               lhs_transpose_values),
                               act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
TEST_SUITE_END() // MixedPrecision
TEST_SUITE_END() // Float
TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute