1 /*
2 * Copyright (c) 2019-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
26 #include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.h"
27 #include "tests/CL/CLAccessor.h"
28 #include "tests/CL/Helper.h"
29 #include "tests/framework/Asserts.h"
30 #include "tests/framework/Macros.h"
31 #include "tests/framework/datasets/Datasets.h"
32 #include "tests/validation/Validation.h"
33 #include "tests/validation/fixtures/GEMMLowpFixture.h"
34
35 namespace arm_compute
36 {
37 namespace test
38 {
39 namespace validation
40 {
41 using namespace arm_compute::misc::shape_calculator;
42
43 // Create function for CLGEMMMatrixMultiplyNativeKernel
44 using CLGEMMLowpMatrixMultiplyNative = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyNativeKernel>;
45
46 // Fixture for CLGEMMLowpMatrixMultiplyNative
47 using CLGEMMLowpMatrixMultiplyNativeFixture = GEMMLowpMatrixMultiplyNativeValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyNative>;
48
49 // Fixture for CLGEMMMatrixMultiplyNative3D
50 using CLGEMMLowpMatrixMultiplyNative3DFixture = GEMMLowpMatrixMultiplyNative3DValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyNative>;
51
52 namespace
53 {
54 // *INDENT-OFF*
55 // clang-format off
56 /** M, N combinations to test
57 * 1: Special 1x1 case
58 * 2: Special multples of processor size in both dimensions
59 * 3: Non multiples of processor size in both dimensions
60 */
61 const auto m_n_values = zip(
62 framework::dataset::make("M", {1, 16, 37}),
63 framework::dataset::make("N", {1, 16, 51})
64 );
65
66 /** M_W values to test */
67 const auto m_w_values = framework::dataset::make("M_W", 5);
68
69 /** M_H values to test */
70 const auto m_h_values = framework::dataset::make("M_H", 7);
71
72 /** N values to test */
73 const auto n_values = framework::dataset::make("N", 51);
74
75 /** K values to test */
76 const auto k_values = framework::dataset::make("K", 23);
77
78 /** Batch size values to test */
79 const auto b_values = framework::dataset::make("batch_size", 1, 3);
80
81 /** M0 values to test - Precommit */
82 const auto m0_values_precommit = framework::dataset::make("M0", {4, 6});
83
84 /** N0 values to test - Precommit */
85 const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
86
87 /** K0 values to test - Precommit */
88 const auto k0_values_precommit = framework::dataset::make("K0", { 16 });
89
90 /** M0 values to test - Nightly */
91 const auto m0_values_nightly = framework::dataset::make("M0", {1, 2, 7});
92
93 /** N0 values to test - Nightly */
94 const auto n0_values_nightly = framework::dataset::make("N0", { 1, 2, 3, 4, 8 });
95
96 /** K0 values to test - Nightly */
97 const auto k0_values_nightly = framework::dataset::make("K0", { 1, 2, 3, 4, 8, 16 });
98 } // namespace
99
100 TEST_SUITE(CL)
TEST_SUITE(GEMMLowpMatrixMultiplyNative)101 TEST_SUITE(GEMMLowpMatrixMultiplyNative)
102 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyNativeFixture, framework::DatasetMode::ALL,
103 combine(combine(combine(combine(combine(m_n_values,
104 k_values),
105 b_values),
106 m0_values_precommit),
107 n0_values_precommit),
108 k0_values_precommit))
109 {
110 // Validate output
111 validate(CLAccessor(_target), _reference);
112 }
113
FIXTURE_DATA_TEST_CASE(RunLarge,CLGEMMLowpMatrixMultiplyNativeFixture,framework::DatasetMode::ALL,combine (combine (combine (combine (combine (m_n_values,k_values),b_values),m0_values_nightly),n0_values_nightly),k0_values_nightly))114 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyNativeFixture, framework::DatasetMode::ALL,
115 combine(combine(combine(combine(combine(m_n_values,
116 k_values),
117 b_values),
118 m0_values_nightly),
119 n0_values_nightly),
120 k0_values_nightly))
121 {
122 // Validate output
123 validate(CLAccessor(_target), _reference);
124 }
125
FIXTURE_DATA_TEST_CASE(RunSmall3D,CLGEMMLowpMatrixMultiplyNative3DFixture,framework::DatasetMode::ALL,combine (combine (combine (combine (combine (combine (combine (m_w_values,m_h_values),n_values),k_values),b_values),m0_values_precommit),n0_values_precommit),k0_values_precommit))126 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMLowpMatrixMultiplyNative3DFixture, framework::DatasetMode::ALL,
127 combine(combine(combine(combine(combine(combine(combine(m_w_values,
128 m_h_values),
129 n_values),
130 k_values),
131 b_values),
132 m0_values_precommit),
133 n0_values_precommit),
134 k0_values_precommit))
135 {
136 // Validate output
137 validate(CLAccessor(_target), _reference);
138 }
139
FIXTURE_DATA_TEST_CASE(RunLarge3D,CLGEMMLowpMatrixMultiplyNative3DFixture,framework::DatasetMode::ALL,combine (combine (combine (combine (combine (combine (combine (m_w_values,m_h_values),n_values),k_values),b_values),m0_values_nightly),n0_values_nightly),k0_values_nightly))140 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMLowpMatrixMultiplyNative3DFixture, framework::DatasetMode::ALL,
141 combine(combine(combine(combine(combine(combine(combine(m_w_values,
142 m_h_values),
143 n_values),
144 k_values),
145 b_values),
146 m0_values_nightly),
147 n0_values_nightly),
148 k0_values_nightly))
149 {
150 // Validate output
151 validate(CLAccessor(_target), _reference);
152 }
153 TEST_SUITE_END() // GEMMLowpMatrixMultiplyNative
154 TEST_SUITE_END() // CL
155 } // namespace validation
156 } // namespace test
157 } // namespace arm_compute
158