1"""Generates the specialized gemm functions.""" 2 3import mul_Nx8_Mx8_neon 4import qnt_Nx8_neon 5import zip_Nx8_neon 6 7_QUANTIZED_8BIT = 'quantized_8bit' 8_FULL_32BIT = 'full_32bit' 9_FULL_FLOAT = 'full_float' 10 11 12class Error(Exception): 13 """Module level error.""" 14 15 16class ConfigurationError(Error): 17 """Runtime configuration error.""" 18 19 20def GenerateCommonTempsCountersAndConsts(emitter, rows): 21 emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3') 22 emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3') 23 emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8') 24 emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3') 25 emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size', 26 '(padded_k + 16) * 3') 27 emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size', 28 '(padded_k + 16) * n') 29 emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs') 30 emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs') 31 emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch') 32 emitter.EmitDeclare( 33 'std::int32_t*', 'zipped_lhs_3_offsets', 34 'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)') 35 if rows is not 0: 36 emitter.EmitDeclare( 37 'std::int32_t*', 'zipped_lhs_%d_offsets' % rows, 38 'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows) 39 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs', 40 'scratch + zipped_chunk_size') 41 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs') 42 emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride', 43 'result_stride * 3') 44 emitter.EmitNewline() 45 46 47def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows): 48 """Generates all the boilerplate variables for the q8 gemm function.""" 49 GenerateCommonTempsCountersAndConsts(emitter, rows) 50 emitter.EmitDeclare('const std::int32_t', 'const_offset', 51 'lhs_offset * rhs_offset * k + result_offset') 52 emitter.EmitDeclare('const std::int32_t', 'rounding_offset', 53 '(1 << (shift - 1))') 54 emitter.EmitDeclare('std::int32_t*', 'temp_result', 55 'reinterpret_cast<std::int32_t*>(' 56 'scratch + zipped_chunk_size + zipped_rhs_size)') 57 emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result') 58 emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result') 59 emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes', 60 '((n * 4 + 7) / 8) * 8') 61 emitter.EmitNewline() 62 63 64def GenerateFullTempsCountersAndConsts(emitter, result_type, rows): 65 """Generates all the boilerplate variables for the int32 and float gemms.""" 66 GenerateCommonTempsCountersAndConsts(emitter, rows) 67 emitter.EmitDeclare('const std::int32_t', 'const_offset', 68 'lhs_offset * rhs_offset * k') 69 emitter.EmitDeclare(result_type, 'result_chunk', 'result') 70 emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result') 71 emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes', 72 'result_stride * 4') 73 emitter.EmitNewline() 74 75 76def ZipName(rows, leftovers, aligned): 77 return zip_Nx8_neon.BuildName(rows, leftovers, aligned) 78 79 80def GenerateZipRhs(emitter, aligned, cols, leftovers): 81 """Emits the code responsible for zipping the rhs matrix.""" 82 emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)') 83 emitter.EmitCall( 84 ZipName(3, leftovers, aligned), 85 ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0]) 86 emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size') 87 
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitCloseBracket()

  if cols != 0:
    emitter.EmitCall(
        ZipName(cols, leftovers, aligned),
        ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
  emitter.EmitNewline()


def MulName(result_type, lhs_add, rhs_add, rows, cols):
  return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows, cols)


def GetMulParams(result_type):
  """Returns the parameter list passed to the multiplication kernel."""
  params = ['zipped_lhs', 'zipped_rhs_chunk', 'padded_k', 'mul_result_chunk',
            'mul_result_chunk_stride_bytes']
  if result_type == 'float':
    params.append('result_scale')
  return params


def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
                    rows, cols, leftovers):
  """Emits code responsible for multiplication of one horizontal lhs strip."""
  emitter.EmitCall(
      ZipName(rows, leftovers, aligned),
      ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
  emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitAssign('mul_result_chunk', result)

  emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')

  emitter.EmitCall(
      MulName(result_type, lhs_add, rhs_add, rows, 3),
      GetMulParams(result_type))
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitAssignIncrement('mul_result_chunk', 3)

  emitter.EmitCloseBracket()

  if cols != 0:
    emitter.EmitCall(
        MulName(result_type, lhs_add, rhs_add, rows, cols),
        GetMulParams(result_type))


def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
  """Emits code for all lhs strips and leftover rows; quantizes after mul."""
  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitCall(
      qnt_Nx8_neon.BuildMultiQuantizeName(aligned, 3),
      ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
       'zipped_lhs_3_offsets', 'result_chunk', 'result_stride',
       'multiplicative_offset', 'rounding_offset', '-shift'])
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  if rows != 0:
    GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
                    cols, leftovers)
    emitter.EmitCall(
        qnt_Nx8_neon.BuildMultiQuantizeName(aligned, rows),
        ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
         'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
         'multiplicative_offset', 'rounding_offset', '-shift'])


def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
  """Emits the multiplication code for the int32 and float gemms."""
  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  if rows != 0:
    GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
                    rows, cols, leftovers)


def BuildName(output_type, aligned, rows, cols, leftover):
  """Builds a specialized gemm name, e.g. 'gemm_q8_1_2_5_aligned'."""
  name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
  if aligned:
    name += '_aligned'
  return name


def GetCommonGemmParameters():
  """Returns the (type, name) pairs common to all gemm signatures."""
  return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'],
          ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'm'],
          ['std::int32_t', 'n'], ['std::int32_t', 'k'],
          ['std::int32_t', 'lhs_offset'], ['std::int32_t', 'rhs_offset']]


def GetGemmParameters(output_type, extra_params=None):
  """Prepares a (type, parameter) array for the gemm functions."""
  if extra_params is None:
    extra_params = []
  params = GetCommonGemmParameters()
  if output_type == _QUANTIZED_8BIT:
    params += [['std::int32_t', 'result_offset'],
               ['std::int32_t', 'multiplicative_offset'],
               ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
  elif output_type == _FULL_32BIT:
    params += [['std::int32_t*', 'result']]
  elif output_type == _FULL_FLOAT:
    params += [['float', 'result_scale'], ['float*', 'result']]
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
  return params + extra_params


def GetStridedGemmParameters(output_type):
  """Returns the gemm parameters extended with an explicit result stride."""
  return GetGemmParameters(output_type, [['std::int32_t', 'result_stride']])


def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
  """Builds one gemm function for the given row, col and depth leftovers."""
  emitter.EmitFunctionBeginA(
      BuildName(output_type, aligned, rows, cols, leftovers),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitAssert('m %% 3 == %d' % rows)
  emitter.EmitAssert('n %% 3 == %d' % cols)
  emitter.EmitAssert('k %% 8 == %d' % leftovers)

  if output_type == _QUANTIZED_8BIT:
    GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
  elif output_type == _FULL_32BIT:
    GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
  elif output_type == _FULL_FLOAT:
    GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
  else:
    raise ConfigurationError('Unknown output type: %s' % output_type)

  emitter.EmitFunctionEnd()


def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
  """Emits a call to the specialized gemm in the internal namespace."""
  emitter.EmitCall(
      emitter.Scope('internal',
                    BuildName(output_type, aligned, m_mod, n_mod, leftovers)),
      [p for (unused_t, p) in GetStridedGemmParameters(output_type)])


def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
  """Third level of the main switch: selects the version by depth leftover."""
  emitter.EmitSwitch('k % 8')

  for leftovers in range(0, 8):
    emitter.EmitCase(leftovers)
    emitter.PushIndent()
    GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
  """Second level of the main switch: selects the version by cols leftover."""
  emitter.EmitSwitch('n % 3')

  for n_mod in range(0, 3):
    emitter.EmitCase(n_mod)
    emitter.PushIndent()
    GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def GenerateGemmSwitch1(emitter, output_type, aligned):
  """First level of the main switch: selects the version by rows leftover."""
  emitter.EmitSwitch('m % 3')

  for m_mod in range(0, 3):
    emitter.EmitCase(m_mod)
    emitter.PushIndent()
    GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def BuildMainGemmName(output_type):
  """Returns the base name of the gemm for the given output type."""
  if output_type == _QUANTIZED_8BIT:
    return 'gemm_q8'
  elif output_type == _FULL_32BIT:
    return 'gemm_i32'
  elif output_type == _FULL_FLOAT:
    return 'gemm_f'
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)


def BuildStridedMainGemmName(output_type):
  return BuildMainGemmName(output_type) + '_strided'


def GenerateMainGemmFunction(emitter, output_type):
  """Emits the high-level gemm that switches between optimized versions."""
  emitter.EmitFunctionBeginA(
      BuildStridedMainGemmName(output_type),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitDeclare('const bool', 'lhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'rhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

  if output_type == _QUANTIZED_8BIT:
    emitter.EmitDeclare('const bool', 'result_aligned',
                        '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'result_stride_aligned',
                        '((result_stride % 8) == 0)')
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && result_aligned '
                        '&& k_aligned && result_stride_aligned')
  else:
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && k_aligned')

  emitter.EmitIf('aligned')
  GenerateGemmSwitch1(emitter, output_type, True)
  emitter.EmitElse()
  GenerateGemmSwitch1(emitter, output_type, False)
  emitter.EmitEndif()
  emitter.EmitFunctionEnd()


def GenerateWrapperGemmFunction(emitter, output_type):
  """Emits the non-strided gemm wrapper that calls the strided version."""
  emitter.EmitFunctionBeginA(
      BuildMainGemmName(output_type), GetGemmParameters(output_type), 'void')
  emitter.EmitCall(
      BuildStridedMainGemmName(output_type),
      [p for (unused_t, p) in GetGemmParameters(output_type)] + ['n'])
  emitter.EmitFunctionEnd()


def GenerateInternalFunctions(emitter):
  """Generates all the functions hidden in the internal namespace."""
  for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    for aligned in [True, False]:
      for rows in range(0, 3):
        for cols in range(0, 3):
          for leftover in range(0, 8):
            GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
            emitter.EmitNewline()


def GeneratePublicFunctions(emitter):
  """Generates the public strided and non-strided gemm functions."""
  for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    GenerateMainGemmFunction(emitter, output_type)
    emitter.EmitNewline()

    GenerateWrapperGemmFunction(emitter, output_type)
    emitter.EmitNewline()
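

# Usage sketch, added for illustration only. The real build drives these
# generators with a C++ code emitter whose interface (EmitDeclare, EmitCall,
# EmitFunctionBeginA, ...) lives in a separate module that is not shown here.
# The stand-in below merely records the calls so the control flow can be
# exercised; the class name and the __main__ block are assumptions, not the
# project's actual entry point.
class _TracingEmitter(object):
  """Prints every Emit*/Scope call it receives; illustration only."""

  def __getattr__(self, name):
    def _record(*args):
      print('%s%r' % (name, args))
    return _record


if __name__ == '__main__':
  # Exercise the generators with the tracing stand-in; a real run would pass
  # the project's C++ emitter instead and capture the generated source.
  _emitter = _TracingEmitter()
  GenerateInternalFunctions(_emitter)
  GeneratePublicFunctions(_emitter)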