# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stream generators that pack uint8 data and accumulate per-lane sums.

The generators emit assembly (through an asm emitter) that copies rows or
columns of uint8_t data into a packed output buffer while widen-adding every
lane into an aggregator. After the copy loop the aggregators are reduced,
multiplied by `multiplicative_sum_offset`, shifted by `additive_sum_offset`
and stored through the output address.
"""

import common


def _AlignForLanes(lanes_count):
  """Return the alignment hint used when storing a packed block of lanes.

  Each store writes 8 * lanes_count 8-bit elements (see the EmitVStoreAE
  calls below). For lanes_count in 1..8 the result is the largest of
  64/128/256 that divides 64 * lanes_count. This suggests the value is an
  alignment in bits. NOTE(review): confirm against the emitter's
  EmitVStoreAE.

  Args:
    lanes_count: number of rows/columns packed together.

  Returns:
    256 for 4 or 8 lanes, 128 for 2 or 6 lanes, 64 otherwise.
  """
  # Use `in` / `==` instead of `is`: identity of int literals is a CPython
  # caching detail and is a SyntaxWarning on Python 3.8+.
  if lanes_count in (4, 8):
    return 256
  elif lanes_count in (2, 6):
    return 128
  else:
    return 64


def _AlignForSums(lanes_count):
  """Return the alignment hint used when storing the reduced lane sums.

  Called by _GenerateAggregatorReduction with the number of aggregators.
  NOTE(review): like _AlignForLanes this looks like an alignment in bits;
  confirm against the emitter's EmitVStoreA.

  Args:
    lanes_count: number of per-lane aggregators being stored.

  Returns:
    256 for 8 lanes, 128 for 2, 4 or 6 lanes, 64 otherwise.
  """
  if lanes_count == 8:
    return 256
  elif lanes_count in (2, 4, 6):
    return 128
  else:
    return 64


def _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
  """Emit code computing one input address register per lane.

  Lane 0 reuses `input_address`. Every following lane gets a freshly
  allocated general register set to the previous lane's address plus
  `stride`.

  Args:
    emitter: assembly emitter that receives the EmitAdd instructions.
    registers: register allocator providing general registers.
    lanes_count: number of lane addresses to produce.
    input_address: register holding the address of lane 0.
    stride: operand added to step from one lane's address to the next.

  Returns:
    A list of `lanes_count` address registers in lane order (empty when
    lanes_count is not positive).
  """
  inputs = [input_address] if lanes_count > 0 else []
  for _ in range(1, lanes_count):
    address_register = registers.GeneralRegister()
    emitter.EmitAdd(address_register, inputs[-1], stride)
    inputs.append(address_register)
  return inputs


def _GenerateClear(emitter, clear_type, block):
  """Emit a VMov of immediate 0 (typed `clear_type`) into every register."""
  for row in block:
    emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0))


def _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
                                aggregators, inputs, output):
  """Emit inner loop code for reading N lanes and interweaving them.

  Loads `elements_count` 8-bit elements from each lane address in `inputs`,
  widen-adds ('u8') each loaded row into the matching aggregator, then
  stores the whole block of `lanes_count` rows to `output`.

  Args:
    emitter: assembly emitter.
    registers: register allocator; the temporary block is freed on return.
    lanes_count: number of lanes (rows) processed.
    elements_count: elements loaded per lane (8 in the main loop, the
      leftovers count in the tail).
    aggregators: per-lane accumulator registers, updated in place.
    inputs: per-lane address registers.
    output: output address register.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count,
                                                        elements_count))

  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]

  # Zero the block before a partial load. A load of fewer than 8 elements
  # presumably leaves the rest of each register untouched, so those bytes
  # are then stored and summed as zeros.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  for (row, input_address) in zip(block, inputs):
    emitter.EmitVLoadE(8, elements_count, row, input_address, None)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))

  registers.FreeRegisters(block)


def _LoadMemoryParameter(emitter, registers, name, source):
  """Emit an ldr of a memory-mapped parameter into a new general register.

  Returns:
    The newly allocated general register holding the loaded value.
  """
  register = registers.GeneralRegister()
  emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
  return register


