# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
import csv
import os
import subprocess
import sys

# NOTE(review): whitespace inside the emitted C/asm string literals below is
# kept exactly as found in this file; it only affects the indentation of the
# generated code, not its meaning.

# Largest log2(size) handled; the driver iterates log_n over 1..max_log_n.
max_log_n = 30


def is_distinct(l):
    """Return True when no element of ``l`` occurs more than once."""
    return len(set(l)) == len(l)


# Spelled-out counts used in the "needs at least N auxiliary registers" errors.
_COUNT_WORDS = {2: "two", 4: "four", 5: "five"}


def _check_step_registers(step_name, register, aux_registers, needed):
    """Validate the register arguments of an in-register step emitter.

    Raises:
        Exception: if ``aux_registers`` contains duplicates, contains
            ``register``, or holds fewer than ``needed`` entries (checked in
            that order).
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < needed:
        raise Exception(
            "%s needs at least %s auxiliary registers"
            % (step_name, _COUNT_WORDS[needed])
        )


def _check_four_distinct(from_register_0, from_register_1, to_register_0, to_register_1):
    """Raise unless the four registers of a cross-register step all differ."""
    if not is_distinct(
        [from_register_0, from_register_1, to_register_0, to_register_1]
    ):
        raise Exception("four registers must be distinct")


def _x86(ident, mnemonic, operands, imm=None):
    """Format one x86 (AT&T syntax) instruction as a C string-literal line.

    Every register operand is prefixed with ``%%`` (the escape needed inside
    an extended-asm template); ``imm``, when given, is emitted first as
    ``$imm``.  The result ends with a literal ``\\n`` inside the C string and
    a real newline after it.
    """
    fields = [] if imm is None else ["$%d" % imm]
    fields.extend("%%" + str(reg) for reg in operands)
    return '%s"%s %s\\n"\n' % (ident, mnemonic, ", ".join(fields))


def float_avx_0(register, aux_registers, ident=""):
    """Emit the AVX float step that combines adjacent lanes.

    Given ``register`` = ABCDEFGH the emitted code leaves
    (A+B)(A-B)(C+D)(C-D)(E+F)(E-F)(G+H)(G-H) in ``register``.

    Args:
        register: name of the ymm register holding the data.
        aux_registers: at least four distinct scratch register names.
        ident: prefix prepended to every emitted line.

    Raises:
        Exception: on invalid register arguments.
    """
    _check_step_registers("float_avx_0", register, aux_registers, 4)
    a0, a1, a2, a3 = aux_registers[:4]
    return "".join(
        [
            # a0 <- AACCEEGG
            _x86(ident, "vpermilps", [register, a0], imm=160),
            # a1 <- BBDDFFHH
            _x86(ident, "vpermilps", [register, a1], imm=245),
            # a2 <- 0
            _x86(ident, "vxorps", [a2, a2, a2]),
            # a3 <- -B -B -D -D -F -F -H -H
            _x86(ident, "vsubps", [a1, a2, a3]),
            # reg <- (A+B)(A-B)(C+D)(C-D)(E+F)(E-F)(G+H)(G-H)
            _x86(ident, "vaddsubps", [a3, a0, register]),
        ]
    )


def float_avx_1(register, aux_registers, ident=""):
    """Emit the AVX float step that combines lanes two apart.

    Given ``register`` = ABCDEFGH the emitted code leaves
    (A+C)(B+D)(A-C)(B-D)(E+G)(F+H)(E-G)(F-H) in ``register``.  Needs at least
    five distinct auxiliary registers; raises ``Exception`` otherwise.
    """
    _check_step_registers("float_avx_1", register, aux_registers, 5)
    a0, a1, a2, a3, a4 = aux_registers[:5]
    return "".join(
        [
            # a0 <- ABABEFEF
            _x86(ident, "vpermilps", [register, a0], imm=68),
            # a1 <- CDCDGHGH
            _x86(ident, "vpermilps", [register, a1], imm=238),
            # a2 <- 0
            _x86(ident, "vxorps", [a2, a2, a2]),
            # a3 <- -C -D -C -D -G -H -G -H
            _x86(ident, "vsubps", [a1, a2, a3]),
            # a4 <- C D -C -D G H -G -H
            _x86(ident, "vblendps", [a3, a1, a4], imm=204),
            # reg <- (A + C) (B + D) (A - C) (B - D) etc.
            _x86(ident, "vaddps", [a0, a4, register]),
        ]
    )


def float_avx_2(register, aux_registers, ident=""):
    """Emit the AVX float step that combines the two 128-bit halves.

    Given ``register`` = ABCDEFGH the emitted code leaves
    (A+E)(B+F)(C+G)(D+H)(A-E)(B-F)(C-G)(D-H) in ``register``.  Needs at least
    four distinct auxiliary registers; raises ``Exception`` otherwise.
    """
    _check_step_registers("float_avx_2", register, aux_registers, 4)
    a0, a1, a2, a3 = aux_registers[:4]
    return "".join(
        [
            # a0 <- 0
            _x86(ident, "vxorps", [a0, a0, a0]),
            # a1 <- -A -B -C -D -E -F -G -H
            _x86(ident, "vsubps", [register, a0, a1]),
            # a2 <- ABCDABCD (low half of reg in both halves)
            _x86(ident, "vperm2f128", [register, register, a2], imm=0),
            # a3 <- E F G H -E -F -G -H (high half of reg, high half of a1)
            _x86(ident, "vperm2f128", [a1, register, a3], imm=49),
            # reg <- (A + E) (B + F) (C + G) (D + H) (A - E) (B - F) etc.
            _x86(ident, "vaddps", [a2, a3, register]),
        ]
    )


def float_avx_3_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the AVX float step that combines two whole registers.

    ``to_register_0`` receives ``from_register_0 + from_register_1`` and
    ``to_register_1`` receives ``from_register_0 - from_register_1``.

    Raises:
        Exception: if the four register names are not pairwise distinct.
    """
    _check_four_distinct(from_register_0, from_register_1, to_register_0, to_register_1)
    return _x86(
        ident, "vaddps", [from_register_1, from_register_0, to_register_0]
    ) + _x86(ident, "vsubps", [from_register_1, from_register_0, to_register_1])


