1 /* 2 * Copyright (c) 2018-2022 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #pragma once 25 26 #include <cstring> 27 #include <memory> 28 #include <vector> 29 30 #include "arm_gemm_local.hpp" 31 #include "gemm_common.hpp" 32 33 namespace arm_gemm 34 { 35 enum class GemmMethod 36 { 37 DEFAULT, 38 GEMV_BATCHED, 39 GEMV_PRETRANSPOSED, 40 GEMV_NATIVE_TRANSPOSED, 41 GEMM_NATIVE, 42 GEMM_HYBRID, 43 GEMM_INTERLEAVED, 44 GEMM_INTERLEAVED_2D, 45 QUANTIZE_WRAPPER, 46 QUANTIZE_WRAPPER_2D, 47 GEMM_HYBRID_QUANTIZED 48 }; 49 50 enum class WeightFormat 51 { 52 UNSPECIFIED = 0x1, 53 ANY = 0x2, 54 OHWI = 0x100100, 55 OHWIo2 = 0x100200, 56 OHWIo4 = 0x100400, 57 OHWIo8 = 0x100800, 58 OHWIo16 = 0x101000, 59 OHWIo32 = 0x102000, 60 OHWIo64 = 0x104000, 61 OHWIo128 = 0x108000, 62 OHWIo4i2 = 0x200400, 63 OHWIo4i2_bf16 = 0x200410, 64 OHWIo8i2 = 0x200800, 65 OHWIo8i2_bf16 = 0x200810, 66 OHWIo16i2 = 0x201000, 67 OHWIo16i2_bf16 = 0x201010, 68 OHWIo32i2 = 0x202000, 69 OHWIo32i2_bf16 = 0x202010, 70 OHWIo64i2 = 0x204000, 71 OHWIo64i2_bf16 = 0x204010, 72 OHWIo4i4 = 0x400400, 73 OHWIo4i4_bf16 = 0x400410, 74 OHWIo8i4 = 0x400800, 75 OHWIo8i4_bf16 = 0x400810, 76 OHWIo16i4 = 0x401000, 77 OHWIo16i4_bf16 = 0x401010, 78 OHWIo32i4 = 0x402000, 79 OHWIo32i4_bf16 = 0x402010, 80 OHWIo64i4 = 0x404000, 81 OHWIo64i4_bf16 = 0x404010, 82 OHWIo2i8 = 0x800200, 83 OHWIo4i8 = 0x800400, 84 OHWIo8i8 = 0x800800, 85 OHWIo16i8 = 0x801000, 86 OHWIo32i8 = 0x802000, 87 OHWIo64i8 = 0x804000 88 }; 89 90 struct KernelDescription 91 { 92 GemmMethod method = GemmMethod::DEFAULT; 93 std::string name = ""; 94 bool is_default = false; 95 uint64_t cycle_estimate = 0; 96 KernelDescriptionarm_gemm::KernelDescription97 KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0) 98 : method(m), name(n), is_default(d), cycle_estimate(c) 99 { 100 } KernelDescriptionarm_gemm::KernelDescription101 KernelDescription() noexcept 102 { 103 } 104 }; 105 106 struct GemmConfig 107 { 108 GemmMethod method = GemmMethod::DEFAULT; 109 std::string filter = ""; 110 unsigned int inner_block_size = 0; 111 unsigned int outer_block_size = 0; 112 WeightFormat weight_format = WeightFormat::ANY; 113 GemmConfigarm_gemm::GemmConfig114 GemmConfig(GemmMethod method) 115 : method(method) 116 { 117 } GemmConfigarm_gemm::GemmConfig118 GemmConfig() 119 { 120 } 121 }; 122 123 struct Activation 124 { 125 enum class Type 126 { 127 None, 128 ReLU, 129 BoundedReLU 130 }; 131 132 Type type; 133 float param1; 134 float param2; 135 Activationarm_gemm::Activation136 Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f) 137 : type(type), param1(p1), param2(p2) 138 { 139 } 140 }; 141 142 struct GemmArgs 143 { 144 public: 145 const CPUInfo *_ci; 146 unsigned int _Msize; // num of tiles 147 unsigned int _Nsize; // output channels 148 unsigned int _Ksize; // input channels 149 unsigned int _Ksections; 150 unsigned int _nbatches; 151 unsigned int _nmulti; // n_gemms to be performed 152 bool _indirect_input; 153 Activation _act; 154 int _maxthreads; 155 bool _fixed_format; 156 bool _fast_mode; 157 const GemmConfig *_cfg; 158 GemmArgsarm_gemm::GemmArgs159 GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, 160 unsigned int K, unsigned int Ksections, unsigned int nbatches, 161 unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, 162 bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr) 163 : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), 164 _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg) 165 { 166 } 167 }; 168 169 struct Requantize32 170 { 171 public: 172 const int32_t *bias = nullptr; 173 size_t bias_multi_stride = 0; 174 int32_t a_offset = 0; 175 int32_t b_offset = 0; 176 int32_t c_offset = 0; 177 bool per_channel_requant = false; 178 int32_t per_layer_left_shift = 0; 179 int32_t per_layer_right_shift = 0; 180 int32_t per_layer_mul = 0; 181 const int32_t *per_channel_left_shifts = nullptr; 182 const int32_t *per_channel_right_shifts = nullptr; 183 const int32_t *per_channel_muls = nullptr; 184 int32_t minval = 0; 185 int32_t maxval = 0; 186 187 Requantize32() = default; 188 189 // Constructor for per-tensor quantization Requantize32arm_gemm::Requantize32190 Requantize32(const int32_t *bias, size_t bias_multi_stride, 191 int32_t a_offset, int32_t b_offset, int32_t c_offset, 192 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv) 193 : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max<int32_t>(requant_shift, 0)), 194 per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), per_layer_mul(requant_mul), minval(minv), maxval(maxv) 195 { 196 } 197 198 // Constructor for per-channel quantization Requantize32arm_gemm::Requantize32199 Requantize32(const int32_t *bias, size_t bias_multi_stride, 200 int32_t a_offset, int32_t b_offset, int32_t c_offset, 201 const int32_t *requant_left_shifts, 202 const int32_t *requant_right_shifts, 203 const int32_t *requant_muls, 204 int32_t minv, int32_t maxv) 205 : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts), 206 per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv) 207 { 208 } 209 }; 210 211 struct Nothing 212 { 213 }; 214 215 template <typename Top, typename Tret> 216 using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>; 217 218 /* Low level API calls. 219 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */ 220 221 /* get_gemm_method(): Given the templated types and provided parameters, 222 * which is the preferred method to implement this GEMM? */ 223 template <typename Top, typename Tret, class OutputStage = Nothing> 224 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {}); 225 226 template <typename Top, typename Tret, class OutputStage = Nothing> 227 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {}); 228 229 template <typename Top, typename Tret, class OutputStage = Nothing> 230 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); 231 232 template <typename Top, typename Tret, class OutputStage = Nothing> 233 bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); 234 235 } // namespace arm_gemm 236