• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The Gemmlowp Authors. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""."""
15
16_HEADER_COPYRIGHT = (
17    '''// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
18//
19// Licensed under the Apache License, Version 2.0 (the "License");
20// you may not use this file except in compliance with the License.
21// You may obtain a copy of the License at
22//
23//     http://www.apache.org/licenses/LICENSE-2.0
24//
25// Unless required by applicable law or agreed to in writing, software
26// distributed under the License is distributed on an "AS IS" BASIS,
27// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28// See the License for the specific language governing permissions and
29// limitations under the License.
30''')
31
32
def GenerateHeader(cc, header_name, preprocessor_directive):
  """Emits the preamble of a generated C++ header.

  Output order: the license banner, the header begin marker, an
  ``#ifdef preprocessor_directive`` line and the ``<cassert>`` / ``<cstdint>``
  includes. GenerateFooter emits the ``#else`` / ``#endif`` and the header
  end marker.

  Args:
    cc: code emitter that receives all output.
    header_name: forwarded unchanged to ``cc.EmitHeaderBegin``.
    preprocessor_directive: macro name tested by the opening ``#ifdef``.
  """
  cc.EmitCodeNoSemicolon(_HEADER_COPYRIGHT)
  cc.EmitHeaderBegin(header_name)

  cc.EmitPreprocessor1('ifdef', preprocessor_directive)
  cc.EmitNewline()

  for system_header in ('<cassert>', '<cstdint>'):
    cc.EmitInclude(system_header)
  cc.EmitNewline()
43
44
def GenerateFooter(cc, message):
  """Emits the closing part of a generated C++ header.

  Writes ``#else``, then ``#warning "<message>"``, then ``#endif``, a blank
  line and finally the header end marker.

  Args:
    cc: code emitter that receives all output.
    message: text placed, wrapped in double quotes, after ``#warning``.
  """
  cc.EmitPreprocessor('else')
  quoted_message = '"%s"' % message
  cc.EmitPreprocessor1('warning', quoted_message)
  cc.EmitPreprocessor('endif')
  cc.EmitNewline()
  cc.EmitHeaderEnd()
51
52
def GenerateDebugLog(cc, message):
  """Emits a verbose trace statement that only exists in debug builds.

  The generated ``std::cout`` line prints ``__FILE__``, ``__LINE__`` and
  ``message``. It is wrapped in ``#ifdef DEBUG`` and
  ``#ifdef DEBUG_METAGEMM_VERBOSE``, so both macros must be defined for it
  to be compiled.

  NOTE(review): ``std::cout`` needs ``<iostream>``, which GenerateHeader does
  not include; presumably it is provided elsewhere -- confirm.

  Args:
    cc: code emitter that receives all output.
    message: text to print. It is embedded inside a C++ string literal, so
      backslashes, double quotes and newlines are escaped.
  """
  # Escape backslashes first so the backslashes introduced for quotes and
  # newlines are not doubled afterwards.
  escaped = (str(message).replace('\\', '\\\\').replace('"', '\\"')
             .replace('\n', '\\n'))
  cc.EmitPreprocessor1('ifdef', 'DEBUG')
  cc.EmitPreprocessor1('ifdef', 'DEBUG_METAGEMM_VERBOSE')
  # The adjacent literals are joined at compile time, before '%' is applied.
  cc.EmitCode('std::cout << __FILE__ << "(" << __LINE__ << ") %s" '
              '<< std::endl << std::flush' % escaped)
  cc.EmitPreprocessor('endif')
  cc.EmitPreprocessor('endif')
60
61
62def _TemplateName(base, params):
63  return '%s<%s>' % (base, ', '.join(map(str, params)))
64
65
class StreamGenerator(object):
  """Base for generators of ``Stream<...>::Pack`` specializations.

  A subclass supplies ``EmitPack(in_type, lanes_count, pack_size, leftovers)``
  to write the function body. If no callable ``EmitPack`` exists,
  SpecializeStream emits nothing.
  """

  def __init__(self, emitter, name):
    # Code emitter that receives all generated output.
    self.emitter = emitter
    # Type of the ``params`` argument; also the last template argument.
    self.name = name

  def SpecializeStream(self, in_type, lanes_count, pack_size, leftovers):
    """Emits ``Stream<in_type, lanes_count, pack_size, leftovers, name>::Pack``.

    The emitted signature is
    ``(const in_type* in, const name& params, in_type* out)``.
    """
    if not callable(getattr(self, 'EmitPack', None)):
      return

    params_type = self.name
    template_params = [in_type, lanes_count, pack_size, leftovers, params_type]
    arguments = [
        ['const %s*' % in_type, 'in'],
        ['const %s&' % params_type, 'params'],
        ['%s*' % in_type, 'out'],
    ]
    self.emitter.EmitMemberFunctionBegin('Stream', [], template_params, 'Pack',
                                         arguments, 'inline void')
    # The trace is labelled with the params type name, not 'Stream'.
    trace_name = _TemplateName(params_type, template_params)
    GenerateDebugLog(self.emitter, '%s::Pack()' % trace_name)
    self.EmitPack(in_type, lanes_count, pack_size, leftovers)
    self.emitter.EmitFunctionEnd()
84
85
class MulKernelGenerator(object):
  """Base for generators of ``MulKernel<...>::Multiply`` specializations.

  A subclass is expected to provide
  ``EmitMultiply(in_type, out_type, kernel_m, kernel_n, pack_size)``; it is
  called unconditionally to write the function body.
  """

  def __init__(self, emitter, kernel_name, output_stream_name):
    # Code emitter that receives all generated output.
    self.emitter = emitter
    # Kernel and output-stream names; both become template arguments and
    # parameterize the FusedKernelParams argument type.
    self.kernel_name = kernel_name
    self.output_stream_name = output_stream_name

  def SpecializeMulKernel(self, in_type, out_type, kernel_m, kernel_n,
                          pack_size):
    """Generates the kernel wrapped in a MulKernel template specialization.

    The emitted signature is ``(const in_type* lhs, const in_type* rhs,
    const FusedKernelParams<kernel, stream>& params, out_type* result)``.
    """
    kernel = self.kernel_name
    stream = self.output_stream_name
    template_params = [
        in_type, out_type, kernel, stream, kernel_m, kernel_n, pack_size
    ]
    fused_params_type = 'const FusedKernelParams<%s, %s>&' % (kernel, stream)
    arguments = [
        ['const %s*' % in_type, 'lhs'],
        ['const %s*' % in_type, 'rhs'],
        [fused_params_type, 'params'],
        ['%s*' % out_type, 'result'],
    ]
    self.emitter.EmitMemberFunctionBegin('MulKernel', [], template_params,
                                         'Multiply', arguments, 'inline void')
    # The trace is labelled with kernel + stream name, not 'MulKernel'.
    trace_name = _TemplateName(kernel + stream, template_params)
    GenerateDebugLog(self.emitter, '%s::Multiply()' % trace_name)
    self.EmitMultiply(in_type, out_type, kernel_m, kernel_n, pack_size)
    self.emitter.EmitFunctionEnd()
113
114
class Transform1DKernelGenerator(object):
  """Base for generators of ``Transform1DKernel<...>::Transform`` code.

  A subclass is expected to provide
  ``EmitTransform(in_type, out_type, kernel_size, leftovers)``; it is called
  unconditionally to write the function body.
  """

  def __init__(self, emitter, kernel_name):
    # Code emitter that receives all generated output.
    self.emitter = emitter
    # Type of the ``params`` argument; also a template argument.
    self.kernel_name = kernel_name

  def SpecializeTransform1DKernel(self, in_type, out_type, kernel_size,
                                  leftovers):
    """Generates the kernel wrapped in a Transform1DKernel specialization.

    The emitted signature is
    ``(const in_type* input, const kernel_name& params, out_type* output)``.
    """
    kernel = self.kernel_name
    template_params = [in_type, out_type, kernel, kernel_size, leftovers]
    arguments = [
        ['const %s*' % in_type, 'input'],
        ['const %s&' % kernel, 'params'],
        ['%s*' % out_type, 'output'],
    ]
    self.emitter.EmitMemberFunctionBegin('Transform1DKernel', [],
                                         template_params, 'Transform',
                                         arguments, 'inline void')
    # The trace is labelled with the kernel name, not 'Transform1DKernel'.
    trace_name = _TemplateName(kernel, template_params)
    GenerateDebugLog(self.emitter, '%s::Transform()' % trace_name)
    self.EmitTransform(in_type, out_type, kernel_size, leftovers)
    self.emitter.EmitFunctionEnd()
137