• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/KernelDescriptors.h"
25 #include "arm_compute/core/Types.h"
26 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
27 #include "arm_compute/runtime/CL/CLTensor.h"
28 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
29 #include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
30 #include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
31 #include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
32 #include "tests/CL/CLAccessor.h"
33 #include "tests/CL/Helper.h"
34 #include "tests/PaddingCalculator.h"
35 #include "tests/datasets/ShapeDatasets.h"
36 #include "tests/framework/Asserts.h"
37 #include "tests/framework/Macros.h"
38 #include "tests/framework/datasets/Datasets.h"
39 #include "tests/validation/Validation.h"
40 #include "tests/validation/fixtures/GEMMFixture.h"
41 
42 namespace arm_compute
43 {
44 namespace test
45 {
46 namespace validation
47 {
48 using namespace arm_compute::misc::shape_calculator;
49 
50 // Create function for CLGEMMReshapeLHSMatrixKernel
51 using CLGEMMReshapeLHSMatrix = CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
52 
53 // Create function for CLGEMMReshapeRHSMatrixKernel
54 using CLGEMMReshapeRHSMatrix = CLSynthetizeFunction<CLGEMMReshapeRHSMatrixKernel>;
55 
56 // Create function for CLGEMMMatrixMultiplyReshapedKernel
57 using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
58 
59 // Fixture for CLGEMMMatrixMultiplyReshaped
60 template <typename T>
61 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
62 
63 // Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
64 template <typename T>
65 using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
66     GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
67 
68 // Fixture for CLGEMMMatrixMultiplyReshaped3D
69 template <typename T>
70 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
71 
72 // Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
73 template <typename T>
74 using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
75     GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
76 
77 namespace
78 {
79 // *INDENT-OFF*
80 // clang-format off
81 RelativeTolerance<float> rel_tolerance_f32(0.001f);
82 constexpr float          abs_tolerance_f32(0.0001f);
83 
84 RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
85 constexpr float          abs_tolerance_f16_mixed_precision(0.01f);
86 
87 RelativeTolerance<float> rel_tolerance_f16(0.001f);
88 constexpr float          abs_tolerance_f16(0.01f);
89 
90 /** M values to test */
91 const auto m_values = framework::dataset::make("M", 17);
92 
93 /** M_W values to test */
94 const auto m_w_values = framework::dataset::make("M_W", 5);
95 
96 /** M_H values to test */
97 const auto m_h_values = framework::dataset::make("M_H", 7);
98 
99 /** N values to test */
100 const auto n_values = framework::dataset::make("N", 21);
101 
102 /** K values to test */
103 const auto k_values = framework::dataset::make("K", 13);
104 
105 /** Batch size values to test */
106 const auto b_values = framework::dataset::make("batch_size", 2, 3);
107 
108 /** Activation values to test */
109 const auto act_values = framework::dataset::make("Activation",
110 {
111     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
112 });
113 
114 /** Alpha values to test - Precommit */
115 const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );
116 
117 /** Beta values to test - Precommit */
118 const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );
119 
120 /** M0 values to test - Precommit */
121 const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
122 
123 /** N0 values to test - Precommit */
124 const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
125 
126 /** K0 values to test - Precommit */
127 const auto k0_values_precommit = framework::dataset::make("K0", { 4 });
128 
129 /** V0 values to test - Precommit */
130 const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);
131 
132 /** H0 values to test - Precommit */
133 const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
134 
135 /** Alpha values to test - Nightly */
136 const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );
137 
138 /** Beta values to test - Nightly */
139 const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );
140 
141 /** M0 values to test - Nightly */
142 const auto m0_values_nightly = framework::dataset::make("M0", { 8 });
143 
144 /** N0 values to test - Nightly */
145 const auto n0_values_nightly = framework::dataset::make("N0", { 8 });
146 
147 /** K0 values to test - Nightly */
148 const auto k0_values_nightly = framework::dataset::make("K0", { 4 });
149 
150 /** N0 values to test with export to OpenCL image object - Nightly */
151 const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0", { 4, 8, 16 });
152 
153 /** K0 values to test with export to OpenCL image object - Nightly */
154 const auto k0_export_to_cl_image_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });
155 
156 /** V0 values to test - Nightly */
157 const auto v0_values_nightly = framework::dataset::make("V0", 1, 3);
158 
159 /** H0 values to test - Nightly */
160 const auto h0_values_nightly = framework::dataset::make("H0", 1, 3);
161 
162 /** Interleave values to test with LHS matrix */
163 const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });
164 
165 /** Interleave values to test with RHS matrix */
166 const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });
167 
168 /** Broadcast bias from vector to matrix */
169 const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
170 
171 /** LHS transposed values */
172 const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );
173 
174 } // namespace
175 
176 TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)177 TEST_SUITE(GEMMMatrixMultiplyReshaped)
178 
179 // *INDENT-OFF*
180 // clang-format off
181 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
182                framework::dataset::make("Input0Info", { TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),      // OK
183                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK
184                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),  // Data type not supported
185                                                         TensorInfo(TensorShape(10U, 5U, 2U), 1, DataType::F32),      // Incorrect dimension bias
186                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),      // Mismatching shapes
187                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, do not broadcast bias
188                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, wider accummulation
189                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, RHS 4,4,2
190 
191                                                       }),
192                framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
193                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
194                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),
195                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
196                                                        TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
197                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
198                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
199                                                        TensorInfo(TensorShape(128U, 3U, 2U), 1, DataType::F16),
200 
201                       })),
202                framework::dataset::make("Input2Info", { TensorInfo(TensorShape(21U), 1, DataType::F32),
203                                                         TensorInfo(TensorShape(21U), 1, DataType::F16),
204                                                         TensorInfo(TensorShape(21U), 1, DataType::QASYMM8),
205                                                         TensorInfo(TensorShape(21U), 1, DataType::F32),
206                                                         TensorInfo(TensorShape(21U), 1, DataType::F32),
207                                                         TensorInfo(TensorShape(21U,17U), 1, DataType::F16),
208                                                         TensorInfo(TensorShape(21U,17U), 1, DataType::F16),
209                                                         TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
210 
211                                                       })),
212                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
213                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
214                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::QASYMM8),
215                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
216                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
217                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
218                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
219                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
220 
221                            })),
222                framework::dataset::make("LHSMInfo",{
223                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
224                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
225                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
226                                                           GEMMLHSMatrixInfo(4,2,4,false,false),
227                                                           GEMMLHSMatrixInfo(4,2,4,false,false),
228                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
229                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
230                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
231 
232                                 })),
233                framework::dataset::make("RHSMInfo",{
234                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
235                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
236                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
237                                                           GEMMRHSMatrixInfo(2,2,1,true,false,false),
238                                                           GEMMRHSMatrixInfo(2,2,1,true,false,false),
239                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
240                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
241                                                           GEMMRHSMatrixInfo(4,4,2,true,false,false),
242 
243 
244                            })),
245 
246 
247                framework::dataset::make("GEMMInfo",{
248                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
249                                                                             21 /**<N Number of RHS columns*/,
250                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
251                                                                      false /**< reinterpret the input as 3D */,
252                                                                      true  /**< Flag used to broadcast the bias addition */,
253                                                                      false /**< wider accumm */,
254                                                                      false /**< has pad y */,
255                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
256                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
257                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
258                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
259                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
260                                                                      0  /**< Offset to be added to each element of the matrix A */,
261                                                                      0 /**< Offset to be added to each element of the matrix B */),
262 
263                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
264                                                                             21 /**<N Number of RHS columns*/,
265                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
266                                                                      false /**< reinterpret the input as 3D */,
267                                                                      true  /**< Flag used to broadcast the bias addition */,
268                                                                      false /**< wider accumm */,
269                                                                      false /**< has pad y */,
270                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
271                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
272                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
273                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
274                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
275                                                                      0  /**< Offset to be added to each element of the matrix A */,
276                                                                      0 /**< Offset to be added to each element of the matrix B */),
277                                                             GEMMKernelInfo(),
278                                                             GEMMKernelInfo(),
279                                                             GEMMKernelInfo(),
280 
281                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
282                                                                             21 /**<N Number of RHS columns*/,
283                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
284                                                                      false /**< reinterpret the input as 3D */,
285                                                                      false  /**< Flag used to broadcast the bias addition */,
286                                                                      false /**< wider accumm */,
287                                                                      false /**< has pad y */,
288                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
289                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
290                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
291                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
292                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
293                                                                      0  /**< Offset to be added to each element of the matrix A */,
294                                                                      0 /**< Offset to be added to each element of the matrix B */),
295 
296 
297                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
298                                                                             21 /**<N Number of RHS columns*/,
299                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
300                                                                      false /**< reinterpret the input as 3D */,
301                                                                      false  /**< Flag used to broadcast the bias addition */,
302                                                                      true /**< wider accumm */,
303                                                                      true /**< has pad y */,
304                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
305                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
306                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
307                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
308                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
309                                                                      0  /**< Offset to be added to each element of the matrix A */,
310                                                                      0 /**< Offset to be added to each element of the matrix B */),
311 
312                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
313                                                                             21 /**<N Number of RHS columns*/,
314                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
315                                                                      false /**< reinterpret the input as 3D */,
316                                                                      false  /**< Flag used to broadcast the bias addition */,
317                                                                      false /**< wider accumm */,
318                                                                      false /**< has pad y */,
319                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
320                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
321                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
322                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
323                                                                      GEMMRHSMatrixInfo(4,4,2,true,false,false),
324                                                                      0  /**< Offset to be added to each element of the matrix A */,
325                                                                      0 /**< Offset to be added to each element of the matrix B */),
326                                                     })),
327                framework::dataset::make("Expected", { true, true, false, false, false, true, true,true})),
328                     input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
329 {
330    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
331                                                           &input1_info.clone()->set_is_resizable(true),
332                                                           &input2_info.clone()->set_is_resizable(true),
333                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
334                                                           lhs_info,
335                                                           rhs_info,
336                                                           gemm_info)) == expected, framework::LogLevel::ERRORS);
337 }
338 TEST_SUITE(Float)
TEST_SUITE(FP32)339 TEST_SUITE(FP32)
340 
341 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
342                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
343                                                                    m_values,
344                                                                    n_values),
345                                                                    k_values),
346                                                                    b_values),
347                                                                    m0_values_precommit),
348                                                                    n0_values_precommit),
349                                                                    k0_values_precommit),
350                                                                    v0_values_precommit),
351                                                                    h0_values_precommit),
352                                                                    i_values_lhs),
353                                                                    i_values_rhs),
354                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
355                                                                    framework::dataset::make("DataType", DataType::F32)),
356                                                                    a_values_precommit),
357                                                                    beta_values_precommit),
358                                                                    broadcast_bias_values),
359                                                                    lhs_transpose_values),
360                                                                    act_values))
361 {
362     // Validate output
363     validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
364 }
365 
366 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::DISABLED,
367                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
368                                                                    m_values,
369                                                                    n_values),
370                                                                    k_values),
371                                                                    b_values),
372                                                                    m0_values_nightly),
373                                                                    n0_values_nightly),
374                                                                    k0_values_nightly),
375                                                                    v0_values_nightly),
376                                                                    h0_values_nightly),
377                                                                    i_values_lhs),
378                                                                    i_values_rhs),
379                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
380                                                                    framework::dataset::make("DataType", DataType::F32)),
381                                                                    a_values_nightly),
382                                                                    beta_values_nightly),
383                                                                    broadcast_bias_values),
384                                                                    lhs_transpose_values),
385                                                                    act_values))
386 {
387     // Validate output
388     validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
389 }
390 
391 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
392                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
393                                                                    m_w_values,
394                                                                    m_h_values),
395                                                                    n_values),
396                                                                    k_values),
397                                                                    b_values),
398                                                                    m0_values_precommit),
399                                                                    n0_values_precommit),
400                                                                    k0_values_precommit),
401                                                                    v0_values_precommit),
402                                                                    h0_values_precommit),
403                                                                    i_values_lhs),
404                                                                    i_values_rhs),
405                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
406                                                                    framework::dataset::make("DataType", DataType::F32)),
407                                                                    a_values_precommit),
408                                                                    beta_values_precommit),
409                                                                    lhs_transpose_values),
410                                                                    act_values))
411 {
412     // Validate output
413     validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
414 }
415 
416 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::DISABLED,
417                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
418                                                                    m_w_values,
419                                                                    m_h_values),
420                                                                    n_values),
421                                                                    k_values),
422                                                                    b_values),
423                                                                    m0_values_nightly),
424                                                                    n0_values_nightly),
425                                                                    k0_values_nightly),
426                                                                    v0_values_nightly),
427                                                                    h0_values_nightly),
428                                                                    i_values_lhs),
429                                                                    i_values_rhs),
430                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
431                                                                    framework::dataset::make("DataType", DataType::F32)),
432                                                                    a_values_nightly),
433                                                                    beta_values_nightly),
434                                                                    lhs_transpose_values),
435                                                                    act_values))
436 {
437     // Validate output
438     validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
439 }
440 TEST_SUITE(ExportToCLImage)
441 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
442                framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
443                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
444                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
445                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // Incorrect k0
446                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // Incorrect n0
447 
448                                                       }),
449                framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
450                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
451                                                        TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F32),
452                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
453                                                        TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F32),
454 
455                       })),
456                framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F32),
457                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
458                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
459                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
460                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
461 
462                                                       })),
463                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
464                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
465                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
466                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
467                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
468                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
469 
470                            })),
471                framework::dataset::make("LHSMInfo",{
472                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
473                                                           GEMMLHSMatrixInfo(4, 8, 1, false, true),
474                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
475                                                           GEMMLHSMatrixInfo(4, 2, 1, false, false),
476                                                           GEMMLHSMatrixInfo(4, 4, 1, false, false),
477 
478                                 })),
479                framework::dataset::make("RHSMInfo",{
480                                                           GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
481                                                           GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
482                                                           GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
483                                                           GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
484                                                           GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
485                            })),
486                framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
487                                                                     64 /**<N Number of RHS columns*/,
488                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
489                                                              false /**< reinterpret the input as 3D */,
490                                                              true  /**< Flag used to broadcast the bias addition */,
491                                                              false /**< wider accumm */,
492                                                              false /**< has pad y */,
493                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
494                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
495                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
496                                                              GEMMLHSMatrixInfo(),
497                                                              GEMMRHSMatrixInfo(),
498                                                              0  /**< Offset to be added to each element of the matrix A */,
499                                                              0 /**< Offset to be added to each element of the matrix B */),
500                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
501                                                                     64 /**<N Number of RHS columns*/,
502                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
503                                                              false /**< reinterpret the input as 3D */,
504                                                              true  /**< Flag used to broadcast the bias addition */,
505                                                              false /**< wider accumm */,
506                                                              false /**< has pad y */,
507                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
508                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
509                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
510                                                              GEMMLHSMatrixInfo(),
511                                                              GEMMRHSMatrixInfo(),
512                                                              0  /**< Offset to be added to each element of the matrix A */,
513                                                              0 /**< Offset to be added to each element of the matrix B */),
514                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
515                                                                     64 /**<N Number of RHS columns*/,
516                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
517                                                              false /**< reinterpret the input as 3D */,
518                                                              true  /**< Flag used to broadcast the bias addition */,
519                                                              false /**< wider accumm */,
520                                                              false /**< has pad y */,
521                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
522                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
523                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
524                                                              GEMMLHSMatrixInfo(),
525                                                              GEMMRHSMatrixInfo(),
526                                                              0  /**< Offset to be added to each element of the matrix A */,
527                                                              0 /**< Offset to be added to each element of the matrix B */),
528 
529                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
530                                                                     64 /**<N Number of RHS columns*/,
531                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
532                                                              false /**< reinterpret the input as 3D */,
533                                                              true  /**< Flag used to broadcast the bias addition */,
534                                                              false /**< wider accumm */,
535                                                              false /**< has pad y */,
536                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
537                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
538                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
539                                                              GEMMLHSMatrixInfo(),
540                                                              GEMMRHSMatrixInfo(),
541                                                              0  /**< Offset to be added to each element of the matrix A */,
542                                                              0 /**< Offset to be added to each element of the matrix B */),
543                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
544                                                                     64 /**<N Number of RHS columns*/,
545                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
546                                                              false /**< reinterpret the input as 3D */,
547                                                              true  /**< Flag used to broadcast the bias addition */,
548                                                              false /**< wider accumm */,
549                                                              false /**< has pad y */,
550                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
551                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
552                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
553                                                              GEMMLHSMatrixInfo(),
554                                                              GEMMRHSMatrixInfo(),
555                                                              0  /**< Offset to be added to each element of the matrix A */,
556                                                              0 /**< Offset to be added to each element of the matrix B */)
557                                                     })),
558                framework::dataset::make("Expected", { true,
559                                                       true,
560                                                       true,
561                                                       false,
562                                                       false})),
563                     input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
564 {
565    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
566                                                           &input1_info.clone()->set_is_resizable(true),
567                                                           &input2_info.clone()->set_is_resizable(true),
568                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
569                                                           lhs_info,
570                                                           rhs_info,
571                                                           gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
572 }
573 
574 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
575                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
576                                                                    m_values,
577                                                                    n_values),
578                                                                    k_values),
579                                                                    b_values),
580                                                                    m0_values_precommit),
581                                                                    n0_values_precommit),
582                                                                    k0_values_precommit),
583                                                                    v0_values_precommit),
584                                                                    h0_values_precommit),
585                                                                    i_values_lhs),
586                                                                    i_values_rhs),
587                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
588                                                                    framework::dataset::make("DataType", DataType::F32)),
589                                                                    a_values_precommit),
590                                                                    beta_values_precommit),
591                                                                    broadcast_bias_values),
592                                                                    lhs_transpose_values),
593                                                                    act_values))
594 {
595      // Validate output only if validate() is successful
596     if(validate_result)
597     {
598         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
599     }
600     else
601     {
602         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
603         framework::ARM_COMPUTE_PRINT_INFO();
604     }
605 
606 }
607 
608 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
609                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
610                                                                    m_values,
611                                                                    n_values),
612                                                                    k_values),
613                                                                    b_values),
614                                                                    m0_values_nightly),
615                                                                    n0_export_to_cl_image_values_nightly),
616                                                                    k0_export_to_cl_image_values_nightly),
617                                                                    v0_values_nightly),
618                                                                    h0_values_nightly),
619                                                                    i_values_lhs),
620                                                                    i_values_rhs),
621                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
622                                                                    framework::dataset::make("DataType", DataType::F32)),
623                                                                    a_values_nightly),
624                                                                    beta_values_nightly),
625                                                                    broadcast_bias_values),
626                                                                    lhs_transpose_values),
627                                                                    act_values))
628 {
629      // Validate output only if validate() is successful
630     if(validate_result)
631     {
632         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
633     }
634     else
635     {
636         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
637         framework::ARM_COMPUTE_PRINT_INFO();
638     }
639 }
640 
641 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
642                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
643                                                                    m_w_values,
644                                                                    m_h_values),
645                                                                    n_values),
646                                                                    k_values),
647                                                                    b_values),
648                                                                    m0_values_precommit),
649                                                                    n0_values_precommit),
650                                                                    k0_values_precommit),
651                                                                    v0_values_precommit),
652                                                                    h0_values_precommit),
653                                                                    i_values_lhs),
654                                                                    i_values_rhs),
655                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
656                                                                    framework::dataset::make("DataType", DataType::F32)),
657                                                                    a_values_precommit),
658                                                                    beta_values_precommit),
659                                                                    lhs_transpose_values),
660                                                                    act_values))
661 {
662      // Validate output only if validate() is successful
663     if(validate_result)
664     {
665         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
666     }
667     else
668     {
669         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
670         framework::ARM_COMPUTE_PRINT_INFO();
671     }
672 }
673 
674 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
675                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
676                                                                    m_w_values,
677                                                                    m_h_values),
678                                                                    n_values),
679                                                                    k_values),
680                                                                    b_values),
681                                                                    m0_values_nightly),
682                                                                    n0_export_to_cl_image_values_nightly),
683                                                                    k0_export_to_cl_image_values_nightly),
684                                                                    v0_values_nightly),
685                                                                    h0_values_nightly),
686                                                                    i_values_lhs),
687                                                                    i_values_rhs),
688                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
689                                                                    framework::dataset::make("DataType", DataType::F32)),
690                                                                    a_values_nightly),
691                                                                    beta_values_nightly),
692                                                                    lhs_transpose_values),
693                                                                    act_values))
694 {
695     // Validate output only if validate() is successful
696     if(validate_result)
697     {
698         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
699     }
700     else
701     {
702         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
703         framework::ARM_COMPUTE_PRINT_INFO();
704     }
705 }
706 TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END()707 TEST_SUITE_END() // FP32
708 
709 TEST_SUITE(FP16)
710 
711 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
712                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
713                                                                    m_values,
714                                                                    n_values),
715                                                                    k_values),
716                                                                    b_values),
717                                                                    m0_values_precommit),
718                                                                    n0_values_precommit),
719                                                                    k0_values_precommit),
720                                                                    v0_values_precommit),
721                                                                    h0_values_precommit),
722                                                                    i_values_lhs),
723                                                                    i_values_rhs),
724                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
725                                                                    framework::dataset::make("DataType", DataType::F16)),
726                                                                    a_values_precommit),
727                                                                    beta_values_precommit),
728                                                                    broadcast_bias_values),
729                                                                    lhs_transpose_values),
730                                                                    act_values))
731 {
732     // Validate output
733     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
734 }
735 
736 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::DISABLED,
737                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
738                                                                    m_values,
739                                                                    n_values),
740                                                                    k_values),
741                                                                    b_values),
742                                                                    m0_values_nightly),
743                                                                    n0_values_nightly),
744                                                                    k0_values_nightly),
745                                                                    v0_values_nightly),
746                                                                    h0_values_nightly),
747                                                                    i_values_lhs),
748                                                                    i_values_rhs),
749                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
750                                                                    framework::dataset::make("DataType", DataType::F16)),
751                                                                    a_values_nightly),
752                                                                    beta_values_nightly),
753                                                                    broadcast_bias_values),
754                                                                    lhs_transpose_values),
755                                                                    act_values))
756 {
757     // Validate output
758     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
759 }
760 
761 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
762                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
763                                                                    m_w_values,
764                                                                    m_h_values),
765                                                                    n_values),
766                                                                    k_values),
767                                                                    b_values),
768                                                                    m0_values_precommit),
769                                                                    n0_values_precommit),
770                                                                    k0_values_precommit),
771                                                                    v0_values_precommit),
772                                                                    h0_values_precommit),
773                                                                    i_values_lhs),
774                                                                    i_values_rhs),
775                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
776                                                                    framework::dataset::make("DataType", DataType::F16)),
777                                                                    a_values_precommit),
778                                                                    beta_values_precommit),
779                                                                    lhs_transpose_values),
780                                                                    act_values))
781 {
782     // Validate output
783     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
784 }
785 
786 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::DISABLED,
787                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
788                                                                    m_w_values,
789                                                                    m_h_values),
790                                                                    n_values),
791                                                                    k_values),
792                                                                    b_values),
793                                                                    m0_values_nightly),
794                                                                    n0_values_nightly),
795                                                                    k0_values_nightly),
796                                                                    v0_values_nightly),
797                                                                    h0_values_nightly),
798                                                                    i_values_lhs),
799                                                                    i_values_rhs),
800                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
801                                                                    framework::dataset::make("DataType", DataType::F16)),
802                                                                    a_values_nightly),
803                                                                    beta_values_nightly),
804                                                                    lhs_transpose_values),
805                                                                    act_values))
806 {
807     // Validate output
808     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
809 }
810 
811 TEST_SUITE(ExportToCLImage)
812 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
813                framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
814                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
815                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
816                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // Incorrect k0
817                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // Incorrect n0
818 
819                                                       }),
820                framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
821                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
822                                                        TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F16),
823                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
824                                                        TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F16),
825 
826                       })),
827                framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F16),
828                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
829                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
830                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
831                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
832 
833                                                       })),
834                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
835                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
836                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
837                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
838                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
839                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
840 
841                            })),
842                framework::dataset::make("LHSMInfo",{
843                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
844                                                           GEMMLHSMatrixInfo(4, 8, 1, false, true),
845                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
846                                                           GEMMLHSMatrixInfo(4, 2, 1, false, false),
847                                                           GEMMLHSMatrixInfo(4, 4, 1, false, false),
848 
849                                 })),
850                framework::dataset::make("RHSMInfo",{
851                                                           GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
852                                                           GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
853                                                           GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
854                                                           GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
855                                                           GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
856                            })),
857                framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
858                                                                     64 /**<N Number of RHS columns*/,
859                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
860                                                              false /**< reinterpret the input as 3D */,
861                                                              true  /**< Flag used to broadcast the bias addition */,
862                                                              false /**< wider accumm */,
863                                                              false /**< has pad y */,
864                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
865                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
866                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
867                                                              GEMMLHSMatrixInfo(),
868                                                              GEMMRHSMatrixInfo(),
869                                                              0  /**< Offset to be added to each element of the matrix A */,
870                                                              0 /**< Offset to be added to each element of the matrix B */),
871                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
872                                                                     64 /**<N Number of RHS columns*/,
873                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
874                                                              false /**< reinterpret the input as 3D */,
875                                                              true  /**< Flag used to broadcast the bias addition */,
876                                                              false /**< wider accumm */,
877                                                              false /**< has pad y */,
878                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
879                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
880                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
881                                                              GEMMLHSMatrixInfo(),
882                                                              GEMMRHSMatrixInfo(),
883                                                              0  /**< Offset to be added to each element of the matrix A */,
884                                                              0 /**< Offset to be added to each element of the matrix B */),
885                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
886                                                                     64 /**<N Number of RHS columns*/,
887                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
888                                                              false /**< reinterpret the input as 3D */,
889                                                              true  /**< Flag used to broadcast the bias addition */,
890                                                              false /**< wider accumm */,
891                                                              false /**< has pad y */,
892                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
893                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
894                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
895                                                              GEMMLHSMatrixInfo(),
896                                                              GEMMRHSMatrixInfo(),
897                                                              0  /**< Offset to be added to each element of the matrix A */,
898                                                              0 /**< Offset to be added to each element of the matrix B */),
899 
900                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
901                                                                     64 /**<N Number of RHS columns*/,
902                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
903                                                              false /**< reinterpret the input as 3D */,
904                                                              true  /**< Flag used to broadcast the bias addition */,
905                                                              false /**< wider accumm */,
906                                                              false /**< has pad y */,
907                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
908                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
909                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
910                                                              GEMMLHSMatrixInfo(),
911                                                              GEMMRHSMatrixInfo(),
912                                                              0  /**< Offset to be added to each element of the matrix A */,
913                                                              0 /**< Offset to be added to each element of the matrix B */),
914                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
915                                                                     64 /**<N Number of RHS columns*/,
916                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
917                                                              false /**< reinterpret the input as 3D */,
918                                                              true  /**< Flag used to broadcast the bias addition */,
919                                                              false /**< wider accumm */,
920                                                              false /**< has pad y */,
921                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
922                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
923                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
924                                                              GEMMLHSMatrixInfo(),
925                                                              GEMMRHSMatrixInfo(),
926                                                              0  /**< Offset to be added to each element of the matrix A */,
927                                                              0 /**< Offset to be added to each element of the matrix B */)
928                                                     })),
929                framework::dataset::make("Expected", { true,
930                                                       true,
931                                                       true,
932                                                       false,
933                                                       false})),
934                     input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
935 {
936    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
937                                                           &input1_info.clone()->set_is_resizable(true),
938                                                           &input2_info.clone()->set_is_resizable(true),
939                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
940                                                           lhs_info,
941                                                           rhs_info,
942                                                           gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
943 }
944 
945 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
946                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
947                                                                    m_values,
948                                                                    n_values),
949                                                                    k_values),
950                                                                    b_values),
951                                                                    m0_values_precommit),
952                                                                    n0_values_precommit),
953                                                                    k0_values_precommit),
954                                                                    v0_values_precommit),
955                                                                    h0_values_precommit),
956                                                                    i_values_lhs),
957                                                                    i_values_rhs),
958                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
959                                                                    framework::dataset::make("DataType", DataType::F16)),
960                                                                    a_values_precommit),
961                                                                    beta_values_precommit),
962                                                                    broadcast_bias_values),
963                                                                    lhs_transpose_values),
964                                                                    act_values))
965 {
966     // Validate output only if validate() is successful
967     if(validate_result)
968     {
969         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
970     }
971     else
972     {
973         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
974         framework::ARM_COMPUTE_PRINT_INFO();
975     }
976 
977 }
978 
979 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
980                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
981                                                                    m_values,
982                                                                    n_values),
983                                                                    k_values),
984                                                                    b_values),
985                                                                    m0_values_nightly),
986                                                                    n0_export_to_cl_image_values_nightly),
987                                                                    k0_export_to_cl_image_values_nightly),
988                                                                    v0_values_nightly),
989                                                                    h0_values_nightly),
990                                                                    i_values_lhs),
991                                                                    i_values_rhs),
992                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
993                                                                    framework::dataset::make("DataType", DataType::F16)),
994                                                                    a_values_nightly),
995                                                                    beta_values_nightly),
996                                                                    broadcast_bias_values),
997                                                                    lhs_transpose_values),
998                                                                    act_values))
999 {
1000     // Validate output only if validate() is successful
1001     if(validate_result)
1002     {
1003         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1004     }
1005     else
1006     {
1007         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1008         framework::ARM_COMPUTE_PRINT_INFO();
1009     }
1010 }
1011 
1012 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
1013                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1014                                                                    m_w_values,
1015                                                                    m_h_values),
1016                                                                    n_values),
1017                                                                    k_values),
1018                                                                    b_values),
1019                                                                    m0_values_precommit),
1020                                                                    n0_values_precommit),
1021                                                                    k0_values_precommit),
1022                                                                    v0_values_precommit),
1023                                                                    h0_values_precommit),
1024                                                                    i_values_lhs),
1025                                                                    i_values_rhs),
1026                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1027                                                                    framework::dataset::make("DataType", DataType::F16)),
1028                                                                    a_values_precommit),
1029                                                                    beta_values_precommit),
1030                                                                    lhs_transpose_values),
1031                                                                    act_values))
1032 {
1033     // Validate output only if validate() is successful
1034     if(validate_result)
1035     {
1036         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1037     }
1038     else
1039     {
1040         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1041         framework::ARM_COMPUTE_PRINT_INFO();
1042     }
1043 }
1044 
1045 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
1046                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1047                                                                    m_w_values,
1048                                                                    m_h_values),
1049                                                                    n_values),
1050                                                                    k_values),
1051                                                                    b_values),
1052                                                                    m0_values_nightly),
1053                                                                    n0_export_to_cl_image_values_nightly),
1054                                                                    k0_export_to_cl_image_values_nightly),
1055                                                                    v0_values_nightly),
1056                                                                    h0_values_nightly),
1057                                                                    i_values_lhs),
1058                                                                    i_values_rhs),
1059                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1060                                                                    framework::dataset::make("DataType", DataType::F16)),
1061                                                                    a_values_nightly),
1062                                                                    beta_values_nightly),
1063                                                                    lhs_transpose_values),
1064                                                                    act_values))
1065 {
1066     // Validate output only if validate() is successful
1067     if(validate_result)
1068     {
1069         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1070     }
1071     else
1072     {
1073         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1074         framework::ARM_COMPUTE_PRINT_INFO();
1075     }
1076 }
1077 TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END()1078 TEST_SUITE_END() // FP16
1079 
1080 TEST_SUITE(MixedPrecision)
1081 
1082 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
1083                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1084                                                                    m_values,
1085                                                                    n_values),
1086                                                                    k_values),
1087                                                                    b_values),
1088                                                                    m0_values_precommit),
1089                                                                    n0_values_precommit),
1090                                                                    k0_values_precommit),
1091                                                                    v0_values_precommit),
1092                                                                    h0_values_precommit),
1093                                                                    i_values_lhs),
1094                                                                    i_values_rhs),
1095                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1096                                                                    framework::dataset::make("DataType", DataType::F16)),
1097                                                                    a_values_precommit),
1098                                                                    beta_values_precommit),
1099                                                                    broadcast_bias_values),
1100                                                                    lhs_transpose_values),
1101                                                                    act_values))
1102 {
1103     // Validate output
1104     validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1105 }
1106 
1107 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
1108                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1109                                                                    m_values,
1110                                                                    n_values),
1111                                                                    k_values),
1112                                                                    b_values),
1113                                                                    m0_values_nightly),
1114                                                                    n0_values_nightly),
1115                                                                    k0_values_nightly),
1116                                                                    v0_values_nightly),
1117                                                                    h0_values_nightly),
1118                                                                    i_values_lhs),
1119                                                                    i_values_rhs),
1120                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1121                                                                    framework::dataset::make("DataType", DataType::F16)),
1122                                                                    a_values_nightly),
1123                                                                    beta_values_nightly),
1124                                                                    broadcast_bias_values),
1125                                                                    lhs_transpose_values),
1126                                                                    act_values))
1127 {
1128     // Validate output
1129     validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1130 }
1131 
1132 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
1133                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1134                                                                    m_w_values,
1135                                                                    m_h_values),
1136                                                                    n_values),
1137                                                                    k_values),
1138                                                                    b_values),
1139                                                                    m0_values_precommit),
1140                                                                    n0_values_precommit),
1141                                                                    k0_values_precommit),
1142                                                                    v0_values_precommit),
1143                                                                    h0_values_precommit),
1144                                                                    i_values_lhs),
1145                                                                    i_values_rhs),
1146                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1147                                                                    framework::dataset::make("DataType", DataType::F16)),
1148                                                                    a_values_precommit),
1149                                                                    beta_values_precommit),
1150                                                                    lhs_transpose_values),
1151                                                                    act_values))
1152 {
1153     // Validate output
1154     validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1155 }
1156 
1157 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
1158                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1159                                                                    m_w_values,
1160                                                                    m_h_values),
1161                                                                    n_values),
1162                                                                    k_values),
1163                                                                    b_values),
1164                                                                    m0_values_nightly),
1165                                                                    n0_values_nightly),
1166                                                                    k0_values_nightly),
1167                                                                    v0_values_nightly),
1168                                                                    h0_values_nightly),
1169                                                                    i_values_lhs),
1170                                                                    i_values_rhs),
1171                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1172                                                                    framework::dataset::make("DataType", DataType::F16)),
1173                                                                    a_values_nightly),
1174                                                                    beta_values_nightly),
1175                                                                    lhs_transpose_values),
1176                                                                    act_values))
1177 {
1178     // Validate output
1179     validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1180 }
1181 TEST_SUITE_END() // MixedPrecision
1182 TEST_SUITE_END() // Float
1183 TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
1184 TEST_SUITE_END() // CL
1185 } // namespace validation
1186 } // namespace test
1187 } // namespace arm_compute
1188