"""Multiply primitive optimized for the gemv operation."""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                  count, lhs, rhs_1, rhs_2):
  """Emit inner loop for 1 row x M cols multiplication, M = 4 + lanes_count.

  Every iteration loads 8 bytes from `lhs`, 4 x 8 bytes from `rhs_1` and
  lanes_count x 8 bytes from `rhs_2`. It multiplies each 8-byte rhs chunk by
  the lhs chunk as unsigned 8-bit values and pairwise-accumulates the 16-bit
  products into `aggregators`: aggregators[0:4] receive the rhs_1 products and
  aggregators[4:4 + lanes_count] receive the rhs_2 products.

  Args:
    emitter: emitter the assembly is written to.
    registers: register allocator. Every temporary allocated here is freed
      before returning.
    lanes_count: number of 8-byte chunks read from rhs_2 per iteration.
    aggregators: quad registers, at least 4 + lanes_count of them.
    count: register holding the remaining byte count. It is decremented by 8
      at the top of the loop body and the loop branches back while the
      result is non-zero, so the body always runs at least once.
    lhs: pointer register, post-incremented by the load.
    rhs_1: pointer register, post-incremented by the load.
    rhs_2: pointer register, post-incremented by the load.
  """
  emitter.EmitComment('General 1xM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  right_load = [registers.DoubleRegister() for unused_i in range(4)]
  left_load = registers.DoubleRegister()

  emitter.EmitVLoad('1.8', left_load, emitter.DereferenceIncrement(lhs, 64))
  emitter.EmitVLoadA('1.8', right_load, emitter.DereferenceIncrement(rhs_1, 64))

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(rhs_1, emitter.ImmediateConstant(128))

  multiply_results = [registers.QuadRegister() for unused_i in range(4)]

  for product, right in zip(multiply_results, right_load):
    emitter.EmitVMull('u8', product, right, left_load)

  # right_load is reused for rhs_2. This is safe only because the load is
  # emitted after the multiplies above have read the rhs_1 values.
  emitter.EmitVLoadA('1.8', right_load[:lanes_count],
                     emitter.DereferenceIncrement(rhs_2, 64))
  emitter.EmitPldOffset(rhs_2, emitter.ImmediateConstant(lanes_count * 32))

  for aggregator, product in zip(aggregators[:4], multiply_results):
    emitter.EmitVPadal('u16', aggregator, product)

  for product, right in zip(multiply_results[:lanes_count],
                            right_load[:lanes_count]):
    emitter.EmitVMull('u8', product, right, left_load)

  for aggregator, product in zip(aggregators[4:4 + lanes_count],
                                 multiply_results[:lanes_count]):
    emitter.EmitVPadal('u16', aggregator, product)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  # Branches on the flags set by the EmitSubs at the top of the loop body.
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  registers.FreeRegister(left_load)
  registers.FreeRegisters(right_load)
  registers.FreeRegisters(multiply_results)


def ReadLeft(emitter, registers, lhs):
  """Emit a 32-bit all-lanes load from [lhs] into a new quad register.

  Both halves of the register are listed as emitter.AllLanes targets. This
  presumably replicates the single 32-bit value at [lhs] into all four
  lanes; confirm against neon_emitter.

  Returns:
    The newly allocated quad register. It is not freed here.
  """
  register = registers.QuadRegister()
  emitter.EmitVLoadA('1.32', [emitter.AllLanes(registers.Low(register)),
                              emitter.AllLanes(registers.High(register))],
                     emitter.Dereference(lhs, None))
  return register


def ReadRight(emitter, registers, rhs, count):
  """Emit a load of `count` 32-bit values from [rhs].

  A double register (two 32-bit lanes) is used for 1 or 2 values and a quad
  register (four lanes) for 3 or 4 values. For odd counts one extra value is
  therefore read. The 64 passed to Dereference is presumably an alignment
  hint; confirm in neon_emitter.

  Returns:
    The newly allocated register. It is not freed here.

  Raises:
    ConfigurationError: if count is not 1, 2, 3 or 4.
  """
  if count in (1, 2):
    register = registers.DoubleRegister()
  elif count in (3, 4):
    register = registers.QuadRegister()
  else:
    raise ConfigurationError('Unsupported elements no: %d' % count)
  emitter.EmitVLoad('1.32', register, emitter.Dereference(rhs, 64))
  return register


def DuplicateGeneralRegister(emitter, registers, general_register,
                             min_register):
  """Emit a 32-bit vdup of `general_register` into a new quad register.

  `min_register` is forwarded unchanged to registers.QuadRegister(). It is
  presumably the lowest register index that may be allocated; verify in
  neon_emitter.

  Returns:
    The newly allocated quad register. It is not freed here.
  """
  duplicated = registers.QuadRegister(min_register)
  emitter.EmitVDup('32', duplicated, general_register)
  return duplicated


def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                  result_type, lhs_add, rhs_add, lhs, rhs_1,
                                  rhs_2, results):
  """Generates assembly responsible for reducing the 4 way aggregators.

  Each aggregator's four 32-bit lanes are summed to a single value. The four
  totals of aggregators[0:4] are gathered in aggregators[0] (`temp`). The
  lanes_count totals of aggregators[4:] are gathered in aggregators[1], or in
  its low half when lanes_count <= 2 (`temp_2`).

  Optional steps, applied in this order:
    - lhs_add: add the value read by ReadLeft at [lhs].
    - rhs_add: add the values read by ReadRight at [rhs_1] (4 values) and
      [rhs_2] (lanes_count values).
    - result_type == 'float': convert s32 to f32 and multiply by the
      'result_scale' parameter.

  Finally, 4 + lanes_count 32-bit results are stored starting at [results].
  The offset reads happen after the main loop, when lhs/rhs_1/rhs_2 already
  point past the multiplied data (presumably where offsets are laid out;
  verify against the data packer).
  """
  if lhs_add:
    left_offset = ReadLeft(emitter, registers, lhs)
  else:
    left_offset = None

  if rhs_add:
    right_offset_1 = ReadRight(emitter, registers, rhs_1, 4)
    right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count)
  else:
    right_offset_1 = None
    right_offset_2 = None

  # String equality, not identity: `is` would depend on literal interning.
  if result_type == 'float':
    result_scale = DuplicateGeneralRegister(
        emitter, registers, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  # First step: fold each aggregator's high half into its low half.
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  # aggregators[0] and [1] are reused as outputs. Each of them is fully read
  # before the VPadd that overwrites it.
  temp = aggregators[0]
  emitter.EmitVPadd('u32', registers.Low(temp), registers.Low(aggregators[0]),
                    registers.Low(aggregators[1]))
  emitter.EmitVPadd('u32', registers.High(temp), registers.Low(aggregators[2]),
                    registers.Low(aggregators[3]))

  # For an odd lanes_count the last aggregator, aggregators[3 + lanes_count],
  # is paired with itself. The duplicated lane is never stored.
  if lanes_count in (1, 2):
    temp_2 = registers.Low(aggregators[1])
    emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                      registers.Low(aggregators[3 + lanes_count]))
  elif lanes_count in (3, 4):
    temp_2 = aggregators[1]
    emitter.EmitVPadd('u32', registers.Low(temp_2),
                      registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
    emitter.EmitVPadd('u32', registers.High(temp_2),
                      registers.Low(aggregators[6]),
                      registers.Low(aggregators[3 + lanes_count]))
  else:
    temp_2 = None

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    emitter.EmitVAdd('s32', temp, temp, left_offset)
    # temp_2 is a double register for 1-2 lanes, so use the matching half.
    if lanes_count in (1, 2):
      emitter.EmitVAdd('s32', temp_2, temp_2, registers.Low(left_offset))
    elif lanes_count in (3, 4):
      emitter.EmitVAdd('s32', temp_2, temp_2, left_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    # ReadRight sized right_offset_2 (double/quad) to match temp_2.
    emitter.EmitVAdd('s32', temp, temp, right_offset_1)
    emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2)

  if result_type == 'float':
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float and scale.')
    emitter.EmitVCvt('f32', 's32', temp, temp)
    emitter.EmitVCvt('f32', 's32', temp_2, temp_2)
    emitter.EmitVMul('f32', temp, temp, result_scale)
    if lanes_count in (1, 2):
      emitter.EmitVMul('f32', temp_2, temp_2, registers.Low(result_scale))
    elif lanes_count in (3, 4):
      emitter.EmitVMul('f32', temp_2, temp_2, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store results.')
  # Exactly 4 + lanes_count 32-bit values are written. The odd-lane cases
  # store the last value from a single lane so the padding lane is skipped.
  if lanes_count == 1:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp)],
                        emitter.DereferenceIncrement(results, None))
    emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0),
                       emitter.Dereference(results, None))
  elif lanes_count == 2:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                 temp_2], emitter.Dereference(results, None))
  elif lanes_count == 3:
    emitter.EmitVStoreA(
        '1.32',
        [registers.Low(temp), registers.High(temp), registers.Low(temp_2)],
        emitter.DereferenceIncrement(results, None))
    emitter.EmitVStore('1.32', emitter.Lane(
        registers.High(temp_2), 0), emitter.Dereference(results, None))
  elif lanes_count == 4:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                 registers.Low(temp_2), registers.High(temp_2)],
                        emitter.Dereference(results, None))


