• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The Gemmlowp Authors. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""."""
15import collections
16
# Apache 2.0 license banner, written as C++ `//` comments. GenerateHeader passes
# it to `cc.EmitCodeNoSemicolon` as the first thing in every generated header.
# Keep the text byte-exact: it ends up verbatim in generated sources.
_HEADER_COPYRIGHT = (
    '''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
''')
32
33
def GenerateHeader(cc, header_name, preprocessor_directive):
  """Writes the opening section of a generated C++ header through `cc`.

  Emission order: license banner, `cc.EmitHeaderBegin(header_name)`, an
  `#ifdef <preprocessor_directive>` line, a blank line, the `<cassert>` and
  `<cstdint>` includes, and a final blank line.

  The `#ifdef` is intentionally left open; GenerateFooter emits the `#else` /
  `#endif` that presumably closes it.

  Args:
    cc: code emitter receiving the output.
    header_name: forwarded unchanged to `cc.EmitHeaderBegin`.
    preprocessor_directive: macro name tested by the opening `#ifdef`.
  """
  cc.EmitCodeNoSemicolon(_HEADER_COPYRIGHT)
  cc.EmitHeaderBegin(header_name)

  cc.EmitPreprocessor1('ifdef', preprocessor_directive)
  cc.EmitNewline()

  for std_header in ('<cassert>', '<cstdint>'):
    cc.EmitInclude(std_header)
  cc.EmitNewline()
44
45
def GenerateFooter(cc, message):
  """Writes the closing section of a generated C++ header through `cc`.

  Emits `#else`, then `#warning "<message>"`, then `#endif`, followed by a
  blank line and `cc.EmitHeaderEnd()`. The warning therefore fires only when
  the preceding `#if`/`#ifdef` branch was not taken.

  Args:
    cc: code emitter receiving the output.
    message: text placed inside double quotes in the `#warning` directive.
  """
  cc.EmitPreprocessor('else')
  quoted_warning = '"%s"' % message
  cc.EmitPreprocessor1('warning', quoted_warning)
  cc.EmitPreprocessor('endif')

  cc.EmitNewline()
  cc.EmitHeaderEnd()
52
53
def GenerateDebugLog(cc, message):
  """Emits a `std::cout` trace line for `message`, guarded by debug macros.

  The generated statement prints `__FILE__(__LINE__) <message>` followed by
  `std::endl` and a flush. It sits inside nested `#ifdef DEBUG` and
  `#ifdef DEBUG_METAGEMM_VERBOSE` blocks, so it is compiled in only when both
  macros are defined.

  Args:
    cc: code emitter receiving the output.
    message: text spliced (unescaped) into a C++ string literal.
  """
  trace_format = ('std::cout << __FILE__ << "(" << __LINE__ << ") %s" '
                  '<< std::endl << std::flush')
  guard_macros = ('DEBUG', 'DEBUG_METAGEMM_VERBOSE')

  for macro in guard_macros:
    cc.EmitPreprocessor1('ifdef', macro)
  cc.EmitCode(trace_format % message)
  # One `#endif` per guard opened above.
  for _ in guard_macros:
    cc.EmitPreprocessor('endif')
61
62
63def _TemplateName(base, params):
64  return '%s<%s>' % (base, ', '.join(map(str, params)))
65
66
class StreamGenerator(object):
  """Base class for generators of `Stream<...>::Pack` specializations.

  Subclasses provide the function body by defining
  `EmitPack(in_type, lanes_count, pack_size, leftovers)`. If that hook is
  missing (or not callable), SpecializeStream emits nothing.

  Attributes:
    name: stream parameters type name; used as the last Stream template
      argument and as the type of the generated `params` argument.
    emitter: code emitter that receives all generated output.
  """

  def __init__(self, emitter, name):
    self.name = name
    self.emitter = emitter

  def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers):
    """Emits one `Stream<...>::Pack` template specialization.

    The generated function takes `in` (const `in_type`*), `params`
    (const `self.name`&) and `out` (`in_type`*). Its body is a debug trace
    followed by whatever `self.EmitPack` generates.

    Args:
      in_type: C++ element type of the `in` and `out` pointers.
      lanes_count: Stream template argument, also forwarded to EmitPack.
      pack_size: Stream template argument, also forwarded to EmitPack.
      leftovers: Stream template argument, also forwarded to EmitPack.
    """
    # `collections.Callable` was a deprecated alias that Python 3.10 removed
    # (AttributeError there); the builtin `callable` does the same check on
    # every Python version.
    if not callable(getattr(self, 'EmitPack', None)):
      return

    template_params = [in_type, lanes_count, pack_size, leftovers, self.name]
    self.emitter.EmitMemberFunctionBegin(
        'Stream', [], template_params, 'Pack',
        [['const %s*' % in_type, 'in'], ['const %s&' % self.name, 'params'],
         ['%s*' % in_type, 'out']], 'inline void')
    GenerateDebugLog(self.emitter,
                     '%s::Pack()' % _TemplateName(self.name, template_params))
    self.EmitPack(in_type, lanes_count, pack_size, leftovers)
    self.emitter.EmitFunctionEnd()
85
86
class MulKernelGenerator(object):
  """Base class for generators of `MulKernel<...>::Multiply` specializations.

  Subclasses must define
  `EmitMultiply(in_type, out_type, kernel_m, kernel_n, pack_size)`, which
  produces the function body through `self.emitter`.

  Attributes:
    kernel_name: fused kernel name; first FusedKernelParams template argument.
    output_stream_name: output stream name; second FusedKernelParams template
      argument.
    emitter: code emitter that receives all generated output.
  """

  def __init__(self, emitter, kernel_name, output_stream_name):
    self.kernel_name = kernel_name
    self.output_stream_name = output_stream_name
    self.emitter = emitter

  def SpecializeMulKernel(self, in_type, out_type, kernel_m, kernel_n,
                          pack_size):
    """Emits the kernel wrapped in a `MulKernel<...>` template specialization.

    The generated `Multiply` takes `lhs` and `rhs` (const `in_type`*),
    `params` (const FusedKernelParams<kernel_name, output_stream_name>&) and
    `result` (`out_type`*). Its body is a debug trace followed by whatever
    `self.EmitMultiply` generates.
    """
    kernel = self.kernel_name
    stream = self.output_stream_name
    template_params = [
        in_type, out_type, kernel, stream, kernel_m, kernel_n, pack_size
    ]

    lhs_arg = ['const %s*' % in_type, 'lhs']
    rhs_arg = ['const %s*' % in_type, 'rhs']
    params_arg = [
        'const FusedKernelParams<%s, %s>&' % (kernel, stream), 'params'
    ]
    result_arg = ['%s*' % out_type, 'result']

    self.emitter.EmitMemberFunctionBegin(
        'MulKernel', [], template_params, 'Multiply',
        [lhs_arg, rhs_arg, params_arg, result_arg], 'inline void')
    # The trace label concatenates kernel and stream names as its base.
    trace_label = _TemplateName(kernel + stream, template_params)
    GenerateDebugLog(self.emitter, '%s::Multiply()' % trace_label)
    self.EmitMultiply(in_type, out_type, kernel_m, kernel_n, pack_size)
    self.emitter.EmitFunctionEnd()
114
115
class Transform1DKernelGenerator(object):
  """Base class for generators of `Transform1DKernel<...>::Transform` code.

  Subclasses must define
  `EmitTransform(in_type, out_type, kernel_size, leftovers)`, which produces
  the function body through `self.emitter`.

  Attributes:
    kernel_name: transform kernel name; used as a template argument and as the
      type of the generated `params` argument.
    emitter: code emitter that receives all generated output.
  """

  def __init__(self, emitter, kernel_name):
    self.kernel_name = kernel_name
    self.emitter = emitter

  def SpecializeTransform1DKernel(self, in_type, out_type, kernel_size,
                                  leftovers):
    """Emits the kernel wrapped in a `Transform1DKernel<...>` specialization.

    The generated `Transform` takes `input` (const `in_type`*), `params`
    (const `kernel_name`&) and `output` (`out_type`*). Its body is a debug
    trace followed by whatever `self.EmitTransform` generates.
    """
    kernel = self.kernel_name
    template_params = [in_type, out_type, kernel, kernel_size, leftovers]
    signature = [
        ['const %s*' % in_type, 'input'],
        ['const %s&' % kernel, 'params'],
        ['%s*' % out_type, 'output'],
    ]

    self.emitter.EmitMemberFunctionBegin('Transform1DKernel', [],
                                         template_params, 'Transform',
                                         signature, 'inline void')
    trace_label = _TemplateName(kernel, template_params)
    GenerateDebugLog(self.emitter, '%s::Transform()' % trace_label)
    self.EmitTransform(in_type, out_type, kernel_size, leftovers)
    self.emitter.EmitFunctionEnd()
138