"""Generates the specialized gemv functions.

Every function here writes C++ source through a caller-provided `emitter`
object. For each output type, GenerateInternalFunctions produces one kernel per
(aligned, n % 8, k % 8) combination, and GeneratePublicFunctions produces the
public entry point that picks one of those kernels at runtime.
"""

import mul_1x8_Mx8_neon
import mul_Nx8_Mx8_neon
import qnt_Nx8_neon
import zip_Nx8_neon

_QUANTIZED_8BIT = 'quantized_8bit'
_FULL_32BIT = 'full_32bit'
_FULL_FLOAT = 'full_float'

# Every supported output type, in the order functions are generated.
_OUTPUT_TYPES = (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT)


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Runtime configuration error (e.g. an unsupported output type)."""


def GenerateCommonTempsCountersAndConsts(emitter):
  """Generates common gemv boilerplate variables.

  Declares the loop constants and carves the caller-provided `scratch` buffer
  into: the zipped lhs vector (padded_k bytes), 16 bytes that are read as int32
  `zipped_lhs_offsets`, and two zipped rhs chunks of `zipped_chunk_size` bytes.

  Args:
    emitter: code emitter the declarations are written to.
  """
  # Number of full 8-column groups processed by the main loop.
  emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 8')
  # Depth rounded up to the next multiple of 8.
  emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
  # Bytes of rhs consumed by one 4-row zip (rhs_chunk advances by this much).
  emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 4')
  # Size of one zipped 4-row rhs chunk; presumably each zipped row is padded_k
  # bytes followed by 16 bytes of per-row data -- TODO confirm in zip_Nx8_neon.
  emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
                      '(padded_k + 16) * 4')
  emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
  emitter.EmitDeclare('std::int32_t*', 'zipped_lhs_offsets',
                      'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k)')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_1',
                      'scratch + padded_k + 16')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_2',
                      'zipped_rhs_1 + zipped_chunk_size')
  emitter.EmitNewline()


def GenerateQuantized8BitTempsCountersAndConsts(emitter):
  """Generates all the boilerplate variables for the q8 gemv function.

  On top of the common variables, int32 multiplication results are staged in
  `temp_result`, which lives in scratch right after the second zipped rhs
  chunk, before being quantized into the uint8 `result`.

  Args:
    emitter: code emitter the declarations are written to.
  """
  GenerateCommonTempsCountersAndConsts(emitter)
  # const_offset is passed to every rhs zip call, so result_offset is folded
  # in here rather than applied separately.
  emitter.EmitDeclare('const std::int32_t', 'const_offset',
                      'lhs_offset * rhs_offset * k + result_offset')
  # NOTE(review): shift == 0 would generate '1 << -1' (undefined in C++);
  # confirm callers always pass shift > 0.
  emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
                      '(1 << (shift - 1))')
  emitter.EmitDeclare('std::int32_t*', 'temp_result',
                      'reinterpret_cast<std::int32_t*>('
                      'zipped_rhs_2 + zipped_chunk_size)')
  emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
  emitter.EmitNewline()


def GenerateFullTempsCountersAndConsts(emitter, result_type):
  """Generates all the boilerplate variables for the int32 and float gemvs.

  Unlike the q8 variant, multiplication results are written straight into
  `result`, so no temporary buffer is declared.

  Args:
    emitter: code emitter the declarations are written to.
    result_type: C++ pointer type of `result` ('std::int32_t*' or 'float*').
  """
  GenerateCommonTempsCountersAndConsts(emitter)
  emitter.EmitDeclare('const std::int32_t', 'const_offset',
                      'lhs_offset * rhs_offset * k')
  emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
  emitter.EmitNewline()


def GenerateZipVector(emitter, aligned, leftovers):
  """Emits the call that zips the single lhs row into `zipped_lhs`.

  Args:
    emitter: code emitter the call is written to.
    aligned: whether to use the aligned zip kernel.
    leftovers: k % 8 handled by the selected zip kernel.
  """
  emitter.EmitCall(
      zip_Nx8_neon.BuildName(1, leftovers, aligned),
      ['lhs', 'k', 'k', 'zipped_lhs', 'rhs_offset', 0])


def GetMul2Params(result_type):
  """Returns the argument list for a mul kernel over both zipped rhs chunks.

  Args:
    result_type: 'int32' or 'float'; float kernels also take `result_scale`.

  Returns:
    A new list of C++ argument expressions.
  """
  params = ['zipped_lhs', 'zipped_rhs_1', 'zipped_rhs_2', 'padded_k',
            'mul_result_chunk']
  # Equality, not identity: `is` on a str literal depends on interning.
  if result_type == 'float':
    params.append('result_scale')
  return params


def GetMulParams(result_type):
  """Returns the argument list for a mul kernel over one zipped rhs chunk.

  Args:
    result_type: 'int32' or 'float'; float kernels also take `result_scale`.

  Returns:
    A new list of C++ argument expressions.
  """
  params = ['zipped_lhs', 'zipped_rhs_1', 'padded_k', 'mul_result_chunk', 0]
  # Equality, not identity: `is` on a str literal depends on interning.
  if result_type == 'float':
    params.append('result_scale')
  return params


def _EmitZipRhsChunk(emitter, rows, leftovers, aligned, destination,
                     advance=True):
  """Emits a zip of `rows` rhs rows into `destination`, optionally advancing.

  Args:
    emitter: code emitter the call is written to.
    rows: number of rhs rows the selected zip kernel handles.
    leftovers: k % 8 handled by the selected zip kernel.
    aligned: whether to use the aligned zip kernel.
    destination: name of the zipped rhs buffer to write.
    advance: when True, also emit `rhs_chunk += chunk_size`.
  """
  emitter.EmitCall(
      zip_Nx8_neon.BuildName(rows, leftovers, aligned),
      ['rhs_chunk', 'k', 'k', destination, 'lhs_offset', 'const_offset'])
  if advance:
    emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')


def GenerateMulCols(emitter, result_type, lhs_add, rhs_add, aligned, cols,
                    leftovers):
  """Emits code responsible for multiplication of one horizontal lhs strip.

  A loop handles `col_chunks` groups of 8 rhs rows (two 4-row zips, one 1x8
  mul each). Then the remaining `cols` (= n % 8) rows get a tail: two zips
  (4 + cols - 4) when cols > 4, a single zip when 0 < cols <= 4, nothing
  when cols == 0.

  Args:
    emitter: code emitter the code is written to.
    result_type: mul kernel result type, 'int32' or 'float'.
    lhs_add: forwarded to the mul kernel name builders.
    rhs_add: forwarded to the mul kernel name builders.
    aligned: whether to use the aligned zip kernels.
    cols: n % 8, the number of leftover rhs rows.
    leftovers: k % 8, the depth leftover.
  """
  emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
  _EmitZipRhsChunk(emitter, 4, leftovers, aligned, 'zipped_rhs_1')
  _EmitZipRhsChunk(emitter, 4, leftovers, aligned, 'zipped_rhs_2')
  emitter.EmitCall(
      mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 8),
      GetMul2Params(result_type))
  emitter.EmitAssignIncrement('mul_result_chunk', 8)
  emitter.EmitCloseBracket()

  # Tail: rhs_chunk is not advanced after the last zip since it is unused.
  if cols > 4:
    _EmitZipRhsChunk(emitter, 4, leftovers, aligned, 'zipped_rhs_1')
    _EmitZipRhsChunk(emitter, cols - 4, leftovers, aligned, 'zipped_rhs_2',
                     advance=False)
    emitter.EmitCall(
        mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, cols),
        GetMul2Params(result_type))
  elif cols > 0:
    _EmitZipRhsChunk(emitter, cols, leftovers, aligned, 'zipped_rhs_1',
                     advance=False)
    emitter.EmitCall(
        mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 1, cols),
        GetMulParams(result_type))


def GenerateQuantized8BitMul(emitter, aligned, cols, leftovers):
  """Emits code for all lhs strips & leftover rows. Quantize after mul code.

  Args:
    emitter: code emitter the code is written to.
    aligned: whether to use the aligned kernels.
    cols: n % 8.
    leftovers: k % 8.
  """
  GenerateMulCols(emitter, 'int32', False, True, aligned, cols, leftovers)
  emitter.EmitCall(
      qnt_Nx8_neon.BuildName(1, cols, aligned),
      ['temp_result', 'n', 0, 'zipped_lhs_offsets', 'result', 0,
       'multiplicative_offset', 'rounding_offset', '-shift'])


def GenerateFullMul(emitter, result_type, aligned, cols, leftovers):
  """Emits the mul code for the int32 / float gemvs (no quantization step).

  Args:
    emitter: code emitter the code is written to.
    result_type: 'int32' or 'float'.
    aligned: whether to use the aligned kernels.
    cols: n % 8.
    leftovers: k % 8.
  """
  GenerateMulCols(emitter, result_type, True, True, aligned, cols, leftovers)


def BuildName(output_type, aligned, cols, leftover):
  """Returns the internal kernel name, e.g. 'gemv_q8_3_5_aligned'.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  name = BuildMainGemvName(output_type) + '_%d_%d' % (cols, leftover)
  if aligned:
    name += '_aligned'
  return name


