• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #pragma once
25 
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
29 
30 #include "arm_gemm_local.hpp"
31 #include "gemm_common.hpp"
32 
33 namespace arm_gemm
34 {
/** Identifies the method used to implement a GEMM.
 *
 * DEFAULT is the value that KernelDescription::method and GemmConfig::method
 * start out with. The numeric values are the ones the compiler would assign
 * implicitly; they are spelled out so that they stay pinned.
 */
enum class GemmMethod
{
    DEFAULT                = 0,
    GEMV_BATCHED           = 1,
    GEMV_PRETRANSPOSED     = 2,
    GEMV_NATIVE_TRANSPOSED = 3,
    GEMM_NATIVE            = 4,
    GEMM_HYBRID            = 5,
    GEMM_INTERLEAVED       = 6,
    GEMM_INTERLEAVED_2D    = 7,
    QUANTIZE_WRAPPER       = 8,
    QUANTIZE_WRAPPER_2D    = 9,
    GEMM_HYBRID_QUANTIZED  = 10,
};
49 
/** Weight memory-layout identifiers.
 *
 * UNSPECIFIED and ANY are the two small sentinel values (0x1 and 0x2); ANY is
 * what GemmConfig::weight_format defaults to.
 *
 * Every OHWI* value follows the same visible bit pattern, which is written out
 * explicitly below:
 *   - the number following 'o' in the name is stored starting at bit 8
 *     (1 when the name has no 'o' suffix),
 *   - the number following 'i' in the name is stored starting at bit 20
 *     (1 when the name has no 'i' suffix),
 *   - bit 4 is set for the "_bf16" variants.
 *
 * NOTE(review): the pattern is derived from the values alone; what the 'o' and
 * 'i' fields mean to the kernels is not stated in this file - confirm before
 * relying on it.
 */
enum class WeightFormat
{
    UNSPECIFIED    = 0x1,
    ANY            = 0x2,

    // No 'i' suffix.
    OHWI           = (1 << 20) | (1 << 8),
    OHWIo2         = (1 << 20) | (2 << 8),
    OHWIo4         = (1 << 20) | (4 << 8),
    OHWIo8         = (1 << 20) | (8 << 8),
    OHWIo16        = (1 << 20) | (16 << 8),
    OHWIo32        = (1 << 20) | (32 << 8),
    OHWIo64        = (1 << 20) | (64 << 8),
    OHWIo128       = (1 << 20) | (128 << 8),

    // 'i2' family.
    OHWIo4i2       = (2 << 20) | (4 << 8),
    OHWIo4i2_bf16  = (2 << 20) | (4 << 8) | (1 << 4),
    OHWIo8i2       = (2 << 20) | (8 << 8),
    OHWIo8i2_bf16  = (2 << 20) | (8 << 8) | (1 << 4),
    OHWIo16i2      = (2 << 20) | (16 << 8),
    OHWIo16i2_bf16 = (2 << 20) | (16 << 8) | (1 << 4),
    OHWIo32i2      = (2 << 20) | (32 << 8),
    OHWIo32i2_bf16 = (2 << 20) | (32 << 8) | (1 << 4),
    OHWIo64i2      = (2 << 20) | (64 << 8),
    OHWIo64i2_bf16 = (2 << 20) | (64 << 8) | (1 << 4),

    // 'i4' family.
    OHWIo4i4       = (4 << 20) | (4 << 8),
    OHWIo4i4_bf16  = (4 << 20) | (4 << 8) | (1 << 4),
    OHWIo8i4       = (4 << 20) | (8 << 8),
    OHWIo8i4_bf16  = (4 << 20) | (8 << 8) | (1 << 4),
    OHWIo16i4      = (4 << 20) | (16 << 8),
    OHWIo16i4_bf16 = (4 << 20) | (16 << 8) | (1 << 4),
    OHWIo32i4      = (4 << 20) | (32 << 8),
    OHWIo32i4_bf16 = (4 << 20) | (32 << 8) | (1 << 4),
    OHWIo64i4      = (4 << 20) | (64 << 8),
    OHWIo64i4_bf16 = (4 << 20) | (64 << 8) | (1 << 4),

    // 'i8' family.
    OHWIo2i8       = (8 << 20) | (2 << 8),
    OHWIo4i8       = (8 << 20) | (4 << 8),
    OHWIo8i8       = (8 << 20) | (8 << 8),
    OHWIo16i8      = (8 << 20) | (16 << 8),
    OHWIo32i8      = (8 << 20) | (32 << 8),
    OHWIo64i8      = (8 << 20) | (64 << 8),
};
89 
90 struct KernelDescription
91 {
92     GemmMethod  method         = GemmMethod::DEFAULT;
93     std::string name           = "";
94     bool        is_default     = false;
95     uint64_t    cycle_estimate = 0;
96 
KernelDescriptionarm_gemm::KernelDescription97     KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
98         : method(m), name(n), is_default(d), cycle_estimate(c)
99     {
100     }
KernelDescriptionarm_gemm::KernelDescription101     KernelDescription() noexcept
102     {
103     }
104 };
105 
106 struct GemmConfig
107 {
108     GemmMethod   method           = GemmMethod::DEFAULT;
109     std::string  filter           = "";
110     unsigned int inner_block_size = 0;
111     unsigned int outer_block_size = 0;
112     WeightFormat weight_format    = WeightFormat::ANY;
113 
GemmConfigarm_gemm::GemmConfig114     GemmConfig(GemmMethod method)
115         : method(method)
116     {
117     }
GemmConfigarm_gemm::GemmConfig118     GemmConfig()
119     {
120     }
121 };
122 
/** Activation selector plus two float parameters. */
struct Activation
{
    /** Kind of activation. */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU,
    };

    Type  type;   /**< Selected activation kind. */
    float param1; /**< First parameter. NOTE(review): meaning per activation kind is not visible here - confirm. */
    float param2; /**< Second parameter. NOTE(review): meaning per activation kind is not visible here - confirm. */

    /** Construct an activation description.
     *
     * With no arguments this yields Type::None with both parameters 0.0f.
     *
     * @param[in] type Activation kind.
     * @param[in] p1   Value stored in param1.
     * @param[in] p2   Value stored in param2.
     */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type{ type },
          param1{ p1 },
          param2{ p2 }
    {
    }
};
141 
142 struct GemmArgs
143 {
144 public:
145     const CPUInfo    *_ci;
146     unsigned int      _Msize; // num of tiles
147     unsigned int      _Nsize; // output channels
148     unsigned int      _Ksize; // input channels
149     unsigned int      _Ksections;
150     unsigned int      _nbatches;
151     unsigned int      _nmulti; // n_gemms to be performed
152     bool              _indirect_input;
153     Activation        _act;
154     int               _maxthreads;
155     bool              _fixed_format;
156     bool              _fast_mode;
157     const GemmConfig *_cfg;
158 
GemmArgsarm_gemm::GemmArgs159     GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
160              unsigned int K, unsigned int Ksections, unsigned int nbatches,
161              unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
162              bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr)
163         : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads),
164           _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg)
165     {
166     }
167 };
168 
/** Parameters of a 32-bit requantization output stage.
 *
 * Holds either per-layer values (per_channel_requant == false) or pointers to
 * per-channel arrays (per_channel_requant == true), depending on which
 * constructor was used. None of the pointers are owned.
 */