def BuildName(result_type, lhs_add, rhs_add, lanes):
  """Return the generated function name, e.g. 'mul_1x8_5x8_int32_lhsadd'."""
  name = 'mul_1x8_%dx8_%s' % (lanes, result_type)
  if lhs_add:
    name += '_lhsadd'
  if rhs_add:
    name += '_rhsadd'
  return name


def CppResultType(result_type):
  """Map a result type name to the C++ type of the `result` pointer.

  Raises:
    ConfigurationError: if result_type is neither 'int32' nor 'float'.
  """
  # String equality, not identity: `is` would depend on literal interning.
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)


def GetParameters(result_type):
  """Return the [C++ type, name] parameter list of the generated function.

  'float' results take an extra trailing 'result_scale' parameter.

  Raises:
    ConfigurationError: if result_type is unsupported (via CppResultType).
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs_1'],
            ['const std::uint8_t*', 'rhs_2'], ['std::int32_t', 'count'],
            [CppResultType(result_type), 'result']]
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params


def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Prepare aggregators and emit aggregator clear code.

  The first three registers are set from an immediate 0. Each later register
  is copied from the register three positions earlier, which is already zero,
  so every aggregator ends up cleared.

  Returns:
    List of `aggregator_count` newly allocated quad registers.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Clear aggregators.')
  aggregators = []
  for i in range(aggregator_count):
    aggregator = registers.QuadRegister()
    aggregators.append(aggregator)
    if i < 3:
      emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0))
    else:
      emitter.EmitVMov('i32', aggregator, aggregators[i - 3])
  emitter.EmitNewline()
  return aggregators


def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count):
  """Generates the 1xN multiplication primitive, N = 4 + lanes_count.

  Emits an inline C++ function (see BuildName for its name) with an inline
  asm body. The body asserts that count is a positive multiple of 8, runs
  the multiply-accumulate loop, then reduces, optionally offsets/scales, and
  stores 4 + lanes_count results.

  Raises:
    ConfigurationError: if lanes_count is outside [1, 4], or result_type is
      unsupported.
  """
  if lanes_count < 1 or lanes_count > 4:
    raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.')

  emitter.EmitFunctionBeginA(
      BuildName(result_type, lhs_add, rhs_add, lanes_count + 4),
      GetParameters(result_type), 'inline void')

  emitter.EmitAssert('count % 8 == 0')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  lhs = registers.MapParameter('lhs')
  rhs_1 = registers.MapParameter('rhs_1')
  rhs_2 = registers.MapParameter('rhs_2')

  emitter.EmitPld(lhs)
  emitter.EmitPld(rhs_1)
  emitter.EmitPld(rhs_2)

  aggregators = GenerateAndClearAggregators(emitter, registers, lanes_count + 4)

  GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                count, lhs, rhs_1, rhs_2)
  GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                result_type, lhs_add, rhs_add, lhs, rhs_1,
                                rhs_2, registers.MapParameter('result'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit the 1x8 * 5x8 ... 1x8 * 8x8 variants (lanes_count 1 to 4)."""
  for lanes in range(1, 5):
    GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes)
    emitter.EmitNewline()


if __name__ == '__main__':
  GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True, True)