• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The Gemmlowp Authors. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""."""
15
16import common
17
18
19def _ReadParams(emitter, registers, input_address, elements, min_register):
20  registers_count = (elements + 3) / 4
21  registers = [
22      registers.QuadRegister(min_register)
23      for unused_i in range(registers_count)
24  ]
25  emitter.EmitVLoadAE(registers_count * 4, 32, registers, input_address, 64)
26  return registers
27
28
29def _Duplicate(emitter, registers, rows, values):
30  """Populate a grid of registers duplicating provided values."""
31  duplicated = []
32  for i in range(rows):
33    if i is rows - 1:
34      duplicated.append(values[0])
35    else:
36      duplicated.append(registers.QuadRegister())
37
38    emitter.EmitVDup('32', duplicated[i],
39                     emitter.Lane(32, values[i / 4], i % 4))
40
41  return duplicated
42
43
44def _DuplicateGeneralRegister(emitter, registers, value, min_register):
45  register = registers.QuadRegister(min_register)
46  emitter.EmitVDup('32', register, value)
47  return register
48
49
50class _StaticQuantizationUInt8Transformation(object):
51  """Calculate quantized values and cast back to uint8."""
52
53  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
54    """Load parameters and prepare duplicated registers."""
55    emitter.EmitNewline()
56    emitter.EmitComment('StaticQuantization::Prepare')
57
58    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
59    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
60    self.multiplicative_offset = _DuplicateGeneralRegister(
61        emitter, registers,
62        registers.MapParameter('multiplicative_offset',
63                               'params.kernel.multiplicative_offset'), 4)
64    self.rounding_offset = _DuplicateGeneralRegister(
65        emitter, registers,
66        registers.MapParameter('rounding_offset',
67                               'params.kernel.rounding_offset'), 4)
68    self.shift = _DuplicateGeneralRegister(
69        emitter, registers,
70        registers.MapParameter('shift', 'params.kernel.shift'), 4)
71    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
72
73  def Transform(self, emitter, registers, data, unused_kernel_m,
74                unused_kernel_n):
75    """Quantize the data."""
76    emitter.EmitNewline()
77    emitter.EmitComment('StaticQuantization::Transform')
78
79    for (row, lhs_offset) in zip(data, self.lhs_offsets):
80      for row_register in row:
81        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
82
83    for row in data:
84      for (row_register, rhs_offset_register) in zip(row, self.rhs_offsets):
85        emitter.EmitVAdd('s32', row_register, row_register, rhs_offset_register)
86
87    for row in data:
88      for row_register in row:
89        emitter.EmitVMul('i32', row_register, row_register,
90                         self.multiplicative_offset)
91
92    for row in data:
93      for row_register in row:
94        emitter.EmitVAdd('i32', row_register, row_register,
95                         self.rounding_offset)
96
97    for row in data:
98      for row_register in row:
99        emitter.EmitVShl('s32', row_register, row_register, self.shift)
100
101    if len(data[0]) is 1:
102      for row in data:
103        emitter.EmitVQmovn('s32', row[0], row[0])
104
105      for row in data:
106        emitter.EmitVQmovun('s16', row[0], row[0])
107
108      return data
109    elif len(data[0]) is 2:
110      results = []
111      for row in data:
112        emitter.EmitVQmovn2('s32', row[0], row[0], row[1])
113        registers.FreeRegister(row[1])
114        results.append([row[0]])
115
116      for row in results:
117        emitter.EmitVQmovun('s16', row[0], row[0])
118
119      return results
120    else:
121      assert False
122
123  def Type(self):
124    return 8
125
126
127class _StaticQuantizationInt32Transformation(object):
128  """."""
129
130  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
131    emitter.EmitNewline()
132    emitter.EmitComment('StaticQuantizationInt32::Prepare')
133
134    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
135    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
136    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
137
138  def Transform(self, emitter, unused_registers, data, unused_kernel_m,
139                unused_kernel_n):
140    """Quantize data and output as int32."""
141    emitter.EmitNewline()
142    emitter.EmitComment('StaticQuantizationInt32::Transform')
143
144    for (row, lhs_offset) in zip(data, self.lhs_offsets):
145      for row_register in row:
146        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
147
148    for row in data:
149      for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
150        emitter.EmitVAdd('s32', row_register, row_register,
151                         rhs_offsets_register)
152
153    return data
154
155  def Type(self):
156    return 32
157
158
159class _StaticQuantizationFloatTransformation(object):
160  """."""
161
162  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
163    emitter.EmitNewline()
164    emitter.EmitComment('StaticQuantizationFloat::Prepare')
165
166    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
167    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
168    self.scale = _DuplicateGeneralRegister(
169        emitter, registers,
170        registers.MapParameter('scale', 'params.kernel.scale'), 4)
171    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
172
173  def Transform(self, emitter, unused_registers, data, unused_kernel_m,
174                unused_kernel_n):
175    """Quantize data and output as float."""
176    emitter.EmitNewline()
177    emitter.EmitComment('StaticQuantizationFloat::Transform')
178
179    for (row, lhs_offset) in zip(data, self.lhs_offsets):
180      for row_register in row:
181        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
182
183    for row in data:
184      for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
185        emitter.EmitVAdd('s32', row_register, row_register,
186                         rhs_offsets_register)
187
188    for row in data:
189      for row_register in row:
190        emitter.EmitVCvt('f32', 's32', row_register, row_register)
191
192    for row in data:
193      for row_register in row:
194        emitter.EmitVMul('f32', row_register, row_register, self.scale)
195
196    return data
197
198  def Type(self):
199    return 32
200
201
202class _RowMajorOutput(object):
203  """Output data in row major layout."""
204
205  def Prepare(self, emitter, registers, kernel_m, unused_kernel_n,
206              unused_data_type):
207    """Prepare strided load addresses."""
208    emitter.EmitNewline()
209    emitter.EmitComment('RowMajorOutput::Prepare')
210
211    stride = registers.MapParameter('stride', 'params.output_stream.stride')
212
213    self.outputs = []
214    self.outputs.append(registers.MapOutputParameter('result'))
215
216    for unused_i in range(kernel_m - 1):
217      register = registers.GeneralRegister()
218      emitter.EmitAdd(register, self.outputs[-1], stride)
219      self.outputs.append(register)
220
221  def Output(self, emitter, unused_registers, data, data_type, unused_kernel_m,
222             kernel_n):
223    emitter.EmitNewline()
224    emitter.EmitComment('RowMajorOutput::Output')
225
226    for (datum, output) in zip(data, self.outputs):
227      emitter.EmitVStoreAE(data_type, kernel_n, datum, output, None)
228
229
230def _GenerateAndClearAggregators(emitter, registers, count):
231  """Prepare aggregators and emit aggregator clear code."""
232  emitter.EmitNewline()
233  emitter.EmitComment('Clear aggregators.')
234  aggregators = [registers.QuadRegister() for unused_i in range(count)]
235  for i in range(count):
236    if i < 3:
237      emitter.EmitVMov('i32', aggregators[i], emitter.ImmediateConstant(0))
238    else:
239      emitter.EmitVMov('i32', aggregators[i], aggregators[i - 3])
240  return aggregators
241
242
243def _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
244                                      count):
245  """Emit inner loop for 3 rows x 3 cols multiplication."""
246  emitter.EmitNewline()
247  emitter.EmitComment('3x3 lanes loop.')
248  emitter.EmitNumericalLabel(1)
249  emitter.EmitNewline()
250
251  lhs_load = [registers.DoubleRegister() for unused_i in range(3)]
252  rhs_load = [registers.DoubleRegister() for unused_i in range(3)]
253  temp = [registers.QuadRegister() for unused_i in range(4)]
254
255  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 64))
256  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))
257
258  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
259  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))
260
261  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
262  emitter.EmitVLoad(1, 8, lhs_load[2], emitter.DereferenceIncrement(lhs, 64))
263
264  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
265  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
266
267  emitter.EmitVMull('u8', temp[3], lhs_load[1], rhs_load[0])
268  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
269
270  emitter.EmitVPadal('u16', aggregators[0], temp[0])
271  emitter.EmitVPadal('u16', aggregators[1], temp[1])
272  emitter.EmitVPadal('u16', aggregators[2], temp[2])
273  emitter.EmitVPadal('u16', aggregators[3], temp[3])
274
275  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
276  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
277
278  registers.FreeRegisters([lhs_load[0], lhs_load[1]])
279  temp.append(registers.QuadRegister())
280
281  emitter.EmitVMull('u8', temp[2], lhs_load[2], rhs_load[0])
282  emitter.EmitVMull('u8', temp[3], lhs_load[2], rhs_load[1])
283
284  emitter.EmitNewline()
285  emitter.EmitComment('Subtract counter.')
286  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
287  emitter.EmitNewline()
288
289  emitter.EmitVMull('u8', temp[4], lhs_load[2], rhs_load[2])
290
291  emitter.EmitVPadal('u16', aggregators[4], temp[0])
292  emitter.EmitVPadal('u16', aggregators[5], temp[1])
293  emitter.EmitVPadal('u16', aggregators[6], temp[2])
294  emitter.EmitVPadal('u16', aggregators[7], temp[3])
295  emitter.EmitVPadal('u16', aggregators[8], temp[4])
296
297  emitter.EmitNewline()
298  emitter.EmitComment('Loop break.')
299  emitter.EmitBgtBack(1)
300
301  registers.FreeRegisters(temp + [lhs_load[2]] + rhs_load)
302
303
304def _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
305                                      count):
306  """Emit inner loop for 2 rows x 4 cols multiplication."""
307  emitter.EmitNewline()
308  emitter.EmitComment('2x4 lanes loop.')
309  emitter.EmitNumericalLabel(1)
310  emitter.EmitNewline()
311
312  lhs_load = [registers.DoubleRegister() for unused_i in range(2)]
313  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
314  temp = [registers.QuadRegister() for unused_i in range(5)]
315
316  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 256))
317  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))
318
319  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
320  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))
321
322  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
323  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
324
325  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
326  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
327
328  emitter.EmitVMull('u8', temp[3], lhs_load[0], rhs_load[3])
329  emitter.EmitVMull('u8', temp[4], lhs_load[1], rhs_load[0])
330
331  emitter.EmitVPadal('u16', aggregators[0], temp[0])
332  emitter.EmitVPadal('u16', aggregators[1], temp[1])
333  emitter.EmitVPadal('u16', aggregators[2], temp[2])
334
335  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
336  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
337  emitter.EmitVMull('u8', temp[2], lhs_load[1], rhs_load[3])
338
339  emitter.EmitNewline()
340  emitter.EmitComment('Subtract counter.')
341  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
342
343  emitter.EmitNewline()
344  emitter.EmitVPadal('u16', aggregators[3], temp[3])
345  emitter.EmitVPadal('u16', aggregators[4], temp[4])
346  emitter.EmitVPadal('u16', aggregators[5], temp[0])
347  emitter.EmitVPadal('u16', aggregators[6], temp[1])
348  emitter.EmitVPadal('u16', aggregators[7], temp[2])
349
350  emitter.EmitNewline()
351  emitter.EmitComment('Loop break.')
352  emitter.EmitBgtBack(1)
353
354  registers.FreeRegisters(temp + lhs_load + rhs_load)
355
356
357def _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
358                                      count):
359  """Emit inner loop for 1 rows x 8 cols multiplication."""
360  emitter.EmitNewline()
361  emitter.EmitComment('1x8 lanes loop.')
362  emitter.EmitNumericalLabel(1)
363  emitter.EmitNewline()
364
365  lhs_load = registers.DoubleRegister()
366  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
367  temp = [registers.QuadRegister() for unused_i in range(5)]
368
369  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
370  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
371
372  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[0])
373  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[1])
374  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[2])
375  emitter.EmitVMull('u8', temp[3], lhs_load, rhs_load[3])
376
377  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
378
379  emitter.EmitVPadal('u16', aggregators[0], temp[0])
380  emitter.EmitVPadal('u16', aggregators[1], temp[1])
381  emitter.EmitVPadal('u16', aggregators[2], temp[2])
382  emitter.EmitVPadal('u16', aggregators[3], temp[3])
383
384  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(256))
385
386  emitter.EmitVMull('u8', temp[4], lhs_load, rhs_load[0])
387  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[1])
388  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[2])
389  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[3])
390
391  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(32))
392
393  emitter.EmitNewline()
394  emitter.EmitComment('Subtract counter.')
395  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
396
397  emitter.EmitNewline()
398  emitter.EmitVPadal('u16', aggregators[4], temp[4])
399  emitter.EmitVPadal('u16', aggregators[5], temp[0])
400  emitter.EmitVPadal('u16', aggregators[6], temp[1])
401  emitter.EmitVPadal('u16', aggregators[7], temp[2])
402
403  emitter.EmitNewline()
404  emitter.EmitComment('Loop break.')
405  emitter.EmitBgtBack(1)
406
407  registers.FreeRegisters(temp + [lhs_load] + rhs_load)
408
409
410def _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
411                                      aggregators, lhs, rhs, count):
412  """Emit inner loop for N rows x M cols multiplication."""
413  emitter.EmitNewline()
414  emitter.EmitComment('General NxM lanes loop.')
415  emitter.EmitNumericalLabel(1)
416  emitter.EmitNewline()
417  emitter.EmitComment('Subtract counter.')
418  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
419  emitter.EmitNewline()
420
421  lhs_load = [registers.DoubleRegister() for unused_i in range(kernel_m)]
422  rhs_load = [registers.DoubleRegister() for unused_i in range(kernel_n)]
423
424  emitter.EmitVLoadAE(8 * kernel_m, 8, lhs_load, lhs, 64)
425  emitter.EmitVLoadAE(8 * kernel_n, 8, rhs_load, rhs, 64)
426
427  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
428  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
429
430  results = [
431      registers.QuadRegister() for unused_i in range(kernel_m * kernel_n)
432  ]
433
434  for row in range(kernel_m):
435    for col in range(kernel_n):
436      index = row * kernel_n + col
437      emitter.EmitVMull('u8', results[index], rhs_load[col], lhs_load[row])
438
439  for i in range(kernel_m * kernel_n):
440    emitter.EmitVPadal('u16', aggregators[i], results[i])
441
442  emitter.EmitNewline()
443  emitter.EmitComment('Loop break.')
444  emitter.EmitBgtBack(1)
445
446  registers.FreeRegisters(lhs_load + rhs_load + results)
447
448
449def _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
450                                      lhs, rhs, count):
451  """Emit inner loop for 1 row x M cols multiplication."""
452  assert kernel_n in [5, 6, 7, 8]
453  emitter.EmitNewline()
454  emitter.EmitComment('General 1xM lanes loop.')
455  emitter.EmitNumericalLabel(1)
456  emitter.EmitNewline()
457  emitter.EmitComment('Subtract counter.')
458  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
459  emitter.EmitNewline()
460
461  leftover = kernel_n - 4
462
463  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
464  lhs_load = registers.DoubleRegister()
465
466  emitter.EmitVLoadAE(8 * 4, 8, rhs_load, rhs, 64)
467  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
468
469  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
470
471  results = [registers.QuadRegister() for unused_i in range(4)]
472
473  for i in range(4):
474    emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)
475
476  emitter.EmitVLoadAE(8 * leftover, 8, rhs_load, rhs, 64)
477  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(128))
478
479  for i in range(4):
480    emitter.EmitVPadal('u16', aggregators[i], results[i])
481
482  for i in range(leftover):
483    emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)
484
485  for i in range(leftover):
486    emitter.EmitVPadal('u16', aggregators[i + 4], results[i])
487
488  emitter.EmitNewline()
489  emitter.EmitComment('Loop break.')
490  emitter.EmitBgtBack(1)
491
492  registers.FreeRegisters([lhs_load] + rhs_load + results)
493
494
495def _GenerateMultiplyKernel(emitter, registers, kernel_m, kernel_n, lhs, rhs):
496  """Main muliply loop. Pick best implementation for given kernel shape."""
497  count = registers.MapParameter('count', 'params.kernel.count')
498
499  aggregators = _GenerateAndClearAggregators(emitter, registers,
500                                             kernel_m * kernel_n)
501  if kernel_m is 3 and kernel_n is 3:
502    _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
503                                      count)
504  elif kernel_m is 2 and kernel_n is 4:
505    _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
506                                      count)
507  elif kernel_m is 1 and kernel_n is 8:
508    _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
509                                      count)
510  elif kernel_m is 1 and kernel_n > 4:
511    _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
512                                      lhs, rhs, count)
513  else:
514    _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
515                                      aggregators, lhs, rhs, count)
516  return aggregators
517
518
519def _ReduceAggregators(emitter, aggregators):
520  reduced_count = (len(aggregators) + 3) / 4
521  reduced = aggregators[:reduced_count]
522  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
523  return reduced
524
525
526def _GenerateAggregatorReduce(emitter, aggregators, kernel_m, kernel_n):
527  emitter.EmitNewline()
528  emitter.EmitComment('Reduce aggregators.')
529  row_temps = []
530  for i in range(kernel_m):
531    row_temps.append(
532        _ReduceAggregators(emitter, aggregators[i * kernel_n:(i + 1) *
533                                                kernel_n]))
534  return row_temps
535
536
537class QuantizedMulKernel(common.MulKernelGenerator):
538  """."""
539
540  def __init__(self, cc_emitter, kernel_name, output_stream_name, asm_emitter,
541               fused_transformation, output_strategy):
542    common.MulKernelGenerator.__init__(self, cc_emitter, kernel_name,
543                                       output_stream_name)
544    self.asm_emitter = asm_emitter
545    self.fused_transformation = fused_transformation
546    self.output_strategy = output_strategy
547
548  def EmitMultiply(self, in_type, out_type, kernel_m, kernel_n, pack_size):
549    assert in_type is 'uint8_t'
550    assert pack_size is 8
551    assert kernel_m * kernel_n <= 9
552
553    registers = self.asm_emitter.CreateRegisters()
554
555    self.asm_emitter.PushIndent(self.emitter.indent)
556    self.asm_emitter.EmitAsmBegin()
557
558    lhs = registers.MapOutputParameter('lhs')
559    rhs = registers.MapOutputParameter('rhs')
560    self.asm_emitter.EmitPld(lhs)
561    self.asm_emitter.EmitPld(rhs)
562
563    aggregators = _GenerateMultiplyKernel(self.asm_emitter, registers, kernel_m,
564                                          kernel_n, lhs, rhs)
565
566    self.fused_transformation.Prepare(self.asm_emitter, registers, kernel_m,
567                                      kernel_n, lhs, rhs)
568
569    self.output_strategy.Prepare(self.asm_emitter, registers, kernel_m,
570                                 kernel_n, self.fused_transformation.Type())
571
572    reduced = _GenerateAggregatorReduce(self.asm_emitter, aggregators, kernel_m,
573                                        kernel_n)
574
575    transformed = self.fused_transformation.Transform(self.asm_emitter,
576                                                      registers, reduced,
577                                                      kernel_m, kernel_n)
578
579    self.output_strategy.Output(self.asm_emitter, registers, transformed,
580                                self.fused_transformation.Type(), kernel_m,
581                                kernel_n)
582
583    self.asm_emitter.EmitAsmEnd(registers)
584    self.asm_emitter.PopIndent(len(self.emitter.indent))
585
586
587class QuantizedMulStaticRowMajor(QuantizedMulKernel):
588  """."""
589
590  def __init__(self, cc_emitter, asm_emitter):
591    QuantizedMulKernel.__init__(self, cc_emitter, 'QuantizedStaticPreprocessed',
592                                'RowMajor', asm_emitter,
593                                _StaticQuantizationUInt8Transformation(),
594                                _RowMajorOutput())
595
596
597class QuantizedMulStaticAsInt32RowMajor(QuantizedMulKernel):
598  """."""
599
600  def __init__(self, cc_emitter, asm_emitter):
601    QuantizedMulKernel.__init__(self, cc_emitter,
602                                'QuantizedStaticPreprocessedAsInt32',
603                                'RowMajor', asm_emitter,
604                                _StaticQuantizationInt32Transformation(),
605                                _RowMajorOutput())
606
607
608class QuantizedMulStaticAsFloatRowMajor(QuantizedMulKernel):
609  """."""
610
611  def __init__(self, cc_emitter, asm_emitter):
612    QuantizedMulKernel.__init__(self, cc_emitter,
613                                'QuantizedStaticPreprocessedAsFloat',
614                                'RowMajor', asm_emitter,
615                                _StaticQuantizationFloatTransformation(),
616                                _RowMajorOutput())
617
618
619def GenerateKernels(cc_emitter, asm_emitter, shapes):
620  """Generate the quantized multiplication kernels for uint8 operands."""
621  quantized_mul_static_row_major = QuantizedMulStaticRowMajor(cc_emitter,
622                                                              asm_emitter)
623  quantized_mul_static_int32_row_major = QuantizedMulStaticAsInt32RowMajor(
624      cc_emitter, asm_emitter)
625
626  quantized_mul_static_float_row_major = QuantizedMulStaticAsFloatRowMajor(
627      cc_emitter, asm_emitter)
628
629  for shape in shapes:
630    quantized_mul_static_row_major.SpecializeMulKernel('uint8_t', 'uint8_t',
631                                                       shape[0], shape[1], 8)
632  for shape in shapes:
633    quantized_mul_static_int32_row_major.SpecializeMulKernel('uint8_t',
634                                                             'int32_t',
635                                                             shape[0], shape[1],
636                                                             8)
637
638  for shape in shapes:
639    quantized_mul_static_float_row_major.SpecializeMulKernel('uint8_t', 'float',
640                                                             shape[0], shape[1],
641                                                             8)
642