• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
26 #include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.h"
27 #include "tests/CL/CLAccessor.h"
28 #include "tests/CL/Helper.h"
29 #include "tests/framework/Asserts.h"
30 #include "tests/framework/Macros.h"
31 #include "tests/framework/datasets/Datasets.h"
32 #include "tests/validation/Validation.h"
33 #include "tests/validation/fixtures/GEMMLowpFixture.h"
34 
35 namespace arm_compute
36 {
37 namespace test
38 {
39 namespace validation
40 {
41 using namespace arm_compute::misc::shape_calculator;
42 
43 // Create function for CLGEMMMatrixMultiplyNativeKernel
44 using CLGEMMLowpMatrixMultiplyNative = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyNativeKernel>;
45 
46 // Fixture for CLGEMMLowpMatrixMultiplyNative
47 using CLGEMMLowpMatrixMultiplyNativeFixture = GEMMLowpMatrixMultiplyNativeValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyNative>;
48 
49 // Fixture for CLGEMMMatrixMultiplyNative3D
50 using CLGEMMLowpMatrixMultiplyNative3DFixture = GEMMLowpMatrixMultiplyNative3DValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyNative>;
51 
52 namespace
53 {
54 // *INDENT-OFF*
55 // clang-format off
56 /** M, N combinations to test
57  *  1: Special 1x1 case
58  *  2: Special multples of processor size in both dimensions
59  *  3: Non multiples of processor size in both dimensions
60 */
61 const auto m_n_values = zip(
62     framework::dataset::make("M", {1, 16, 37}),
63     framework::dataset::make("N", {1, 16, 51})
64     );
65 
66 /** M_W values to test */
67 const auto m_w_values = framework::dataset::make("M_W", 5);
68 
69 /** M_H values to test */
70 const auto m_h_values = framework::dataset::make("M_H", 7);
71 
72 /** N values to test */
73 const auto n_values = framework::dataset::make("N", 51);
74 
75 /** K values to test */
76 const auto k_values = framework::dataset::make("K", 23);
77 
78 /** Batch size values to test */
79 const auto b_values = framework::dataset::make("batch_size", 1, 3);
80 
81 /** M0 values to test - Precommit */
82 const auto m0_values_precommit = framework::dataset::make("M0", {4, 6});
83 
84 /** N0 values to test - Precommit */
85 const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
86 
87 /** K0 values to test - Precommit */
88 const auto k0_values_precommit = framework::dataset::make("K0", { 16 });
89 
90 /** M0 values to test - Nightly */
91 const auto m0_values_nightly = framework::dataset::make("M0", {1, 2, 7});
92 
93 /** N0 values to test - Nightly */
94 const auto n0_values_nightly = framework::dataset::make("N0", { 1, 2, 3, 4, 8 });
95 
96 /** K0 values to test - Nightly */
97 const auto k0_values_nightly = framework::dataset::make("K0", { 1, 2, 3, 4, 8, 16 });
98 } // namespace
99 
100 TEST_SUITE(CL)
TEST_SUITE(GEMMLowpMatrixMultiplyNative)101 TEST_SUITE(GEMMLowpMatrixMultiplyNative)
102 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyNativeFixture, framework::DatasetMode::ALL,
103                 combine(combine(combine(combine(combine(m_n_values,
104                                                                 k_values),
105                                                                 b_values),
106                                                                 m0_values_precommit),
107                                                                 n0_values_precommit),
108                                                                 k0_values_precommit))
109 {
110     // Validate output
111     validate(CLAccessor(_target), _reference);
112 }
113 
FIXTURE_DATA_TEST_CASE(RunLarge,CLGEMMLowpMatrixMultiplyNativeFixture,framework::DatasetMode::ALL,combine (combine (combine (combine (combine (m_n_values,k_values),b_values),m0_values_nightly),n0_values_nightly),k0_values_nightly))114 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyNativeFixture, framework::DatasetMode::ALL,
115                 combine(combine(combine(combine(combine(m_n_values,
116                                                                 k_values),
117                                                                 b_values),
118                                                                 m0_values_nightly),
119                                                                 n0_values_nightly),
120                                                                 k0_values_nightly))
121 {
122     // Validate output
123     validate(CLAccessor(_target), _reference);
124 }
125 
FIXTURE_DATA_TEST_CASE(RunSmall3D,CLGEMMLowpMatrixMultiplyNative3DFixture,framework::DatasetMode::ALL,combine (combine (combine (combine (combine (combine (combine (m_w_values,m_h_values),n_values),k_values),b_values),m0_values_precommit),n0_values_precommit),k0_values_precommit))126 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMLowpMatrixMultiplyNative3DFixture, framework::DatasetMode::ALL,
127                 combine(combine(combine(combine(combine(combine(combine(m_w_values,
128                                                                         m_h_values),
129                                                                         n_values),
130                                                                         k_values),
131                                                                         b_values),
132                                                                         m0_values_precommit),
133                                                                         n0_values_precommit),
134                                                                         k0_values_precommit))
135 {
136     // Validate output
137     validate(CLAccessor(_target), _reference);
138 }
139 
FIXTURE_DATA_TEST_CASE(RunLarge3D,CLGEMMLowpMatrixMultiplyNative3DFixture,framework::DatasetMode::ALL,combine (combine (combine (combine (combine (combine (combine (m_w_values,m_h_values),n_values),k_values),b_values),m0_values_nightly),n0_values_nightly),k0_values_nightly))140 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMLowpMatrixMultiplyNative3DFixture, framework::DatasetMode::ALL,
141                 combine(combine(combine(combine(combine(combine(combine(m_w_values,
142                                                                         m_h_values),
143                                                                         n_values),
144                                                                         k_values),
145                                                                         b_values),
146                                                                         m0_values_nightly),
147                                                                         n0_values_nightly),
148                                                                         k0_values_nightly))
149 {
150     // Validate output
151     validate(CLAccessor(_target), _reference);
152 }
153 TEST_SUITE_END() // GEMMLowpMatrixMultiplyNative
154 TEST_SUITE_END() // CL
155 } // namespace validation
156 } // namespace test
157 } // namespace arm_compute
158