"""Generates the whole gemm header.

Running this module as a script drives a cc_emitter.CCEmitter to produce the
single threaded meta gemm header: the internal zip / multiply / quantize
helpers, one specialized gemm function per (output type, alignment, row
leftover, column leftover, depth leftover) combination, and the public
entry points that switch between those specializations.
"""

import cc_emitter
import mul_Nx8_Mx8_neon
import neon_emitter
import qnt_Nx8_neon
import zip_Nx8_neon

_HEADER_COPYRIGHT = """// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// single_thread_gemm.h: programatically generated GEMM library header.
"""

# Output type selectors accepted by the generator functions below.
_QUANTIZED_8BIT = 'quantized_8bit'
_FULL_32BIT = 'full_32bit'
_FULL_FLOAT = 'full_float'


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Runtime configuration error."""


def GenerateCommonTempsCountersAndConsts(emitter, rows):
  """Declares the variables shared by every generated gemm function.

  Args:
    emitter: code emitter receiving the declarations.
    rows: number of leftover lhs rows (m % 3 in the generated function). When
      non-zero an additional 'zipped_lhs_<rows>_offsets' pointer is declared.
  """
  emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
  emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
  emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
  emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
  emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
                      '(padded_k + 16) * 3')
  emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
                      '(padded_k + 16) * n')
  emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
  emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
  emitter.EmitDeclare(
      'std::int32_t*', 'zipped_lhs_3_offsets',
      'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)')
  # Compare by value: identity checks against int literals are unreliable.
  if rows != 0:
    emitter.EmitDeclare(
        'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
        'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
                      'scratch + zipped_chunk_size')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
                      'result_stride * 3')
  emitter.EmitNewline()


def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
  """Generates all the boilerplate variables for the q8 gemm function.

  Args:
    emitter: code emitter receiving the declarations.
    rows: number of leftover lhs rows, forwarded to the common declarations.
  """
  GenerateCommonTempsCountersAndConsts(emitter, rows)
  emitter.EmitDeclare('const std::int32_t', 'const_offset',
                      'lhs_offset * rhs_offset * k + result_offset')
  emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
                      '(1 << (shift - 1))')
  # The int32 temporary result lives in scratch, after zipped lhs and rhs.
  emitter.EmitDeclare('std::int32_t*', 'temp_result',
                      'reinterpret_cast<std::int32_t*>('
                      'scratch + zipped_chunk_size + zipped_rhs_size)')
  emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result')
  emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
  emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
                      '((n * 4 + 7) / 8) * 8')
  emitter.EmitNewline()


def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
  """Generates all the boilerplate variables for the int32 and float gemms.

  Args:
    emitter: code emitter receiving the declarations.
    result_type: C++ pointer type of the result ('std::int32_t*' or 'float*'
      at the call sites in this module).
    rows: number of leftover lhs rows, forwarded to the common declarations.
  """
  GenerateCommonTempsCountersAndConsts(emitter, rows)
  emitter.EmitDeclare('const std::int32_t', 'const_offset',
                      'lhs_offset * rhs_offset * k')
  emitter.EmitDeclare(result_type, 'result_chunk', 'result')
  emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
  emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
                      'result_stride * 4')
  emitter.EmitNewline()


def ZipName(rows, leftovers, aligned):
  """Returns the zip function name as built by zip_Nx8_neon.BuildName."""
  return zip_Nx8_neon.BuildName(rows, leftovers, aligned)


def GenerateZipRhs(emitter, aligned, cols, leftovers):
  """Emits the code responsible for zipping the rhs matrix.

  Args:
    emitter: code emitter receiving the generated statements.
    aligned: whether the aligned zip variants should be called.
    cols: number of leftover rhs columns (n % 3); when non-zero one extra zip
      call is emitted after the loop over full 3-column chunks.
    leftovers: depth leftover (k % 8) selecting the zip variant.
  """
  emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
  emitter.EmitCall(
      ZipName(3, leftovers, aligned),
      ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
  emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitCloseBracket()

  if cols != 0:
    emitter.EmitCall(
        ZipName(cols, leftovers, aligned),
        ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
  emitter.EmitNewline()


def MulName(result_type, lhs_add, rhs_add, rows, cols):
  """Returns the mul function name as built by mul_Nx8_Mx8_neon.BuildName."""
  return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows, cols)


def GetMulParams(result_type):
  """Returns the argument list passed to the generated multiply call.

  Args:
    result_type: 'int32' or 'float' at the call sites in this module; the
      'float' variant additionally receives 'result_scale'.

  Returns:
    A new list of argument names.
  """
  params = ['zipped_lhs', 'zipped_rhs_chunk', 'padded_k', 'mul_result_chunk',
            'mul_result_chunk_stride_bytes']
  # Equality, not identity: the string may not be the interned literal.
  if result_type == 'float':
    params.append('result_scale')
  return params


def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
                    rows, cols, leftovers):
  """Emits code responsible for multiplication of one horizontal lhs strip.

  Args:
    emitter: code emitter receiving the generated statements.
    result: name of the variable that mul_result_chunk is reset to.
    result_type: element type selector for the multiply kernel name.
    lhs_add: forwarded to MulName.
    rhs_add: forwarded to MulName.
    aligned: whether the aligned zip variant should be called.
    rows: number of lhs rows in this strip.
    cols: number of leftover rhs columns; when non-zero one extra multiply
      call is emitted after the loop over full 3-column chunks.
    leftovers: depth leftover (k % 8) selecting the zip variant.
  """
  emitter.EmitCall(
      ZipName(rows, leftovers, aligned),
      ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
  emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitAssign('mul_result_chunk', result)

  emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')

  emitter.EmitCall(
      MulName(result_type, lhs_add, rhs_add, rows, 3),
      GetMulParams(result_type))
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitAssignIncrement('mul_result_chunk', 3)

  emitter.EmitCloseBracket()

  if cols != 0:
    emitter.EmitCall(
        MulName(result_type, lhs_add, rhs_add, rows, cols),
        GetMulParams(result_type))


def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
  """Emits code for all lhs strips & leftover rows. Quantize after mul code.

  Args:
    emitter: code emitter receiving the generated statements.
    aligned: whether the aligned zip / quantize variants should be called.
    rows: number of leftover lhs rows; when non-zero a final strip of that
      height is multiplied and quantized after the main loop.
    cols: number of leftover rhs columns.
    leftovers: depth leftover (k % 8).
  """
  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitCall(
      BuildMultiQuantizeName(aligned, 3),
      ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
       'zipped_lhs_3_offsets', 'result_chunk', 'result_stride',
       'multiplicative_offset', 'rounding_offset', '-shift'])
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  if rows != 0:
    GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned,
                    rows, cols, leftovers)
    emitter.EmitCall(
        BuildMultiQuantizeName(aligned, rows),
        ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
         'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
         'multiplicative_offset', 'rounding_offset', '-shift'])


def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
  """Emits multiplication code for all lhs strips of an int32 / float gemm.

  Args:
    emitter: code emitter receiving the generated statements.
    result_type: element type selector for the multiply kernel name.
    aligned: whether the aligned zip variants should be called.
    rows: number of leftover lhs rows; when non-zero a final strip of that
      height is multiplied after the main loop.
    cols: number of leftover rhs columns.
    leftovers: depth leftover (k % 8).
  """
  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  if rows != 0:
    GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
                    rows, cols, leftovers)


def BuildName(output_type, aligned, rows, cols, leftover):
  """Returns the name of one specialized gemm function.

  The name is '<main gemm name>_<rows>_<cols>_<leftover>' with an '_aligned'
  suffix when aligned is true.

  Raises:
    ConfigurationError: output_type is not one of the supported selectors.
  """
  name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
  if aligned:
    name += '_aligned'
  return name


def GetCommonGemmParameters():
  """Returns the [type, name] pairs shared by every gemm signature."""
  return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'],
          ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'm'],
          ['std::int32_t', 'n'], ['std::int32_t', 'k'],
          ['std::int32_t', 'lhs_offset'], ['std::int32_t', 'rhs_offset']]


def GetGemmParameters(output_type, extra_params=None):
  """Prepares a (type, parameter) array for the gemm functions.

  Args:
    output_type: one of the module level output type selectors.
    extra_params: optional list of [type, name] pairs appended at the end.

  Returns:
    A new list of [type, name] pairs.

  Raises:
    ConfigurationError: output_type is not one of the supported selectors.
  """
  if extra_params is None:
    extra_params = []
  params = GetCommonGemmParameters()
  # Selectors are strings, so compare by value rather than identity.
  if output_type == _QUANTIZED_8BIT:
    params += [['std::int32_t', 'result_offset'],
               ['std::int32_t', 'multiplicative_offset'],
               ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
  elif output_type == _FULL_32BIT:
    params += [['std::int32_t*', 'result']]
  elif output_type == _FULL_FLOAT:
    params += [['float', 'result_scale'], ['float*', 'result']]
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
  return params + extra_params


def GetStridedGemmParameters(output_type):
  """Returns the gemm parameters with a trailing 'result_stride' parameter."""
  return GetGemmParameters(output_type, [['std::int32_t', 'result_stride']])


def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
  """Build one gemm function for given row, col, and depth leftovers.

  Args:
    emitter: code emitter receiving the generated function.
    output_type: one of the module level output type selectors.
    aligned: whether the aligned variant is generated.
    rows: row leftover; the generated function asserts m % 3 == rows.
    cols: column leftover; the generated function asserts n % 3 == cols.
    leftovers: depth leftover; the generated function asserts
      k % 8 == leftovers.

  Raises:
    ConfigurationError: output_type is not one of the supported selectors.
  """
  emitter.EmitFunctionBeginA(
      BuildName(output_type, aligned, rows, cols, leftovers),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitAssert('m %% 3 == %d' % rows)
  emitter.EmitAssert('n %% 3 == %d' % cols)
  emitter.EmitAssert('k %% 8 == %d' % leftovers)

  if output_type == _QUANTIZED_8BIT:
    GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
  elif output_type == _FULL_32BIT:
    GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
  elif output_type == _FULL_FLOAT:
    GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
  else:
    raise ConfigurationError('Unknown output type: %s' % output_type)

  emitter.EmitFunctionEnd()


def BuildMultiQuantizeName(aligned, rows):
  """Returns 'multi_qnt_<rows>x8', suffixed with '_aligned' when aligned."""
  name = 'multi_qnt_%dx8' % rows
  if aligned:
    name = '%s_aligned' % name
  return name


def GenerateMultiQuantize(emitter, aligned, rows):
  """Emit main quantization code that switches between optimized versions.

  The generated function switches on 'count % 8' and forwards all of its
  arguments to the matching qnt_Nx8_neon function.

  Args:
    emitter: code emitter receiving the generated function.
    aligned: whether the aligned quantize variants should be called.
    rows: number of rows handled by the generated function.
  """
  name = BuildMultiQuantizeName(aligned, rows)
  emitter.EmitFunctionBeginA(
      name,
      [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
       ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
       ['std::uint8_t*', 'destination'], ['std::int32_t', 'destination_stride'],
       ['std::int32_t', 'multiplicative_offset'],
       ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']], 'void')
  emitter.EmitSwitch('count % 8')

  for leftovers in range(0, 8):
    emitter.EmitCase(leftovers)
    emitter.PushIndent()
    emitter.EmitCall(
        qnt_Nx8_neon.BuildName(rows, leftovers, aligned),
        ['source', 'count', 'stride', 'offsets', 'destination',
         'destination_stride', 'multiplicative_offset', 'rounding_offset',
         'shift'])
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
  emitter.EmitFunctionEnd()


def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
  """Emits a call to one specialized gemm in the 'internal' scope.

  All strided gemm parameters are forwarded by name.
  """
  emitter.EmitCall(
      emitter.Scope('internal',
                    BuildName(output_type, aligned, m_mod, n_mod, leftovers)),
      [p for (unused_t, p) in GetStridedGemmParameters(output_type)])


def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
  """Third level of main switch, choose optimized version on depth leftover."""
  emitter.EmitSwitch('k % 8')

  for leftovers in range(0, 8):
    emitter.EmitCase(leftovers)
    emitter.PushIndent()
    GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
  """Second level of main switch, choose optimized version on cols leftover."""
  emitter.EmitSwitch('n % 3')

  for n_mod in range(0, 3):
    emitter.EmitCase(n_mod)
    emitter.PushIndent()
    GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def GenerateGemmSwitch1(emitter, output_type, aligned):
  """First level of main switch, choose optimized version on rows leftover."""
  emitter.EmitSwitch('m % 3')

  for m_mod in range(0, 3):
    emitter.EmitCase(m_mod)
    emitter.PushIndent()
    GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()


def BuildMainGemmName(output_type):
  """Returns the public gemm function name for the given output type.

  Raises:
    ConfigurationError: output_type is not one of the supported selectors.
  """
  if output_type == _QUANTIZED_8BIT:
    return 'gemm_q8'
  elif output_type == _FULL_32BIT:
    return 'gemm_i32'
  elif output_type == _FULL_FLOAT:
    return 'gemm_f'
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)


def BuildStridedMainGemmName(output_type):
  """Returns the public gemm function name with a '_strided' suffix."""
  return BuildMainGemmName(output_type) + '_strided'


def GenerateMainGemmFunction(emitter, output_type):
  """Emit high level gemm function that switches between optimized versions.

  The generated function computes an 'aligned' flag from the input pointers
  and k (plus the result pointer and result stride for the quantized output)
  and dispatches to the aligned or unaligned specializations accordingly.

  Args:
    emitter: code emitter receiving the generated function.
    output_type: one of the module level output type selectors.

  Raises:
    ConfigurationError: output_type is not one of the supported selectors.
  """
  emitter.EmitFunctionBeginA(
      BuildStridedMainGemmName(output_type),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitDeclare('const bool', 'lhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'rhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

  if output_type == _QUANTIZED_8BIT:
    emitter.EmitDeclare('const bool', 'result_aligned',
                        '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'result_stride_aligned',
                        '((result_stride % 8) == 0)')
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && result_aligned '
                        '&& k_aligned && result_stride_aligned')
  else:
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && k_aligned')

  emitter.EmitIf('aligned')
  GenerateGemmSwitch1(emitter, output_type, True)
  emitter.EmitElse()
  GenerateGemmSwitch1(emitter, output_type, False)
  emitter.EmitEndif()
  emitter.EmitFunctionEnd()


def GenerateWrapperGemmFunction(emitter, output_type):
  """Emits the non-strided gemm entry point.

  The generated function forwards its arguments to the strided variant and
  passes 'n' as the trailing result stride argument.

  Raises:
    ConfigurationError: output_type is not one of the supported selectors.
  """
  emitter.EmitFunctionBeginA(
      BuildMainGemmName(output_type), GetGemmParameters(output_type), 'void')
  emitter.EmitCall(
      BuildStridedMainGemmName(output_type),
      [p for (unused_t, p) in GetGemmParameters(output_type)] + ['n'])
  emitter.EmitFunctionEnd()


def GenerateInternalFunctions(emitter):
  """Generate all the functions hidden in the internal namespace.

  Args:
    emitter: code emitter receiving the multi-quantize and gemm functions.
  """
  # NOTE(review): the helper generators each get a fresh NeonEmitter instead
  # of `emitter`; presumably both kinds write to the same output - confirm.
  zip_Nx8_neon.GenerateFunctions(neon_emitter.NeonEmitter())
  emitter.EmitNewline()

  mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', False,
                                     True)
  emitter.EmitNewline()

  mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True,
                                     True)
  emitter.EmitNewline()

  mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), 'float', True,
                                     True)
  emitter.EmitNewline()

  qnt_Nx8_neon.GenerateFunctions(neon_emitter.NeonEmitter())
  emitter.EmitNewline()

  for aligned in [True, False]:
    for rows in range(1, 4):
      GenerateMultiQuantize(emitter, aligned, rows)
      emitter.EmitNewline()

  for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    for aligned in [True, False]:
      for rows in range(0, 3):
        for cols in range(0, 3):
          for leftover in range(0, 8):
            GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
            emitter.EmitNewline()


def Main():
  """Generate the single threaded meta gemm library."""
  emitter = cc_emitter.CCEmitter()

  emitter.EmitCodeNoSemicolon(_HEADER_COPYRIGHT)
  emitter.EmitHeaderBegin('gemmlowp_meta_single_thread_gemm')

  emitter.EmitPreprocessor1('ifdef', 'GEMMLOWP_NEON_32')
  emitter.EmitNewline()

  emitter.EmitInclude('<cassert>')
  emitter.EmitNewline()

  emitter.EmitNamespaceBegin('gemmlowp')
  emitter.EmitNamespaceBegin('meta')
  emitter.EmitNamespaceBegin('internal')
  emitter.EmitNewline()

  GenerateInternalFunctions(emitter)

  emitter.EmitNamespaceEnd()
  emitter.EmitNewline()

  GenerateMainGemmFunction(emitter, _QUANTIZED_8BIT)
  emitter.EmitNewline()
  GenerateMainGemmFunction(emitter, _FULL_32BIT)
  emitter.EmitNewline()
  GenerateMainGemmFunction(emitter, _FULL_FLOAT)
  emitter.EmitNewline()
  GenerateWrapperGemmFunction(emitter, _QUANTIZED_8BIT)
  emitter.EmitNewline()
  GenerateWrapperGemmFunction(emitter, _FULL_32BIT)
  emitter.EmitNewline()
  GenerateWrapperGemmFunction(emitter, _FULL_FLOAT)
  emitter.EmitNewline()

  emitter.EmitNamespaceEnd()
  emitter.EmitNamespaceEnd()
  emitter.EmitNewline()

  emitter.EmitPreprocessor('else')
  emitter.EmitPreprocessor1('warning',
                            '"Meta gemm fast-path requires GEMMLOWP_NEON_32!"')
  emitter.EmitPreprocessor('endif')
  emitter.EmitNewline()

  emitter.EmitHeaderEnd()


if __name__ == '__main__':
  Main()