• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2016-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CLGEMM_H
25 #define ARM_COMPUTE_CLGEMM_H
26 
27 #include "arm_compute/runtime/CL/CLTensor.h"
28 #include "arm_compute/runtime/CL/CLTypes.h"
29 #include "arm_compute/runtime/IFunction.h"
30 #include "arm_compute/runtime/IMemoryManager.h"
31 #include "arm_compute/runtime/IWeightsManager.h"
32 #include "arm_compute/runtime/MemoryGroup.h"
33 
34 #include <memory>
35 
36 namespace arm_compute
37 {
38 class CLCompileContext; // Forward declarations: this header only uses these types through pointers, references or std::unique_ptr members
39 class CLGEMMReshapeRHSMatrixKernel;
40 class CLGEMMMatrixMultiplyKernel;
41 class CLGEMMMatrixMultiplyReshapedKernel;
42 class CLGEMMMatrixMultiplyReshapedOnlyRHSKernel;
43 class CLGEMMReshapeLHSMatrixKernel;
44 class ICLTensor;
45 class ITensorInfo;
46 
47 namespace weights_transformations
48 {
49 /** Weights transformation (see @ref ITransformWeights) that manages the reshaped weights generated by @ref CLGEMMReshapeRHSMatrixKernel */
50 class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights
51 {
52 public:
53     /** Default constructor */
54     CLGEMMReshapeRHSMatrixKernelManaged();
55     /** Prevent instances of this class from being copied (as this class contains pointers) */
56     CLGEMMReshapeRHSMatrixKernelManaged(const CLGEMMReshapeRHSMatrixKernelManaged &) = delete;
57     /** Default move constructor */
58     CLGEMMReshapeRHSMatrixKernelManaged(CLGEMMReshapeRHSMatrixKernelManaged &&) = default;
59     /** Prevent instances of this class from being copy-assigned (as this class contains pointers) */
60     CLGEMMReshapeRHSMatrixKernelManaged &operator=(const CLGEMMReshapeRHSMatrixKernelManaged &) = delete;
61     /** Default move assignment operator */
62     CLGEMMReshapeRHSMatrixKernelManaged &operator=(CLGEMMReshapeRHSMatrixKernelManaged &&) = default;
63     /** Destructor (declared here and defined out of line; CLGEMMReshapeRHSMatrixKernel is only forward-declared in this header) */
64     ~CLGEMMReshapeRHSMatrixKernelManaged();
65     // Inherited method override (presumably executes the reshape kernel - confirm in the implementation file)
66     void run() override;
67 
68     // Inherited method override (presumably releases the resources held for the reshaped weights - confirm in the implementation file)
69     void release() override;
70 
71     // Inherited method override (presumably returns the reshaped weights tensor - confirm in the implementation file)
72     ICLTensor *get_weights() override;
73 
74     // Inherited method override (presumably returns the identifier stored in _uid - confirm in the implementation file)
75     uint32_t uid() override;
76 
77     /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
78      *
79      * @param[in] input Input tensor. Data types supported: All
80      * @param[in] info  RHS matrix information to be used for reshaping.
81      */
82     void configure(const ICLTensor *input, GEMMRHSMatrixInfo info);
83 
84     /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
85      *
86      * @param[in] compile_context The compile context to be used.
87      * @param[in] input           Input tensor. Data types supported: All
88      * @param[in] info            RHS matrix information to be used for reshaping.
89      */
90     void configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info);
91 
92 private:
93     static constexpr uint32_t                     _uid{ 0x15 }; // Compile-time identifier constant of this transformation
94     CLTensor                                      _output{};    // Tensor owned by this object (presumably the reshaped weights - confirm)
95     std::unique_ptr<CLGEMMReshapeRHSMatrixKernel> _kernel;      // Owned reshape kernel (type only forward-declared in this header)
96 };
97 } // namespace weights_transformations
98 
99 /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
100  *
101  *  -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model)
102  *  -# @ref CLGEMMReshapeRHSMatrixKernel (only if either the RESHAPED_V1 or RESHAPED_ONLY_RHS is selected by the select_gemm_kernel() method)
103  *  -# @ref CLGEMMMatrixMultiplyKernel (only if either the NATIVE or RESHAPED_V1 is selected by the select_gemm_kernel() method)
104  *  -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if RESHAPED_V1 is selected by the select_gemm_kernel() method)
105  *  -# @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel() method)
106  *  NOTE(review): RESHAPED_V1 is listed as the condition for both multiply kernels above - confirm the kernel types against the implementation.
107  */
108 class CLGEMM : public IFunction
109 {
110 public:
111     /** Default constructor.
112      *
113      * @param[in] memory_manager  (Optional) Memory manager.
114      * @param[in] weights_manager (Optional) Weights manager.
115      */
116     CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
117     /** Prevent instances of this class from being copied (as this class contains pointers) */
118     CLGEMM(const CLGEMM &) = delete;
119     /** Default move constructor */
120     CLGEMM(CLGEMM &&) = default;
121     /** Prevent instances of this class from being copy-assigned (as this class contains pointers) */
122     CLGEMM &operator=(const CLGEMM &) = delete;
123     /** Default move assignment operator */
124     CLGEMM &operator=(CLGEMM &&) = default;
125     /** Destructor (declared here and defined out of line; the kernel types held by std::unique_ptr are only forward-declared in this header) */
126     ~CLGEMM();
127     /** Initialise the kernel's inputs and output
128      *
129      * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
130      *
131      * @note All tensors must have the same data type.
132      *
133      * @note Whilst the first input tensor can be a vector, the second input tensor must be at least a matrix
134      *
135      * @param[in]  a         First input tensor  (Matrix or Vector A). Data types supported: F16/F32
136      * @param[in]  b         Second input tensor (Matrix B). Data type supported: same as @p a.
137      * @param[in]  c         Third input tensor  (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
138      * @param[out] output    Output tensor. Data type supported: same as @p a
139      * @param[in]  alpha     Weight of the matrix product
140      * @param[in]  beta      Weight of matrix C
141      * @param[in]  gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
142      *                       if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
143      *                       in case matrix A and matrix B have been already transformed.
144      */
145     void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
146     /** Initialise the kernel's inputs and output
147      *
148      * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
149      *
150      * @note All tensors must have the same data type.
151      *
152      * @note Whilst the first input tensor can be a vector, the second input tensor must be at least a matrix
153      *
154      * @param[in]  compile_context The compile context to be used.
155      * @param[in]  a               First input tensor  (Matrix or Vector A). Data types supported: F16/F32
156      * @param[in]  b               Second input tensor (Matrix B). Data type supported: same as @p a.
157      * @param[in]  c               Third input tensor  (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
158      * @param[out] output          Output tensor. Data type supported: same as @p a
159      * @param[in]  alpha           Weight of the matrix product
160      * @param[in]  beta            Weight of matrix C
161      * @param[in]  gemm_info       (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
162      *                             if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
163      *                             in case matrix A and matrix B have been already transformed.
164      */
165     void configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
166     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMM.
167      *
168      * @param[in] a         First input tensor info  (Matrix or Vector A). Data types supported: F16/F32
169      * @param[in] b         Second input tensor info (Matrix B). Data type supported: same as @p a.
170      * @param[in] c         Third input tensor info  (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
171      * @param[in] output    Output tensor info. Data type supported: same as @p a
172      * @param[in] alpha     Weight of the matrix product
173      * @param[in] beta      Weight of matrix C
174      * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
175      *                      if the reshape of matrix B should happen only for the first run
176      *
177      * @return a status
178      */
179     static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
180 
181     // Inherited methods overridden:
182     void run() override;
183     void prepare() override;
184 
185 private:
186     static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run);
187     // Private configuration helpers; each takes the same arguments as the public configure() overload with a compile context (presumably one per GEMM kernel type - confirm).
188     void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
189     void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
190     void configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
191     void configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
192                                      const GEMMInfo &gemm_info);
193     // Private static validation helpers taking the same arguments as validate(); note the naming asymmetry: configure_reshaped_v2 vs validate_reshaped.
194     static Status validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
195     static Status validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
196     static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
197     static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
198     // Data members: memory group, raw weights-manager pointer, kernels owned through std::unique_ptr, temporary tensors, raw tensor pointers and configuration flags.
199     MemoryGroup                                                                   _memory_group;
200     IWeightsManager                                                              *_weights_manager;
201     std::unique_ptr<CLGEMMMatrixMultiplyKernel>                                   _mm_kernel;
202     std::unique_ptr<CLGEMMReshapeLHSMatrixKernel>                                 _reshape_lhs_kernel;
203     std::unique_ptr<CLGEMMReshapeRHSMatrixKernel>                                 _reshape_rhs_kernel;
204     std::unique_ptr<weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged> _reshape_rhs_kernel_managed;
205     std::unique_ptr<CLGEMMMatrixMultiplyReshapedKernel>                           _mm_reshaped_kernel;
206     std::unique_ptr<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>                    _mm_reshaped_only_rhs_kernel;
207     std::unique_ptr<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>                    _mm_reshaped_only_rhs_fallback_kernel;
208     CLTensor                                                                      _tmp_a;
209     CLTensor                                                                      _tmp_b;
210     const ICLTensor                                                              *_original_b;
211     const ICLTensor                                                              *_lhs;
212     ICLTensor                                                                    *_dst;
213     bool                                                                          _reshape_b_only_on_first_run;
214     bool                                                                          _is_prepared;
215     CLGEMMKernelType                                                              _gemm_kernel_type;
216 };
217 } // namespace arm_compute
218 
219 #endif /* ARM_COMPUTE_CLGEMM_H */
220