# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates quantized multiplication kernels with fused output transforms.

Emits ARM NEON-style assembly (through the provided emitter objects) for
small GEMM kernels on uint8 operands: an inner multiply-aggregate loop,
a fused static-quantization transform of the accumulators, and a row-major
output stage producing uint8, int32 or float results.
"""

import common


def _ReadParams(emitter, registers, input_address, elements, min_register):
  """Load `elements` 32-bit parameters into quad registers (4 lanes each)."""
  registers_count = (elements + 3) // 4  # ceil(elements / 4)
  registers = [
      registers.QuadRegister(min_register)
      for unused_i in range(registers_count)
  ]
  emitter.EmitVLoadAE(registers_count * 4, 32, registers, input_address, 64)
  return registers


def _Duplicate(emitter, registers, rows, values):
  """Populate a grid of registers duplicating provided values.

  Returns one quad register per row, each broadcast-filled from lane
  (i % 4) of source register values[i // 4].  The last row reuses
  values[0] as its destination rather than allocating a fresh register.
  """
  duplicated = []
  for i in range(rows):
    if i == rows - 1:
      duplicated.append(values[0])
    else:
      duplicated.append(registers.QuadRegister())

    emitter.EmitVDup('32', duplicated[i],
                     emitter.Lane(32, values[i // 4], i % 4))

  return duplicated


def _DuplicateGeneralRegister(emitter, registers, value, min_register):
  """Broadcast a general register value into all lanes of a quad register."""
  register = registers.QuadRegister(min_register)
  emitter.EmitVDup('32', register, value)
  return register


class _StaticQuantizationUInt8Transformation(object):
  """Calculate quantized values and cast back to uint8."""

  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    """Load parameters and prepare duplicated registers."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantization::Prepare')

    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    self.multiplicative_offset = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('multiplicative_offset',
                               'params.kernel.multiplicative_offset'), 4)
    self.rounding_offset = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('rounding_offset',
                               'params.kernel.rounding_offset'), 4)
    self.shift = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('shift', 'params.kernel.shift'), 4)
    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)

  def Transform(self, emitter, registers, data, unused_kernel_m,
                unused_kernel_n):
    """Quantize the data."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantization::Transform')

    # Add the per-row lhs offset to every register in that row.
    for (row, lhs_offset) in zip(data, self.lhs_offsets):
      for row_register in row:
        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)

    # Add the per-column rhs offsets.
    for row in data:
      for (row_register, rhs_offset_register) in zip(row, self.rhs_offsets):
        emitter.EmitVAdd('s32', row_register, row_register,
                         rhs_offset_register)

    # Scale, round and shift.
    for row in data:
      for row_register in row:
        emitter.EmitVMul('i32', row_register, row_register,
                         self.multiplicative_offset)

    for row in data:
      for row_register in row:
        emitter.EmitVAdd('i32', row_register, row_register,
                         self.rounding_offset)

    for row in data:
      for row_register in row:
        emitter.EmitVShl('s32', row_register, row_register, self.shift)

    # Saturating narrow: s32 -> s16 -> u8.  Rows of 2 registers are
    # narrowed pairwise into one register first.
    if len(data[0]) == 1:
      for row in data:
        emitter.EmitVQmovn('s32', row[0], row[0])

      for row in data:
        emitter.EmitVQmovun('s16', row[0], row[0])

      return data
    elif len(data[0]) == 2:
      results = []
      for row in data:
        emitter.EmitVQmovn2('s32', row[0], row[0], row[1])
        registers.FreeRegister(row[1])
        results.append([row[0]])

      for row in results:
        emitter.EmitVQmovun('s16', row[0], row[0])

      return results
    else:
      assert False

  def Type(self):
    """Bit width of the output elements."""
    return 8


class _StaticQuantizationInt32Transformation(object):
  """Apply static quantization offsets, leaving results as int32."""

  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    """Load offset parameters and prepare duplicated registers."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationInt32::Prepare')

    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)

  def Transform(self, emitter, unused_registers, data, unused_kernel_m,
                unused_kernel_n):
    """Quantize data and output as int32."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationInt32::Transform')

    for (row, lhs_offset) in zip(data, self.lhs_offsets):
      for row_register in row:
        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)

    for row in data:
      for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
        emitter.EmitVAdd('s32', row_register, row_register,
                         rhs_offsets_register)

    return data

  def Type(self):
    """Bit width of the output elements."""
    return 32


class _StaticQuantizationFloatTransformation(object):
  """Apply static quantization offsets and rescale results to float."""

  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    """Load offset/scale parameters and prepare duplicated registers."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationFloat::Prepare')

    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    self.scale = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('scale', 'params.kernel.scale'), 4)
    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)

  def Transform(self, emitter, unused_registers, data, unused_kernel_m,
                unused_kernel_n):
    """Quantize data and output as float."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationFloat::Transform')

    for (row, lhs_offset) in zip(data, self.lhs_offsets):
      for row_register in row:
        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)

    for row in data:
      for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
        emitter.EmitVAdd('s32', row_register, row_register,
                         rhs_offsets_register)

    # Convert to float and apply the scale.
    for row in data:
      for row_register in row:
        emitter.EmitVCvt('f32', 's32', row_register, row_register)

    for row in data:
      for row_register in row:
        emitter.EmitVMul('f32', row_register, row_register, self.scale)

    return data

  def Type(self):
    """Bit width of the output elements."""
    return 32


class _RowMajorOutput(object):
  """Output data in row major layout."""

  def Prepare(self, emitter, registers, kernel_m, unused_kernel_n,
              unused_data_type):
    """Prepare strided load addresses."""
    emitter.EmitNewline()
    emitter.EmitComment('RowMajorOutput::Prepare')

    stride = registers.MapParameter('stride', 'params.output_stream.stride')

    # One output address per row; each row starts `stride` bytes after the
    # previous one.
    self.outputs = []
    self.outputs.append(registers.MapOutputParameter('result'))

    for unused_i in range(kernel_m - 1):
      register = registers.GeneralRegister()
      emitter.EmitAdd(register, self.outputs[-1], stride)
      self.outputs.append(register)

  def Output(self, emitter, unused_registers, data, data_type, unused_kernel_m,
             kernel_n):
    """Store each row of results at its precomputed address."""
    emitter.EmitNewline()
    emitter.EmitComment('RowMajorOutput::Output')

    for (datum, output) in zip(data, self.outputs):
      emitter.EmitVStoreAE(data_type, kernel_n, datum, output, None)


def _GenerateAndClearAggregators(emitter, registers, count):
  """Prepare aggregators and emit aggregator clear code."""
  emitter.EmitNewline()
  emitter.EmitComment('Clear aggregators.')
  aggregators = [registers.QuadRegister() for unused_i in range(count)]
  for i in range(count):
    if i < 3:
      emitter.EmitVMov('i32', aggregators[i], emitter.ImmediateConstant(0))
    else:
      # Copy from an already-zeroed aggregator instead of re-emitting the
      # immediate move.
      emitter.EmitVMov('i32', aggregators[i], aggregators[i - 3])
  return aggregators


def _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count):
  """Emit inner loop for 3 rows x 3 cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(3)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(3)]
  temp = [registers.QuadRegister() for unused_i in range(4)]

  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 64))
  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))

  # Loads, multiplies and prefetches are interleaved to hide latency.
  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
  emitter.EmitVLoad(1, 8, lhs_load[2], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))

  emitter.EmitVMull('u8', temp[3], lhs_load[1], rhs_load[0])
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])

  # lhs rows 0 and 1 are consumed; recycle their registers for a 5th temp.
  registers.FreeRegisters([lhs_load[0], lhs_load[1]])
  temp.append(registers.QuadRegister())

  emitter.EmitVMull('u8', temp[2], lhs_load[2], rhs_load[0])
  emitter.EmitVMull('u8', temp[3], lhs_load[2], rhs_load[1])

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVMull('u8', temp[4], lhs_load[2], rhs_load[2])

  emitter.EmitVPadal('u16', aggregators[4], temp[0])
  emitter.EmitVPadal('u16', aggregators[5], temp[1])
  emitter.EmitVPadal('u16', aggregators[6], temp[2])
  emitter.EmitVPadal('u16', aggregators[7], temp[3])
  emitter.EmitVPadal('u16', aggregators[8], temp[4])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + [lhs_load[2]] + rhs_load)


