# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stream generators that pack uint8 lanes and accumulate per-lane sums.

The helpers in this module drive an assembly emitter and a register
allocator (both supplied by the caller) to emit code that loads blocks of
8-bit values, adds them into 16-bit aggregators, stores the packed block,
and finally reduces the aggregators into 32-bit sums.
"""

import common


def _AlignForLanes(lanes_count):
    """Return the alignment value used when storing a packed block.

    The result is passed as the last argument of `EmitVStoreAE`.
    NOTE(review): the unit is presumably bits -- confirm against the emitter.

    Args:
        lanes_count: number of lanes in the block being stored.

    Returns:
        256 for 4 or 8 lanes, 128 for 2 or 6 lanes, 64 otherwise.
    """
    # Value comparison instead of `is`: identity of int literals is a
    # CPython implementation detail.
    if lanes_count in (4, 8):
        return 256
    if lanes_count in (2, 6):
        return 128
    return 64


def _AlignForSums(lanes_count):
    """Return the alignment value used when storing the reduced sums.

    The result is passed to `emitter.Dereference` together with the output
    address. NOTE(review): the unit is presumably bits -- confirm.

    Args:
        lanes_count: number of lanes (aggregators) being reduced.

    Returns:
        256 for 8 lanes, 128 for 2, 4 or 6 lanes, 64 otherwise.
    """
    if lanes_count == 8:
        return 256
    if lanes_count in (2, 4, 6):
        return 128
    return 64


def _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
    """Build one address register per lane, each `stride` past the previous.

    The first lane reuses `input_address`. Every further lane gets a fresh
    general register and an emitted add of `stride` to the previous lane's
    address register.

    Args:
        emitter: assembly emitter providing `EmitAdd`.
        registers: register allocator providing `GeneralRegister`.
        lanes_count: number of lanes to create addresses for.
        input_address: register holding the address of the first lane.
        stride: register (or operand) added to step from lane to lane.

    Returns:
        A list of `lanes_count` address registers (empty when
        `lanes_count` is not positive).
    """
    if lanes_count < 1:
        return []

    inputs = [input_address]
    for _ in range(1, lanes_count):
        address_register = registers.GeneralRegister()
        # Each lane address is derived from the previous lane's address.
        emitter.EmitAdd(address_register, inputs[-1], stride)
        inputs.append(address_register)
    return inputs


def _GenerateClear(emitter, clear_type, block):
    """Emit a move of immediate zero into every register of `block`.

    Args:
        emitter: assembly emitter providing `EmitVMov`/`ImmediateConstant`.
        clear_type: element type string handed to `EmitVMov` (e.g. 'i8').
        block: iterable of registers to clear.
    """
    for row in block:
        emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0))


def _GenerateLoadAggregateStore(emitter, registers, lanes_count,
                                elements_count, aggregators, inputs, output):
    """Emit inner loop code for reading N lanes and interweaving them.

    Args:
        emitter: assembly emitter.
        registers: register allocator.
        lanes_count: number of lanes processed per block.
        elements_count: number of elements loaded per lane (8 for a full
            block, fewer for the leftover block).
        aggregators: one accumulator register per lane.
        inputs: one input address register per lane.
        output: output address register.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Load Aggregate Store: %dx%d.' %
                        (lanes_count, elements_count))

    block = [registers.DoubleRegister() for _ in range(lanes_count)]

    # A partial load leaves some elements untouched, so zero the registers
    # first to keep stale data out of the sums and the stored block.
    if elements_count != 8:
        _GenerateClear(emitter, 'i8', block)

    for row, input_address in zip(block, inputs):
        emitter.EmitVLoadE(8, elements_count, row, input_address, None)

    for aggregator, row in zip(aggregators, block):
        emitter.EmitVAddw('u8', aggregator, aggregator, row)

    emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                         _AlignForLanes(lanes_count))

    registers.FreeRegisters(block)


def _LoadMemoryParameter(emitter, registers, name, source):
    """Emit a load of a memory-mapped parameter into a fresh register.

    Args:
        emitter: assembly emitter providing `EmitLdr`.
        registers: register allocator.
        name: parameter name passed to `MapMemoryParameter`.
        source: source expression passed to `MapMemoryParameter`.

    Returns:
        The general register the parameter was loaded into.
    """
    register = registers.GeneralRegister()
    emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
    return register


