1 /* 2 * Copyright (c) 2018-2020 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #pragma once 25 26 #include <cstring> 27 #include <memory> 28 29 #include "arm_gemm_local.hpp" 30 #include "gemm_common.hpp" 31 32 namespace arm_gemm 33 { 34 enum class GemmMethod 35 { 36 DEFAULT, 37 GEMV_BATCHED, 38 GEMV_PRETRANSPOSED, 39 GEMV_NATIVE_TRANSPOSED, 40 GEMM_NATIVE, 41 GEMM_HYBRID, 42 GEMM_INTERLEAVED, 43 GEMM_INTERLEAVED_2D, 44 QUANTIZE_WRAPPER, 45 QUANTIZE_WRAPPER_2D, 46 GEMM_HYBRID_QUANTIZED, 47 INDIRECT_GEMM, 48 CONVOLUTION_GEMM 49 }; 50 51 struct KernelDescription 52 { 53 GemmMethod method = GemmMethod::DEFAULT; 54 std::string name = ""; 55 bool is_default = false; 56 uint64_t cycle_estimate = 0; 57 KernelDescriptionarm_gemm::KernelDescription58 KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0) 59 : method(m), name(n), is_default(d), cycle_estimate(c) 60 { 61 } KernelDescriptionarm_gemm::KernelDescription62 KernelDescription() noexcept 63 { 64 } 65 }; 66 67 struct GemmConfig 68 { 69 GemmMethod method = GemmMethod::DEFAULT; 70 std::string filter = ""; 71 unsigned int inner_block_size = 0; 72 unsigned int outer_block_size = 0; 73 GemmConfigarm_gemm::GemmConfig74 GemmConfig(GemmMethod method) 75 : method(method) 76 { 77 } GemmConfigarm_gemm::GemmConfig78 GemmConfig() 79 { 80 } 81 }; 82 83 struct Activation 84 { 85 enum class Type 86 { 87 None, 88 ReLU, 89 BoundedReLU 90 }; 91 92 Type type; 93 float param1; 94 float param2; 95 Activationarm_gemm::Activation96 Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f) 97 : type(type), param1(p1), param2(p2) 98 { 99 } 100 }; 101 102 struct GemmArgs 103 { 104 public: 105 const CPUInfo *_ci; 106 unsigned int _Msize; 107 unsigned int _Nsize; 108 unsigned int _Ksize; 109 unsigned int _Ksections; 110 unsigned int _nbatches; 111 unsigned int _nmulti; 112 bool _indirect_input; 113 Activation _act; 114 int _maxthreads; 115 const GemmConfig *_cfg; 116 GemmArgsarm_gemm::GemmArgs117 GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, 118 unsigned int K, unsigned int Ksections, unsigned int nbatches, 119 unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, 120 const GemmConfig *cfg = nullptr) 121 : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _cfg(cfg) 122 { 123 } 124 }; 125 126 struct Requantize32 127 { 128 public: 129 const int32_t *bias = nullptr; 130 size_t bias_multi_stride = 0; 131 int32_t a_offset = 0; 132 int32_t b_offset = 0; 133 int32_t c_offset = 0; 134 bool per_channel_requant = false; 135 int32_t per_layer_left_shift = 0; 136 int32_t per_layer_right_shift = 0; 137 int32_t per_layer_mul = 0; 138 const int32_t *per_channel_left_shifts = nullptr; 139 const int32_t *per_channel_right_shifts = nullptr; 140 const int32_t *per_channel_muls = nullptr; 141 int32_t minval = 0; 142 int32_t maxval = 0; 143 144 Requantize32() = default; 145 146 // Constructor for per-tensor quantization Requantize32arm_gemm::Requantize32147 Requantize32(const int32_t *bias, size_t bias_multi_stride, 148 int32_t a_offset, int32_t b_offset, int32_t c_offset, 149 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv) 150 : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max<int32_t>(requant_shift, 0)), 151 per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), per_layer_mul(requant_mul), minval(minv), maxval(maxv) 152 { 153 } 154 155 // Constructor for per-channel quantization Requantize32arm_gemm::Requantize32156 Requantize32(const int32_t *bias, size_t bias_multi_stride, 157 int32_t a_offset, int32_t b_offset, int32_t c_offset, 158 const int32_t *requant_left_shifts, 159 const int32_t *requant_right_shifts, 160 const int32_t *requant_muls, 161 int32_t minv, int32_t maxv) 162 : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts), 163 per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv) 164 { 165 } 166 }; 167 168 struct Nothing 169 { 170 }; 171 172 template <typename Top, typename Tret> 173 using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>; 174 175 /* Low level API calls. 176 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */ 177 178 /* get_gemm_method(): Given the templated types and provided parameters, 179 * which is the preferred method to implement this GEMM? */ 180 template <typename Top, typename Tret, class OutputStage = Nothing> 181 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {}); 182 183 template <typename Top, typename Tret, class OutputStage = Nothing> 184 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {}); 185 186 template <typename Top, typename Tret, class OutputStage = Nothing> 187 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); 188 189 } // namespace arm_gemm 190