def double_avx_0(register, aux_registers, ident=""):
    """Emit the AVX double step that combines adjacent lanes.

    Given ``register`` = ABCD the emitted code leaves (A+B)(A-B)(C+D)(C-D) in
    ``register``.  Needs at least four distinct auxiliary registers; raises
    ``Exception`` otherwise.
    """
    _check_step_registers("double_avx_0", register, aux_registers, 4)
    a0, a1, a2, a3 = aux_registers[:4]
    return "".join(
        [
            # a0 <- AACC
            _x86(ident, "vpermilpd", [register, a0], imm=0),
            # a1 <- BBDD
            _x86(ident, "vpermilpd", [register, a1], imm=15),
            # a2 <- 0
            _x86(ident, "vxorpd", [a2, a2, a2]),
            # a3 <- -B -B -D -D
            _x86(ident, "vsubpd", [a1, a2, a3]),
            # reg <- (A + B)(A - B)(C + D)(C - D)
            _x86(ident, "vaddsubpd", [a3, a0, register]),
        ]
    )


def double_avx_1(register, aux_registers, ident=""):
    """Emit the AVX double step that combines the two 128-bit halves.

    Given ``register`` = ABCD the emitted code leaves (A+C)(B+D)(A-C)(B-D) in
    ``register``.  Needs at least four distinct auxiliary registers; raises
    ``Exception`` otherwise.
    """
    _check_step_registers("double_avx_1", register, aux_registers, 4)
    a0, a1, a2, a3 = aux_registers[:4]
    return "".join(
        [
            # a0 <- ABAB
            _x86(ident, "vperm2f128", [register, register, a0], imm=0),
            # a1 <- 0
            _x86(ident, "vxorpd", [a1, a1, a1]),
            # a2 <- -A -B -C -D
            _x86(ident, "vsubpd", [register, a1, a2]),
            # a3 <- C D -C -D
            _x86(ident, "vperm2f128", [a2, register, a3], imm=49),
            # reg <- (A + C)(B + D)(A - C)(B - D)
            _x86(ident, "vaddpd", [a3, a0, register]),
        ]
    )