def _GenerateAggregatorReductionLowRegisters(emitter, registers, aggregators,
                                             output_address):
    """Emit the aggregator reduction, loading both offsets from memory.

    Both sum offsets are mapped as memory parameters and loaded with
    `_LoadMemoryParameter` (multiplicative first, then additive) before
    the reduction itself is emitted.

    Args:
        emitter: assembly emitter.
        registers: register allocator.
        aggregators: accumulator registers to reduce.
        output_address: register holding the address the sums go to.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Aggregator Reduction.')
    multiplicative_sum_offset = _LoadMemoryParameter(
        emitter, registers, 'multiplicative_sum_offset',
        'params.multiplicative_sum_offset')
    additive_sum_offset = _LoadMemoryParameter(
        emitter, registers, 'additive_sum_offset',
        'params.additive_sum_offset')
    _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset)


def _GenerateAggregatorReductionHighRegisters(emitter, registers, aggregators,
                                              output_address):
    """Emit the aggregator reduction with offsets mapped as parameters.

    Both sum offsets are mapped with `registers.MapParameter`
    (multiplicative first, then additive) and handed straight to the
    reduction.

    Args:
        emitter: assembly emitter.
        registers: register allocator.
        aggregators: accumulator registers to reduce.
        output_address: register holding the address the sums go to.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Aggregator Reduction.')
    multiplicative_sum_offset = registers.MapParameter(
        'multiplicative_sum_offset', 'params.multiplicative_sum_offset')
    additive_sum_offset = registers.MapParameter(
        'additive_sum_offset', 'params.additive_sum_offset')
    _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset)


def _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset):
    """Reduce 4 lane sum aggregators to 1 value and store the sums.

    Emitted sequence: pairwise-widen every aggregator, sum-reduce the
    aggregators into the first ceil(len / 4) of them, multiply those by
    the multiplicative offset, add the additive offset, and store them at
    `output_address`.

    Args:
        emitter: assembly emitter.
        registers: register allocator.
        aggregators: accumulator registers, one per lane.
        output_address: register holding the address the sums go to.
        multiplicative_sum_offset: operand moved into lane 0 of the
            multiplier register.
        additive_sum_offset: operand duplicated across the offset register.
    """
    multiplier = registers.DoubleRegister()
    emitter.EmitVMov('32', emitter.Lane(32, multiplier, 0),
                     multiplicative_sum_offset)

    offset = registers.QuadRegister()
    emitter.EmitVDup('32', offset, additive_sum_offset)

    for aggregator in aggregators:
        emitter.EmitVPaddl('u16', aggregator, aggregator)

    # Ceiling of len / 4. Floor division keeps the result an int, which a
    # slice bound requires on Python 3 (true division would give a float).
    reduced_count = (len(aggregators) + 3) // 4
    reduced = aggregators[:reduced_count]

    emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)

    for temp in reduced:
        emitter.EmitVMulScalar('i32', temp, temp,
                               emitter.Lane(32, multiplier, 0))

    for temp in reduced:
        emitter.EmitVAdd('i32', temp, temp, offset)

    emitter.EmitVStoreA(
        1, 32, reduced,
        emitter.Dereference(output_address,
                            _AlignForSums(len(aggregators))))