def _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count):
  """Emit inner loop for 2 rows x 4 cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('2x4 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(2)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  temp = [registers.QuadRegister() for unused_i in range(5)]

  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 256))
  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))

  emitter.EmitVMull('u8', temp[3], lhs_load[0], rhs_load[3])
  emitter.EmitVMull('u8', temp[4], lhs_load[1], rhs_load[0])

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])

  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
  emitter.EmitVMull('u8', temp[2], lhs_load[1], rhs_load[3])

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitVPadal('u16', aggregators[3], temp[3])
  emitter.EmitVPadal('u16', aggregators[4], temp[4])
  emitter.EmitVPadal('u16', aggregators[5], temp[0])
  emitter.EmitVPadal('u16', aggregators[6], temp[1])
  emitter.EmitVPadal('u16', aggregators[7], temp[2])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + lhs_load + rhs_load)


def _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count):
  """Emit inner loop for 1 row x 8 cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('1x8 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = registers.DoubleRegister()
  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  temp = [registers.QuadRegister() for unused_i in range(5)]

  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)

  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[0])
  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[1])
  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[2])
  emitter.EmitVMull('u8', temp[3], lhs_load, rhs_load[3])

  # Reload rhs for the second group of 4 columns.
  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(256))

  emitter.EmitVMull('u8', temp[4], lhs_load, rhs_load[0])
  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[2])
  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[3])

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(32))

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitVPadal('u16', aggregators[4], temp[4])
  emitter.EmitVPadal('u16', aggregators[5], temp[0])
  emitter.EmitVPadal('u16', aggregators[6], temp[1])
  emitter.EmitVPadal('u16', aggregators[7], temp[2])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + [lhs_load] + rhs_load)