def double_avx_2_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the AVX double step that combines two whole registers.

    ``to_register_0`` receives the sum and ``to_register_1`` the difference
    (``from_register_0 - from_register_1``).  Raises ``Exception`` if the four
    register names are not pairwise distinct.
    """
    _check_four_distinct(from_register_0, from_register_1, to_register_0, to_register_1)
    return _x86(
        ident, "vaddpd", [from_register_1, from_register_0, to_register_0]
    ) + _x86(ident, "vsubpd", [from_register_1, from_register_0, to_register_1])


def float_sse_0(register, aux_registers, ident=""):
    """Emit the SSE float step that combines adjacent lanes.

    Given ``register`` = ABCD the emitted code leaves (A+B)(A-B)(C+D)(C-D) in
    ``register``.  Needs at least two distinct auxiliary registers; raises
    ``Exception`` otherwise.
    """
    _check_step_registers("float_sse_0", register, aux_registers, 2)
    a0, a1 = aux_registers[:2]
    return "".join(
        [
            _x86(ident, "movaps", [register, a0]),
            # a0 <- AACC
            _x86(ident, "shufps", [a0, a0], imm=160),
            # reg <- BBDD
            _x86(ident, "shufps", [register, register], imm=245),
            # a1 <- 0, then a1 <- -B -B -D -D
            _x86(ident, "xorps", [a1, a1]),
            _x86(ident, "subps", [register, a1]),
            # a0 <- (A+B)(A-B)(C+D)(C-D)
            _x86(ident, "addsubps", [a1, a0]),
            _x86(ident, "movaps", [a0, register]),
        ]
    )


def float_sse_1(register, aux_registers, ident=""):
    """Emit the SSE float step that combines lanes two apart.

    Given ``register`` = ABCD the emitted code leaves (A+C)(B+D)(A-C)(B-D) in
    ``register``.  Needs at least four distinct auxiliary registers; raises
    ``Exception`` otherwise.
    """
    _check_step_registers("float_sse_1", register, aux_registers, 4)
    a0, a1, a2, a3 = aux_registers[:4]
    return "".join(
        [
            _x86(ident, "movaps", [register, a0]),
            # a0 <- ABAB
            _x86(ident, "shufps", [a0, a0], imm=68),
            # a1 <- 0
            _x86(ident, "xorps", [a1, a1]),
            _x86(ident, "movaps", [register, a2]),
            # a2 <- C D 0 0
            _x86(ident, "shufps", [a1, a2], imm=14),
            _x86(ident, "movaps", [register, a3]),
            # a1 <- 0 0 C D
            _x86(ident, "shufps", [a3, a1], imm=224),
            # a2 <- (A+C)(B+D) A B, then a2 <- (A+C)(B+D)(A-C)(B-D)
            _x86(ident, "addps", [a0, a2]),
            _x86(ident, "subps", [a1, a2]),
            _x86(ident, "movaps", [a2, register]),
        ]
    )


def float_sse_2_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the SSE float step that combines two whole registers.

    ``to_register_0`` receives the sum and ``to_register_1`` the difference
    (``from_register_0 - from_register_1``).  Raises ``Exception`` if the four
    register names are not pairwise distinct.
    """
    _check_four_distinct(from_register_0, from_register_1, to_register_0, to_register_1)
    return "".join(
        [
            _x86(ident, "movaps", [from_register_0, to_register_0]),
            _x86(ident, "movaps", [from_register_0, to_register_1]),
            _x86(ident, "addps", [from_register_1, to_register_0]),
            _x86(ident, "subps", [from_register_1, to_register_1]),
        ]
    )


def double_sse_0(register, aux_registers, ident=""):
    """Emit the SSE double step on a two-lane register.

    Given ``register`` = AB the emitted code leaves (A+B)(A-B) in
    ``register``.  Needs at least two distinct auxiliary registers; raises
    ``Exception`` otherwise.
    """
    _check_step_registers("double_sse_0", register, aux_registers, 2)
    a0, a1 = aux_registers[:2]
    return "".join(
        [
            _x86(ident, "movapd", [register, a0]),
            # a0 <- (A+B)(A+B)
            _x86(ident, "haddpd", [a0, a0]),
            _x86(ident, "movapd", [register, a1]),
            # a1 <- (A-B)(A-B)
            _x86(ident, "hsubpd", [a1, a1]),
            # a1 <- (A+B)(A-B)
            _x86(ident, "blendpd", [a0, a1], imm=1),
            _x86(ident, "movapd", [a1, register]),
        ]
    )


