"""Qnt primitive used by the GEMM function.

Generates NEON assembly (through the emitter objects passed in) that converts
rows of unquantized int32 multiplication results into uint8 values, plus C++
wrapper functions that dispatch on the row length remainder (count % 8).
"""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


class QntLane(object):
  """Registers used while quantizing a single row ("lane").

  Attributes:
    source: register holding the read pointer into the int32 input row.
    output: register holding the write pointer into the uint8 output row.
    offset: quad register holding this row's offset in every element.
    load_1: quad register receiving the first four int32 values of a chunk.
    load_2: quad register receiving the last four int32 values of a chunk.
  """

  def __init__(self, source, output, offset, load_1, load_2):
    self.source = source
    self.output = output
    self.offset = offset
    self.load_1 = load_1
    self.load_2 = load_2


def _QntParameters():
  """Returns the [C++ type, name] pairs shared by all generated functions.

  A fresh list is built on every call so that no emitter can accidentally
  mutate state shared between generated functions.
  """
  return [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
          ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
          ['std::uint8_t*', 'destination'],
          ['std::int32_t', 'destination_stride'],
          ['std::int32_t', 'multiplicative_offset'],
          ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']]


def BuildName(lanes, leftovers, aligned):
  """Returns the generated kernel name, e.g. 'qnt_2x8_3_aligned'.

  Args:
    lanes: number of rows processed at once.
    leftovers: count % 8 handled by the kernel; 0 adds no suffix.
    aligned: whether the kernel assumes an 8-byte aligned destination.
  """
  name = 'qnt_%dx8' % lanes
  if leftovers:
    name += '_%d' % leftovers
  if aligned:
    name += '_aligned'
  return name


def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Loads one int32 offset per row and replicates it across a quad register.

  Args:
    emitter: NEON emitter receiving the instructions.
    registers: register allocator.
    lanes: number of rows; must be 1, 2 or 3.
    offsets: register holding the pointer to the per-row int32 offsets.

  Returns:
    A list of `lanes` quad registers, one per row, in row order.

  Raises:
    ConfigurationError: if `lanes` is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  offset_registers = []
  for _ in range(lanes):
    register = registers.QuadRegister()
    # AllLanes presumably emits the "{dN[]}" replicate-to-all-lanes load form,
    # filling both halves of the quad register with the same offset. The
    # increment moves `offsets` on to the next row's value.
    emitter.EmitVLoadA('1.32', [emitter.AllLanes(registers.Low(register)),
                                emitter.AllLanes(registers.High(register))],
                       emitter.DereferenceIncrement(offsets, 32))
    offset_registers.append(register)
  return offset_registers


def GenerateQntLanes(emitter, registers, qnt_lanes, source, stride, destination,
                     destination_stride, offsets):
  """Prepare lanes for reading unquantized multiplication results.

  The first lane reuses the `source` / `destination` registers directly. Each
  further lane gets a fresh pointer pair, computed as the previous lane's
  pointers plus `stride` / `destination_stride`.

  Returns:
    A list of `qnt_lanes` QntLane objects, in row order.
  """
  offset_registers = LoadAndDuplicateOffsets(emitter, registers, qnt_lanes,
                                             offsets)

  lanes = []
  previous_source = source
  previous_output = destination
  for index, offset_register in enumerate(offset_registers):
    if index == 0:
      lanes.append(QntLane(source,
                           destination,
                           offset_register,
                           registers.QuadRegister(),  # load 1
                           registers.QuadRegister()))  # load 2
      continue

    input_register = registers.GeneralRegister()
    output_register = registers.GeneralRegister()
    lanes.append(QntLane(input_register,
                         output_register,
                         offset_register,
                         registers.QuadRegister(),  # load 1
                         registers.QuadRegister()))  # load 2
    emitter.EmitAdd(input_register, previous_source, stride)
    emitter.EmitAdd(output_register, previous_output, destination_stride)
    previous_source = input_register
    previous_output = output_register
  return lanes


def DuplicateRegister(emitter, registers, value):
  """Allocates a quad register and fills every 32-bit element with `value`."""
  register = registers.QuadRegister()
  emitter.EmitVDup('32', register, value)
  return register


def GenerateQuantize(emitter, registers, lanes, lane_temps,
                     multiplicative_offset, rounding_offset, shift):
  """Inner loop for quantization: add offsets, multiply, round, shift.

  Each step is emitted for every lane before the next step starts. Presumably
  this interleaves independent instructions for better scheduling; the
  emission order is significant and is kept as is.

  Args:
    emitter: NEON emitter receiving the instructions.
    registers: register allocator (used for Low()).
    lanes: list of [value, offset, narrow_target] entries:
      - value: a quad register of int32 values, updated in place.
      - offset: a register added to value.
      - narrow_target: the register that receives the s32 -> s16 saturating
        narrow of the result.
    lane_temps: quad registers whose int16 contents are narrowed, with
      unsigned saturation, into their own low halves (the final uint8 bytes).
    multiplicative_offset: register multiplied into each value.
    rounding_offset: register added after the multiplication.
    shift: register used as the VSHL shift amount. Presumably this is
      negative so that VSHL shifts right, as NEON VSHL (register) does;
      confirm with the caller.
  """
  for lane in lanes:
    emitter.EmitVAdd('i32', lane[0], lane[0], lane[1])

  for lane in lanes:
    emitter.EmitVMul('i32', lane[0], lane[0], multiplicative_offset)

  for lane in lanes:
    emitter.EmitVAdd('i32', lane[0], lane[0], rounding_offset)

  for lane in lanes:
    emitter.EmitVShl('s32', lane[0], lane[0], shift)

  for lane in lanes:
    emitter.EmitVQmovn('s32', lane[2], lane[0])

  for lane_temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(lane_temp), lane_temp)


def GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                              rounding_offset, shift, alignment):
  """Load unquantized data from lanes, quantize, store final result.

  Processes one full chunk of 8 int32 values per lane. Each lane's source
  and output pointers are advanced past the chunk.

  Args:
    alignment: alignment hint for the uint8 store, or None for no hint.
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  # 8 int32 values fill the four double registers of load_1 and load_2.
  for lane in lanes:
    emitter.EmitVLoadA(
        '1.32', [registers.Low(lane.load_1), registers.High(lane.load_1),
                 registers.Low(lane.load_2), registers.High(lane.load_2)],
        emitter.DereferenceIncrement(lane.source, 64))

  # Prefetch from the already-advanced source pointers (the next chunk).
  for lane in lanes:
    emitter.EmitPld(lane.source)

  # The first four values narrow into the low half of the temp and the last
  # four into the high half, so the final u8 narrow sees all 8 in order.
  quantize_setup = []
  for lane_temp, lane in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    quantize_setup.append([lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  for lane_temp, lane in zip(lane_temps, lanes):
    emitter.EmitVStore('1.8', registers.Low(lane_temp),
                       emitter.DereferenceIncrement(lane.output, alignment))

  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)


def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Handle loading of leftovers when the row size is not a multiple of 8.

  Loads `leftovers` (1..7) int32 values per lane into load_1 (values 0-3) and
  load_2 (values 4-6). It uses the widest loads that fit, then single-lane
  loads for an odd remainder. The 64-bit alignment hints rely on the
  `source % 8 == 0` assertion emitted by GenerateQntNx8, because full chunks
  advance the pointer by 32 bytes.

  Raises:
    ConfigurationError: if `leftovers` is not in 1..7.
  """
  if leftovers == 1:
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.Low(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 2:
    for lane in lanes:
      emitter.EmitVLoad('1.32', registers.Low(lane.load_1),
                        emitter.Dereference(lane.source, 64))
  elif leftovers == 3:
    for lane in lanes:
      emitter.EmitVLoad('1.32', registers.Low(lane.load_1),
                        emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.High(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 4:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 5:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.Low(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 6:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1),
                                  registers.Low(lane.load_2)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 7:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1),
                                  registers.Low(lane.load_2)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.High(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  else:
    raise ConfigurationError('Unsupported leftover count: %d' % leftovers)


def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
  """Handle storing of leftovers when the row size is not a multiple of 8.

  Stores the first `leftovers` (1..7) uint8 values from the low half of each
  lane's temporary register, using the widest element stores that fit. Every
  store except the last one advances the output pointer.

  Raises:
    ConfigurationError: if `leftovers` is not in 1..7.
  """
  setup = [(registers.Low(temp), lane.output)
           for temp, lane in zip(lane_temps, lanes)]

  # Lane indices are in units of the element size: e.g. '1.16' lane 2 covers
  # bytes 4-5 and '1.8' lane 6 covers byte 6.
  if leftovers == 1:
    for data, output in setup:
      emitter.EmitVStore('1.8', emitter.Lane(data, 0),
                         emitter.Dereference(output, None))
  elif leftovers == 2:
    for data, output in setup:
      emitter.EmitVStore('1.16', emitter.Lane(data, 0),
                         emitter.Dereference(output, None))
  elif leftovers == 3:
    for data, output in setup:
      emitter.EmitVStore('1.16', emitter.Lane(data, 0),
                         emitter.DereferenceIncrement(output, None))
    for data, output in setup:
      emitter.EmitVStore('1.8', emitter.Lane(data, 2),
                         emitter.Dereference(output, None))
  elif leftovers == 4:
    for data, output in setup:
      emitter.EmitVStore('1.32', emitter.Lane(data, 0),
                         emitter.Dereference(output, None))
  elif leftovers == 5:
    for data, output in setup:
      emitter.EmitVStore('1.32', emitter.Lane(data, 0),
                         emitter.DereferenceIncrement(output, None))
    for data, output in setup:
      emitter.EmitVStore('1.8', emitter.Lane(data, 4),
                         emitter.Dereference(output, None))
  elif leftovers == 6:
    for data, output in setup:
      emitter.EmitVStore('1.32', emitter.Lane(data, 0),
                         emitter.DereferenceIncrement(output, None))
    for data, output in setup:
      emitter.EmitVStore('1.16', emitter.Lane(data, 2),
                         emitter.Dereference(output, None))
  elif leftovers == 7:
    for data, output in setup:
      emitter.EmitVStore('1.32', emitter.Lane(data, 0),
                         emitter.DereferenceIncrement(output, None))
    for data, output in setup:
      emitter.EmitVStore('1.16', emitter.Lane(data, 2),
                         emitter.DereferenceIncrement(output, None))
    # Last store: like every other branch, no pointer write-back is needed.
    for data, output in setup:
      emitter.EmitVStore('1.8', emitter.Lane(data, 6),
                         emitter.Dereference(output, None))
  else:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)


def GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift):
  """Handle leftovers when the row size is not a multiple of 8.

  Loads, quantizes and stores the trailing `leftovers` values of every lane.
  load_2 only takes part in quantization when it holds data (leftovers > 4).
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  quantize_setup = []
  for lane_temp, lane in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    if leftovers > 4:
      quantize_setup.append(
          [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)


def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  Args:
    emitter: NEON emitter receiving the generated function.
    qnt_lanes: number of rows processed at once (1, 2 or 3).
    leftovers: count % 8 this kernel is specialized for (0..7).
    aligned: if True, the destination (and, for several rows,
      destination_stride) is asserted to be 8-byte aligned, and full-chunk
      stores use a 64-bit alignment hint.

  Raises:
    ConfigurationError: if `leftovers` or `qnt_lanes` is out of range.
  """
  if not 0 <= leftovers <= 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if not 1 <= qnt_lanes <= 3:
    raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

  name = BuildName(qnt_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(name, _QntParameters(), 'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    if qnt_lanes > 1:
      emitter.EmitAssert('destination_stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  # Parameter mapping and register allocation order is kept exactly as is,
  # because it determines which physical registers the generated code uses.
  count = registers.MapParameter('count')

  multiplicative_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('multiplicative_offset'))
  rounding_offset = DuplicateRegister(emitter, registers,
                                      registers.MapParameter('rounding_offset'))
  shift = DuplicateRegister(emitter, registers, registers.MapParameter('shift'))

  lanes = GenerateQntLanes(
      emitter, registers, qnt_lanes, registers.MapParameter('source'),
      registers.MapParameter('stride'), registers.MapParameter('destination'),
      registers.MapParameter('destination_stride'),
      registers.MapParameter('offsets'))

  if leftovers:
    # Peel the remainder off first. If nothing else is left, skip the main
    # loop and go straight to the leftover handling at label 2.
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  # Main loop: decrements before testing, so it relies on the remaining count
  # being a positive multiple of 8 on entry (see the assertions above).
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                            rounding_offset, shift, 64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift)

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def BuildMultiQuantizeName(aligned, rows):
  """Returns the dispatcher name, e.g. 'multi_qnt_2x8_aligned'."""
  name = 'multi_qnt_%dx8' % rows
  if aligned:
    name += '_aligned'
  return name


def GenerateMultiQuantize(emitter, aligned, rows):
  """Emit main quantization code that switches between optimized versions.

  The generated function switches on `count % 8` and forwards all of its
  arguments to the matching qnt_<rows>x8[_<leftovers>][_aligned] kernel.
  """
  name = BuildMultiQuantizeName(aligned, rows)
  parameters = _QntParameters()
  call_arguments = [parameter_name for _, parameter_name in parameters]

  emitter.EmitFunctionBeginA(name, parameters, 'void')
  emitter.EmitSwitch('count % 8')

  for leftovers in range(0, 8):
    emitter.EmitCase(leftovers)
    emitter.PushIndent()
    emitter.EmitCall(BuildName(rows, leftovers, aligned), call_arguments)
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
  emitter.EmitFunctionEnd()


def GenerateFunctions(neon, cc):
  """Emits every NEON kernel via `neon` and every dispatcher via `cc`."""
  for aligned in [True, False]:
    for lanes in range(1, 4):
      for leftovers in range(0, 8):
        GenerateQntNx8(neon, lanes, leftovers, aligned)
        neon.EmitNewline()

  for aligned in [True, False]:
    for rows in range(1, 4):
      GenerateMultiQuantize(cc, aligned, rows)
      cc.EmitNewline()