• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The Gemmlowp Authors. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Stream generators that pack uint8 rows/columns and emit per-lane sums."""
15
16import common
17
18
19def _AlignForLanes(lanes_count):
20  if lanes_count is 8 or lanes_count is 4:
21    return 256
22  elif lanes_count is 6 or lanes_count is 2:
23    return 128
24  else:
25    return 64
26
27
28def _AlignForSums(lanes_count):
29  if lanes_count is 8:
30    return 256
31  elif lanes_count in [2, 4, 6]:
32    return 128
33  else:
34    return 64
35
36
37def _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
38  """."""
39  inputs = []
40  last_address_register = input_address
41  for i in range(lanes_count):
42    if not i:
43      inputs.append(input_address)
44    else:
45      address_register = registers.GeneralRegister()
46      inputs.append(address_register)
47      emitter.EmitAdd(address_register, last_address_register, stride)
48      last_address_register = address_register
49  return inputs
50
51
52def _GenerateClear(emitter, clear_type, block):
53  for row in block:
54    emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0))
55
56
def _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
                                aggregators, inputs, output):
  """Emit inner loop code for reading N lanes and interweaving them.

  One double register is allocated per lane, loaded with `elements_count`
  8-bit elements from that lane's input address, added into the lane's
  aggregator and then stored as part of the block written to `output`. The
  temporary registers are released before returning.

  Args:
    emitter: assembly emitter used to output the instructions.
    registers: register allocator.
    lanes_count: number of lanes processed in this step.
    elements_count: number of elements loaded per lane (8 for a full step).
    aggregators: per-lane sum registers, paired positionally with the lanes.
    inputs: per-lane input address registers.
    output: output address register.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count,
                                                        elements_count))

  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]

  # Value comparison; `is not 8` tested object identity against a literal.
  # Partial loads are preceded by a clear of the destination registers,
  # presumably so that lanes not loaded contribute zero to the sums.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  for (row, input_address) in zip(block, inputs):
    emitter.EmitVLoadE(8, elements_count, row, input_address, None)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  # A full 8-element-per-lane block is always stored, whatever elements_count.
  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))

  registers.FreeRegisters(block)
79
80
81def _LoadMemoryParameter(emitter, registers, name, source):
82  register = registers.GeneralRegister()
83  emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
84  return register
85
86
def _GenerateAggregatorReductionLowRegisters(emitter, registers,
                                             aggregators, output_address):
  """Emit the aggregator reduction with both sum offsets loaded from memory.

  The multiplicative and additive sum offsets are mapped as memory parameters
  and loaded into newly allocated general registers before the shared
  reduction code is emitted.

  Args:
    emitter: assembly emitter used to output the instructions.
    registers: register allocator.
    aggregators: per-lane sum registers to reduce.
    output_address: register holding the address the sums are stored to.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  # Multiplicative offset first, then additive: same emission order as the
  # argument evaluation order of a single nested call.
  multiplicative = _LoadMemoryParameter(
      emitter, registers, 'multiplicative_sum_offset',
      'params.multiplicative_sum_offset')
  additive = _LoadMemoryParameter(
      emitter, registers, 'additive_sum_offset', 'params.additive_sum_offset')

  _GenerateAggregatorReduction(emitter, registers, aggregators, output_address,
                               multiplicative, additive)
97
98
def _GenerateAggregatorReductionHighRegisters(emitter, registers,
                                              aggregators, output_address):
  """Emit the aggregator reduction with both sum offsets mapped as parameters.

  Unlike the low-register variant, no explicit load is emitted here: the
  multiplicative and additive sum offsets are obtained directly through
  registers.MapParameter.

  Args:
    emitter: assembly emitter used to output the instructions.
    registers: register allocator.
    aggregators: per-lane sum registers to reduce.
    output_address: register holding the address the sums are stored to.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  multiplicative = registers.MapParameter('multiplicative_sum_offset',
                                          'params.multiplicative_sum_offset')
  additive = registers.MapParameter('additive_sum_offset',
                                    'params.additive_sum_offset')

  _GenerateAggregatorReduction(emitter, registers, aggregators, output_address,
                               multiplicative, additive)
