• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
25 #define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
26 
27 #include "arm_compute/core/Utils.h"
28 #include "arm_compute/core/Validate.h"
29 #include "arm_gemm_compute_iface.hpp"
30 #include "src/core/NEON/INEKernel.h"
31 
32 #include "gemm_common.hpp"
33 
34 namespace arm_compute
35 {
36 class ITensor;
37 
38 /** This class is a wrapper for the assembly kernels.
39   *
40   * Some kernels were written in assembly and highly optimised for specific CPUs like A53 or A55.
41   * This class works as a wrapper for these assembly kernels. The arm compute library creates an instance
42   * of NEGEMMAssemblyWrapperKernel and other auxiliary data structures to execute a single assembly kernel
43   * in the context of an NEFunctions.
44   *
45   * The type T is the type of the actual kernel implemented in assembly which is of type
46   *         template<typename To, typename Tr> class GemmCommon
47   *
48   *
49   */
50 template <typename TypeInput, typename TypeOutput>
51 class NEGEMMAssemblyWrapperKernel final : public INEKernel
52 {
53 public:
54     /** Constructor
55      */
NEGEMMAssemblyWrapperKernel()56     NEGEMMAssemblyWrapperKernel()
57         : _kernel(nullptr), _name("NEGEMMAssemblyWrapperKernel")
58     {
59     }
60 
61     NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &)  = delete;
62     NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &&) = default;
63     NEGEMMAssemblyWrapperKernel &operator=(NEGEMMAssemblyWrapperKernel &) = delete;
64 
name()65     const char *name() const override
66     {
67         return _name.c_str();
68     }
69 
run(const Window & window,const ThreadInfo & info)70     void run(const Window &window, const ThreadInfo &info) override
71     {
72         ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
73         ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
74 
75         auto win = arm_gemm::to_ndcoord(window);
76 
77         arm_gemm::ndcoord_t thread_locator{};
78 
79         _kernel->execute(win, thread_locator, info.thread_id);
80     }
81 
82     // Inherited methods overridden:
run_nd(const Window & window,const ThreadInfo & info,const Window & thread_locator)83     void run_nd(const Window &window, const ThreadInfo &info, const Window &thread_locator) override
84     {
85         ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
86         ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
87 
88         //convert between arm_compute and arm_gemm types
89         auto ndc_win = arm_gemm::to_ndcoord(window);
90         auto ndc_tlc = arm_gemm::to_ndcoord(thread_locator);
91 
92         _kernel->execute(ndc_win, ndc_tlc, info.thread_id);
93     }
94 
95     /** Initialise the kernel's input and output.
96      *
97      * @param[in] kernel      Pointer to an assembly kernel implementation.
98      * @param[in] num_threads Number of concurrent threads which will execute the kernel.
99      */
configure(arm_gemm::GemmCommon<TypeInput,TypeOutput> * kernel,std::string kernel_name_tag)100     void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
101     {
102         ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
103         _kernel = kernel;
104 
105         Window win = to_window(kernel->get_window_size());
106 
107         INEKernel::configure(win);
108 
109         if(!kernel_name_tag.empty())
110         {
111             _name += "/" + kernel_name_tag;
112         }
113     }
114 
115 private:
116     arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
117     std::string _name;
118 };
119 } // namespace arm_compute
120 #endif /* ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H */
121