def _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
                                      aggregators, lhs, rhs, count):
  """Emit inner loop for N rows x M cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('General NxM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(kernel_m)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(kernel_n)]

  emitter.EmitVLoadAE(8 * kernel_m, 8, lhs_load, lhs, 64)
  emitter.EmitVLoadAE(8 * kernel_n, 8, rhs_load, rhs, 64)

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  results = [
      registers.QuadRegister() for unused_i in range(kernel_m * kernel_n)
  ]

  for row in range(kernel_m):
    for col in range(kernel_n):
      index = row * kernel_n + col
      emitter.EmitVMull('u8', results[index], rhs_load[col], lhs_load[row])

  for i in range(kernel_m * kernel_n):
    emitter.EmitVPadal('u16', aggregators[i], results[i])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(lhs_load + rhs_load + results)


def _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
                                      lhs, rhs, count):
  """Emit inner loop for 1 row x M cols multiplication."""
  assert kernel_n in [5, 6, 7, 8]
  emitter.EmitNewline()
  emitter.EmitComment('General 1xM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  # First handle 4 columns, then the remaining (kernel_n - 4).
  leftover = kernel_n - 4

  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  lhs_load = registers.DoubleRegister()

  emitter.EmitVLoadAE(8 * 4, 8, rhs_load, rhs, 64)
  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))

  results = [registers.QuadRegister() for unused_i in range(4)]

  for i in range(4):
    emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)

  emitter.EmitVLoadAE(8 * leftover, 8, rhs_load, rhs, 64)
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(128))

  for i in range(4):
    emitter.EmitVPadal('u16', aggregators[i], results[i])

  for i in range(leftover):
    emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)

  for i in range(leftover):
    emitter.EmitVPadal('u16', aggregators[i + 4], results[i])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters([lhs_load] + rhs_load + results)


def _GenerateMultiplyKernel(emitter, registers, kernel_m, kernel_n, lhs, rhs):
  """Main multiply loop. Pick best implementation for given kernel shape."""
  count = registers.MapParameter('count', 'params.kernel.count')

  aggregators = _GenerateAndClearAggregators(emitter, registers,
                                             kernel_m * kernel_n)
  # Hand-scheduled loops exist for the common shapes; everything else
  # falls back to the generic NxM loop.
  if kernel_m == 3 and kernel_n == 3:
    _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count)
  elif kernel_m == 2 and kernel_n == 4:
    _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count)
  elif kernel_m == 1 and kernel_n == 8:
    _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count)
  elif kernel_m == 1 and kernel_n > 4:
    _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
                                      lhs, rhs, count)
  else:
    _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
                                      aggregators, lhs, rhs, count)
  return aggregators


def _ReduceAggregators(emitter, aggregators):
  """Horizontally sum-reduce aggregators into the first ceil(n/4) of them."""
  reduced_count = (len(aggregators) + 3) // 4
  reduced = aggregators[:reduced_count]
  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
  return reduced


def _GenerateAggregatorReduce(emitter, aggregators, kernel_m, kernel_n):
  """Reduce each row's aggregators; returns one register list per row."""
  emitter.EmitNewline()
  emitter.EmitComment('Reduce aggregators.')
  row_temps = []
  for i in range(kernel_m):
    row_temps.append(
        _ReduceAggregators(emitter, aggregators[i * kernel_n:(i + 1) *
                                                kernel_n]))
  return row_temps


