• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDKERNEL_H
25 #define ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDKERNEL_H
26 
27 #include "src/core/CL/ICLKernel.h"
28 
29 namespace arm_compute
30 {
31 class ICLTensor;
32 
33 /** OpenCL kernel to multiply matrices when both the input matrices LHS (input0) and RHS (input1) have been reshaped
34  *
35  * @note The input matrices @p input0 and @p input1 must be reshaped through @ref CLGEMMReshapeLHSMatrixKernel and  @ref CLGEMMReshapeRHSMatrixKernel
36  */
37 class CLGEMMLowpMatrixMultiplyReshapedKernel : public ICLKernel
38 {
39 public:
40     /** Default Constructor */
41     CLGEMMLowpMatrixMultiplyReshapedKernel();
42     /** Prevent instances of this class from being copied (As this class contains pointers) */
43     CLGEMMLowpMatrixMultiplyReshapedKernel(const CLGEMMLowpMatrixMultiplyReshapedKernel &) = delete;
44     /** Prevent instances of this class from being copied (As this class contains pointers) */
45     CLGEMMLowpMatrixMultiplyReshapedKernel &operator=(const CLGEMMLowpMatrixMultiplyReshapedKernel &) = delete;
46     /** Allow instances of this class to be moved */
47     CLGEMMLowpMatrixMultiplyReshapedKernel(CLGEMMLowpMatrixMultiplyReshapedKernel &&) = default;
48     /** Allow instances of this class to be moved */
49     CLGEMMLowpMatrixMultiplyReshapedKernel &operator=(CLGEMMLowpMatrixMultiplyReshapedKernel &&) = default;
50     /** Initialise the kernel's input and output.
51      *
52      * @param[in]  input0    Input tensor containing the LHS reshaped matrix. Data type supported: QASYMM8/QASYMM8_SIGNED. The number of dimensions for the LHS matrix must be less or equal than 4.
53      * @param[in]  input1    Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
54      * @param[out] output    Output tensor to store the result of matrix multiplication. Data type supported: S32
55      * @param[in]  lhs_info  LHS matrix information used for reshaping the input0 tensor.  Only the following values are supported:
56      *                       lhs_info.m0: 2,3,4,5,6,7,8
57      *                       lhs_info.k0: 2,3,4,8,16
58      *                       lhs_info.transpose: false
59      * @param[in]  rhs_info  RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
60      *                       rhs_info.n0: 2,3,4,8,16
61      *                       rhs_info.k0: same as lhs_info.k0
62      *                       rhs_info.transpose: true
63      * @param[in]  gemm_info GEMM information used to retrieve the original dimensions of the input matrices
64      *
65      * @note lhs_info.k0 must be equal to rhs_info.k0
66      */
67     void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info);
68     /** Initialise the kernel's input and output.
69      *
70      * @param[in]  compile_context The compile context to be used.
71      * @param[in]  input0          Input tensor containing the LHS reshaped matrix. Data type supported: QASYMM8/QASYMM8_SIGNED. The number of dimensions for the LHS matrix must be less or equal than 4.
72      * @param[in]  input1          Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
73      * @param[out] output          Output tensor to store the result of matrix multiplication. Data type supported: S32
74      * @param[in]  lhs_info        LHS matrix information used for reshaping the input0 tensor.  Only the following values are supported:
75      *                             lhs_info.m0: 2,3,4,5,6,7,8
76      *                             lhs_info.k0: 2,3,4,8,16
77      *                             lhs_info.transpose: false
78      * @param[in]  rhs_info        RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
79      *                             rhs_info.n0: 2,3,4,8,16
80      *                             rhs_info.k0: same as lhs_info.k0
81      *                             rhs_info.transpose: true
82      * @param[in]  gemm_info       GEMM information used to retrieve the original dimensions of the input matrices
83      *
84      * @note lhs_info.k0 must be equal to rhs_info.k0
85      */
86     void configure(const CLCompileContext &compile_context, const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
87                    const GEMMReshapeInfo &gemm_info);
88     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyReshapedKernel
89      *
90      * @param[in] input0    Input tensor info containing the LHS reshaped matrix. Data type supported: QASYMM8/QASYMM8_SIGNED. The number of dimensions for the LHS matrix must be less or equal than 4.
91      * @param[in] input1    Input tensor info containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
92      * @param[in] output    Output tensor info. Data type supported: S32
93      * @param[in] lhs_info  LHS matrix information used for reshaping the input0 tensor.  Only the following values are supported:
94      *                      lhs_info.m0: 2,3,4,5,6,7,8
95      *                      lhs_info.k0: 2,3,4,8,16
96      *                      lhs_info.transpose: false
97      * @param[in] rhs_info  RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
98      *                      rhs_info.n0: 2,3,4,8,16
99      *                      rhs_info.k0: 2,3,4,8,16
100      *                      rhs_info.transpose: true
101      * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices
102      *
103      * @note lhs_info.k0 must be equal to rhs_info.k0
104      *
105      * @return a status
106      */
107     static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
108                            const GEMMReshapeInfo &gemm_info);
109 
110     // Inherited methods overridden:
111     void run(const Window &window, cl::CommandQueue &queue) override;
112 
113 private:
114     const ICLTensor *_input0;
115     const ICLTensor *_input1;
116     ICLTensor       *_output;
117     bool             _slide_matrix_b;
118     bool             _reinterpret_output_as_3d;
119     unsigned int     _k;
120     bool             _use_dummy_work_items;
121 };
122 } // namespace arm_compute
123 #endif /*ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDKERNEL_H*/
124