"""Mul primitive used by the GEMM function. The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns and performs matrix multiplication on those resulting in a small 1x1 to 3x3 block of results. """ import neon_emitter class Error(Exception): """Module level error.""" class ConfigurationError(Error): """Unsupported configuration.""" class MulLanes(object): def __init__(self, input_address): self.input_address = input_address self.lanes = [] def AddLane(self, lane): self.lanes.append(lane) def FreeRegisters(self, registers): for i in range(0, len(self.lanes)): registers.FreeRegister(self.lanes[i]) self.lanes[i] = None def GenerateMulLanes(registers, lane_count, address): lanes = MulLanes(address) for unused_i in range(0, lane_count): lanes.AddLane(registers.DoubleRegister()) return lanes def Generate3MulLanes(quad_register, registers, address): lanes = MulLanes(address) lanes.AddLane(registers.Low(quad_register)) lanes.AddLane(registers.High(quad_register)) lanes.AddLane(registers.DoubleRegister()) return lanes def GenerateAndClearAggregators(emitter, registers, aggregator_count): """Prepare aggregators and emit aggregator clear code.""" emitter.EmitComment('Clear aggregators.') aggregators = [] for i in range(0, aggregator_count): aggregator = registers.QuadRegister() aggregators.append(aggregator) if i < 3: emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0)) else: emitter.EmitVMov('i32', aggregator, aggregators[i - 3]) emitter.EmitNewline() return aggregators def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes, right_lanes, aggregators, count): """Emit inner loop for N rows x M cols multiplication.""" emitter.EmitComment('General NxM lanes loop.') emitter.EmitNumericalLabel(1) emitter.EmitNewline() emitter.EmitComment('Subtract counter.') emitter.EmitSubs(count, count, emitter.ImmediateConstant(8)) emitter.EmitNewline() emitter.EmitVLoadA('1.8', left_lanes.lanes, emitter.DereferenceIncrement(left_lanes.input_address, 64)) emitter.EmitVLoadA( '1.8', right_lanes.lanes, emitter.DereferenceIncrement(right_lanes.input_address, 64)) emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64)) emitter.EmitPldOffset(right_lanes.input_address, emitter.ImmediateConstant(64)) rows = len(left_lanes.lanes) cols = len(right_lanes.lanes) multiply_results = [] for i in range(0, rows * cols): multiply_results.append(registers.QuadRegister()) for row in range(0, rows): for col in range(0, cols): index = row * cols + col emitter.EmitVMull('u8', multiply_results[index], right_lanes.lanes[col], left_lanes.lanes[row]) for i in range(0, rows * cols): emitter.EmitVPadal('u16', aggregators[i], multiply_results[i]) emitter.EmitNewline() emitter.EmitComment('Loop break.') emitter.EmitBneBack(1) emitter.EmitNewline() for register in multiply_results: registers.FreeRegister(register) def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes, right_lanes, aggregators, count, backup_register): """Emit inner loop for 3 rows x 3 cols multiplication (register trick).""" emitter.EmitComment('3x3 lanes loop.') emitter.EmitNumericalLabel(1) emitter.EmitNewline() emitter.EmitComment('Subtract counter.') emitter.EmitSubs(count, count, emitter.ImmediateConstant(8)) emitter.EmitNewline() emitter.EmitVLoadA('1.8', left_lanes.lanes, emitter.DereferenceIncrement(left_lanes.input_address, 64)) emitter.EmitVLoadA( '1.8', right_lanes.lanes, emitter.DereferenceIncrement(right_lanes.input_address, 64)) emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64)) emitter.EmitPldOffset(right_lanes.input_address, emitter.ImmediateConstant(64)) temp = [] for unused_i in range(0, 4): temp.append(registers.QuadRegister()) emitter.EmitVMull('u8', temp[0], left_lanes.lanes[0], right_lanes.lanes[0]) emitter.EmitVMull('u8', temp[1], left_lanes.lanes[0], right_lanes.lanes[1]) emitter.EmitVMull('u8', temp[2], left_lanes.lanes[0], right_lanes.lanes[2]) emitter.EmitVMull('u8', temp[3], left_lanes.lanes[1], right_lanes.lanes[0]) emitter.EmitVPadal('u16', aggregators[0], temp[0]) emitter.EmitVPadal('u16', aggregators[1], temp[1]) emitter.EmitVPadal('u16', aggregators[2], temp[2]) emitter.EmitVPadal('u16', aggregators[3], temp[3]) emitter.EmitVMull('u8', temp[0], left_lanes.lanes[1], right_lanes.lanes[1]) emitter.EmitVMull('u8', temp[1], left_lanes.lanes[1], right_lanes.lanes[2]) emitter.EmitVMull('u8', temp[2], left_lanes.lanes[2], right_lanes.lanes[0]) emitter.EmitVMull('u8', temp[3], left_lanes.lanes[2], right_lanes.lanes[1]) emitter.EmitVMull('u8', backup_register, left_lanes.lanes[2], right_lanes.lanes[2]) emitter.EmitVPadal('u16', aggregators[4], temp[0]) emitter.EmitVPadal('u16', aggregators[5], temp[1]) emitter.EmitVPadal('u16', aggregators[6], temp[2]) emitter.EmitVPadal('u16', aggregators[7], temp[3]) emitter.EmitVPadal('u16', aggregators[8], backup_register) emitter.EmitNewline() emitter.EmitComment('Loop break.') emitter.EmitBneBack(1) emitter.EmitNewline() for register in temp: registers.FreeRegister(register) def ReadParams(emitter, registers, input_address, elements, min_reg): if elements == 1 or elements == 2: register = registers.DoubleRegister(min_reg * 2) emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64)) return register elif elements == 3: register = registers.QuadRegister(min_reg) emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64)) return register else: raise ConfigurationError('Unsupported elements no: %d' % elements) def Duplicate(emitter, registers, rows, cols, min_register, values): """Populate a grid of registers duplicating provided values.""" duplicated = [] if cols == 1 or cols == 2: for unused_i in range(0, rows): duplicated.append(registers.DoubleRegister(min_register)) elif cols == 3: for unused_i in range(0, rows): duplicated.append(registers.QuadRegister(min_register)) else: raise ConfigurationError('Unsupported duplicate amount: %d' % cols) if rows == 1: emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0)) elif rows == 2: emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0)) emitter.EmitVDup('32', duplicated[1], emitter.Lane(values, 1)) elif rows == 3: emitter.EmitVDup('32', duplicated[0], emitter.Lane( registers.Low(values), 0)) emitter.EmitVDup('32', duplicated[1], emitter.Lane( registers.Low(values), 1)) emitter.EmitVDup('32', duplicated[2], emitter.Lane( registers.High(values), 0)) return duplicated def DuplicateGeneralRegister(emitter, registers, cols, general_register, min_register): if cols == 1 or cols == 2: duplicated = registers.DoubleRegister(min_register) elif cols == 3: duplicated = registers.QuadRegister(min_register) else: raise ConfigurationError('Unsupported duplicate amount: %d' % cols) emitter.EmitVDup('32', duplicated, general_register) return duplicated def ReduceAggregator(emitter, registers, aggregators, row, cols): if cols == 1: register = registers.Low(aggregators[row]) emitter.EmitVPadd('u32', register, register, register) return register elif cols == 2: register = registers.Low(aggregators[row * 2]) emitter.EmitVPadd('u32', register, register, registers.Low(aggregators[row * 2 + 1])) return register elif cols == 3: register = aggregators[row * 3] emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register), registers.Low(aggregators[row * 3 + 1])) emitter.EmitVPadd('u32', registers.High(register), registers.Low(aggregators[row * 3 + 2]), registers.Low(aggregators[row * 3 + 2])) return register else: raise ConfigurationError('Unsupported columns no: %d' % cols) def StoreAggregator(emitter, registers, aggregator, cols, result_address, result_stride): if cols == 1: emitter.EmitVStoreOffset('1.32', emitter.Lane(aggregator, 0), emitter.Dereference(result_address, None), result_stride) elif cols == 2: emitter.EmitVStoreOffset('1.32', aggregator, emitter.Dereference(result_address, None), result_stride) elif cols == 3: emitter.EmitVStore('1.32', registers.Low(aggregator), emitter.DereferenceIncrement(result_address, None)) emitter.EmitVStoreOffset('1.32', emitter.Lane( registers.High(aggregator), 0), emitter.Dereference(result_address, None), result_stride) emitter.EmitNewline() else: raise ConfigurationError('Unsupported columns no: %d' % cols) def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type, lhs_add, rhs_add, left_lanes, right_lanes, results, results_stride): """Emit code that reduces 4 lane aggregators to 1 value, and stores them.""" rows = len(left_lanes.lanes) cols = len(right_lanes.lanes) if lhs_add: left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows, 4) left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset) else: left_offsets = None if rhs_add: right_offset = ReadParams(emitter, registers, right_lanes.input_address, cols, 4) else: right_offset = None if result_type is 'float': result_scale = DuplicateGeneralRegister( emitter, registers, cols, registers.MapParameter('result_scale'), 4) else: result_scale = None if cols == 3: emitter.EmitNewline() emitter.EmitComment('Change stride because storing in two ops.') emitter.EmitSub(results_stride, results_stride, emitter.ImmediateConstant(8)) emitter.EmitNewline() emitter.EmitComment('Horizontal reduce aggregators.') for aggregator in aggregators: emitter.EmitVPadd('u32', registers.Low(aggregator), registers.Low(aggregator), registers.High(aggregator)) emitter.EmitNewline() emitter.EmitComment('Reduce rows.') row_temps = [] for i in range(0, rows): row_temps.append(ReduceAggregator(emitter, registers, aggregators, i, cols)) if lhs_add: emitter.EmitNewline() emitter.EmitComment('Add lhs offsets to aggregated rows.') for (row_temp, left_offset) in zip(row_temps, left_offsets): emitter.EmitVAdd('s32', row_temp, row_temp, left_offset) if rhs_add: emitter.EmitNewline() emitter.EmitComment('Add rhs offset to aggregated rows.') for row_temp in row_temps: emitter.EmitVAdd('s32', row_temp, row_temp, right_offset) if result_type is 'float': emitter.EmitNewline() emitter.EmitComment('Convert to float. Multiply by result scale.') for row_temp in row_temps: emitter.EmitVCvt('f32', 's32', row_temp, row_temp) for row_temp in row_temps: emitter.EmitVMul('f32', row_temp, row_temp, result_scale) emitter.EmitNewline() emitter.EmitComment('Store reduced rows.') for row_temp in row_temps: StoreAggregator(emitter, registers, row_temp, cols, results, results_stride) def BuildName(result_type, lhs_add, rhs_add, left, right): name = 'mul_%dx8_%dx8_%s' % (left, right, result_type) if lhs_add: name += '_lhsadd' if rhs_add: name += '_rhsadd' return name def CppResultType(result_type): if result_type is 'int32': return 'std::int32_t*' elif result_type is 'float': return 'float*' else: raise ConfigurationError('Unsupported result type: %s' % result_type) def GetParameters(result_type): params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'count'], [CppResultType(result_type), 'result'], ['std::int32_t', 'result_stride']] if result_type is 'float': params.append(['float', 'result_scale']) return params def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count, right_lanes_count): """Emit the multiply code for given rows and cols counts.""" if left_lanes_count < 1 or left_lanes_count > 3: raise ConfigurationError('Left_lanes should be: 1, 2 or 3.') if right_lanes_count < 1 or right_lanes_count > 3: raise ConfigurationError('Right_lanes should be: 1, 2 or 3.') emitter.EmitFunctionBeginA( BuildName(result_type, lhs_add, rhs_add, left_lanes_count, right_lanes_count), GetParameters(result_type), 'inline void') emitter.EmitAssert('count % 8 == 0') emitter.EmitAssert('count >= 8') emitter.EmitAsmBegin() registers = neon_emitter.NeonRegisters() count = registers.MapParameter('count') size = left_lanes_count * right_lanes_count if size < 9: aggregators = GenerateAndClearAggregators(emitter, registers, size) left_lanes = GenerateMulLanes(registers, left_lanes_count, registers.MapParameter('lhs')) right_lanes = GenerateMulLanes(registers, right_lanes_count, registers.MapParameter('rhs')) emitter.EmitPld(left_lanes.input_address) emitter.EmitPld(right_lanes.input_address) GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes, right_lanes, aggregators, count) else: # left == 3 and right == 3 aggregators = GenerateAndClearAggregators(emitter, registers, size) backup_register = registers.QuadRegister() left_lanes = Generate3MulLanes(backup_register, registers, registers.MapParameter('lhs')) right_lanes = GenerateMulLanes(registers, right_lanes_count, registers.MapParameter('rhs')) emitter.EmitPld(left_lanes.input_address) emitter.EmitPld(right_lanes.input_address) Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes, right_lanes, aggregators, count, backup_register) left_lanes.FreeRegisters(registers) right_lanes.FreeRegisters(registers) GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type, lhs_add, rhs_add, left_lanes, right_lanes, registers.MapParameter('result'), registers.MapParameter('result_stride')) emitter.EmitAsmEnd(registers.MappedParameters(), [], registers.Clobbers() + ['cc', 'memory']) emitter.EmitFunctionEnd() def GenerateFunctions(emitter, result_type, lhs_add, rhs_add): for left_lanes in range(1, 4): for right_lanes in range(1, 4): GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes, right_lanes) emitter.EmitNewline()