class QuantizedMulKernel(common.MulKernelGenerator):
  """Generates a quantized multiply kernel from pluggable strategy objects.

  The fused transformation turns raw int32 accumulators into the output
  type; the output strategy stores the results.
  """

  def __init__(self, cc_emitter, kernel_name, output_stream_name, asm_emitter,
               fused_transformation, output_strategy):
    common.MulKernelGenerator.__init__(self, cc_emitter, kernel_name,
                                       output_stream_name)
    self.asm_emitter = asm_emitter
    self.fused_transformation = fused_transformation
    self.output_strategy = output_strategy

  def EmitMultiply(self, in_type, out_type, kernel_m, kernel_n, pack_size):
    """Emit the full inline-asm kernel for a kernel_m x kernel_n shape."""
    assert in_type == 'uint8_t'
    assert pack_size == 8
    # All aggregators plus temporaries must fit in the register file.
    assert kernel_m * kernel_n <= 9

    registers = self.asm_emitter.CreateRegisters()

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    lhs = registers.MapOutputParameter('lhs')
    rhs = registers.MapOutputParameter('rhs')
    self.asm_emitter.EmitPld(lhs)
    self.asm_emitter.EmitPld(rhs)

    aggregators = _GenerateMultiplyKernel(self.asm_emitter, registers, kernel_m,
                                          kernel_n, lhs, rhs)

    self.fused_transformation.Prepare(self.asm_emitter, registers, kernel_m,
                                      kernel_n, lhs, rhs)

    self.output_strategy.Prepare(self.asm_emitter, registers, kernel_m,
                                 kernel_n, self.fused_transformation.Type())

    reduced = _GenerateAggregatorReduce(self.asm_emitter, aggregators, kernel_m,
                                        kernel_n)

    transformed = self.fused_transformation.Transform(self.asm_emitter,
                                                      registers, reduced,
                                                      kernel_m, kernel_n)

    self.output_strategy.Output(self.asm_emitter, registers, transformed,
                                self.fused_transformation.Type(), kernel_m,
                                kernel_n)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))


class QuantizedMulStaticRowMajor(QuantizedMulKernel):
  """uint8 x uint8 -> uint8 kernel, row-major output."""

  def __init__(self, cc_emitter, asm_emitter):
    QuantizedMulKernel.__init__(self, cc_emitter, 'QuantizedStaticPreprocessed',
                                'RowMajor', asm_emitter,
                                _StaticQuantizationUInt8Transformation(),
                                _RowMajorOutput())


class QuantizedMulStaticAsInt32RowMajor(QuantizedMulKernel):
  """uint8 x uint8 -> int32 kernel, row-major output."""

  def __init__(self, cc_emitter, asm_emitter):
    QuantizedMulKernel.__init__(self, cc_emitter,
                                'QuantizedStaticPreprocessedAsInt32',
                                'RowMajor', asm_emitter,
                                _StaticQuantizationInt32Transformation(),
                                _RowMajorOutput())


class QuantizedMulStaticAsFloatRowMajor(QuantizedMulKernel):
  """uint8 x uint8 -> float kernel, row-major output."""

  def __init__(self, cc_emitter, asm_emitter):
    QuantizedMulKernel.__init__(self, cc_emitter,
                                'QuantizedStaticPreprocessedAsFloat',
                                'RowMajor', asm_emitter,
                                _StaticQuantizationFloatTransformation(),
                                _RowMajorOutput())


def GenerateKernels(cc_emitter, asm_emitter, shapes):
  """Generate the quantized multiplication kernels for uint8 operands."""
  quantized_mul_static_row_major = QuantizedMulStaticRowMajor(cc_emitter,
                                                              asm_emitter)
  quantized_mul_static_int32_row_major = QuantizedMulStaticAsInt32RowMajor(
      cc_emitter, asm_emitter)

  quantized_mul_static_float_row_major = QuantizedMulStaticAsFloatRowMajor(
      cc_emitter, asm_emitter)

  for shape in shapes:
    quantized_mul_static_row_major.SpecializeMulKernel('uint8_t', 'uint8_t',
                                                       shape[0], shape[1], 8)
  for shape in shapes:
    quantized_mul_static_int32_row_major.SpecializeMulKernel('uint8_t',
                                                             'int32_t',
                                                             shape[0], shape[1],
                                                             8)

  for shape in shapes:
    quantized_mul_static_float_row_major.SpecializeMulKernel('uint8_t', 'float',
                                                             shape[0], shape[1],
                                                             8)