• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #pragma once
25 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"
31 
32 namespace arm_gemm
33 {
/** Identifies a GEMM implementation strategy.
 *
 * DEFAULT is the value that KernelDescription and GemmConfig start out with
 * when no specific method has been chosen. The enumerators keep their
 * declaration order, so their underlying values run from 0 (DEFAULT) to
 * 12 (CONVOLUTION_GEMM).
 */
enum class GemmMethod
{
    DEFAULT,
    GEMV_BATCHED,
    GEMV_PRETRANSPOSED,
    GEMV_NATIVE_TRANSPOSED,
    GEMM_NATIVE,
    GEMM_HYBRID,
    GEMM_INTERLEAVED,
    GEMM_INTERLEAVED_2D,
    QUANTIZE_WRAPPER,
    QUANTIZE_WRAPPER_2D,
    GEMM_HYBRID_QUANTIZED,
    INDIRECT_GEMM,
    CONVOLUTION_GEMM
};
50 
51 struct KernelDescription
52 {
53     GemmMethod  method         = GemmMethod::DEFAULT;
54     std::string name           = "";
55     bool        is_default     = false;
56     uint64_t    cycle_estimate = 0;
57 
KernelDescriptionarm_gemm::KernelDescription58     KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
59         : method(m), name(n), is_default(d), cycle_estimate(c)
60     {
61     }
KernelDescriptionarm_gemm::KernelDescription62     KernelDescription() noexcept
63     {
64     }
65 };
66 
67 struct GemmConfig
68 {
69     GemmMethod   method           = GemmMethod::DEFAULT;
70     std::string  filter           = "";
71     unsigned int inner_block_size = 0;
72     unsigned int outer_block_size = 0;
73 
GemmConfigarm_gemm::GemmConfig74     GemmConfig(GemmMethod method)
75         : method(method)
76     {
77     }
GemmConfigarm_gemm::GemmConfig78     GemmConfig()
79     {
80     }
81 };
82 
/** Activation selection plus up to two float parameters.
 *
 * NOTE(review): the meaning of param1/param2 for each Type is defined by the
 * kernels consuming this struct, not by this header - confirm there.
 */
struct Activation
{
    /** Kind of activation requested. */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU
    };

    Type  type;
    float param1;
    float param2;

    /** Build an activation descriptor.
     *
     * With no arguments the result is Type::None with both parameters 0.0f.
     *
     * @param type Activation kind to store.
     * @param p1   Value stored in param1.
     * @param p2   Value stored in param2.
     */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type(type), param1(p1), param2(p2)
    {
    }
};
101 
102 struct GemmArgs
103 {
104 public:
105     const CPUInfo    *_ci;
106     unsigned int      _Msize;
107     unsigned int      _Nsize;
108     unsigned int      _Ksize;
109     unsigned int      _Ksections;
110     unsigned int      _nbatches;
111     unsigned int      _nmulti;
112     bool              _indirect_input;
113     Activation        _act;
114     int               _maxthreads;
115     const GemmConfig *_cfg;
116 
GemmArgsarm_gemm::GemmArgs117     GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
118              unsigned int K, unsigned int Ksections, unsigned int nbatches,
119              unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
120              const GemmConfig *cfg = nullptr)
121         : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _cfg(cfg)
122     {
123     }
124 };
125 
/** Parameters for a 32-bit requantization output stage.
 *
 * The struct is a plain parameter holder with two modes, distinguished by
 * per_channel_requant:
 *  - false: the per_layer_* scalars are set and the per_channel_* pointers
 *           stay nullptr;
 *  - true:  the per_channel_* pointers are set and the per_layer_* scalars
 *           stay 0.
 * None of the pointers are owned; the arrays must outlive this object.
 */
struct Requantize32
{
    const int32_t *bias                     = nullptr;
    size_t         bias_multi_stride        = 0;
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;
    int32_t        per_layer_left_shift     = 0;
    int32_t        per_layer_right_shift    = 0;
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;
    int32_t        maxval                   = 0;

    /** Build an all-defaults object (null pointers, zero scalars, per-layer mode). */
    Requantize32() = default;

    /** Constructor for per-tensor quantization.
     *
     * The single signed @p requant_shift is split by sign: its positive part
     * goes to per_layer_left_shift (max(shift, 0)) and its non-positive part
     * goes to per_layer_right_shift (min(shift, 0)). The stored right shift is
     * therefore always zero or negative.
     *
     * @param bias              Bias pointer (may be nullptr), stored as-is.
     * @param bias_multi_stride Stored in bias_multi_stride.
     * @param a_offset          Stored in a_offset.
     * @param b_offset          Stored in b_offset.
     * @param c_offset          Stored in c_offset.
     * @param requant_shift     Signed shift, split as described above.
     * @param requant_mul       Stored in per_layer_mul.
     * @param minv              Stored in minval.
     * @param maxv              Stored in maxval.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
        : bias(bias),
          bias_multi_stride(bias_multi_stride),
          a_offset(a_offset),
          b_offset(b_offset),
          c_offset(c_offset),
          per_channel_requant(false),
          per_layer_left_shift(std::max<int32_t>(requant_shift, 0)),
          per_layer_right_shift(std::min<int32_t>(requant_shift, 0)),
          per_layer_mul(requant_mul),
          minval(minv),
          maxval(maxv)
    {
    }

    /** Constructor for per-channel quantization.
     *
     * Sets per_channel_requant to true and stores the three array pointers
     * unchanged; the per_layer_* scalars keep their default of 0.
     *
     * @param bias                 Bias pointer (may be nullptr), stored as-is.
     * @param bias_multi_stride    Stored in bias_multi_stride.
     * @param a_offset             Stored in a_offset.
     * @param b_offset             Stored in b_offset.
     * @param c_offset             Stored in c_offset.
     * @param requant_left_shifts  Stored in per_channel_left_shifts.
     * @param requant_right_shifts Stored in per_channel_right_shifts.
     * @param requant_muls         Stored in per_channel_muls.
     * @param minv                 Stored in minval.
     * @param maxv                 Stored in maxval.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t minv, int32_t maxv)
        : bias(bias),
          bias_multi_stride(bias_multi_stride),
          a_offset(a_offset),
          b_offset(b_offset),
          c_offset(c_offset),
          per_channel_requant(true),
          per_channel_left_shifts(requant_left_shifts),
          per_channel_right_shifts(requant_right_shifts),
          per_channel_muls(requant_muls),
          minval(minv),
          maxval(maxv)
    {
    }
};
167 
/** Empty placeholder type.
 *
 * Used as the default OutputStage template argument of get_gemm_method(),
 * gemm() and get_compatible_kernels() when no output stage is supplied.
 */
struct Nothing
{
};
171 
172 template <typename Top, typename Tret>
173 using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
174 
/* Low level API calls.
 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
177 
178 /* get_gemm_method(): Given the templated types and provided parameters,
179  * which is the preferred method to implement this GEMM?  */
180 template <typename Top, typename Tret, class OutputStage = Nothing>
181 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
182 
183 template <typename Top, typename Tret, class OutputStage = Nothing>
184 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
185 
186 template <typename Top, typename Tret, class OutputStage = Nothing>
187 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
188 
189 } // namespace arm_gemm
190