"""Multiply primitive optimized for the gemv operation.""" import neon_emitter class Error(Exception): """Module level error.""" class ConfigurationError(Error): """Unsupported configuration.""" def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators, count, lhs, rhs_1, rhs_2): """Emit inner loop for 1 row x M cols multiplication.""" emitter.EmitComment('General 1xM lanes loop.') emitter.EmitNumericalLabel(1) emitter.EmitNewline() emitter.EmitComment('Subtract counter.') emitter.EmitSubs(count, count, emitter.ImmediateConstant(8)) emitter.EmitNewline() right_load = [registers.DoubleRegister() for unused_i in range(4)] left_load = registers.DoubleRegister() emitter.EmitVLoad('1.8', left_load, emitter.DereferenceIncrement(lhs, 64)) emitter.EmitVLoadA('1.8', right_load, emitter.DereferenceIncrement(rhs_1, 64)) emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64)) emitter.EmitPldOffset(rhs_1, emitter.ImmediateConstant(128)) multiply_results = [registers.QuadRegister() for unused_i in range(4)] for i in range(4): emitter.EmitVMull('u8', multiply_results[i], right_load[i], left_load) emitter.EmitVLoadA('1.8', right_load[:lanes_count], emitter.DereferenceIncrement(rhs_2, 64)) emitter.EmitPldOffset(rhs_2, emitter.ImmediateConstant(lanes_count * 32)) for i in range(4): emitter.EmitVPadal('u16', aggregators[i], multiply_results[i]) for i in range(lanes_count): emitter.EmitVMull('u8', multiply_results[i], right_load[i], left_load) for i in range(lanes_count): emitter.EmitVPadal('u16', aggregators[i + 4], multiply_results[i]) emitter.EmitNewline() emitter.EmitComment('Loop break.') emitter.EmitBneBack(1) emitter.EmitNewline() registers.FreeRegister(left_load) registers.FreeRegisters(right_load) registers.FreeRegisters(multiply_results) def ReadLeft(emitter, registers, lhs): register = registers.QuadRegister() emitter.EmitVLoadA('1.32', [emitter.AllLanes(registers.Low(register)), emitter.AllLanes(registers.High(register))], emitter.Dereference(lhs, None)) return register def ReadRight(emitter, registers, rhs, count): if count == 1 or count == 2: register = registers.DoubleRegister() elif count == 3 or count == 4: register = registers.QuadRegister() else: raise ConfigurationError('Unsupported elements no: %d' % count) emitter.EmitVLoad('1.32', register, emitter.Dereference(rhs, 64)) return register def DuplicateGeneralRegister(emitter, registers, general_register, min_register): duplicated = registers.QuadRegister(min_register) emitter.EmitVDup('32', duplicated, general_register) return duplicated def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators, result_type, lhs_add, rhs_add, lhs, rhs_1, rhs_2, results): """Generates assembly responsible for reducing the 4 way aggregators.""" if lhs_add: left_offset = ReadLeft(emitter, registers, lhs) else: left_offset = None if rhs_add: right_offset_1 = ReadRight(emitter, registers, rhs_1, 4) right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count) else: right_offset_1 = None right_offset_2 = None if result_type is 'float': result_scale = DuplicateGeneralRegister( emitter, registers, registers.MapParameter('result_scale'), 4) else: result_scale = None emitter.EmitNewline() emitter.EmitComment('Horizontal reduce aggregators.') for aggregator in aggregators: emitter.EmitVPadd('u32', registers.Low(aggregator), registers.Low(aggregator), registers.High(aggregator)) temp = aggregators[0] emitter.EmitVPadd('u32', registers.Low(temp), registers.Low(aggregators[0]), registers.Low(aggregators[1])) emitter.EmitVPadd('u32', registers.High(temp), registers.Low(aggregators[2]), registers.Low(aggregators[3])) if lanes_count == 1: temp_2 = registers.Low(aggregators[1]) emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]), registers.Low(aggregators[4])) elif lanes_count == 2: temp_2 = registers.Low(aggregators[1]) emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]), registers.Low(aggregators[5])) elif lanes_count == 3: temp_2 = aggregators[1] emitter.EmitVPadd('u32', registers.Low(temp_2), registers.Low(aggregators[4]), registers.Low(aggregators[5])) emitter.EmitVPadd('u32', registers.High(temp_2), registers.Low(aggregators[6]), registers.Low(aggregators[6])) elif lanes_count == 4: temp_2 = aggregators[1] emitter.EmitVPadd('u32', registers.Low(temp_2), registers.Low(aggregators[4]), registers.Low(aggregators[5])) emitter.EmitVPadd('u32', registers.High(temp_2), registers.Low(aggregators[6]), registers.Low(aggregators[7])) else: temp_2 = None if lhs_add: emitter.EmitNewline() emitter.EmitComment('Add lhs offsets to aggregated rows.') emitter.EmitVAdd('s32', temp, temp, left_offset) if lanes_count == 1 or lanes_count == 2: emitter.EmitVAdd('s32', temp_2, temp_2, registers.Low(left_offset)) elif lanes_count == 3 or lanes_count == 4: emitter.EmitVAdd('s32', temp_2, temp_2, left_offset) if rhs_add: emitter.EmitNewline() emitter.EmitComment('Add rhs offset to aggregated rows.') emitter.EmitVAdd('s32', temp, temp, right_offset_1) emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2) if result_type is 'float': emitter.EmitNewline() emitter.EmitComment('Convert to float and scale.') emitter.EmitVCvt('f32', 's32', temp, temp) emitter.EmitVCvt('f32', 's32', temp_2, temp_2) emitter.EmitVMul('f32', temp, temp, result_scale) if lanes_count == 1 or lanes_count == 2: emitter.EmitVMul('f32', temp_2, temp_2, registers.Low(result_scale)) elif lanes_count == 3 or lanes_count == 4: emitter.EmitVMul('f32', temp_2, temp_2, result_scale) emitter.EmitNewline() emitter.EmitComment('Store results.') if lanes_count == 1: emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp)], emitter.DereferenceIncrement(results, None)) emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0), emitter.Dereference(results, None)) elif lanes_count == 2: emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp), temp_2], emitter.Dereference(results, None)) elif lanes_count == 3: emitter.EmitVStoreA( '1.32', [registers.Low(temp), registers.High(temp), registers.Low(temp_2)], emitter.DereferenceIncrement(results, None)) emitter.EmitVStore('1.32', emitter.Lane( registers.High(temp_2), 0), emitter.Dereference(results, None)) elif lanes_count == 4: emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp), registers.Low(temp_2), registers.High(temp_2)], emitter.Dereference(results, None)) def BuildName(result_type, lhs_add, rhs_add, lanes): name = 'mul_1x8_%dx8_%s' % (lanes, result_type) if lhs_add: name += '_lhsadd' if rhs_add: name += '_rhsadd' return name def CppResultType(result_type): if result_type is 'int32': return 'std::int32_t*' elif result_type is 'float': return 'float*' else: raise ConfigurationError('Unsupported result type: %s' % result_type) def GetParameters(result_type): params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs_1'], ['const std::uint8_t*', 'rhs_2'], ['std::int32_t', 'count'], [CppResultType(result_type), 'result']] if result_type is 'float': params.append(['float', 'result_scale']) return params def GenerateAndClearAggregators(emitter, registers, aggregator_count): """Prepare aggregators and emit aggregator clear code.""" emitter.EmitNewline() emitter.EmitComment('Clear aggregators.') aggregators = [] for i in range(aggregator_count): aggregator = registers.QuadRegister() aggregators.append(aggregator) if i < 3: emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0)) else: emitter.EmitVMov('i32', aggregator, aggregators[i - 3]) emitter.EmitNewline() return aggregators def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count): """Generates the 1xN multiplication primitive.""" if lanes_count < 1 or lanes_count > 4: raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.') emitter.EmitFunctionBeginA( BuildName(result_type, lhs_add, rhs_add, lanes_count + 4), GetParameters(result_type), 'inline void') emitter.EmitAssert('count % 8 == 0') emitter.EmitAssert('count >= 8') emitter.EmitAsmBegin() registers = neon_emitter.NeonRegisters() count = registers.MapParameter('count') lhs = registers.MapParameter('lhs') rhs_1 = registers.MapParameter('rhs_1') rhs_2 = registers.MapParameter('rhs_2') emitter.EmitPld(lhs) emitter.EmitPld(rhs_1) emitter.EmitPld(rhs_2) aggregators = GenerateAndClearAggregators(emitter, registers, lanes_count + 4) GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators, count, lhs, rhs_1, rhs_2) GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators, result_type, lhs_add, rhs_add, lhs, rhs_1, rhs_2, registers.MapParameter('result')) emitter.EmitAsmEnd(registers.MappedParameters(), [], registers.Clobbers() + ['cc', 'memory']) emitter.EmitFunctionEnd() def GenerateFunctions(emitter, result_type, lhs_add, rhs_add): for lanes in range(1, 5): GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes) emitter.EmitNewline() if __name__ == '__main__': GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True, True)