109
110
def _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset):
  """Reduce 4 lane sum aggregators to 1 value and store the sums.

  Each aggregator is pairwise-widened, the aggregators are sum-reduced in
  groups of four into the leading registers, and each reduced register is
  multiplied by the multiplicative offset and increased by the additive
  offset before being stored through `output_address`.

  Args:
    emitter: assembly emitter used to output the instructions.
    registers: register allocator.
    aggregators: list of per-lane sum registers; the leading ones are reused
      to hold the reduced results.
    output_address: register holding the address the sums are stored to.
    multiplicative_sum_offset: operand moved into lane 0 of the multiplier.
    additive_sum_offset: operand duplicated across the offset register.
  """
  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32',
                   emitter.Lane(32, multiplier, 0), multiplicative_sum_offset)

  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_sum_offset)

  for aggregator in aggregators:
    emitter.EmitVPaddl('u16', aggregator, aggregator)

  # Ceiling of len / 4. Floor division keeps the result an int on Python 3,
  # where `/` would yield a float and make the slice below raise TypeError.
  reduced_count = (len(aggregators) + 3) // 4
  reduced = aggregators[:reduced_count]

  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)

  for temp in reduced:
    emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0))

  for temp in reduced:
    emitter.EmitVAdd('i32', temp, temp, offset)

  emitter.EmitVStoreA(1, 32, reduced,
                      emitter.Dereference(output_address,
                                          _AlignForSums(len(aggregators))))
139
140
class RowMajorWithSumUInt8x8(common.StreamGenerator):
  """Stream generator registered as 'RowMajorWithSum' for uint8_t, pack size 8.

  EmitPack emits an inline assembly block that reads `lanes_count` input rows
  (consecutive rows are `params.stride` apart), stores the loaded blocks to
  `out`, accumulates a running sum per lane and finally stores the reduced
  sums through `out` as well.
  """

  def __init__(self, emitter, asm_emitter):
    """Keep the assembly emitter; `emitter` is handed to the base class."""
    common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit the packing code for one (lanes_count, leftovers) specialization.

    Args:
      in_type: input element type; must equal 'uint8_t'.
      lanes_count: number of input rows processed together.
      pack_size: number of elements per packing step; must equal 8.
      leftovers: number of elements handled by the tail step after the main
        8-element loop; a falsy value emits no tail step.

    Raises:
      AssertionError: if `pack_size` or `in_type` is not supported.
    """
    # Compare by value: `is` against a literal tests object identity, which
    # depends on interpreter caching/interning rather than on the value.
    assert pack_size == 8
    assert in_type == 'uint8_t'

    registers = self.asm_emitter.CreateRegisters()

    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    output = registers.MapOutputParameter('out')
    inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count,
                             registers.MapOutputParameter('in'),
                             registers.MapParameter('stride', 'params.stride'))
    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    if leftovers:
      # Take the tail elements off the count up front; the emitted branch
      # targets label 2 (the tail step) when the subtraction sets the zero
      # flag, i.e. when there is no full step to run.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    # Main loop (label 1): one full 8-element step per iteration.
    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(8))

    _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
                                aggregators, inputs, output)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      # Tail step (label 2): a single partial step of `leftovers` elements.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitNumericalLabel(2)
      _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                  leftovers, aggregators, inputs, output)

    registers.FreeRegisters(inputs)

    # NOTE(review): with more than 6 lanes the offsets are loaded from memory
    # instead of being mapped as parameters, presumably because too few
    # general registers are available -- confirm against the register
    # allocator.
    if len(inputs) <= 6:
      _GenerateAggregatorReductionHighRegisters(
          self.asm_emitter, registers, aggregators, output)
    else:
      _GenerateAggregatorReductionLowRegisters(
          self.asm_emitter, registers, aggregators, output)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
203
204
def _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
                                   elements_count, aggregators, input_address,
                                   stride, output):
  """Emit inner loop code for reading N col lanes and interweaving them.

  One double register is allocated per lane; the column block is loaded
  through emitter.EmitLoadColBlock, each returned row is added into its
  aggregator, and the block is stored to `output`. The registers of the
  block returned by the load are released before returning.

  Args:
    emitter: assembly emitter used to output the instructions.
    registers: register allocator.
    lanes_count: number of lanes processed in this step.
    elements_count: number of elements loaded per lane (8 for a full step).
    aggregators: per-lane sum registers, paired positionally with the rows.
    input_address: register holding the input address.
    stride: register holding the input stride.
    output: output address register.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
                      (lanes_count, elements_count))

  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]

  # Value comparison; `is not 8` tested object identity against a literal.
  # Partial loads are preceded by a clear of the destination registers,
  # presumably so that lanes not loaded contribute zero to the sums.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  # The load returns the block to use from here on, which replaces the list
  # allocated above.
  block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count,
                                   block, input_address, stride)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  # A full 8-element-per-lane block is always stored, whatever elements_count.
  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))

  registers.FreeRegisters(block)