def _GenerateAggregatorReductionLowRegisters(emitter, registers,
                                             aggregators, output_address):
  """Emit the aggregator reduction, loading the sum offsets from memory.

  Used by the row-major generator when more than 6 lane address registers
  are in use. The offsets are loaded into general registers through
  EmitLdr instead of being mapped as register parameters, presumably to
  limit register pressure.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')
  _GenerateAggregatorReduction(
      emitter, registers, aggregators, output_address,
      _LoadMemoryParameter(emitter, registers, 'multiplicative_sum_offset',
                           'params.multiplicative_sum_offset'),
      _LoadMemoryParameter(emitter, registers, 'additive_sum_offset',
                           'params.additive_sum_offset'))


def _GenerateAggregatorReductionHighRegisters(emitter, registers,
                                              aggregators, output_address):
  """Emit the aggregator reduction with the sum offsets as register params."""
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')
  _GenerateAggregatorReduction(
      emitter, registers, aggregators, output_address,
      registers.MapParameter('multiplicative_sum_offset',
                             'params.multiplicative_sum_offset'),
      registers.MapParameter('additive_sum_offset',
                             'params.additive_sum_offset'))


def _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset):
  """Reduce 4 lane sum aggregators to 1 value and store the sums.

  The emitted code does the following:
    1. Pairwise-widens each aggregator ('u16' VPaddl).
    2. Sum-reduces the aggregators into the first
       ceil(len(aggregators) / 4) of them.
    3. Multiplies every reduced register by `multiplicative_sum_offset`.
    4. Adds `additive_sum_offset`.
    5. Stores the result through `output_address`.

  Args:
    emitter: assembly emitter.
    registers: register allocator for the temporaries.
    aggregators: per-lane accumulator registers (clobbered).
    output_address: register holding the destination of the sums.
    multiplicative_sum_offset: scalar operand moved into lane 0 of a double
      register and used as the multiplier.
    additive_sum_offset: scalar operand duplicated across a quad register
      and added to every sum.
  """
  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32', emitter.Lane(32, multiplier, 0),
                   multiplicative_sum_offset)

  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_sum_offset)

  for aggregator in aggregators:
    emitter.EmitVPaddl('u16', aggregator, aggregator)

  # Floor division: under Python 3 `/` yields a float, and a float slice
  # bound raises TypeError. `//` behaves identically on Python 2.
  reduced_count = (len(aggregators) + 3) // 4
  reduced = aggregators[:reduced_count]

  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)

  for temp in reduced:
    emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0))

  for temp in reduced:
    emitter.EmitVAdd('i32', temp, temp, offset)

  emitter.EmitVStoreA(1, 32, reduced,
                      emitter.Dereference(output_address,
                                          _AlignForSums(len(aggregators))))


def _GenerateCountedLoop(emitter, count, leftovers, generate_body):
  """Emit the 8-elements-per-iteration loop plus an optional leftovers tail.

  The emitted structure is:
    * When `leftovers` is non-zero, subtract `leftovers` from `count` and
      branch forward to label 2 if the result is zero.
    * Label 1: subtract 8 from `count`, emit the body for 8 elements, and
      branch back to label 1 while the result is non-zero.
    * When `leftovers` is non-zero, label 2: emit the body for `leftovers`
      elements.

  NOTE(review): the main loop is do-while shaped, so it executes at least
  once. With leftovers == 0, `count` presumably must be a non-zero
  multiple of 8; confirm with callers.

  Args:
    emitter: assembly emitter.
    count: register holding the element count (decremented in place).
    leftovers: number of trailing elements (0..7) handled after the loop.
    generate_body: callable taking the per-lane element count to emit.
  """
  if leftovers:
    emitter.EmitNewline()
    emitter.EmitComment('Reduce count by leftovers.')
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  generate_body(8)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    emitter.EmitNewline()
    emitter.EmitNumericalLabel(2)
    generate_body(leftovers)


class RowMajorWithSumUInt8x8(common.StreamGenerator):
  """Packs `lanes_count` uint8 rows while accumulating per-row sums."""

  def __init__(self, emitter, asm_emitter):
    common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit the inline-assembly body packing `lanes_count` rows.

    Args:
      in_type: element type; only 'uint8_t' is supported.
      lanes_count: number of rows packed together.
      pack_size: elements per main-loop iteration; only 8 is supported.
      leftovers: trailing elements (0..7) handled after the main loop.
    """
    assert pack_size == 8
    assert in_type == 'uint8_t'

    registers = self.asm_emitter.CreateRegisters()

    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    output = registers.MapOutputParameter('out')
    inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count,
                             registers.MapOutputParameter('in'),
                             registers.MapParameter('stride', 'params.stride'))
    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    _GenerateCountedLoop(
        self.asm_emitter, count, leftovers,
        lambda elements_count: _GenerateLoadAggregateStore(
            self.asm_emitter, registers, lanes_count, elements_count,
            aggregators, inputs, output))

    registers.FreeRegisters(inputs)

    # With many lane address registers, the sum offsets are loaded from
    # memory instead of being mapped as register parameters.
    if len(inputs) <= 6:
      _GenerateAggregatorReductionHighRegisters(
          self.asm_emitter, registers, aggregators, output)
    else:
      _GenerateAggregatorReductionLowRegisters(
          self.asm_emitter, registers, aggregators, output)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))


