1# Copyright 2016 The Gemmlowp Authors. All rights reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14""".""" 15 16_HEADER_COPYRIGHT = ( 17 '''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 18// 19// Licensed under the Apache License, Version 2.0 (the "License"); 20// you may not use this file except in compliance with the License. 21// You may obtain a copy of the License at 22// 23// http://www.apache.org/licenses/LICENSE-2.0 24// 25// Unless required by applicable law or agreed to in writing, software 26// distributed under the License is distributed on an "AS IS" BASIS, 27// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 28// See the License for the specific language governing permissions and 29// limitations under the License. 30''') 31 32 33def GenerateHeader(cc, header_name, preprocessor_directive): 34 cc.EmitCodeNoSemicolon(_HEADER_COPYRIGHT) 35 cc.EmitHeaderBegin(header_name) 36 37 cc.EmitPreprocessor1('ifdef', preprocessor_directive) 38 cc.EmitNewline() 39 40 cc.EmitInclude('<cassert>') 41 cc.EmitInclude('<cstdint>') 42 cc.EmitNewline() 43 44 45def GenerateFooter(cc, message): 46 cc.EmitPreprocessor('else') 47 cc.EmitPreprocessor1('warning', '"%s"' % message) 48 cc.EmitPreprocessor('endif') 49 cc.EmitNewline() 50 cc.EmitHeaderEnd() 51 52 53def GenerateDebugLog(cc, message): 54 cc.EmitPreprocessor1('ifdef', 'DEBUG') 55 cc.EmitPreprocessor1('ifdef', 'DEBUG_METAGEMM_VERBOSE') 56 cc.EmitCode('std::cout << __FILE__ << \"(\" << __LINE__ << \") %s\" ' 57 '<< std::endl << std::flush' % message) 58 cc.EmitPreprocessor('endif') 59 cc.EmitPreprocessor('endif') 60 61 62def _TemplateName(base, params): 63 return '%s<%s>' % (base, ', '.join(map(str, params))) 64 65 66class StreamGenerator(object): 67 """.""" 68 69 def __init__(self, emitter, name): 70 self.name = name 71 self.emitter = emitter 72 73 def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers): 74 if callable(getattr(self, 'EmitPack', None)): 75 template_params = [in_type, lanes_count, pack_size, leftovers, self.name] 76 self.emitter.EmitMemberFunctionBegin( 77 'Stream', [], template_params, 'Pack', 78 [['const %s*' % in_type, 'in'], ['const %s&' % self.name, 'params'], 79 ['%s*' % in_type, 'out']], 'inline void') 80 GenerateDebugLog(self.emitter, 81 '%s::Pack()' % _TemplateName(self.name, template_params)) 82 self.EmitPack(in_type, lanes_count, pack_size, leftovers) 83 self.emitter.EmitFunctionEnd() 84 85 86class MulKernelGenerator(object): 87 """.""" 88 89 def __init__(self, emitter, kernel_name, output_stream_name): 90 self.kernel_name = kernel_name 91 self.output_stream_name = output_stream_name 92 self.emitter = emitter 93 94 def SpecializeMulKernel(self, in_type, out_type, kernel_m, kernel_n, 95 pack_size): 96 """Generates the kernel wrapped in a MulKernel template specialization.""" 97 template_params = [ 98 in_type, out_type, self.kernel_name, self.output_stream_name, kernel_m, 99 kernel_n, pack_size 100 ] 101 self.emitter.EmitMemberFunctionBegin( 102 'MulKernel', [], template_params, 'Multiply', 103 [['const %s*' % in_type, 'lhs'], ['const %s*' % in_type, 'rhs'], [ 104 'const FusedKernelParams<%s, %s>&' % (self.kernel_name, 105 self.output_stream_name), 106 'params' 107 ], ['%s*' % out_type, 'result']], 'inline void') 108 GenerateDebugLog(self.emitter, '%s::Multiply()' % 109 _TemplateName(self.kernel_name + self.output_stream_name, 110 template_params)) 111 self.EmitMultiply(in_type, out_type, kernel_m, kernel_n, pack_size) 112 self.emitter.EmitFunctionEnd() 113 114 115class Transform1DKernelGenerator(object): 116 """.""" 117 118 def __init__(self, emitter, kernel_name): 119 self.kernel_name = kernel_name 120 self.emitter = emitter 121 122 def SpecializeTransform1DKernel(self, in_type, out_type, kernel_size, 123 leftovers): 124 """Generates the kernel wrapped in a Transform1DKernel specialization.""" 125 template_params = [ 126 in_type, out_type, self.kernel_name, kernel_size, leftovers 127 ] 128 self.emitter.EmitMemberFunctionBegin( 129 'Transform1DKernel', [], template_params, 'Transform', 130 [['const %s*' % in_type, 'input'], 131 ['const %s&' % self.kernel_name, 'params'], 132 ['%s*' % out_type, 'output']], 'inline void') 133 GenerateDebugLog(self.emitter, '%s::Transform()' % 134 _TemplateName(self.kernel_name, template_params)) 135 self.EmitTransform(in_type, out_type, kernel_size, leftovers) 136 self.emitter.EmitFunctionEnd() 137