class RowMajorWithSumUInt8x8(common.StreamGenerator):
    """Stream generator named 'RowMajorWithSum' for uint8 input, pack size 8.

    Each lane is read through its own address register, one `stride` apart
    from the previous lane.
    """

    def __init__(self, emitter, asm_emitter):
        """Store the emitters.

        Args:
            emitter: emitter handed to `common.StreamGenerator`.
            asm_emitter: emitter used for the inline assembly body.
        """
        common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
        self.asm_emitter = asm_emitter

    def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
        """Emit the pack routine body for one specialization.

        Args:
            in_type: input element type; must be 'uint8_t'.
            lanes_count: number of lanes packed together.
            pack_size: elements per lane per block; must be 8.
            leftovers: number of elements handled by the trailing partial
                block (0 means no partial block is emitted).

        Raises:
            AssertionError: if `pack_size` or `in_type` is unsupported.
        """
        # Compare by value; `is` on literals depends on interning/caching.
        assert pack_size == 8
        assert in_type == 'uint8_t'

        registers = self.asm_emitter.CreateRegisters()

        self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')

        self.asm_emitter.PushIndent(self.emitter.indent)
        self.asm_emitter.EmitAsmBegin()

        count = registers.MapOutputParameter('count', 'params_count_copy')
        output = registers.MapOutputParameter('out')
        inputs = _GenerateInputs(
            self.asm_emitter, registers, lanes_count,
            registers.MapOutputParameter('in'),
            registers.MapParameter('stride', 'params.stride'))
        aggregators = [registers.QuadRegister(8) for _ in range(lanes_count)]

        _GenerateClear(self.asm_emitter, 'i16', aggregators)

        if leftovers:
            # Take the partial block out of the count up front; when
            # nothing remains, branch forward to label 2 (emitted below),
            # past the full-block loop.
            self.asm_emitter.EmitNewline()
            self.asm_emitter.EmitComment('Reduce count by leftovers.')
            self.asm_emitter.EmitSubs(
                count, count, self.asm_emitter.ImmediateConstant(leftovers))
            self.asm_emitter.EmitBeqFront(2)

        # Full-block loop: label 1, consume 8 elements, branch back.
        self.asm_emitter.EmitNewline()
        self.asm_emitter.EmitNumericalLabel(1)
        self.asm_emitter.EmitSubs(count, count,
                                  self.asm_emitter.ImmediateConstant(8))

        _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                    8, aggregators, inputs, output)

        self.asm_emitter.EmitNewline()
        self.asm_emitter.EmitBneBack(1)

        if leftovers:
            self.asm_emitter.EmitNewline()
            self.asm_emitter.EmitNumericalLabel(2)
            _GenerateLoadAggregateStore(self.asm_emitter, registers,
                                        lanes_count, leftovers, aggregators,
                                        inputs, output)

        registers.FreeRegisters(inputs)

        # More than 6 lanes switches to the variant that loads the sum
        # offsets from memory. NOTE(review): presumably because too few
        # general registers are left to map them directly -- confirm.
        if len(inputs) <= 6:
            _GenerateAggregatorReductionHighRegisters(
                self.asm_emitter, registers, aggregators, output)
        else:
            _GenerateAggregatorReductionLowRegisters(
                self.asm_emitter, registers, aggregators, output)

        self.asm_emitter.EmitAsmEnd(registers)
        self.asm_emitter.PopIndent(len(self.emitter.indent))