def GetCommonGemvParameters():
  """Returns a new [type, name] list of parameters shared by all gemvs."""
  return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'],
          ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'n'],
          ['std::int32_t', 'k'], ['std::int32_t', 'lhs_offset'],
          ['std::int32_t', 'rhs_offset']]


def GetGemvParameters(output_type):
  """Prepares a (type, parameter) array for the gemv functions.

  Args:
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT.

  Returns:
    A new list of [C++ type, parameter name] pairs.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  params = GetCommonGemvParameters()
  if output_type == _QUANTIZED_8BIT:
    params += [['std::int32_t', 'result_offset'],
               ['std::int32_t', 'multiplicative_offset'],
               ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
  elif output_type == _FULL_32BIT:
    params += [['std::int32_t*', 'result']]
  elif output_type == _FULL_FLOAT:
    params += [['float', 'result_scale'], ['float*', 'result']]
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
  return params


def GenerateGemv(emitter, output_type, aligned, cols, leftovers):
  """Builds one gemv function for the given column and depth leftovers.

  Args:
    emitter: code emitter the function is written to.
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT.
    aligned: whether to generate the aligned variant.
    cols: n % 8 this function is specialized for (asserted at runtime).
    leftovers: k % 8 this function is specialized for (asserted at runtime).

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  emitter.EmitFunctionBeginA(
      BuildName(output_type, aligned, cols, leftovers),
      GetGemvParameters(output_type), 'void')

  emitter.EmitAssert('n %% 8 == %d' % cols)
  emitter.EmitAssert('k %% 8 == %d' % leftovers)

  if output_type == _QUANTIZED_8BIT:
    GenerateQuantized8BitTempsCountersAndConsts(emitter)
    GenerateZipVector(emitter, aligned, leftovers)
    GenerateQuantized8BitMul(emitter, aligned, cols, leftovers)
  elif output_type == _FULL_32BIT:
    GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*')
    GenerateZipVector(emitter, aligned, leftovers)
    GenerateFullMul(emitter, 'int32', aligned, cols, leftovers)
  elif output_type == _FULL_FLOAT:
    GenerateFullTempsCountersAndConsts(emitter, 'float*')
    GenerateZipVector(emitter, aligned, leftovers)
    GenerateFullMul(emitter, 'float', aligned, cols, leftovers)
  else:
    raise ConfigurationError('Unknown output type: %s' % output_type)

  emitter.EmitFunctionEnd()