def double_sse_1_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the SSE double step that combines two whole registers.

    ``to_register_0`` receives the sum and ``to_register_1`` the difference
    (``from_register_0 - from_register_1``).  Raises ``Exception`` if the four
    register names are not pairwise distinct.
    """
    _check_four_distinct(from_register_0, from_register_1, to_register_0, to_register_1)
    return "".join(
        [
            _x86(ident, "movapd", [from_register_0, to_register_0]),
            _x86(ident, "movapd", [from_register_0, to_register_1]),
            _x86(ident, "addpd", [from_register_1, to_register_0]),
            _x86(ident, "subpd", [from_register_1, to_register_1]),
        ]
    )


# Given reg = ABCD, return (A+B)(A-B)(C+D)(C-D)
def float_neon_0(register, aux_registers, ident=""):
    """Emit the NEON float step that combines adjacent lanes.

    Needs at least two distinct auxiliary registers, none equal to
    ``register``; raises ``Exception`` otherwise.
    """
    _check_step_registers("float_neon_0", register, aux_registers, 2)
    # r0 <- AACC
    res = f'{ident}"TRN1 {aux_registers[0]}.4S, {register}.4S, {register}.4S\\n"\n'
    # r1 <- -A -B -C -D
    res += f'{ident}"FNEG {aux_registers[1]}.4S, {register}.4S\\n"\n'
    # r1 <- B (-B) D -D
    res += f'{ident}"TRN2 {aux_registers[1]}.4S, {register}.4S, {aux_registers[1]}.4S\\n"\n'
    # reg <- (A+B)(A-B)(C+D)(C-D)
    res += f'{ident}"FADD {register}.4S, {aux_registers[0]}.4S, {aux_registers[1]}.4S\\n"\n'

    return res


# Given reg = ABCD, return (A + C)(B + D)(A - C)(B - D)
def float_neon_1(register, aux_registers, ident=""):
    """Emit the NEON float step that combines lanes two apart.

    Needs at least two distinct auxiliary registers, none equal to
    ``register``; raises ``Exception`` otherwise.
    """
    _check_step_registers("float_neon_1", register, aux_registers, 2)
    # r0 <- ABAB
    res = f'{ident}"DUP {aux_registers[0]}.2D, {register}.D[0]\\n"\n'
    # r1 <- -A -B -C -D
    res += f'{ident}"FNEG {aux_registers[1]}.4S, {register}.4S\\n"\n'
    # r1 <- C D -C -D
    res += f'{ident}"INS {aux_registers[1]}.D[0], {register}.D[1]\\n"\n'
    # reg <- (A + C)(B + D)(A - C)(B - D)
    res += f'{ident}"FADD {register}.4S, {aux_registers[0]}.4S, {aux_registers[1]}.4S\\n"\n'

    return res


def float_neon_2_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the NEON float step that combines two whole registers.

    ``to_register_0`` receives the sum and ``to_register_1`` the difference
    (``from_register_0 - from_register_1``).  Raises ``Exception`` if the four
    register names are not pairwise distinct.
    """
    _check_four_distinct(from_register_0, from_register_1, to_register_0, to_register_1)
    res = f'{ident}"FADD {to_register_0}.4S, {from_register_0}.4S, {from_register_1}.4S\\n"\n'
    res += f'{ident}"FSUB {to_register_1}.4S, {from_register_0}.4S, {from_register_1}.4S\\n"\n'
    return res


def plain_step(type_name, buf_name, log_n, it, ident=""):
    """Emit plain C loops performing butterfly level ``it`` on 2**log_n items.

    Elements ``1 << it`` apart are replaced by their sum and difference.

    Args:
        type_name: C element type used for the temporaries.
        buf_name: C expression naming the buffer.
        log_n: log2 of the buffer length; must be positive.
        it: level index, ``0 <= it < log_n``.
        ident: prefix prepended to every emitted line.

    Raises:
        Exception: if ``log_n`` or ``it`` is out of range.
    """
    if log_n <= 0:
        raise Exception("log_n must be positive")
    if it < 0:
        raise Exception("it must be non-negative")
    if it >= log_n:
        raise Exception("it must be smaller than log_n")
    n = 1 << log_n
    stride = 1 << it
    return "".join(
        [
            ident + "for (int j = 0; j < %d; j += %d) {\n" % (n, 1 << (it + 1)),
            ident + " for (int k = 0; k < %d; ++k) {\n" % stride,
            ident + " %s u = %s[j + k];\n" % (type_name, buf_name),
            ident + " %s v = %s[j + k + %d];\n" % (type_name, buf_name, stride),
            ident + " %s[j + k] = u + v;\n" % buf_name,
            ident + " %s[j + k + %d] = u - v;\n" % (buf_name, stride),
            ident + " }\n",
            ident + "}\n",
        ]
    )


# Sentinel passed as ``move_instruction`` to select NEON LD1/ST1 syntax.
MOVE_INSTRUCTION_USE_NEON = "NEON MOV"


def _load_line(ident, move_instruction, register, operand):
    """Emit the asm line loading asm operand ``%operand`` into ``register``."""
    # HACK: NEON needs different syntax for loads and stores.
    if move_instruction == MOVE_INSTRUCTION_USE_NEON:
        return f'{ident} "LD1 {{{register}.4S}}, [%{operand}]\\n"\n'
    return ident + ' "%s (%%%d), %%%%%s\\n"\n' % (move_instruction, operand, register)


def _store_line(ident, move_instruction, register, operand):
    """Emit the asm line storing ``register`` to asm operand ``%operand``."""
    # HACK: NEON needs different syntax for loads and stores.
    if move_instruction == MOVE_INSTRUCTION_USE_NEON:
        return f'{ident} "ST1 {{{register}.4S}}, [%{operand}]\\n"\n'
    return ident + ' "%s %%%%%s, (%%%d)\\n"\n' % (move_instruction, register, operand)


