• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Generates the whole gemm header.
2
3"""
4
5import cc_emitter
6import mul_Nx8_Mx8_neon
7import neon_emitter
8import qnt_Nx8_neon
9import zip_Nx8_neon
10
11_HEADER_COPYRIGHT = """// Copyright 2015 Google Inc. All Rights Reserved.
12//
13// Licensed under the Apache License, Version 2.0 (the "License");
14// you may not use this file except in compliance with the License.
15// You may obtain a copy of the License at
16//
17//     http://www.apache.org/licenses/LICENSE-2.0
18//
19// Unless required by applicable law or agreed to in writing, software
20// distributed under the License is distributed on an "AS IS" BASIS,
21// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22// See the License for the specific language governing permissions and
23// limitations under the License.
24//
25// single_thread_gemm.h: programatically generated GEMM library header.
26"""
27
# Identifiers of the gemm output types this generator supports. They select
# the parameter list, the emitted function name and the emitted body.
_QUANTIZED_8BIT = 'quantized_8bit'
_FULL_32BIT = 'full_32bit'
_FULL_FLOAT = 'full_float'
31
32
class Error(Exception):
  """Base class for the exceptions defined in this module."""
35
36
class ConfigurationError(Error):
  """Raised when an unsupported or unknown output type is requested."""
39
40
def GenerateCommonTempsCountersAndConsts(emitter, rows):
  """Emits the local declarations shared by all generated gemm variants.

  Args:
    emitter: emitter object; only EmitDeclare and EmitNewline are called.
    rows: row leftover count of the generated variant. When it is non-zero
      an extra 'zipped_lhs_<rows>_offsets' pointer is declared in addition
      to the one used for full 3-row strips.
  """
  emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
  emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
  emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
  emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
  emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
                      '(padded_k + 16) * 3')
  emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
                      '(padded_k + 16) * n')
  emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
  emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
  emitter.EmitDeclare(
      'std::int32_t*', 'zipped_lhs_3_offsets',
      'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)')
  # Compare by value: 'rows is not 0' tested object identity and only worked
  # because CPython happens to cache small integers.
  if rows != 0:
    emitter.EmitDeclare(
        'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
        'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
                      'scratch + zipped_chunk_size')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
                      'result_stride * 3')
  emitter.EmitNewline()
66
67
def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
  """Emits the local variable declarations of a quantized 8 bit gemm.

  The shared declarations are emitted first, followed by the ones that only
  the quantized variant needs (offsets, temporary int32 result and strides).

  Args:
    emitter: emitter object receiving the declarations.
    rows: row leftover count, forwarded to the shared declaration generator.
  """
  GenerateCommonTempsCountersAndConsts(emitter, rows)

  q8_declarations = (
      ('const std::int32_t', 'const_offset',
       'lhs_offset * rhs_offset * k + result_offset'),
      ('const std::int32_t', 'rounding_offset', '(1 << (shift - 1))'),
      ('std::int32_t*', 'temp_result',
       'reinterpret_cast<std::int32_t*>('
       'scratch + zipped_chunk_size + zipped_rhs_size)'),
      ('std::uint8_t*', 'result_chunk', 'result'),
      ('std::int32_t*', 'mul_result_chunk', 'temp_result'),
      ('const std::int32_t', 'mul_result_chunk_stride_bytes',
       '((n * 4 + 7) / 8) * 8'),
  )
  for c_type, variable, initializer in q8_declarations:
    emitter.EmitDeclare(c_type, variable, initializer)
  emitter.EmitNewline()
83
84
def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
  """Emits the local variable declarations of the int32 and float gemms.

  Args:
    emitter: emitter object receiving the declarations.
    result_type: C type used to declare both 'result_chunk' and
      'mul_result_chunk'.
    rows: row leftover count, forwarded to the shared declaration generator.
  """
  GenerateCommonTempsCountersAndConsts(emitter, rows)

  full_declarations = (
      ('const std::int32_t', 'const_offset', 'lhs_offset * rhs_offset * k'),
      (result_type, 'result_chunk', 'result'),
      (result_type, 'mul_result_chunk', 'result'),
      ('const std::int32_t', 'mul_result_chunk_stride_bytes',
       'result_stride * 4'),
  )
  for c_type, variable, initializer in full_declarations:
    emitter.EmitDeclare(c_type, variable, initializer)
  emitter.EmitNewline()
95
96
def ZipName(rows, leftovers, aligned):
  """Returns the name zip_Nx8_neon builds for the given zip configuration."""
  zip_function_name = zip_Nx8_neon.BuildName(rows, leftovers, aligned)
  return zip_function_name
99
100
def GenerateZipRhs(emitter, aligned, cols, leftovers):
  """Emits the code responsible for zipping the rhs matrix.

  A loop over 'col_chunks' zips the rhs three columns at a time; when the
  variant has column leftovers one more zip call handles them.

  Args:
    emitter: emitter object receiving the generated code.
    aligned: forwarded to ZipName to pick the zip routine variant.
    cols: column leftover count; a trailing zip call is emitted only when it
      is non-zero.
    leftovers: depth leftover count, forwarded to ZipName.
  """
  zip_args = ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0]

  emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
  # Each call gets its own copy of the argument list, as before.
  emitter.EmitCall(ZipName(3, leftovers, aligned), list(zip_args))
  emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitCloseBracket()

  # Compare by value: 'cols is not 0' tested object identity and only worked
  # because CPython happens to cache small integers.
  if cols != 0:
    emitter.EmitCall(ZipName(cols, leftovers, aligned), list(zip_args))
  emitter.EmitNewline()
116
117
def MulName(result_type, lhs_add, rhs_add, rows, cols):
  """Returns the name mul_Nx8_Mx8_neon builds for the given mul variant."""
  mul_function_name = mul_Nx8_Mx8_neon.BuildName(
      result_type, lhs_add, rhs_add, rows, cols)
  return mul_function_name
120
121
def GetMulParams(result_type):
  """Returns the argument names passed to a generated mul function call.

  Args:
    result_type: result type name of the mul function; 'float' adds the
      trailing 'result_scale' argument.

  Returns:
    A new list of argument name strings.
  """
  params = ['zipped_lhs', 'zipped_rhs_chunk', 'padded_k', 'mul_result_chunk',
            'mul_result_chunk_stride_bytes']
  # Compare by value: 'is' only matched when the caller's string happened to
  # be the very same interned object as the literal.
  if result_type == 'float':
    params.append('result_scale')
  return params
128
129
def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
                    rows, cols, leftovers):
  """Emits code responsible for multiplication of one horizontal lhs strip.

  The emitted code zips the lhs strip, resets the rhs and result cursors,
  multiplies the strip with every 3-column rhs chunk in a loop and finally
  handles the leftover columns, if there are any.

  Args:
    emitter: emitter object receiving the generated code.
    result: expression assigned to 'mul_result_chunk' before multiplying.
    result_type: result type name, forwarded to MulName and GetMulParams.
    lhs_add: forwarded to MulName.
    rhs_add: forwarded to MulName.
    aligned: forwarded to ZipName.
    rows: number of lhs rows in the strip.
    cols: column leftover count; a trailing mul call is emitted only when it
      is non-zero.
    leftovers: depth leftover count, forwarded to ZipName.
  """
  emitter.EmitCall(
      ZipName(rows, leftovers, aligned),
      ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
  emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitAssign('mul_result_chunk', result)

  emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')

  emitter.EmitCall(
      MulName(result_type, lhs_add, rhs_add, rows, 3),
      GetMulParams(result_type))
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitAssignIncrement('mul_result_chunk', 3)

  emitter.EmitCloseBracket()

  # Compare by value: 'cols is not 0' tested object identity and only worked
  # because CPython happens to cache small integers.
  if cols != 0:
    emitter.EmitCall(
        MulName(result_type, lhs_add, rhs_add, rows, cols),
        GetMulParams(result_type))
153
154
def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
  """Emits code for all lhs strips & leftover rows. Quantize after mul code.

  Every strip is multiplied into 'temp_result' and then passed to the
  matching multi quantize function.

  Args:
    emitter: emitter object receiving the generated code.
    aligned: selects the aligned zip and quantize variants.
    rows: row leftover count; the trailing leftover strip is emitted only
      when it is non-zero.
    cols: column leftover count, forwarded to GenerateMulRows.
    leftovers: depth leftover count, forwarded to GenerateMulRows.
  """

  def QuantizeArgs(strip_rows):
    # Fresh argument list for the quantize call of a strip_rows high strip.
    return ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
            'zipped_lhs_%d_offsets' % strip_rows, 'result_chunk',
            'result_stride', 'multiplicative_offset', 'rounding_offset',
            '-shift']

  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitCall(BuildMultiQuantizeName(aligned, 3), QuantizeArgs(3))
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  # Compare by value: 'rows is not 0' tested object identity and only worked
  # because CPython happens to cache small integers.
  if rows != 0:
    GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
                    cols, leftovers)
    emitter.EmitCall(BuildMultiQuantizeName(aligned, rows), QuantizeArgs(rows))
178
179
def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
  """Emits code for all lhs strips & leftover rows of a non-quantized gemm.

  Every strip is multiplied straight into 'result_chunk'.

  Args:
    emitter: emitter object receiving the generated code.
    result_type: result type name, forwarded to GenerateMulRows.
    aligned: forwarded to GenerateMulRows.
    rows: row leftover count; the trailing leftover strip is emitted only
      when it is non-zero.
    cols: column leftover count, forwarded to GenerateMulRows.
    leftovers: depth leftover count, forwarded to GenerateMulRows.
  """
  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  # Compare by value: 'rows is not 0' tested object identity and only worked
  # because CPython happens to cache small integers.
  if rows != 0:
    GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
                    rows, cols, leftovers)
192
193
def BuildName(output_type, aligned, rows, cols, leftover):
  """Returns the name of one specialized gemm function.

  The name is the main gemm name for output_type followed by the row, column
  and depth leftover counts, plus an '_aligned' suffix for aligned variants.
  """
  prefix = BuildMainGemmName(output_type)
  suffix = '_aligned' if aligned else ''
  return prefix + '_%d_%d_%d' % (rows, cols, leftover) + suffix
199
200
def GetCommonGemmParameters():
  """Returns the [type, name] pairs that every gemm signature starts with."""
  leading_signature = (
      ('std::uint8_t*', 'scratch'),
      ('const std::uint8_t*', 'lhs'),
      ('const std::uint8_t*', 'rhs'),
      ('std::int32_t', 'm'),
      ('std::int32_t', 'n'),
      ('std::int32_t', 'k'),
      ('std::int32_t', 'lhs_offset'),
      ('std::int32_t', 'rhs_offset'),
  )
  # Callers extend the result in place, so hand out fresh lists every time.
  return [[c_type, name] for c_type, name in leading_signature]
206
207
def GetGemmParameters(output_type, extra_params=None):
  """Prepares a (type, parameter) array for the gemm functions.

  Args:
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.
    extra_params: optional list of [type, name] pairs appended after the
      output specific parameters.

  Returns:
    A new list of [type, name] pairs.

  Raises:
    ConfigurationError: if output_type is not one of the supported values.
  """
  if extra_params is None:
    extra_params = []
  params = GetCommonGemmParameters()
  # Output types are compared by value; 'is' would reject an equal string
  # that is not the very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    params += [['std::int32_t', 'result_offset'],
               ['std::int32_t', 'multiplicative_offset'],
               ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
  elif output_type == _FULL_32BIT:
    params += [['std::int32_t*', 'result']]
  elif output_type == _FULL_FLOAT:
    params += [['float', 'result_scale'], ['float*', 'result']]
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
  return params + extra_params
224
225
def GetStridedGemmParameters(output_type):
  """Returns the gemm parameters followed by a trailing 'result_stride'."""
  stride_parameter = ['std::int32_t', 'result_stride']
  return GetGemmParameters(output_type, [stride_parameter])
228
229
def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
  """Build one gemm function for given row, col, and depth leftovers.

  Args:
    emitter: emitter object receiving the generated function.
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.
    aligned: whether the aligned variant is generated.
    rows: value of 'm % 3' the generated function asserts and handles.
    cols: value of 'n % 3' the generated function asserts and handles.
    leftovers: value of 'k % 8' the generated function asserts and handles.

  Raises:
    ConfigurationError: if output_type is not one of the supported values.
  """
  emitter.EmitFunctionBeginA(
      BuildName(output_type, aligned, rows, cols, leftovers),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitAssert('m %% 3 == %d' % rows)
  emitter.EmitAssert('n %% 3 == %d' % cols)
  emitter.EmitAssert('k %% 8 == %d' % leftovers)

  # Output types are compared by value; 'is' would reject an equal string
  # that is not the very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
  elif output_type == _FULL_32BIT:
    GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
  elif output_type == _FULL_FLOAT:
    GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
  else:
    raise ConfigurationError('Unknown output type: %s' % output_type)

  emitter.EmitFunctionEnd()
256
257
def BuildMultiQuantizeName(aligned, rows):
  """Returns the multi quantize function name for a strip of 'rows' rows."""
  base_name = 'multi_qnt_%dx8' % rows
  if not aligned:
    return base_name
  return '%s_aligned' % base_name
263
264
def GenerateMultiQuantize(emitter, aligned, rows):
  """Emits the quantize dispatcher for strips that are 'rows' rows high.

  The emitted function switches on 'count % 8' and forwards all of its
  arguments to the qnt_Nx8_neon routine named for that leftover.
  """
  signature = [
      ['const std::int32_t*', 'source'],
      ['std::int32_t', 'count'],
      ['std::int32_t', 'stride'],
      ['const std::int32_t*', 'offsets'],
      ['std::uint8_t*', 'destination'],
      ['std::int32_t', 'destination_stride'],
      ['std::int32_t', 'multiplicative_offset'],
      ['std::int32_t', 'rounding_offset'],
      ['std::int32_t', 'shift'],
  ]
  # The dispatcher passes its own parameters through unchanged.
  forwarded_names = [parameter_name for _, parameter_name in signature]

  emitter.EmitFunctionBeginA(
      BuildMultiQuantizeName(aligned, rows), signature, 'void')
  emitter.EmitSwitch('count % 8')

  for count_leftover in range(8):
    emitter.EmitCase(count_leftover)
    emitter.PushIndent()
    quantize_name = qnt_Nx8_neon.BuildName(rows, count_leftover, aligned)
    emitter.EmitCall(quantize_name, list(forwarded_names))
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
  emitter.EmitFunctionEnd()
290
291
def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
  """Emits a call to the specialized gemm in the 'internal' scope.

  All parameters of the strided gemm signature are forwarded by name.
  """
  specialized_name = BuildName(output_type, aligned, m_mod, n_mod, leftovers)
  scoped_name = emitter.Scope('internal', specialized_name)
  argument_names = [
      name for _, name in GetStridedGemmParameters(output_type)]
  emitter.EmitCall(scoped_name, argument_names)
297
298
def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
  """Innermost dispatch level: one case per possible value of 'k % 8'."""
  emitter.EmitSwitch('k % 8')

  for depth_leftover in range(8):
    emitter.EmitCase(depth_leftover)
    emitter.PushIndent()
    GenerateGemmCall(
        emitter, output_type, aligned, m_mod, n_mod, depth_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
311
312
def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
  """Middle dispatch level: one case per possible value of 'n % 3'."""
  emitter.EmitSwitch('n % 3')

  for col_leftover in range(3):
    emitter.EmitCase(col_leftover)
    emitter.PushIndent()
    GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, col_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
325
326
def GenerateGemmSwitch1(emitter, output_type, aligned):
  """Outermost dispatch level: one case per possible value of 'm % 3'."""
  emitter.EmitSwitch('m % 3')

  for row_leftover in range(3):
    emitter.EmitCase(row_leftover)
    emitter.PushIndent()
    GenerateGemmSwitch2(emitter, output_type, aligned, row_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
339
340
def BuildMainGemmName(output_type):
  """Returns the base name of the gemm functions for output_type.

  Args:
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.

  Returns:
    'gemm_q8', 'gemm_i32' or 'gemm_f' respectively.

  Raises:
    ConfigurationError: if output_type is not one of the supported values.
  """
  # Output types are compared by value; 'is' would reject an equal string
  # that is not the very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    return 'gemm_q8'
  elif output_type == _FULL_32BIT:
    return 'gemm_i32'
  elif output_type == _FULL_FLOAT:
    return 'gemm_f'
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
350
351
def BuildStridedMainGemmName(output_type):
  """Returns the main gemm name for output_type with '_strided' appended."""
  main_name = BuildMainGemmName(output_type)
  return main_name + '_strided'
354
355
def GenerateMainGemmFunction(emitter, output_type):
  """Emit high level gemm function that switches between optimized versions.

  The emitted function computes an 'aligned' flag from its arguments and
  then dispatches on 'm % 3', 'n % 3' and 'k % 8' to the aligned or the
  unaligned specializations.

  Args:
    emitter: emitter object receiving the generated function.
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT. The
      quantized variant additionally requires the result pointer and the
      result stride to be aligned.
  """
  emitter.EmitFunctionBeginA(
      BuildStridedMainGemmName(output_type),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitDeclare('const bool', 'lhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'rhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

  # Output types are compared by value; 'is' would reject an equal string
  # that is not the very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    emitter.EmitDeclare('const bool', 'result_aligned',
                        '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'result_stride_aligned',
                        '((result_stride % 8) == 0)')
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && result_aligned '
                        '&& k_aligned && result_stride_aligned')
  else:
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && k_aligned')

  emitter.EmitIf('aligned')
  GenerateGemmSwitch1(emitter, output_type, True)
  emitter.EmitElse()
  GenerateGemmSwitch1(emitter, output_type, False)
  emitter.EmitEndif()
  emitter.EmitFunctionEnd()
386
387
def GenerateWrapperGemmFunction(emitter, output_type):
  """Emits the gemm entry point that has no 'result_stride' parameter.

  The wrapper forwards its arguments to the strided main gemm function and
  passes 'n' as the trailing 'result_stride' argument.
  """
  wrapper_name = BuildMainGemmName(output_type)
  emitter.EmitFunctionBeginA(
      wrapper_name, GetGemmParameters(output_type), 'void')

  strided_name = BuildStridedMainGemmName(output_type)
  forwarded = [name for _, name in GetGemmParameters(output_type)]
  forwarded.append('n')
  emitter.EmitCall(strided_name, forwarded)

  emitter.EmitFunctionEnd()
395
396
def GenerateInternalFunctions(emitter):
  """Generates everything that goes into the internal namespace.

  The zip, mul and quantize generators each get a fresh NeonEmitter. The
  multi quantize dispatchers and the specialized gemm functions are emitted
  through the emitter passed in.
  """
  # Same (result_type, flag, flag) combinations that GenerateQuantized8BitMul
  # and GenerateFullMul pass to MulName.
  mul_variants = (
      ('int32', False, True),
      ('int32', True, True),
      ('float', True, True),
  )

  zip_Nx8_neon.GenerateFunctions(neon_emitter.NeonEmitter())
  emitter.EmitNewline()

  for variant in mul_variants:
    mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), *variant)
    emitter.EmitNewline()

  qnt_Nx8_neon.GenerateFunctions(neon_emitter.NeonEmitter())
  emitter.EmitNewline()

  for aligned in (True, False):
    for strip_rows in (1, 2, 3):
      GenerateMultiQuantize(emitter, aligned, strip_rows)
      emitter.EmitNewline()

  for output_type in (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT):
    for aligned in (True, False):
      for row_leftover in range(3):
        for col_leftover in range(3):
          for depth_leftover in range(8):
            GenerateGemm(emitter, output_type, aligned, row_leftover,
                         col_leftover, depth_leftover)
            emitter.EmitNewline()
428            emitter.EmitNewline()
429
430
def Main():
  """Generate the single threaded meta gemm library."""
  emitter = cc_emitter.CCEmitter()
  output_types = (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT)

  # Copyright banner, header begin and the NEON-only preprocessor guard.
  emitter.EmitCodeNoSemicolon(_HEADER_COPYRIGHT)
  emitter.EmitHeaderBegin('gemmlowp_meta_single_thread_gemm')

  emitter.EmitPreprocessor1('ifdef', 'GEMMLOWP_NEON_32')
  emitter.EmitNewline()

  emitter.EmitInclude('<cassert>')
  emitter.EmitNewline()

  for namespace in ('gemmlowp', 'meta', 'internal'):
    emitter.EmitNamespaceBegin(namespace)
  emitter.EmitNewline()

  GenerateInternalFunctions(emitter)

  # Ends the most recently opened namespace, 'internal'.
  emitter.EmitNamespaceEnd()
  emitter.EmitNewline()

  # All strided main functions first, then all wrappers.
  for generate in (GenerateMainGemmFunction, GenerateWrapperGemmFunction):
    for output_type in output_types:
      generate(emitter, output_type)
      emitter.EmitNewline()

  # Two more ends for the remaining 'meta' and 'gemmlowp' namespaces.
  emitter.EmitNamespaceEnd()
  emitter.EmitNamespaceEnd()
  emitter.EmitNewline()

  # Without GEMMLOWP_NEON_32 the header only produces a warning.
  emitter.EmitPreprocessor('else')
  emitter.EmitPreprocessor1('warning',
                            '"Meta gemm fast-path requires GEMMLOWP_NEON_32!"')
  emitter.EmitPreprocessor('endif')
  emitter.EmitNewline()

  emitter.EmitHeaderEnd()
478
479
# Generate the header only when this file is executed as a script.
if __name__ == '__main__':
  Main()
482