228
229
class ColumnMajorWithSumUInt8x8(common.StreamGenerator):
  """Stream generator registered as 'ColumnMajorWithSum' for uint8_t, pack 8.

  EmitPack emits an inline assembly block that reads column blocks of
  `lanes_count` lanes, stores the loaded blocks to `out`, accumulates a
  running sum per lane and finally stores the reduced sums through `out`.
  """

  def __init__(self, emitter, asm_emitter):
    """Keep the assembly emitter; `emitter` is handed to the base class."""
    common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit the packing code for one (lanes_count, leftovers) specialization.

    Args:
      in_type: input element type; must equal 'uint8_t'.
      lanes_count: number of lanes processed together.
      pack_size: number of elements per packing step; must equal 8.
      leftovers: number of elements handled by the tail step after the main
        8-element loop; a falsy value emits no tail step.

    Raises:
      AssertionError: if `pack_size` or `in_type` is not supported.
    """
    # Compare by value: `is` against a literal tests object identity, which
    # depends on interpreter caching/interning rather than on the value.
    assert pack_size == 8
    assert in_type == 'uint8_t'

    registers = self.asm_emitter.CreateRegisters()

    # Local copies are declared because count and stride are mapped as
    # output parameters of the assembly block below.
    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
    self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    input_address = registers.MapOutputParameter('in')
    output_address = registers.MapOutputParameter('out')
    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
    stride = registers.MapOutputParameter('stride', 'params_stride_copy')

    # The stride register is passed as both destination and source operand.
    self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    if leftovers:
      # Take the tail elements off the count up front; the emitted branch
      # targets label 2 (the tail step) when the subtraction sets the zero
      # flag, i.e. when there is no full step to run.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    # Main loop (label 1): one full 8-element step per iteration.
    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(8))

    _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
                                   aggregators, input_address, stride,
                                   output_address)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      # Tail step (label 2): a single partial step of `leftovers` elements.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitNumericalLabel(2)
      _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                     leftovers, aggregators, input_address,
                                     stride, output_address)

    _GenerateAggregatorReductionHighRegisters(
        self.asm_emitter, registers, aggregators, output_address)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
291
292
def GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
  """Specialize the row- and column-major sum streams for 1..lanes_count lanes.

  For every lane count from 1 to `lanes_count` inclusive and every leftover
  count from 0 to 7, a 'uint8_t' stream with pack size 8 is specialized. All
  row-major specializations are produced before any column-major one.

  Args:
    cc_emitter: emitter handed to both stream generators.
    asm_emitter: assembly emitter handed to both stream generators.
    lanes_count: largest lane count to specialize for.
  """
  generators = (RowMajorWithSumUInt8x8(cc_emitter, asm_emitter),
                ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter))

  # Distinct loop names keep the `lanes_count` parameter from being rebound.
  variants = [(lanes, leftovers)
              for lanes in range(1, 1 + lanes_count)
              for leftovers in range(8)]

  for generator in generators:
    for (lanes, leftovers) in variants:
      generator.SpecializeStream('uint8_t', lanes, 8, leftovers)
305