• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2016-2023 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CL_GEMM_H
25 #define ARM_COMPUTE_CL_GEMM_H
26 
27 #include "arm_compute/core/TensorInfo.h"
28 #include "arm_compute/runtime/CL/CLTensor.h"
29 #include "arm_compute/runtime/CL/CLTypes.h"
30 
31 #include "src/gpu/cl/ClCompileContext.h"
32 #include "src/gpu/cl/IClKernel.h"
33 #include "src/gpu/cl/IClOperator.h"
34 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h"
35 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h"
36 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h"
37 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
38 #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h"
39 #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
40 
41 #include <memory>
42 
43 namespace arm_compute
44 {
45 namespace opencl
46 {
/** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
 *
 *  -# @ref kernels::ClGemmReshapeLhsMatrixKernel (only if RESHAPED is selected by the heuristic model)
 *  -# @ref kernels::ClGemmReshapeRhsMatrixKernel (only if either RESHAPED or RESHAPED_ONLY_RHS is selected by the select_gemm_kernel() method)
 *  -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel() method)
 *  -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel() method)
 *  -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel() method)
 *  -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel() method)
 */
class ClGemm : public IClOperator
{
public:
    /** Constructor */
    ClGemm();
    /** Initialise the kernel's inputs and output
     *
     * Valid data layouts:
     * - All
     *
     * Valid data type configurations:
     * |src0         |src1        |src2      |dst            |
     * |:------------|:-----------|:---------|:--------------|
     * |F32          |F32         |F32       |F32            |
     * |F16          |F16         |F16       |F16            |
     *
     * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
     *
     * @note All tensors must have the same data type.
     *
     * @note Whilst the first input tensor can be a vector, the second input tensor must be at least a matrix
     *
     * @note Batched GEMM only allows RHS tensor's rank to be <= 3
     * @note Batched GEMM only supports broadcasting cases where RHS rank < LHS rank but not the other way around
     *
     * @param[in]  compile_context The compile context to be used.
     * @param[in]  a               First input tensor  (Matrix or Vector A). Data types supported: F16/F32
     * @param[in]  b               Second input tensor (Matrix B). Data type supported: same as @p a.
     * @param[in]  c               Third input tensor  (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
     * @param[out] output          Output tensor. Data type supported: same as @p a
     * @param[in]  alpha           Weight of the matrix product
     * @param[in]  beta            Weight of matrix C
     * @param[in]  gemm_info       (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
     *                             if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
     *                             in case matrix A and matrix B have been already transformed.
     */
    void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    /** Static function to check if given info will lead to a valid configuration
     *
     * Similar to ClGemm::configure()
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);

    // Inherited methods overridden:
    void run(ITensorPack &tensors) override;
    void prepare(ITensorPack &constants) override;
    experimental::MemoryRequirements workspace() const override;

private:
    // Per-variant configuration helpers: one for each GEMM kernel strategy
    // (NATIVE, RESHAPED, RESHAPED_ONLY_RHS, RESHAPED_ONLY_RHS_MMUL).
    // Presumably configure() dispatches to exactly one of these based on the
    // selected CLGEMMKernelType (see _gemm_kernel_type below) — implementation
    // lives in the corresponding .cpp file.
    void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);

    // Static validation counterparts of the configure_* helpers above; each
    // mirrors the signature of validate().
    static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
    static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);

private:
    // Slot indices for the auxiliary (intermediate) tensors this operator may
    // need; Count doubles as the total number of slots.
    enum AuxTensorIdx
    {
        LhsReshape = 0, // reshaped LHS (matrix A) produced by ClGemmReshapeLhsMatrixKernel
        RhsReshape,     // reshaped RHS (matrix B) produced by ClGemmReshapeRhsMatrixKernel
        Count
    };

private:
    // Reshape kernels (used only by the RESHAPED / RESHAPED_ONLY_RHS paths)
    std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel>                  _reshape_lhs_kernel;
    std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel>                  _reshape_rhs_kernel;
    // Matrix-multiply kernels, one per strategy; only the one matching
    // _gemm_kernel_type is expected to be configured/run.
    std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel>              _mm_native_kernel;
    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel>            _mm_reshaped_kernel;
    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel>     _mm_reshaped_only_rhs_kernel;
    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel;
    TensorInfo                                                              _tmp_a;                       // info of the reshaped LHS auxiliary tensor (AuxTensorIdx::LhsReshape)
    TensorInfo                                                              _tmp_b;                       // info of the reshaped RHS auxiliary tensor (AuxTensorIdx::RhsReshape)
    bool                                                                    _reshape_b_only_on_first_run; // if true, the RHS reshape is done once in prepare() and reused
    CLGEMMKernelType                                                        _gemm_kernel_type;            // strategy chosen at configure() time
    bool                                                                    _is_prepared;                 // guards prepare() so one-off work runs only once
    experimental::MemoryRequirements                                        _aux_mem{};                   // workspace requirements reported by workspace()
};
139 } // namespace opencl
140 } // namespace arm_compute
#endif /* ARM_COMPUTE_CL_GEMM_H */
142