def _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
                                   elements_count, aggregators, input_address,
                                   stride, output):
    """Emit inner loop code for reading N col lanes and interweaving them.

    Args:
        emitter: assembly emitter.
        registers: register allocator.
        lanes_count: number of lanes processed per block.
        elements_count: number of elements loaded per lane (8 for a full
            block, fewer for the leftover block).
        aggregators: one accumulator register per lane.
        input_address: register holding the input address.
        stride: stride register passed to `EmitLoadColBlock`.
        output: output address register.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
                        (lanes_count, elements_count))

    block = [registers.DoubleRegister() for _ in range(lanes_count)]

    # Zero the registers before a partial load so unused elements do not
    # pollute the sums or the stored block.
    if elements_count != 8:
        _GenerateClear(emitter, 'i8', block)

    # The emitter returns the registers that hold the loaded block; these
    # replace the locally allocated list from here on.
    block = emitter.EmitLoadColBlock(registers, 8, lanes_count,
                                     elements_count, block, input_address,
                                     stride)

    for aggregator, row in zip(aggregators, block):
        emitter.EmitVAddw('u8', aggregator, aggregator, row)

    emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                         _AlignForLanes(lanes_count))

    registers.FreeRegisters(block)


class ColumnMajorWithSumUInt8x8(common.StreamGenerator):
    """Stream generator named 'ColumnMajorWithSum' for uint8, pack size 8.

    Blocks are read through the emitter's `EmitLoadColBlock` using a single
    input address and a stride register.
    """

    def __init__(self, emitter, asm_emitter):
        """Store the emitters.

        Args:
            emitter: emitter handed to `common.StreamGenerator`.
            asm_emitter: emitter used for the inline assembly body.
        """
        common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
        self.asm_emitter = asm_emitter

    def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
        """Emit the pack routine body for one specialization.

        Args:
            in_type: input element type; must be 'uint8_t'.
            lanes_count: number of lanes packed together.
            pack_size: elements per lane per block; must be 8.
            leftovers: number of elements handled by the trailing partial
                block (0 means no partial block is emitted).

        Raises:
            AssertionError: if `pack_size` or `in_type` is unsupported.
        """
        # Compare by value; `is` on literals depends on interning/caching.
        assert pack_size == 8
        assert in_type == 'uint8_t'

        registers = self.asm_emitter.CreateRegisters()

        self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
        self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')

        self.asm_emitter.PushIndent(self.emitter.indent)
        self.asm_emitter.EmitAsmBegin()

        count = registers.MapOutputParameter('count', 'params_count_copy')
        input_address = registers.MapOutputParameter('in')
        output_address = registers.MapOutputParameter('out')
        aggregators = [registers.QuadRegister(8) for _ in range(lanes_count)]
        stride = registers.MapOutputParameter('stride', 'params_stride_copy')

        # The stride register is passed as both of the last two arguments.
        # NOTE(review): presumably rewritten in place -- confirm.
        self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)

        _GenerateClear(self.asm_emitter, 'i16', aggregators)

        if leftovers:
            # Take the partial block out of the count up front; when
            # nothing remains, branch forward to label 2 (emitted below),
            # past the full-block loop.
            self.asm_emitter.EmitNewline()
            self.asm_emitter.EmitComment('Reduce count by leftovers.')
            self.asm_emitter.EmitSubs(
                count, count, self.asm_emitter.ImmediateConstant(leftovers))
            self.asm_emitter.EmitBeqFront(2)

        # Full-block loop: label 1, consume 8 elements, branch back.
        self.asm_emitter.EmitNewline()
        self.asm_emitter.EmitNumericalLabel(1)
        self.asm_emitter.EmitSubs(count, count,
                                  self.asm_emitter.ImmediateConstant(8))

        _GenerateColLoadAggregateStore(self.asm_emitter, registers,
                                       lanes_count, 8, aggregators,
                                       input_address, stride, output_address)

        self.asm_emitter.EmitNewline()
        self.asm_emitter.EmitBneBack(1)

        if leftovers:
            self.asm_emitter.EmitNewline()
            self.asm_emitter.EmitNumericalLabel(2)
            _GenerateColLoadAggregateStore(self.asm_emitter, registers,
                                           lanes_count, leftovers,
                                           aggregators, input_address, stride,
                                           output_address)

        _GenerateAggregatorReductionHighRegisters(
            self.asm_emitter, registers, aggregators, output_address)

        self.asm_emitter.EmitAsmEnd(registers)
        self.asm_emitter.PopIndent(len(self.emitter.indent))


def GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
    """Specialize both stream generators for every lane/leftover combination.

    All 'RowMajorWithSum' specializations are produced first, followed by
    all 'ColumnMajorWithSum' ones. Within each, lanes run from 1 to
    `lanes_count` inclusive and leftovers from 0 to 7.

    Args:
        cc_emitter: emitter passed to the stream generators as `emitter`.
        asm_emitter: emitter passed to the stream generators for assembly.
        lanes_count: largest lane count to specialize for (inclusive).
    """
    row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter)
    column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter)

    # A separate loop variable keeps the `lanes_count` parameter intact, so
    # the second loop's bound does not depend on what the first loop left
    # behind.
    for lanes in range(1, 1 + lanes_count):
        for leftovers in range(8):
            row_major_with_sum.SpecializeStream('uint8_t', lanes, 8,
                                                leftovers)

    for lanes in range(1, 1 + lanes_count):
        for leftovers in range(8):
            column_major_with_sum.SpecializeStream('uint8_t', lanes, 8,
                                                   leftovers)