• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYNATIVEKERNEL_H
25 #define ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYNATIVEKERNEL_H
26 
27 #include "src/core/CL/ICLKernel.h"
28 
29 namespace arm_compute
30 {
31 class ICLTensor;
32 
33 /** OpenCL kernel to multiply matrices with QASYMM8/QASYMM8_SIGNED data type */
34 class CLGEMMLowpMatrixMultiplyNativeKernel : public ICLKernel
35 {
36 public:
37     /** Default Constructor */
38     CLGEMMLowpMatrixMultiplyNativeKernel();
39     /** Prevent instances of this class from being copied (As this class contains pointers) */
40     CLGEMMLowpMatrixMultiplyNativeKernel(const CLGEMMLowpMatrixMultiplyNativeKernel &) = delete;
41     /** Prevent instances of this class from being copied (As this class contains pointers) */
42     CLGEMMLowpMatrixMultiplyNativeKernel &operator=(const CLGEMMLowpMatrixMultiplyNativeKernel &) = delete;
43     /** Allow instances of this class to be moved */
44     CLGEMMLowpMatrixMultiplyNativeKernel(CLGEMMLowpMatrixMultiplyNativeKernel &&) = default;
45     /** Allow instances of this class to be moved */
46     CLGEMMLowpMatrixMultiplyNativeKernel &operator=(CLGEMMLowpMatrixMultiplyNativeKernel &&) = default;
47     /** Initialise the kernel's input and output.
48      *
49      * @param[in]  input0    Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
50      * @param[in]  input1    Input tensor containing the RHS matrix. Data type supported: same as @p input0
51      * @param[out] output    Output tensor to store the result of matrix multiplication. Data type supported: S32
52      * @param[in]  lhs_info  LHS matrix information used to retrieve the number of rows to be processed by each thread
53      *                       lhs_info.m0: 2,3,4,5,6,7,8
54      *                       lhs_info.k0: 2,3,4,8,16
55      * @param[in]  rhs_info  RHS matrix information used to retrieve the number of columns to be processed by each thread
56      *                       rhs_info.n0: 2,3,4,8,16
57      *                       rhs_info.k0: same as lhs_info.k0
58      * @param[in]  gemm_info GEMM information used to retrieve the original dimensions of the input matrices
59      */
60     void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info);
61     /** Initialise the kernel's input and output.
62      *
63      * @param[in]  compile_context The compile context to be used.
64      * @param[in]  input0          Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
65      * @param[in]  input1          Input tensor containing the RHS matrix. Data type supported: same as @p input0
66      * @param[out] output          Output tensor to store the result of matrix multiplication. Data type supported: S32
67      * @param[in]  lhs_info        LHS matrix information used to retrieve the number of rows to be processed by each thread
68      *                             lhs_info.m0: 2,3,4,5,6,7,8
69      *                             lhs_info.k0: 2,3,4,8,16
70      * @param[in]  rhs_info        RHS matrix information used to retrieve the number of columns to be processed by each thread
71      *                             rhs_info.n0: 2,3,4,8,16
72      *                             rhs_info.k0: same as lhs_info.k0
73      * @param[in]  gemm_info       GEMM information used to retrieve the original dimensions of the input matrices
74      */
75     void configure(const CLCompileContext &compile_context, const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
76                    const GEMMReshapeInfo &gemm_info);
77     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyNativeKernel
78      *
79      * @param[in] input0    Input tensor info for the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
80      * @param[in] input1    Input tensor info for the RHS matrix. Data type supported: same as @p input0
81      * @param[in] output    Output tensor info. Data type supported: S32
82      * @param[in] lhs_info  LHS matrix information used to retrieve the number of rows to be processed by each thread
83      *                      lhs_info.m0: 2,3,4,5,6,7,8
84      *                      lhs_info.k0: 2,3,4,8,16
85      * @param[in] rhs_info  RHS matrix information used to retrieve the number of columns to be processed by each thread
86      *                      rhs_info.n0: 2,3,4,8,16
87      *                      rhs_info.k0: same as lhs_info.k0
88      * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices
89      *
90      * @return a status
91      */
92     static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
93                            const GEMMReshapeInfo &gemm_info);
94 
95     // Inherited methods overridden:
96     void run(const Window &window, cl::CommandQueue &queue) override;
97 
98 private:
99     const ICLTensor *_input0;
100     const ICLTensor *_input1;
101     ICLTensor       *_output;
102     bool             _slide_matrix_b;
103     bool             _reinterpret_input_as_3d;
104     bool             _reinterpret_output_as_3d;
105     bool             _use_dummy_work_items;
106 };
107 } // namespace arm_compute
108 #endif /*ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYNATIVEKERNEL_H*/
109