# Copyright 2016 The Gemmlowp Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """32bit ARM/NEON assembly emitter. Used by code generators to produce ARM assembly with NEON simd code. Provides tools for easier register management: named register variable allocation/deallocation, and offers a more procedural/structured approach to generating assembly. TODO: right now neon emitter prints out assembly instructions immediately, it might be beneficial to keep the whole structure and emit the assembly after applying some optimizations like: instruction reordering or register reuse. TODO: NeonRegister object assigns explicit registers at allocation time. Similarily to emiting code, register mapping and reuse can be performed and optimized lazily. """ class Error(Exception): """Module level error.""" class RegisterAllocationError(Error): """Cannot alocate registers.""" class LaneError(Error): """Wrong lane number.""" class ArgumentError(Error): """Wrong argument.""" def _Low(register): assert register[0] == 'q' num = int(register[1:]) return 'd%d' % (num * 2) def _High(register): assert register[0] == 'q' num = int(register[1:]) return 'd%d' % (num * 2 + 1) def _ExpandQuads(registers): doubles = [] for register in registers: if register[0] == 'q': doubles.append(_Low(register)) doubles.append(_High(register)) else: doubles.append(register) return doubles def _MakeCompatible(op1, op2, op3): if op1[0] == 'd' or op2[0] == 'd' or op3[0] == 'd': if op1[0] == 'q': op1 = _Low(op1) if op2[0] == 'q': op2 = _Low(op2) if op3[0] == 'q': op3 = _Low(op3) return (op1, op2, op3) class _NeonRegisters32Bit(object): """Utility that keeps track of used 32bit ARM/NEON registers.""" def __init__(self): self.double = set() self.double_ever = set() self.general = set() self.general_ever = set() self.parameters = dict() self.output_parameters = dict() def MapParameter(self, parameter, parameter_value=None): if not parameter_value: parameter_value = parameter self.parameters[parameter] = (parameter_value, 'r') return '%%[%s]' % parameter def MapMemoryParameter(self, parameter, parameter_value=None): if not parameter_value: parameter_value = parameter self.parameters[parameter] = (parameter_value, 'm') return '%%[%s]' % parameter def MapOutputParameter(self, parameter, parameter_value=None): if not parameter_value: parameter_value = parameter self.output_parameters[parameter] = (parameter_value, '+r') return '%%[%s]' % parameter def DoubleRegister(self, min_val=0): for i in range(min_val, 32): if i not in self.double: self.double.add(i) self.double_ever.add(i) return 'd%d' % i raise RegisterAllocationError('Not enough double registers.') def QuadRegister(self, min_val=0): for i in range(min_val, 16): if ((i * 2) not in self.double) and ((i * 2 + 1) not in self.double): self.double.add(i * 2) self.double.add(i * 2 + 1) self.double_ever.add(i * 2) self.double_ever.add(i * 2 + 1) return 'q%d' % i raise RegisterAllocationError('Not enough quad registers.') def GeneralRegister(self): for i in range(0, 16): if i not in self.general: self.general.add(i) self.general_ever.add(i) return 'r%d' % i raise RegisterAllocationError('Not enough general registers.') def MappedParameters(self): return [(k, v) for (k, v) in self.parameters.items()] def MappedOutputParameters(self): return [(k, v) for (k, v) in self.output_parameters.items()] def Clobbers(self): return (['r%d' % i for i in self.general_ever] + ['d%d' % i for i in self.DoubleClobbers()]) def DoubleClobbers(self): return sorted(self.double_ever) def FreeRegister(self, register): assert len(register) > 1 if register[0] not in ['r', 'd', 'q']: return num = int(register[1:]) if register[0] == 'r': assert num in self.general self.general.remove(num) elif register[0] == 'd': assert num in self.double self.double.remove(num) elif register[0] == 'q': assert num * 2 in self.double assert num * 2 + 1 in self.double self.double.remove(num * 2) self.double.remove(num * 2 + 1) else: raise RegisterDeallocationError('Register not allocated: %s' % register) def FreeRegisters(self, registers): for register in registers: self.FreeRegister(register) class NeonEmitter(object): """Emits ARM/NEON assembly opcodes.""" def __init__(self, debug=False): self.ops = {} self.indent = '' self.debug = debug def PushIndent(self, delta=' '): self.indent += delta def PopIndent(self, delta=2): self.indent = self.indent[:-delta] def EmitIndented(self, what): print self.indent + what def PushOp(self, op): if op in self.ops.keys(): self.ops[op] += 1 else: self.ops[op] = 1 def ClearCounters(self): self.ops.clear() def EmitNewline(self): print '' def EmitPreprocessor1(self, op, param): print '#%s %s' % (op, param) def EmitPreprocessor(self, op): print '#%s' % op def EmitInclude(self, include): self.EmitPreprocessor1('include', include) def EmitCall1(self, function, param): self.EmitIndented('%s(%s);' % (function, param)) def EmitAssert(self, assert_expression): if self.debug: self.EmitCall1('assert', assert_expression) def EmitHeaderBegin(self, header_name, includes): self.EmitPreprocessor1('ifndef', (header_name + '_H_').upper()) self.EmitPreprocessor1('define', (header_name + '_H_').upper()) self.EmitNewline() if includes: for include in includes: self.EmitInclude(include) self.EmitNewline() def EmitHeaderEnd(self): self.EmitPreprocessor('endif') def EmitCode(self, code): self.EmitIndented('%s;' % code) def EmitFunctionBeginA(self, function_name, params, return_type): self.EmitIndented('%s %s(%s) {' % (return_type, function_name, ', '.join(['%s %s' % (t, n) for (t, n) in params]))) self.PushIndent() def EmitFunctionEnd(self): self.PopIndent() self.EmitIndented('}') def EmitAsmBegin(self): self.EmitIndented('asm volatile(') self.PushIndent() def EmitAsmMapping(self, elements): if elements: self.EmitIndented(': ' + ', '.join( ['[%s] "%s"(%s)' % (d, v[1], v[0]) for (d, v) in elements])) else: self.EmitIndented(':') def EmitClobbers(self, elements): if elements: self.EmitIndented(': ' + ', '.join(['"%s"' % c for c in elements])) else: self.EmitIndented(':') def EmitAsmEnd(self, registers): self.EmitAsmMapping(registers.MappedOutputParameters()) self.EmitAsmMapping(registers.MappedParameters()) self.EmitClobbers(registers.Clobbers() + ['cc', 'memory']) self.PopIndent() self.EmitIndented(');') def EmitComment(self, comment): self.EmitIndented('// ' + comment) def EmitNumericalLabel(self, label): self.EmitIndented('"%d:"' % label) def EmitOp1(self, op, param1): self.PushOp(op) self.EmitIndented('"%s %s\\n"' % (op, param1)) def EmitOp2(self, op, param1, param2): self.PushOp(op) self.EmitIndented('"%s %s, %s\\n"' % (op, param1, param2)) def EmitOp3(self, op, param1, param2, param3): self.PushOp(op) self.EmitIndented('"%s %s, %s, %s\\n"' % (op, param1, param2, param3)) def EmitAdd(self, destination, source, param): self.EmitOp3('add', destination, source, param) def EmitSubs(self, destination, source, param): self.EmitOp3('subs', destination, source, param) def EmitSub(self, destination, source, param): self.EmitOp3('sub', destination, source, param) def EmitMul(self, destination, source, param): self.EmitOp3('mul', destination, source, param) def EmitMov(self, param1, param2): self.EmitOp2('mov', param1, param2) def EmitBeqBack(self, label): self.EmitOp1('beq', '%db' % label) def EmitBeqFront(self, label): self.EmitOp1('beq', '%df' % label) def EmitBgtBack(self, label): self.EmitOp1('bgt', '%db' % label) def EmitBgtFront(self, label): self.EmitOp1('bgt', '%df' % label) def EmitBleBack(self, label): self.EmitOp1('ble', '%db' % label) def EmitBleFront(self, label): self.EmitOp1('ble', '%df' % label) def EmitBneBack(self, label): self.EmitOp1('bne', '%db' % label) def EmitBneFront(self, label): self.EmitOp1('bne', '%df' % label) def EmitVAdd(self, add_type, destination, source_1, source_2): destination, source_1, source_2 = _MakeCompatible(destination, source_1, source_2) self.EmitOp3('vadd.%s' % add_type, destination, source_1, source_2) def EmitVAddw(self, add_type, destination, source_1, source_2): self.EmitOp3('vaddw.%s' % add_type, destination, source_1, source_2) def EmitVSub(self, sub_type, destination, source_1, source_2): destination, source_1, source_2 = _MakeCompatible(destination, source_1, source_2) self.EmitOp3('vsub.%s' % sub_type, destination, source_1, source_2) def EmitVCvt(self, cvt_to, cvt_from, destination, source): self.EmitOp2('vcvt.%s.%s' % (cvt_to, cvt_from), destination, source) def EmitVDup(self, dup_type, destination, source): self.EmitOp2('vdup.%s' % dup_type, destination, source) def EmitVMax(self, size, destination, source_1, source_2): self.EmitOp3('vmax.%s' % size, destination, source_1, source_2) def EmitVMin(self, size, destination, source_1, source_2): self.EmitOp3('vmin.%s' % size, destination, source_1, source_2) def EmitVMov(self, mov_type, destination, source): self.EmitOp2('vmov.%s' % mov_type, destination, source) def EmitVMovl(self, mov_type, destination, source): if source[0] == 'q': source = _Low(source) self.EmitOp2('vmovl.%s' % mov_type, destination, source) def EmitVMovl2(self, mov_type, destination_1, destination_2, source): self.EmitVMovl(mov_type, destination_2, _High(source)) self.EmitVMovl(mov_type, destination_1, _Low(source)) def EmitVQmovn(self, mov_type, destination, source): if destination[0] == 'q': destination = _Low(destination) self.EmitOp2('vqmovn.%s' % mov_type, destination, source) def EmitVQmovn2(self, mov_type, destination, source_1, source_2): self.EmitVQmovn(mov_type, _Low(destination), source_1) self.EmitVQmovn(mov_type, _High(destination), source_2) def EmitVQmovun(self, mov_type, destination, source): if destination[0] == 'q': destination = _Low(destination) self.EmitOp2('vqmovun.%s' % mov_type, destination, source) def EmitVQmovun2(self, mov_type, destination, source_1, source_2): self.EmitVQmovun(mov_type, _Low(destination), source_1) self.EmitVQmovun(mov_type, _High(destination), source_2) def EmitVMul(self, mul_type, destination, source_1, source_2): destination, source_1, source_2 = _MakeCompatible(destination, source_1, source_2) self.EmitOp3('vmul.%s' % mul_type, destination, source_1, source_2) def EmitVMulScalar(self, mul_type, destination, source_1, source_2): self.EmitOp3('vmul.%s' % mul_type, destination, source_1, source_2) def EmitVMull(self, mul_type, destination, source_1, source_2): self.EmitOp3('vmull.%s' % mul_type, destination, source_1, source_2) def EmitVPadd(self, add_type, destination, source_1, source_2): self.EmitOp3('vpadd.%s' % add_type, destination, source_1, source_2) def EmitVPaddl(self, add_type, destination, source): self.EmitOp2('vpaddl.%s' % add_type, destination, source) def EmitVPadal(self, add_type, destination, source): self.EmitOp2('vpadal.%s' % add_type, destination, source) def EmitLdr(self, register, value): self.EmitOp2('ldr', register, value) def EmitVLoad(self, load_no, load_type, destination, source): self.EmitVLoadA(load_no, load_type, [destination], source) def EmitVLoadA(self, load_no, load_type, destinations, source): self.EmitOp2('vld%d.%d' % (load_no, load_type), '{%s}' % ', '.join(_ExpandQuads(destinations)), source) def EmitVLoadAE(self, load_type, elem_count, destinations, source, alignment=None): bits_to_load = load_type * elem_count destinations = _ExpandQuads(destinations) if len(destinations) * 64 < bits_to_load: raise ArgumentError('To few destinations: %d to load %d bits.' % (len(destinations), bits_to_load)) while bits_to_load > 0: if bits_to_load >= 256: self.EmitVLoadA(1, 32, destinations[:4], self.DereferenceIncrement(source, alignment)) bits_to_load -= 256 destinations = destinations[4:] elif bits_to_load >= 192: self.EmitVLoadA(1, 32, destinations[:3], self.DereferenceIncrement(source, alignment)) bits_to_load -= 192 destinations = destinations[3:] elif bits_to_load >= 128: self.EmitVLoadA(1, 32, destinations[:2], self.DereferenceIncrement(source, alignment)) bits_to_load -= 128 destinations = destinations[2:] elif bits_to_load >= 64: self.EmitVLoad(1, 32, destinations[0], self.DereferenceIncrement(source, alignment)) bits_to_load -= 64 destinations = destinations[1:] else: destination = destinations[0] if bits_to_load == 56: self.EmitVLoad(1, 32, self.Lane(32, destination, 0), self.DereferenceIncrement(source)) self.EmitVLoad(1, 16, self.Lane(16, destination, 2), self.DereferenceIncrement(source)) self.EmitVLoad(1, 8, self.Lane(8, destination, 6), self.DereferenceIncrement(source)) elif bits_to_load == 48: self.EmitVLoad(1, 32, self.Lane(32, destination, 0), self.DereferenceIncrement(source)) self.EmitVLoad(1, 16, self.Lane(16, destination, 2), self.DereferenceIncrement(source)) elif bits_to_load == 40: self.EmitVLoad(1, 32, self.Lane(32, destination, 0), self.DereferenceIncrement(source)) self.EmitVLoad(1, 8, self.Lane(8, destination, 4), self.DereferenceIncrement(source)) elif bits_to_load == 32: self.EmitVLoad(1, 32, self.Lane(32, destination, 0), self.DereferenceIncrement(source)) elif bits_to_load == 24: self.EmitVLoad(1, 16, self.Lane(16, destination, 0), self.DereferenceIncrement(source)) self.EmitVLoad(1, 8, self.Lane(8, destination, 2), self.DereferenceIncrement(source)) elif bits_to_load == 16: self.EmitVLoad(1, 16, self.Lane(16, destination, 0), self.DereferenceIncrement(source)) elif bits_to_load == 8: self.EmitVLoad(1, 8, self.Lane(8, destination, 0), self.DereferenceIncrement(source)) else: raise ArgumentError('Wrong leftover: %d' % bits_to_load) return def EmitVLoadE(self, load_type, count, destination, source, alignment=None): self.EmitVLoadAE(load_type, count, [destination], source, alignment) def EmitVLoadAllLanes(self, load_no, load_type, destination, source): destinations = [] if destination[0] == 'q': destinations.append(self.AllLanes(_Low(destination))) destinations.append(self.AllLanes(_High(destination))) else: destinations.append(self.AllLanes(destination)) self.EmitVLoadA(load_no, load_type, destinations, source) def EmitVLoadOffset(self, load_no, load_type, destination, source, offset): self.EmitVLoadOffsetA(load_no, load_type, [destination], source, offset) def EmitVLoadOffsetA(self, load_no, load_type, destinations, source, offset): assert len(destinations) <= 4 self.EmitOp3('vld%d.%d' % (load_no, load_type), '{%s}' % ', '.join(_ExpandQuads(destinations)), source, offset) def EmitPld(self, load_address_register): self.EmitOp1('pld', '[%s]' % load_address_register) def EmitPldw(self, store_address_register): self.EmitOp1('pldw', '[%s]' % store_address_register) def EmitPldOffset(self, load_address_register, offset): self.EmitOp1('pld', '[%s, %s]' % (load_address_register, offset)) def EmitPldwOffset(self, store_address_register, offset): self.EmitOp1('pldw', '[%s, %s]' % (store_address_register, offset)) def EmitVShl(self, shift_type, destination, source, shift): self.EmitOp3('vshl.%s' % shift_type, destination, source, shift) def EmitVStore(self, store_no, store_type, source, destination): self.EmitVStoreA(store_no, store_type, [source], destination) def EmitVStoreA(self, store_no, store_type, sources, destination): self.EmitOp2('vst%d.%d' % (store_no, store_type), '{%s}' % ', '.join(_ExpandQuads(sources)), destination) def EmitVStoreAE(self, store_type, elem_count, sources, destination, alignment=None): bits_to_store = store_type * elem_count sources = _ExpandQuads(sources) if len(sources) * 64 < bits_to_store: raise ArgumentError('To few sources: %d to store %d bits.' % (len(sources), bits_to_store)) while bits_to_store > 0: if bits_to_store >= 256: self.EmitVStoreA(1, 32, sources[:4], self.DereferenceIncrement(destination, alignment)) bits_to_store -= 256 sources = sources[4:] elif bits_to_store >= 192: self.EmitVStoreA(1, 32, sources[:3], self.DereferenceIncrement(destination, alignment)) bits_to_store -= 192 sources = sources[3:] elif bits_to_store >= 128: self.EmitVStoreA(1, 32, sources[:2], self.DereferenceIncrement(destination, alignment)) bits_to_store -= 128 sources = sources[2:] elif bits_to_store >= 64: self.EmitVStore(1, 32, sources[0], self.DereferenceIncrement(destination, alignment)) bits_to_store -= 64 sources = sources[1:] else: source = sources[0] if bits_to_store == 56: self.EmitVStore(1, 32, self.Lane(32, source, 0), self.DereferenceIncrement(destination)) self.EmitVStore(1, 16, self.Lane(16, source, 2), self.DereferenceIncrement(destination)) self.EmitVStore(1, 8, self.Lane(8, source, 6), self.DereferenceIncrement(destination)) elif bits_to_store == 48: self.EmitVStore(1, 32, self.Lane(32, source, 0), self.DereferenceIncrement(destination)) self.EmitVStore(1, 16, self.Lane(16, source, 2), self.DereferenceIncrement(destination)) elif bits_to_store == 40: self.EmitVStore(1, 32, self.Lane(32, source, 0), self.DereferenceIncrement(destination)) self.EmitVStore(1, 8, self.Lane(8, source, 4), self.DereferenceIncrement(destination)) elif bits_to_store == 32: self.EmitVStore(1, 32, self.Lane(32, source, 0), self.DereferenceIncrement(destination)) elif bits_to_store == 24: self.EmitVStore(1, 16, self.Lane(16, source, 0), self.DereferenceIncrement(destination)) self.EmitVStore(1, 8, self.Lane(8, source, 2), self.DereferenceIncrement(destination)) elif bits_to_store == 16: self.EmitVStore(1, 16, self.Lane(16, source, 0), self.DereferenceIncrement(destination)) elif bits_to_store == 8: self.EmitVStore(1, 8, self.Lane(8, source, 0), self.DereferenceIncrement(destination)) else: raise ArgumentError('Wrong leftover: %d' % bits_to_store) return def EmitVStoreE(self, store_type, count, source, destination, alignment=None): self.EmitVStoreAE(store_type, count, [source], destination, alignment) def EmitVStoreOffset(self, store_no, store_type, source, destination, offset): self.EmitVStoreOffsetA(store_no, store_type, [source], destination, offset) def EmitVStoreOffsetA(self, store_no, store_type, sources, destination, offset): self.EmitOp3('vst%d.%d' % (store_no, store_type), '{%s}' % ', '.join(_ExpandQuads(sources)), destination, offset) def EmitVStoreOffsetE(self, store_type, count, source, destination, offset): """Emit assembly to store a number elements from the source registers.""" if store_type is not 32: raise ArgumentError('Unsupported store_type: %d' % store_type) sources = [] if source[0] == 'q': sources.append(_Low(source)) sources.append(_High(source)) if count * store_type > 128: raise ArgumentError('To many %dbit elements in a q register: %d' % (store_type, count)) else: sources.append(source) if count * store_type > 64: raise ArgumentError('To many %dbit elements in a d register: %d' % (store_type, count)) if count == 1: self.EmitVStoreOffset(1, store_type, self.Lane(store_type, sources[0], 0), self.Dereference(destination, None), offset) elif count == 2: self.EmitVStoreOffset(1, store_type, sources[0], self.Dereference(destination, None), offset) elif count == 3: self.EmitVStore(1, store_type, sources[0], self.DereferenceIncrement(destination, None)) self.EmitVStoreOffset(1, store_type, self.Lane(store_type, sources[1], 0), self.Dereference(destination, None), offset) self.EmitSub(destination, destination, self.ImmediateConstant(8)) elif count == 4: self.EmitVStoreOffsetA(1, store_type, sources, self.Dereference(destination, None), offset) else: raise ArgumentError('To many elements: %d' % count) def EmitVSumReduce(self, reduce_type, elem_count, reduce_count, destinations, sources): """Emit assembly for n-fold horizontal sum reduction.""" if reduce_type is not 'u32': raise ArgumentError('Unsupported reduce: %s' % reduce_type) sources = _ExpandQuads(sources) destinations = _ExpandQuads(destinations) if len(destinations) * 2 < elem_count: raise ArgumentError('Not enough space in destination: %d vs %d' % (len(destinations) * 2, elem_count)) if len(sources) * 2 != elem_count * reduce_count: raise ArgumentError('Wrong number of sources: %d vs %d' % (len(sources) * 2, elem_count * reduce_count)) if reduce_count <= 1: raise ArgumentError('Unsupported reduce_count: %d' % reduce_count) while reduce_count > 1: if len(sources) % 2 == 1: sources.append(sources[-1]) if reduce_count == 2: for i in range(len(sources) / 2): self.EmitVPadd(reduce_type, destinations[i], sources[2 * i], sources[2 * i + 1]) return else: sources_2 = [] for i in range(len(sources) / 2): self.EmitVPadd(reduce_type, sources[2 * i], sources[2 * i], sources[2 * i + 1]) sources_2.append(sources[2 * i]) reduce_count /= 2 sources = sources_2 def EmitVUzp(self, uzp_type, operand_1, operand_2): self.EmitOp2('vuzp.%d' % uzp_type, operand_1, operand_2) def EmitVTrn(self, trn_type, operand_1, operand_2): self.EmitOp2('vtrn.%d' % trn_type, operand_1, operand_2) def EmitColBlockStride(self, cols, stride, new_stride): assert cols in [1, 2, 3, 4, 5, 6, 7, 8] if cols in [5, 6, 7]: self.EmitSub(new_stride, stride, self.ImmediateConstant(4)) def EmitLoadColBlock(self, unused_registers, load_type, cols, elements, block, input_address, stride): """Load a block of column major data.""" assert cols is len(block) assert load_type is 8 input_deref = self.Dereference(input_address, None) input_deref_increment = self.DereferenceIncrement(input_address, None) if cols is 1: for i in range(elements): self.EmitVLoadOffset(1, 8, self.Lane(8, block[0], i), input_deref, stride) self.EmitPld(input_address) elif cols is 2: for i in range(elements): self.EmitVLoadOffset(1, 16, self.Lane(16, block[i / 4], i % 4), input_deref, stride) self.EmitPld(input_address) self.EmitVUzp(8, block[0], block[1]) elif cols is 3: for i in range(elements): self.EmitVLoadOffsetA(3, 8, [self.Lane(8, row, i) for row in block], input_deref, stride) elif cols is 4: for i in range(elements): self.EmitVLoadOffset(1, 32, self.Lane(32, block[i % 4], i / 4), input_deref, stride) self.EmitPld(input_address) self.EmitVTrn(16, block[0], block[2]) self.EmitVTrn(16, block[1], block[3]) self.EmitVTrn(8, block[0], block[1]) self.EmitVTrn(8, block[2], block[3]) elif cols is 5: for i in range(elements): self.EmitVLoad(1, 32, self.Lane(32, block[i % 4], i / 4), input_deref_increment) self.EmitVLoadOffset(1, 8, self.Lane(8, block[4], i), input_deref, stride) self.EmitPld(input_address) self.EmitVTrn(16, block[0], block[2]) self.EmitVTrn(16, block[1], block[3]) self.EmitVTrn(8, block[0], block[1]) self.EmitVTrn(8, block[2], block[3]) elif cols is 6: for i in range(elements): self.EmitVLoad(1, 32, self.Lane(32, block[i % 4], i / 4), input_deref_increment) self.EmitVLoadOffset(1, 16, self.Lane(16, block[4 + i / 4], i % 4), input_deref, stride) self.EmitPld(input_address) self.EmitVTrn(16, block[0], block[2]) self.EmitVTrn(16, block[1], block[3]) self.EmitVUzp(8, block[4], block[5]) self.EmitVTrn(8, block[0], block[1]) self.EmitVTrn(8, block[2], block[3]) elif cols is 7: for i in range(elements): self.EmitVLoad(1, 32, self.Lane(32, block[i % 4], i / 4), input_deref_increment) self.EmitVLoadOffsetA(3, 8, [self.Lane(8, row, i) for row in block[4:]], input_deref, stride) self.EmitPld(input_address) self.EmitVTrn(16, block[0], block[2]) self.EmitVTrn(16, block[1], block[3]) self.EmitVTrn(8, block[0], block[1]) self.EmitVTrn(8, block[2], block[3]) elif cols is 8: for i in range(elements): self.EmitVLoadOffset(1, 32, block[i], input_deref, stride) self.EmitPld(input_address) self.EmitVTrn(8, block[0], block[1]) self.EmitVTrn(8, block[2], block[3]) self.EmitVTrn(8, block[4], block[5]) self.EmitVTrn(8, block[6], block[7]) self.EmitVTrn(16, block[0], block[2]) self.EmitVTrn(16, block[1], block[3]) self.EmitVTrn(16, block[4], block[6]) self.EmitVTrn(16, block[5], block[7]) self.EmitVTrn(32, block[0], block[4]) self.EmitVTrn(32, block[1], block[5]) self.EmitVTrn(32, block[2], block[6]) self.EmitVTrn(32, block[3], block[7]) else: assert False return block def Dereference(self, value, alignment=None): if alignment: return '[%s:%d]' % (value, alignment) else: return '[%s]' % value def DereferenceIncrement(self, value, alignment=None): return '%s!' % self.Dereference(value, alignment) def ImmediateConstant(self, value): return '#%d' % value def AllLanes(self, value): return '%s[]' % value def Lane(self, bits, value, lane): """Get the proper n-bit lane from the given register.""" registers = [] if value[0] == 'q': registers.append(_Low(value)) registers.append(_High(value)) else: registers.append(value) elems_per_register = 64 / bits register = lane / elems_per_register lane %= elems_per_register return '%s[%d]' % (registers[register], lane) def CreateRegisters(self): return _NeonRegisters32Bit()