struct Requantize32
{
    const int32_t *bias                     = nullptr;
    size_t         bias_multi_stride        = 0;
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;
    int32_t        per_layer_left_shift     = 0;
    int32_t        per_layer_right_shift    = 0;
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;
    int32_t        maxval                   = 0;

    /** Leave every field at its in-class default. */
    Requantize32() = default;

    /** Per-tensor quantization.
     *
     * The single signed @p requant_shift is split by sign: a positive value
     * becomes per_layer_left_shift (right shift 0), a negative value becomes
     * per_layer_right_shift (left shift 0), and zero leaves both at 0.
     * The per-channel pointers stay nullptr.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
        : bias(bias),
          bias_multi_stride(bias_multi_stride),
          a_offset(a_offset),
          b_offset(b_offset),
          c_offset(c_offset),
          per_channel_requant(false),
          per_layer_left_shift(requant_shift > 0 ? requant_shift : 0),
          per_layer_right_shift(requant_shift < 0 ? requant_shift : 0),
          per_layer_mul(requant_mul),
          minval(minv),
          maxval(maxv)
    {
    }

    /** Per-channel quantization.
     *
     * Stores the three array pointers as given and sets per_channel_requant
     * to true. The per-layer shift and multiplier fields stay at 0.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t minv, int32_t maxv)
        : bias(bias),
          bias_multi_stride(bias_multi_stride),
          a_offset(a_offset),
          b_offset(b_offset),
          c_offset(c_offset),
          per_channel_requant(true),
          per_channel_left_shifts(requant_left_shifts),
          per_channel_right_shifts(requant_right_shifts),
          per_channel_muls(requant_muls),
          minval(minv),
          maxval(maxv)
    {
    }
};
210 
/** Empty placeholder type.
 *
 * Serves as the default OutputStage template argument of the API function
 * templates declared in this file.
 */
struct Nothing {};
214 
215 template <typename Top, typename Tret>
216 using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
217 
218 /* Low level API calls.
219  * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
220 
221 /* get_gemm_method(): Given the templated types and provided parameters,
222  * which is the preferred method to implement this GEMM?  */
223 template <typename Top, typename Tret, class OutputStage = Nothing>
224 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
225 
226 template <typename Top, typename Tret, class OutputStage = Nothing>
227 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
228 
229 template <typename Top, typename Tret, class OutputStage = Nothing>
230 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
231 
232 template <typename Top, typename Tret, class OutputStage = Nothing>
233 bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
234 
235 } // namespace arm_gemm
236