def _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
                                   elements_count, aggregators, input_address,
                                   stride, output):
  """Emit inner loop code for reading N col lanes and interweaving them.

  Loads a column block through the emitter's EmitLoadColBlock, widen-adds
  ('u8') each resulting row into the matching aggregator and stores the
  block of `lanes_count` rows to `output`.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
                      (lanes_count, elements_count))

  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]

  # Zero the block before a partial (leftovers) load.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count,
                                   block, input_address, stride)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))

  registers.FreeRegisters(block)


class ColumnMajorWithSumUInt8x8(common.StreamGenerator):
  """Packs `lanes_count` uint8 columns while accumulating per-column sums."""

  def __init__(self, emitter, asm_emitter):
    common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit the inline-assembly body packing `lanes_count` columns.

    Args:
      in_type: element type; only 'uint8_t' is supported.
      lanes_count: number of columns packed together.
      pack_size: elements per main-loop iteration; only 8 is supported.
      leftovers: trailing elements (0..7) handled after the main loop.
    """
    assert pack_size == 8
    assert in_type == 'uint8_t'

    registers = self.asm_emitter.CreateRegisters()

    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
    self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    input_address = registers.MapOutputParameter('in')
    output_address = registers.MapOutputParameter('out')
    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
    stride = registers.MapOutputParameter('stride', 'params_stride_copy')

    # `stride` is passed as both source and destination. It is presumably
    # adjusted in place for the column block loads; confirm in the emitter.
    self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    _GenerateCountedLoop(
        self.asm_emitter, count, leftovers,
        lambda elements_count: _GenerateColLoadAggregateStore(
            self.asm_emitter, registers, lanes_count, elements_count,
            aggregators, input_address, stride, output_address))

    _GenerateAggregatorReductionHighRegisters(
        self.asm_emitter, registers, aggregators, output_address)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))


def GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
  """Specialize the row- and column-major uint8 streams.

  Emits one specialization per (lanes, leftovers) pair, where lanes ranges
  over 1..lanes_count and leftovers over 0..7. All row-major
  specializations are emitted first, then all column-major ones.
  """
  row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter)
  column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter)

  # Use a distinct loop variable. Rebinding the `lanes_count` parameter
  # made the second loop's bound depend on the first loop's final value.
  for lanes in range(1, 1 + lanes_count):
    for leftovers in range(8):
      row_major_with_sum.SpecializeStream('uint8_t', lanes, 8, leftovers)

  for lanes in range(1, 1 + lanes_count):
    for leftovers in range(8):
      column_major_with_sum.SpecializeStream('uint8_t', lanes, 8, leftovers)