def composite_step(
    buf_name,
    log_n,
    from_it,
    to_it,
    log_w,
    registers,
    move_instruction,
    special_steps,
    main_step,
    ident="",
):
    """Emit C loops with inline asm that fuse levels ``from_it..to_it - 1``.

    Levels below ``log_w`` are handled inside one vector register by
    ``special_steps[level]``; levels at or above ``log_w`` ("nontrivial"
    levels) combine pairs of registers with ``main_step``.  The first half of
    ``registers`` holds the data, the second half serves as scratch/output,
    and the halves swap roles after every nontrivial level.

    Args:
        buf_name: C expression naming the buffer.
        log_n: log2 of the buffer length.
        from_it: first level to perform.
        to_it: one past the last level to perform.
        log_w: log2 of the number of elements per vector register.
        registers: register names; the count must be even.
        move_instruction: load/store mnemonic, or
            ``MOVE_INSTRUCTION_USE_NEON`` for LD1/ST1 syntax.
        special_steps: per-level emitters ``f(register, aux_registers, ident)``.
        main_step: emitter ``f(from0, from1, to0, to1, ident)``.
        ident: prefix prepended to every emitted line.

    Returns:
        The generated C source fragment.

    Raises:
        Exception: if ``log_n < log_w``, the register count is odd, or the
            requested nontrivial levels need more registers than one half.
    """
    if log_n < log_w:
        raise Exception("need at least %d elements" % (1 << log_w))
    num_registers = len(registers)
    if num_registers % 2 == 1:
        raise Exception("odd number of registers: %d" % num_registers)
    half = num_registers // 2
    num_nontrivial_levels = 0
    if to_it > log_w:
        num_nontrivial_levels = to_it - max(from_it, log_w)
        if 1 << num_nontrivial_levels > half:
            raise Exception("not enough registers")
    n = 1 << log_n
    input_registers = list(registers[:half])
    output_registers = list(registers[half:])
    clobber = ", ".join('"%%%s"' % x for x in registers)

    if num_nontrivial_levels == 0:
        # Everything happens inside a single register per loop iteration.
        reg = input_registers[0]
        out = [
            ident + "for (int j = 0; j < %d; j += %d) {\n" % (n, 1 << log_w),
            ident + " __asm__ volatile (\n",
            _load_line(ident, move_instruction, reg, 0),
        ]
        for it in range(from_it, to_it):
            out.append(special_steps[it](reg, output_registers, ident + " "))
        out.append(_store_line(ident, move_instruction, reg, 0))
        out.append(ident + ' :: "r"(%s + j) : %s, "memory"\n' % (buf_name, clobber))
        out.append(ident + " );\n")
        out.append(ident + "}\n")
        return "".join(out)

    num_loaded = 1 << num_nontrivial_levels
    chunk = 1 << (to_it - num_nontrivial_levels)
    # Buffer offsets of the vectors processed together in one asm block.
    subcube = ["j + k + " + str(idx * chunk) for idx in range(num_loaded)]
    out = [
        ident + "for (int j = 0; j < %d; j += %d) {\n" % (n, 1 << to_it),
        ident + " for (int k = 0; k < %d; k += %d) {\n" % (chunk, 1 << log_w),
        ident + " __asm__ volatile (\n",
    ]
    for idx in range(num_loaded):
        out.append(_load_line(ident, move_instruction, input_registers[idx], idx))
    for it in range(from_it, log_w):
        for ii in range(num_loaded):
            out.append(
                special_steps[it](input_registers[ii], output_registers, ident + " ")
            )
    for it in range(num_nontrivial_levels):
        for ii in range(0, num_loaded, 1 << (it + 1)):
            for jj in range(1 << it):
                out.append(
                    main_step(
                        input_registers[ii + jj],
                        input_registers[ii + jj + (1 << it)],
                        output_registers[ii + jj],
                        output_registers[ii + jj + (1 << it)],
                        ident + " ",
                    )
                )
        # Results of this level are the inputs of the next one.
        input_registers, output_registers = output_registers, input_registers
    for idx in range(num_loaded):
        out.append(_store_line(ident, move_instruction, input_registers[idx], idx))
    out.append(
        ident
        + ' :: %s : %s, "memory"\n'
        % (
            ", ".join('"r"(%s + %s)' % (buf_name, x) for x in subcube),
            clobber,
        )
    )
    out.append(ident + " );\n")
    out.append(ident + " }\n")
    out.append(ident + "}\n")
    return "".join(out)


def float_avx_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """``composite_step`` for floats on AVX (ymm0-15, eight lanes)."""
    return composite_step(
        buf_name,
        log_n,
        from_it,
        to_it,
        3,
        ["ymm%d" % x for x in range(16)],
        "vmovups",
        [float_avx_0, float_avx_1, float_avx_2],
        float_avx_3_etc,
        ident,
    )


def double_avx_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """``composite_step`` for doubles on AVX (ymm0-15, four lanes)."""
    return composite_step(
        buf_name,
        log_n,
        from_it,
        to_it,
        2,
        ["ymm%d" % x for x in range(16)],
        "vmovupd",
        [double_avx_0, double_avx_1],
        double_avx_2_etc,
        ident,
    )


