• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/CL/CLTensor.h"
26 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
27 #include "arm_compute/runtime/CL/functions/CLGEMM.h"
28 #include "tests/CL/CLAccessor.h"
29 #include "tests/CL/Helper.h"
30 #include "tests/PaddingCalculator.h"
31 #include "tests/datasets/LargeGEMMDataset.h"
32 #include "tests/datasets/SmallGEMMDataset.h"
33 #include "tests/datasets/TinyGEMMDataset.h"
34 #include "tests/framework/Asserts.h"
35 #include "tests/framework/Macros.h"
36 #include "tests/framework/datasets/Datasets.h"
37 #include "tests/validation/Validation.h"
38 #include "tests/validation/fixtures/GEMMFixture.h"
39 
40 namespace arm_compute
41 {
42 namespace test
43 {
44 namespace validation
45 {
46 namespace
47 {
48 RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
49 constexpr float          abs_tolerance_f32(
50     0.0001f);                                                 /**< Absolute tolerance value for comparing reference's output against implementation's output for floating point data types in case using relative tolerance fails because of small values */
51 RelativeTolerance<half_float::half> tolerance_f16(half(0.2)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
52 constexpr float                     tolerance_num   = 0.02f;  /**< Tolerance number */
53 const auto                          data_interleave = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
54 
55 /** CNN data types */
56 const auto CNNDataTypes = framework::dataset::make("DataType",
57 {
58     DataType::F16,
59     DataType::F32,
60 });
61 } // namespace
62 
63 const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
64 
65 TEST_SUITE(CL)
66 TEST_SUITE(GEMM)
67 
68 template <typename T>
69 using CLGEMMFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T>;
70 
71 template <typename T>
72 using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, false, true>;
73 
74 template <typename T>
75 using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true, true>;
76 
77 TEST_SUITE(Float)
TEST_SUITE(FP16)78 TEST_SUITE(FP16)
79 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
80                                                                                                          framework::dataset::make("ReshapeWeights", { true, false })),
81                                                                                                  framework::dataset::make("DataType", DataType::F16)))
82 {
83     // Validate output
84     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
85 }
86 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
87                                                                                                        framework::dataset::make("ReshapeWeights", { true })),
88                                                                                                framework::dataset::make("DataType", DataType::F16)))
89 {
90     // Validate output
91     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
92 }
93 TEST_SUITE_END()
94 
TEST_SUITE(FP32)95 TEST_SUITE(FP32)
96 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
97                                                                                                           framework::dataset::make("ReshapeWeights", { true, false })),
98                                                                                                   framework::dataset::make("DataType", DataType::F32)))
99 {
100     // Validate output
101     validate(CLAccessor(_target), _reference, tolerance_f32);
102 }
103 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
104                                                                                                         framework::dataset::make("ReshapeWeights", { true, false })),
105                                                                                                 framework::dataset::make("DataType", DataType::F32)))
106 {
107     // Validate output
108     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
109 }
110 TEST_SUITE_END()
TEST_SUITE_END()111 TEST_SUITE_END()
112 
113 TEST_SUITE(INPUT_OUTPUT_3D)
114 TEST_SUITE(Float)
115 TEST_SUITE(FP32)
116 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
117                                                                                                                        framework::dataset::make("ReshapeWeights", { true, false })),
118                                                                                                                framework::dataset::make("DataType", DataType::F32)))
119 {
120     // Validate output
121     validate(CLAccessor(_target), _reference, tolerance_f32);
122 }
123 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
124                                                                                                                      framework::dataset::make("ReshapeWeights", { true, false })),
125                                                                                                              framework::dataset::make("DataType", DataType::F32)))
126 {
127     // Validate output
128     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
129 }
130 TEST_SUITE_END() // FP32
131 
TEST_SUITE(FP16)132 TEST_SUITE(FP16)
133 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
134                                                                                                                       framework::dataset::make("ReshapeWeights", { true, false })),
135                                                                                                               framework::dataset::make("DataType", DataType::F16)))
136 {
137     // Validate output
138     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
139 }
140 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
141                                                                                                                     framework::dataset::make("ReshapeWeights", { true, false })),
142                                                                                                             framework::dataset::make("DataType", DataType::F16)))
143 {
144     // Validate output
145     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
146 }
147 TEST_SUITE_END() // FP16
148 
TEST_SUITE_END()149 TEST_SUITE_END() // Float
150 TEST_SUITE_END() // INPUT_OUTPUT_3D
151 
152 TEST_SUITE(OUTPUT_3D)
153 TEST_SUITE(Float)
154 TEST_SUITE(FP32)
155 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
156                                                                                                                   framework::dataset::make("ReshapeWeights", { true, false })),
157                                                                                                           framework::dataset::make("DataType", DataType::F32)))
158 {
159     // Validate output
160     validate(CLAccessor(_target), _reference, tolerance_f32);
161 }
162 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
163                                                                                                                 framework::dataset::make("ReshapeWeights", { true, false })),
164                                                                                                         framework::dataset::make("DataType", DataType::F32)))
165 {
166     // Validate output
167     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
168 }
169 TEST_SUITE_END() // FP32
170 
TEST_SUITE(FP16)171 TEST_SUITE(FP16)
172 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
173                                                                                                                  framework::dataset::make("ReshapeWeights", { true, false })),
174                                                                                                          framework::dataset::make("DataType", DataType::F16)))
175 {
176     // Validate output
177     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
178 }
179 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
180                                                                                                                framework::dataset::make("ReshapeWeights", { true, false })),
181                                                                                                        framework::dataset::make("DataType", DataType::F16)))
182 {
183     // Validate output
184     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
185 }
186 TEST_SUITE_END() // FP16
187 
188 TEST_SUITE_END() // Float
189 TEST_SUITE_END() // OUTPUT_3D
190 
191 TEST_SUITE_END() // GEMM
192 TEST_SUITE_END() // CL
193 } // namespace validation
194 } // namespace test
195 } // namespace arm_compute
196