Lines Matching refs:emitter
20 def GenerateCommonTempsCountersAndConsts(emitter, rows): argument
21 emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
22 emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
23 emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
24 emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
25 emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
27 emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
29 emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
30 emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
31 emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
32 emitter.EmitDeclare(
36 emitter.EmitDeclare(
39 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
41 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
42 emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
44 emitter.EmitNewline()
47 def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows): argument
49 GenerateCommonTempsCountersAndConsts(emitter, rows)
50 emitter.EmitDeclare('const std::int32_t', 'const_offset',
52 emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
54 emitter.EmitDeclare('std::int32_t*', 'temp_result',
57 emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result')
58 emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
59 emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
61 emitter.EmitNewline()
64 def GenerateFullTempsCountersAndConsts(emitter, result_type, rows): argument
66 GenerateCommonTempsCountersAndConsts(emitter, rows)
67 emitter.EmitDeclare('const std::int32_t', 'const_offset',
69 emitter.EmitDeclare(result_type, 'result_chunk', 'result')
70 emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
71 emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
73 emitter.EmitNewline()
80 def GenerateZipRhs(emitter, aligned, cols, leftovers): argument
82 emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
83 emitter.EmitCall(
86 emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
87 emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
88 emitter.EmitCloseBracket()
91 emitter.EmitCall(
94 emitter.EmitNewline()
109 def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned, argument
112 emitter.EmitCall(
115 emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
116 emitter.EmitAssign('mul_result_chunk', result)
118 emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')
120 emitter.EmitCall(
123 emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
124 emitter.EmitAssignIncrement('mul_result_chunk', 3)
126 emitter.EmitCloseBracket()
129 emitter.EmitCall(
134 def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers): argument
136 emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
137 GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
139 emitter.EmitCall(
144 emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
145 emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
146 emitter.EmitCloseBracket()
147 emitter.EmitNewline()
150 GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
152 emitter.EmitCall(
159 def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers): argument
160 emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
161 GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
163 emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
164 emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
165 emitter.EmitCloseBracket()
166 emitter.EmitNewline()
169 GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
209 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers): argument
211 emitter.EmitFunctionBeginA(
215 emitter.EmitAssert('m %% 3 == %d' % rows)
216 emitter.EmitAssert('n %% 3 == %d' % cols)
217 emitter.EmitAssert('k %% 8 == %d' % leftovers)
220 GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
221 GenerateZipRhs(emitter, aligned, cols, leftovers)
222 GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
224 GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
225 GenerateZipRhs(emitter, aligned, cols, leftovers)
226 GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
228 GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
229 GenerateZipRhs(emitter, aligned, cols, leftovers)
230 GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
234 emitter.EmitFunctionEnd()
237 def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers): argument
238 emitter.EmitCall(
239 emitter.Scope('internal',
244 def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod): argument
246 emitter.EmitSwitch('k % 8')
249 emitter.EmitCase(leftovers)
250 emitter.PushIndent()
251 GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
252 emitter.EmitBreak()
253 emitter.PopIndent()
255 emitter.EmitSwitchEnd()
258 def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod): argument
260 emitter.EmitSwitch('n % 3')
263 emitter.EmitCase(n_mod)
264 emitter.PushIndent()
265 GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
266 emitter.EmitBreak()
267 emitter.PopIndent()
269 emitter.EmitSwitchEnd()
272 def GenerateGemmSwitch1(emitter, output_type, aligned): argument
274 emitter.EmitSwitch('m % 3')
277 emitter.EmitCase(m_mod)
278 emitter.PushIndent()
279 GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
280 emitter.EmitBreak()
281 emitter.PopIndent()
283 emitter.EmitSwitchEnd()
301 def GenerateMainGemmFunction(emitter, output_type): argument
303 emitter.EmitFunctionBeginA(
307 emitter.EmitDeclare('const bool', 'lhs_aligned',
309 emitter.EmitDeclare('const bool', 'rhs_aligned',
311 emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')
314 emitter.EmitDeclare('const bool', 'result_aligned',
316 emitter.EmitDeclare('const bool', 'result_stride_aligned',
318 emitter.EmitDeclare('const bool', 'aligned',
322 emitter.EmitDeclare('const bool', 'aligned',
325 emitter.EmitIf('aligned')
326 GenerateGemmSwitch1(emitter, output_type, True)
327 emitter.EmitElse()
328 GenerateGemmSwitch1(emitter, output_type, False)
329 emitter.EmitEndif()
330 emitter.EmitFunctionEnd()
333 def GenerateWrapperGemmFunction(emitter, output_type): argument
334 emitter.EmitFunctionBeginA(
336 emitter.EmitCall(
339 emitter.EmitFunctionEnd()
342 def GenerateInternalFunctions(emitter): argument
349 GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
350 emitter.EmitNewline()
353 def GeneratePublicFunctions(emitter): argument
355 GenerateMainGemmFunction(emitter, output_type)
356 emitter.EmitNewline()
358 GenerateWrapperGemmFunction(emitter, output_type)
359 emitter.EmitNewline()