def float_sse_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """``composite_step`` for floats on SSE (xmm0-15, four lanes)."""
    return composite_step(
        buf_name,
        log_n,
        from_it,
        to_it,
        2,
        ["xmm%d" % x for x in range(16)],
        "movups",
        [float_sse_0, float_sse_1],
        float_sse_2_etc,
        ident,
    )


def double_sse_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """``composite_step`` for doubles on SSE (xmm0-15, two lanes)."""
    return composite_step(
        buf_name,
        log_n,
        from_it,
        to_it,
        1,
        ["xmm%d" % x for x in range(16)],
        "movupd",
        [double_sse_0],
        double_sse_1_etc,
        ident,
    )


NEON_VECTOR_REGS = [f"v{x}" for x in range(0, 32)]


def float_neon_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """``composite_step`` for floats on NEON (v0-31, four lanes)."""
    return composite_step(
        buf_name,
        log_n,
        from_it,
        to_it,
        2,
        NEON_VECTOR_REGS,
        MOVE_INSTRUCTION_USE_NEON,
        [float_neon_0, float_neon_1],
        float_neon_2_etc,
        ident,
    )


def plain_unmerged(type_name, log_n):
    """Emit ``helper_<type>_<log_n>`` as plain C, one loop nest per level."""
    signature = "static inline void helper_%s_%d(%s *buf)" % (
        type_name,
        log_n,
        type_name,
    )
    parts = ["%s;\n" % signature, "%s {\n" % signature]
    parts.extend(plain_step(type_name, "buf", log_n, i, " ") for i in range(log_n))
    parts.append("}\n")
    return "".join(parts)


def _largest_feasible_to_it(composite_step, cur_it, upper, fixed_log_n=None, warn=False):
    """Return the largest ``to_it`` in ``(cur_it, upper]`` that can be emitted.

    Candidates are tried from ``upper`` downwards by calling
    ``composite_step("buf", log_n, cur_it, to_it)``, where ``log_n`` is
    ``fixed_log_n`` when given and ``to_it`` otherwise.  With ``warn`` set,
    each rejected candidate is reported on stdout.

    Raises:
        Exception: if no candidate above ``cur_it`` is accepted; without this
            bound the search would never terminate.
    """
    for to_it in range(upper, cur_it, -1):
        step_log_n = to_it if fixed_log_n is None else fixed_log_n
        try:
            composite_step("buf", step_log_n, cur_it, to_it)
        except Exception as e:
            if warn:
                print(f"warning: {e}")
        else:
            return to_it
    raise Exception("no composite step can advance past iteration %d" % cur_it)


def greedy_merged(type_name, log_n, composite_step):
    """Emit ``helper_<type>_<log_n>`` from greedily merged composite steps.

    Starting at level 0, each step covers as many levels as
    ``composite_step`` accepts.

    Args:
        type_name: C element type.
        log_n: log2 of the buffer length.
        composite_step: generator ``f(buf_name, log_n, from_it, to_it, ident="")``.

    Raises:
        Exception: if ``composite_step`` rejects ``log_n`` outright.
    """
    try:
        composite_step("buf", log_n, 0, 0)
    except Exception as e:
        raise Exception("log_n is too small: %d" % log_n) from e
    signature = "static inline void helper_%s_%d(%s *buf)" % (
        type_name,
        log_n,
        type_name,
    )
    parts = ["%s;\n" % signature, "%s {\n" % signature]
    cur_it = 0
    while cur_it < log_n:
        cur_to_it = _largest_feasible_to_it(
            composite_step, cur_it, log_n, fixed_log_n=log_n, warn=True
        )
        parts.append(composite_step("buf", log_n, cur_it, cur_to_it, " "))
        cur_it = cur_to_it
    parts.append("}\n")
    return "".join(parts)