def GenerateGemvCall(emitter, output_type, aligned, m_mod, leftovers):
  """Emits a call to the internal kernel, forwarding every gemv parameter."""
  emitter.EmitCall(
      emitter.Scope('internal',
                    BuildName(output_type, aligned, m_mod, leftovers)),
      [name for _, name in GetGemvParameters(output_type)])


def GenerateGemvSwitch2(emitter, output_type, aligned, n_mod):
  """Second level of main switch, choose optimized version on depth leftover."""
  emitter.EmitSwitch('k % 8')

  for leftovers in range(8):
    emitter.EmitCase(leftovers)
    emitter.PushIndent()
    GenerateGemvCall(emitter, output_type, aligned, n_mod, leftovers)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def GenerateGemvSwitch1(emitter, output_type, aligned):
  """First level of main switch, choose optimized version on cols leftover."""
  emitter.EmitSwitch('n % 8')

  for n_mod in range(8):
    emitter.EmitCase(n_mod)
    emitter.PushIndent()
    GenerateGemvSwitch2(emitter, output_type, aligned, n_mod)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def BuildMainGemvName(output_type):
  """Returns the public function name for output_type.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  if output_type == _QUANTIZED_8BIT:
    return 'gemv_q8'
  elif output_type == _FULL_32BIT:
    return 'gemv_i32'
  elif output_type == _FULL_FLOAT:
    return 'gemv_f'
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)


def GenerateMainGemvFunction(emitter, output_type):
  """Emit high level gemv function that switches between optimized versions.

  The aligned kernels are selected only when lhs and rhs (and result, for q8)
  have addresses divisible by 8 and k % 8 == 0.
  """
  emitter.EmitFunctionBeginA(
      BuildMainGemvName(output_type), GetGemvParameters(output_type), 'void')

  emitter.EmitDeclare('const bool', 'lhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'rhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

  if output_type == _QUANTIZED_8BIT:
    emitter.EmitDeclare('const bool', 'result_aligned',
                        '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && result_aligned '
                        '&& k_aligned')
  else:
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && k_aligned')

  # NOTE(review): since `aligned` implies k % 8 == 0, the aligned switch's
  # cases with k % 8 != 0 are unreachable at runtime.
  emitter.EmitIf('aligned')
  GenerateGemvSwitch1(emitter, output_type, True)
  emitter.EmitElse()
  GenerateGemvSwitch1(emitter, output_type, False)
  emitter.EmitEndif()
  emitter.EmitFunctionEnd()


def GenerateInternalFunctions(emitter):
  """Generate all the functions hidden in the internal namespace."""
  for output_type in _OUTPUT_TYPES:
    for aligned in (True, False):
      for cols in range(8):
        for leftover in range(8):
          GenerateGemv(emitter, output_type, aligned, cols, leftover)
          emitter.EmitNewline()


def GeneratePublicFunctions(emitter):
  """Generate the public dispatching gemv function for every output type."""
  for output_type in _OUTPUT_TYPES:
    GenerateMainGemvFunction(emitter, output_type)
    emitter.EmitNewline()