1 /*
2 * Copyright (c) 2017-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/CL/CLTensor.h"
26 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
27 #include "arm_compute/runtime/CL/functions/CLGEMM.h"
28 #include "tests/CL/CLAccessor.h"
29 #include "tests/CL/Helper.h"
30 #include "tests/PaddingCalculator.h"
31 #include "tests/datasets/LargeGEMMDataset.h"
32 #include "tests/datasets/SmallGEMMDataset.h"
33 #include "tests/datasets/TinyGEMMDataset.h"
34 #include "tests/framework/Asserts.h"
35 #include "tests/framework/Macros.h"
36 #include "tests/framework/datasets/Datasets.h"
37 #include "tests/validation/Validation.h"
38 #include "tests/validation/fixtures/GEMMFixture.h"
39
40 namespace arm_compute
41 {
42 namespace test
43 {
44 namespace validation
45 {
46 namespace
47 {
48 RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
49 constexpr float abs_tolerance_f32(
50 0.0001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for floating point data types in case using relative tolerance fails because of small values */
51 RelativeTolerance<half_float::half> tolerance_f16(half(0.2)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
52 constexpr float tolerance_num = 0.02f; /**< Tolerance number */
53 const auto data_interleave = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
54
55 /** CNN data types */
56 const auto CNNDataTypes = framework::dataset::make("DataType",
57 {
58 DataType::F16,
59 DataType::F32,
60 });
61 } // namespace
62
63 const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
64
65 TEST_SUITE(CL)
66 TEST_SUITE(GEMM)
67
68 template <typename T>
69 using CLGEMMFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T>;
70
71 template <typename T>
72 using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, false, true>;
73
74 template <typename T>
75 using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true, true>;
76
77 TEST_SUITE(Float)
TEST_SUITE(FP16)78 TEST_SUITE(FP16)
79 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
80 framework::dataset::make("ReshapeWeights", { true, false })),
81 framework::dataset::make("DataType", DataType::F16)))
82 {
83 // Validate output
84 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
85 }
86 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
87 framework::dataset::make("ReshapeWeights", { true })),
88 framework::dataset::make("DataType", DataType::F16)))
89 {
90 // Validate output
91 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
92 }
93 TEST_SUITE_END()
94
TEST_SUITE(FP32)95 TEST_SUITE(FP32)
96 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
97 framework::dataset::make("ReshapeWeights", { true, false })),
98 framework::dataset::make("DataType", DataType::F32)))
99 {
100 // Validate output
101 validate(CLAccessor(_target), _reference, tolerance_f32);
102 }
103 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
104 framework::dataset::make("ReshapeWeights", { true, false })),
105 framework::dataset::make("DataType", DataType::F32)))
106 {
107 // Validate output
108 validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
109 }
110 TEST_SUITE_END()
TEST_SUITE_END()111 TEST_SUITE_END()
112
113 TEST_SUITE(INPUT_OUTPUT_3D)
114 TEST_SUITE(Float)
115 TEST_SUITE(FP32)
116 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
117 framework::dataset::make("ReshapeWeights", { true, false })),
118 framework::dataset::make("DataType", DataType::F32)))
119 {
120 // Validate output
121 validate(CLAccessor(_target), _reference, tolerance_f32);
122 }
123 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
124 framework::dataset::make("ReshapeWeights", { true, false })),
125 framework::dataset::make("DataType", DataType::F32)))
126 {
127 // Validate output
128 validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
129 }
130 TEST_SUITE_END() // FP32
131
TEST_SUITE(FP16)132 TEST_SUITE(FP16)
133 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
134 framework::dataset::make("ReshapeWeights", { true, false })),
135 framework::dataset::make("DataType", DataType::F16)))
136 {
137 // Validate output
138 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
139 }
140 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
141 framework::dataset::make("ReshapeWeights", { true, false })),
142 framework::dataset::make("DataType", DataType::F16)))
143 {
144 // Validate output
145 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
146 }
147 TEST_SUITE_END() // FP16
148
TEST_SUITE_END()149 TEST_SUITE_END() // Float
150 TEST_SUITE_END() // INPUT_OUTPUT_3D
151
152 TEST_SUITE(OUTPUT_3D)
153 TEST_SUITE(Float)
154 TEST_SUITE(FP32)
155 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
156 framework::dataset::make("ReshapeWeights", { true, false })),
157 framework::dataset::make("DataType", DataType::F32)))
158 {
159 // Validate output
160 validate(CLAccessor(_target), _reference, tolerance_f32);
161 }
162 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
163 framework::dataset::make("ReshapeWeights", { true, false })),
164 framework::dataset::make("DataType", DataType::F32)))
165 {
166 // Validate output
167 validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
168 }
169 TEST_SUITE_END() // FP32
170
TEST_SUITE(FP16)171 TEST_SUITE(FP16)
172 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
173 framework::dataset::make("ReshapeWeights", { true, false })),
174 framework::dataset::make("DataType", DataType::F16)))
175 {
176 // Validate output
177 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
178 }
179 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
180 framework::dataset::make("ReshapeWeights", { true, false })),
181 framework::dataset::make("DataType", DataType::F16)))
182 {
183 // Validate output
184 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
185 }
186 TEST_SUITE_END() // FP16
187
188 TEST_SUITE_END() // Float
189 TEST_SUITE_END() // OUTPUT_3D
190
191 TEST_SUITE_END() // GEMM
192 TEST_SUITE_END() // CL
193 } // namespace validation
194 } // namespace test
195 } // namespace arm_compute
196