def greedy_merged_recursive(type_name, log_n, threshold_step, composite_step):
    """Emit a recursive ``helper_<type>_<log_n>`` plus its ``_recursive`` worker.

    The worker dispatches on ``depth``.  At ``depth == threshold_step`` it
    either runs greedily merged composite steps (when ``threshold_step ==
    log_n``) or calls the existing ``helper_<type>_<threshold_step>``.  Each
    higher depth recurses on its sub-blocks and then finishes the remaining
    levels.

    Args:
        type_name: C element type.
        log_n: log2 of the buffer length.
        threshold_step: depth at which recursion bottoms out; at most ``log_n``.
        composite_step: generator ``f(buf_name, log_n, from_it, to_it, ident="")``.

    Raises:
        Exception: if ``threshold_step`` exceeds ``log_n`` or is rejected by
            ``composite_step``.
    """
    if threshold_step > log_n:
        raise Exception("threshold_step must be at most log_n")
    try:
        composite_step("buf", threshold_step, 0, 0)
    except Exception as e:
        raise Exception("threshold_step is too small: %d" % threshold_step) from e
    signature = "void helper_%s_%d_recursive(%s *buf, int depth)" % (
        type_name,
        log_n,
        type_name,
    )
    parts = [
        "%s;\n" % signature,
        "%s {\n" % signature,
        " if (depth == %d) {\n" % threshold_step,
    ]
    if threshold_step == log_n:
        cur_it = 0
        while cur_it < threshold_step:
            cur_to_it = _largest_feasible_to_it(
                composite_step, cur_it, threshold_step, fixed_log_n=threshold_step
            )
            parts.append(
                composite_step("buf", threshold_step, cur_it, cur_to_it, " ")
            )
            cur_it = cur_to_it
    else:
        parts.append(" helper_%s_%d(buf);\n" % (type_name, threshold_step))

    parts.append(" return;\n")
    parts.append(" }\n")
    cur_it = threshold_step
    while cur_it < log_n:
        cur_to_it = _largest_feasible_to_it(composite_step, cur_it, log_n)
        parts.append(" if (depth == %d) {\n" % cur_to_it)
        for i in range(1 << (cur_to_it - cur_it)):
            parts.append(
                " helper_%s_%d_recursive(buf + %d, %d);\n"
                % (type_name, log_n, i * (1 << cur_it), cur_it)
            )
        if cur_to_it < log_n:
            # NOTE(review): this emits the sub-block recursive calls above and
            # then a call to the complete helper for 2**cur_to_it elements;
            # confirm that combination is intended.
            # The trailing newline keeps the call on its own line, like every
            # other emitted statement, so line-based size counts include it.
            parts.append(" helper_%s_%d(buf);\n" % (type_name, cur_to_it))
        else:
            parts.append(composite_step("buf", cur_to_it, cur_it, cur_to_it, " "))
        parts.append(" return;\n")
        parts.append(" }\n")
        cur_it = cur_to_it
    parts.append("}\n")
    signature = "void helper_%s_%d(%s *buf)" % (type_name, log_n, type_name)
    parts.append("%s;\n" % signature)
    parts.append("%s {\n" % signature)
    parts.append(" helper_%s_%d_recursive(buf, %d);\n" % (type_name, log_n, log_n))
    parts.append("}\n")
    return "".join(parts)


def extract_time(data):
    """Return the ``cpu_time`` of a benchmark CSV row, converted to seconds.

    Raises:
        Exception: if the row's ``time_unit`` is not ``"ns"``.
    """
    cpu_time = float(data["cpu_time"])
    time_unit = data["time_unit"]
    if time_unit != "ns":
        raise Exception("nanoseconds expected")
    return cpu_time / 1e9


def get_mean_stddev(path="measurements/output.csv"):
    """Return the mean benchmark time, in seconds, read from a CSV report.

    Despite the name, only the mean (the ``benchmark_fht_mean`` row) is
    returned; ``benchmark_fht_stddev`` rows are still parsed so that an
    unexpected time unit is reported.

    Args:
        path: CSV file whose first row is the header.

    Raises:
        Exception: if a time unit is not nanoseconds or no
            ``benchmark_fht_mean`` row exists.
    """
    mean = None
    with open(path, "r", newline="") as csvfile:
        for data in csv.DictReader(csvfile):
            if data["name"] == "benchmark_fht_mean":
                mean = extract_time(data)
            elif data["name"] == "benchmark_fht_stddev":
                extract_time(data)  # validated only; the value is not used
    if mean is None:
        raise Exception("no benchmark_fht_mean row found in %s" % path)
    return mean


def measure_time(code, log_n, type_name, method_name, num_it=3):
    """Build and run the benchmark for ``code``; return mean time in seconds.

    Writes ``measurements/to_run.h``, runs ``make run_<type_name>`` inside
    ``measurements``, executes the resulting binary with CSV output redirected
    to ``measurements/output.csv`` and parses that file.

    Args:
        code: generated C source defining ``method_name``.
        log_n: value emitted as the ``log_n`` constant.
        type_name: C element type; also selects the make target and binary.
        method_name: function the emitted ``run`` wrapper calls.
        num_it: number of benchmark repetitions; must be odd.

    Raises:
        Exception: if ``num_it`` is even, a subprocess exits non-zero, or the
            report cannot be parsed.
    """
    if num_it % 2 == 0:
        raise Exception("even number of runs: %d" % num_it)
    with open("measurements/to_run.h", "w") as output:
        output.write(code)
        output.write("const int log_n = %d;\n" % log_n)
        signature = "void run(%s *buf)" % type_name
        output.write("%s;\n" % signature)
        output.write("%s {\n" % signature)
        output.write(" %s(buf);\n" % method_name)
        output.write("}\n")
    build_status = subprocess.call(
        "cd measurements && make run_%s" % type_name,
        shell=True,
        stdout=subprocess.DEVNULL,
    )
    if build_status != 0:
        raise Exception("bad exit code")
    run_status = subprocess.call(
        "./measurements/run_%s --benchmark_repetitions=%d --benchmark_format=csv > ./measurements/output.csv"
        % (type_name, num_it),
        shell=True,
        stderr=subprocess.DEVNULL,
    )
    if run_status != 0:
        raise Exception("bad exit code")
    return get_mean_stddev()


# Configuration parameter; set to False if you want the absolute fastest code without regard to size.
CARE_ABOUT_CODE_SIZE = True

# When CARE_ABOUT_CODE_SIZE, accept the smallest code that is not slower than
# MAX_PERFORMANCE_PENALTY_FOR_REDUCED_SIZE * the fastest time.
MAX_PERFORMANCE_PENALTY_FOR_REDUCED_SIZE = 1.1


if __name__ == "__main__":
    # Driver: for every size, time the iterative helper and each recursive
    # variant, keep one of them, then emit the size dispatcher and the report.
    final_code = '// @generated\n#include "fht.h"\n'
    code_so_far = ""
    hall_of_fame = []
    generators = [("float", float_neon_composite_step)]
    for type_name, composite_step_generator in generators:
        for log_n in range(1, max_log_n + 1):
            sys.stdout.write("log_n = %d\n" % log_n)
            helper_name = "helper_%s_%d" % (type_name, log_n)
            # Each candidate is (time, code, code_size, description).
            candidates = []

            # Iterative candidate; fall back to plain C when merging fails.
            try:
                source = greedy_merged(type_name, log_n, composite_step_generator)
                label = "greedy_merged"
            except Exception:
                source = plain_unmerged(type_name, log_n)
                label = "plain_unmerged"
            elapsed = measure_time(code_so_far + source, log_n, type_name, helper_name)
            num_lines = source.count("\n")
            candidates.append((elapsed, source, num_lines, label))
            sys.stdout.write(
                "log_n = %d; iterative; code_size = %d; time = %.10e\n"
                % (log_n, num_lines, elapsed)
            )

            # Recursive candidates, one per recursion threshold.
            for threshold_step in range(1, log_n + 1):
                try:
                    source = greedy_merged_recursive(
                        type_name, log_n, threshold_step, composite_step_generator
                    )
                    elapsed = measure_time(
                        code_so_far + source, log_n, type_name, helper_name
                    )
                    num_lines = source.count("\n")
                    label = "greedy_merged_recursive %d" % threshold_step
                    candidates.append((elapsed, source, num_lines, label))
                    sys.stdout.write(
                        "log_n = %d; threshold_step = %d; code_size = %d; time = %.10e\n"
                        % (log_n, threshold_step, num_lines, elapsed)
                    )
                except Exception as e:
                    sys.stdout.write(f"FAIL: {threshold_step} ({e})\n")

            if CARE_ABOUT_CODE_SIZE:
                # Smallest candidate whose time stays within the allowed
                # penalty over the fastest one.
                time_budget = (
                    min(candidates)[0] * MAX_PERFORMANCE_PENALTY_FOR_REDUCED_SIZE
                )
                for entry in sorted(candidates, key=lambda c: c[2]):
                    if entry[0] <= time_budget:
                        chosen = entry
                        break
            else:
                chosen = min(candidates)
            best_time, best_code, best_code_size, best_desc = chosen

            hall_of_fame.append((type_name, log_n, best_time, best_desc))
            final_code += best_code
            code_so_far += best_code
            sys.stdout.write(
                "log_n = %d; best_time = %.10e; %s\n" % (log_n, best_time, best_desc)
            )

        # Dispatcher: returns 0 on success and 1 for an unsupported log_n.
        dispatcher = [
            "int fht_%s(%s *buf, int log_n) {\n" % (type_name, type_name),
            " if (log_n == 0) {\n",
            " return 0;\n",
            " }\n",
        ]
        for i in range(1, max_log_n + 1):
            dispatcher.append(" if (log_n == %d) {\n" % i)
            dispatcher.append(" helper_%s_%d(buf);\n" % (type_name, i))
            dispatcher.append(" return 0;\n")
            dispatcher.append(" }\n")
        dispatcher.append(" return 1;\n")
        dispatcher.append("}\n")
        final_code += "".join(dispatcher)

    with open("fht_neon.c", "w") as output:
        output.write(final_code)

    sys.stdout.write("hall of fame\n")
    with open("hall_of_fame_neon.txt", "w") as hof:
        for type_name, log_n, best_time, best_desc in hall_of_fame:
            line = "type_name = %s; log_n = %d; best_time = %.10e; best_desc = %s\n" % (
                type_name,
                log_n,
                best_time,
                best_desc,
            )
            sys.stdout.write(line)
            hof.write(line)