• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
2import csv
3import os
4import subprocess
5import sys
6
# Largest log_n handled: the __main__ driver generates one helper for every
# log_n in 1..max_log_n and the emitted fht_<type>() dispatches over that range.
max_log_n = 30
8
9
def is_distinct(l):
    """Return True when no element of ``l`` occurs more than once.

    Uniqueness is decided by building a set, so elements must be hashable.
    """
    unique_items = set(l)
    return len(l) == len(unique_items)
12
13
def float_avx_0(register, aux_registers, ident=""):
    """Emit the AVX butterfly between adjacent floats of one register.

    With ``register`` holding ABCDEFGH, the emitted instructions leave
    (A+B)(A-B)(C+D)(C-D)(E+F)(E-F)(G+H)(G-H) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least four are required (the
            first four are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 4:
        raise Exception("float_avx_0 needs at least four auxiliary registers")
    evens = aux_registers[0]
    odds = aux_registers[1]
    zero = aux_registers[2]
    neg_odds = aux_registers[3]
    instructions = [
        # evens <- AACCEEGG
        '"vpermilps $160, %%%%%s, %%%%%s\\n"\n' % (register, evens),
        # odds <- BBDDFFHH
        '"vpermilps $245, %%%%%s, %%%%%s\\n"\n' % (register, odds),
        # zero <- 0
        '"vxorps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (zero, zero, zero),
        # neg_odds <- -B -B -D -D -F -F -H -H
        '"vsubps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (odds, zero, neg_odds),
        # register <- (A+B)(A-B)(C+D)(C-D)(E+F)(E-F)(G+H)(G-H)
        '"vaddsubps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (neg_odds, evens, register),
    ]
    return "".join(ident + line for line in instructions)
47
48
def float_avx_1(register, aux_registers, ident=""):
    """Emit the AVX butterfly between floats two positions apart.

    With ``register`` holding ABCDEFGH, the emitted instructions leave
    (A+C)(B+D)(A-C)(B-D)(E+G)(F+H)(E-G)(F-H) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least five are required (the
            first five are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 5:
        raise Exception("float_avx_1 needs at least five auxiliary registers")
    low_pairs = aux_registers[0]
    high_pairs = aux_registers[1]
    zero = aux_registers[2]
    neg_high = aux_registers[3]
    signed_high = aux_registers[4]
    instructions = [
        # low_pairs <- ABABEFEF
        '"vpermilps $68, %%%%%s, %%%%%s\\n"\n' % (register, low_pairs),
        # high_pairs <- CDCDGHGH
        '"vpermilps $238, %%%%%s, %%%%%s\\n"\n' % (register, high_pairs),
        # zero <- 0
        '"vxorps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (zero, zero, zero),
        # neg_high <- -C -D -C -D -G -H -G -H
        '"vsubps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (high_pairs, zero, neg_high),
        # signed_high <- C D -C -D G H -G -H
        '"vblendps $204, %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (neg_high, high_pairs, signed_high),
        # register <- (A+C)(B+D)(A-C)(B-D)(E+G)(F+H)(E-G)(F-H)
        '"vaddps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (low_pairs, signed_high, register),
    ]
    return "".join(ident + line for line in instructions)
88
89
def float_avx_2(register, aux_registers, ident=""):
    """Emit the AVX butterfly between the two 128-bit halves of a register.

    With ``register`` holding ABCDEFGH, the emitted instructions leave
    (A+E)(B+F)(C+G)(D+H)(A-E)(B-F)(C-G)(D-H) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least four are required (the
            first four are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 4:
        raise Exception("float_avx_2 needs at least four auxiliary registers")
    zero = aux_registers[0]
    negated = aux_registers[1]
    low_twice = aux_registers[2]
    signed_high = aux_registers[3]
    instructions = [
        # zero <- 0
        '"vxorps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (zero, zero, zero),
        # negated <- -A -B -C -D -E -F -G -H
        '"vsubps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (register, zero, negated),
        # low_twice <- ABCDABCD (low 128-bit lane copied into both lanes)
        '"vperm2f128 $0, %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (register, register, low_twice),
        # signed_high <- E F G H -E -F -G -H (high lanes of register / negated)
        '"vperm2f128 $49, %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (negated, register, signed_high),
        # register <- (A+E)(B+F)(C+G)(D+H)(A-E)(B-F)(C-G)(D-H)
        '"vaddps %%%%%s, %%%%%s, %%%%%s\\n"\n' % (low_twice, signed_high, register),
    ]
    return "".join(ident + line for line in instructions)
128
129
def float_avx_3_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the AVX float butterfly between two whole registers.

    The emitted instructions compute
    ``to_register_0 <- from_register_0 + from_register_1`` and
    ``to_register_1 <- from_register_0 - from_register_1``.

    Raises:
        Exception: if the four register names are not pairwise distinct.
    """
    operands = [from_register_0, from_register_1, to_register_0, to_register_1]
    if not is_distinct(operands):
        raise Exception("four registers must be distinct")
    instructions = [
        '"vaddps %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (from_register_1, from_register_0, to_register_0),
        '"vsubps %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (from_register_1, from_register_0, to_register_1),
    ]
    return "".join(ident + line for line in instructions)
148
149
def double_avx_0(register, aux_registers, ident=""):
    """Emit the AVX butterfly between adjacent doubles of one register.

    With ``register`` holding ABCD, the emitted instructions leave
    (A+B)(A-B)(C+D)(C-D) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least four are required (the
            first four are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 4:
        raise Exception("double_avx_0 needs at least four auxiliary registers")
    evens = aux_registers[0]
    odds = aux_registers[1]
    zero = aux_registers[2]
    neg_odds = aux_registers[3]
    instructions = [
        # evens <- AACC
        '"vpermilpd $0, %%%%%s, %%%%%s\\n"\n' % (register, evens),
        # odds <- BBDD
        '"vpermilpd $15, %%%%%s, %%%%%s\\n"\n' % (register, odds),
        # zero <- 0
        '"vxorpd %%%%%s, %%%%%s, %%%%%s\\n"\n' % (zero, zero, zero),
        # neg_odds <- -B -B -D -D
        '"vsubpd %%%%%s, %%%%%s, %%%%%s\\n"\n' % (odds, zero, neg_odds),
        # register <- (A+B)(A-B)(C+D)(C-D)
        '"vaddsubpd %%%%%s, %%%%%s, %%%%%s\\n"\n' % (neg_odds, evens, register),
    ]
    return "".join(ident + line for line in instructions)
180
181
def double_avx_1(register, aux_registers, ident=""):
    """Emit the AVX butterfly between the two 128-bit halves (doubles).

    With ``register`` holding ABCD, the emitted instructions leave
    (A+C)(B+D)(A-C)(B-D) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least four are required (the
            first four are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 4:
        raise Exception("double_avx_1 needs at least four auxiliary registers")
    low_twice = aux_registers[0]
    zero = aux_registers[1]
    negated = aux_registers[2]
    signed_high = aux_registers[3]
    instructions = [
        # low_twice <- ABAB
        '"vperm2f128 $0, %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (register, register, low_twice),
        # zero <- 0
        '"vxorpd %%%%%s, %%%%%s, %%%%%s\\n"\n' % (zero, zero, zero),
        # negated <- -A -B -C -D
        '"vsubpd %%%%%s, %%%%%s, %%%%%s\\n"\n' % (register, zero, negated),
        # signed_high <- C D -C -D
        '"vperm2f128 $49, %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (negated, register, signed_high),
        # register <- (A+C)(B+D)(A-C)(B-D)
        '"vaddpd %%%%%s, %%%%%s, %%%%%s\\n"\n' % (signed_high, low_twice, register),
    ]
    return "".join(ident + line for line in instructions)
220
221
def double_avx_2_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the AVX double butterfly between two whole registers.

    The emitted instructions compute
    ``to_register_0 <- from_register_0 + from_register_1`` and
    ``to_register_1 <- from_register_0 - from_register_1``.

    Raises:
        Exception: if the four register names are not pairwise distinct.
    """
    operands = [from_register_0, from_register_1, to_register_0, to_register_1]
    if not is_distinct(operands):
        raise Exception("four registers must be distinct")
    instructions = [
        '"vaddpd %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (from_register_1, from_register_0, to_register_0),
        '"vsubpd %%%%%s, %%%%%s, %%%%%s\\n"\n'
        % (from_register_1, from_register_0, to_register_1),
    ]
    return "".join(ident + line for line in instructions)
240
241
def float_sse_0(register, aux_registers, ident=""):
    """Emit the SSE butterfly between adjacent floats of one register.

    With ``register`` holding ABCD, the emitted instructions leave
    (A+B)(A-B)(C+D)(C-D) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least two are required (the
            first two are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 2:
        raise Exception("float_sse_0 needs at least two auxiliary registers")
    acc = aux_registers[0]
    neg = aux_registers[1]
    instructions = [
        # acc <- ABCD, then AACC
        '"movaps %%%%%s, %%%%%s\\n"\n' % (register, acc),
        '"shufps $160, %%%%%s, %%%%%s\\n"\n' % (acc, acc),
        # register <- BBDD
        '"shufps $245, %%%%%s, %%%%%s\\n"\n' % (register, register),
        # neg <- 0, then -B -B -D -D
        '"xorps %%%%%s, %%%%%s\\n"\n' % (neg, neg),
        '"subps %%%%%s, %%%%%s\\n"\n' % (register, neg),
        # acc <- (A+B)(A-B)(C+D)(C-D)
        '"addsubps %%%%%s, %%%%%s\\n"\n' % (neg, acc),
        # register <- acc
        '"movaps %%%%%s, %%%%%s\\n"\n' % (acc, register),
    ]
    return "".join(ident + line for line in instructions)
263
264
def float_sse_1(register, aux_registers, ident=""):
    """Emit the SSE butterfly between floats two positions apart.

    With ``register`` holding ABCD, the emitted instructions leave
    (A+C)(B+D)(A-C)(B-D) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least four are required (the
            first four are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 4:
        raise Exception("float_sse_1 needs at least four auxiliary registers")
    low_twice = aux_registers[0]
    subtrahend = aux_registers[1]
    acc = aux_registers[2]
    copy = aux_registers[3]
    instructions = [
        # low_twice <- ABCD, then ABAB
        '"movaps %%%%%s, %%%%%s\\n"\n' % (register, low_twice),
        '"shufps $68, %%%%%s, %%%%%s\\n"\n' % (low_twice, low_twice),
        # subtrahend <- 0
        '"xorps %%%%%s, %%%%%s\\n"\n' % (subtrahend, subtrahend),
        # acc <- ABCD, then C D 0 0
        '"movaps %%%%%s, %%%%%s\\n"\n' % (register, acc),
        '"shufps $14, %%%%%s, %%%%%s\\n"\n' % (subtrahend, acc),
        # copy <- ABCD; subtrahend <- 0 0 C D
        '"movaps %%%%%s, %%%%%s\\n"\n' % (register, copy),
        '"shufps $224, %%%%%s, %%%%%s\\n"\n' % (copy, subtrahend),
        # acc <- (A+C)(B+D) A B, then (A+C)(B+D)(A-C)(B-D)
        '"addps %%%%%s, %%%%%s\\n"\n' % (low_twice, acc),
        '"subps %%%%%s, %%%%%s\\n"\n' % (subtrahend, acc),
        # register <- acc
        '"movaps %%%%%s, %%%%%s\\n"\n' % (acc, register),
    ]
    return "".join(ident + line for line in instructions)
292
293
def float_sse_2_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the SSE float butterfly between two whole registers.

    Both destinations are first loaded with ``from_register_0``; then
    ``from_register_1`` is added to ``to_register_0`` and subtracted from
    ``to_register_1``.

    Raises:
        Exception: if the four register names are not pairwise distinct.
    """
    operands = [from_register_0, from_register_1, to_register_0, to_register_1]
    if not is_distinct(operands):
        raise Exception("four registers must be distinct")
    instructions = [
        '"movaps %%%%%s, %%%%%s\\n"\n' % (from_register_0, to_register_0),
        '"movaps %%%%%s, %%%%%s\\n"\n' % (from_register_0, to_register_1),
        '"addps %%%%%s, %%%%%s\\n"\n' % (from_register_1, to_register_0),
        '"subps %%%%%s, %%%%%s\\n"\n' % (from_register_1, to_register_1),
    ]
    return "".join(ident + line for line in instructions)
306
307
def double_sse_0(register, aux_registers, ident=""):
    """Emit the SSE butterfly between the two doubles of one register.

    With ``register`` holding AB, the emitted instructions leave
    (A+B)(A-B) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least two are required (the
            first two are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 2:
        raise Exception("double_sse_0 needs at least two auxiliary registers")
    sums = aux_registers[0]
    diffs = aux_registers[1]
    instructions = [
        # sums <- AB, then (A+B)(A+B)
        '"movapd %%%%%s, %%%%%s\\n"\n' % (register, sums),
        '"haddpd %%%%%s, %%%%%s\\n"\n' % (sums, sums),
        # diffs <- AB, then (A-B)(A-B)
        '"movapd %%%%%s, %%%%%s\\n"\n' % (register, diffs),
        '"hsubpd %%%%%s, %%%%%s\\n"\n' % (diffs, diffs),
        # diffs <- (A+B)(A-B)
        '"blendpd $1, %%%%%s, %%%%%s\\n"\n' % (sums, diffs),
        # register <- diffs
        '"movapd %%%%%s, %%%%%s\\n"\n' % (diffs, register),
    ]
    return "".join(ident + line for line in instructions)
325
326
def double_sse_1_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the SSE double butterfly between two whole registers.

    Both destinations are first loaded with ``from_register_0``; then
    ``from_register_1`` is added to ``to_register_0`` and subtracted from
    ``to_register_1``.

    Raises:
        Exception: if the four register names are not pairwise distinct.
    """
    operands = [from_register_0, from_register_1, to_register_0, to_register_1]
    if not is_distinct(operands):
        raise Exception("four registers must be distinct")
    instructions = [
        '"movapd %%%%%s, %%%%%s\\n"\n' % (from_register_0, to_register_0),
        '"movapd %%%%%s, %%%%%s\\n"\n' % (from_register_0, to_register_1),
        '"addpd %%%%%s, %%%%%s\\n"\n' % (from_register_1, to_register_0),
        '"subpd %%%%%s, %%%%%s\\n"\n' % (from_register_1, to_register_1),
    ]
    return "".join(ident + line for line in instructions)
339
340
def float_neon_0(register, aux_registers, ident=""):
    """Emit the NEON butterfly between adjacent floats of one register.

    Given ``register`` = ABCD, the emitted instructions leave
    (A+B)(A-B)(C+D)(C-D) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least two are required (the
            first two are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 2:
        raise Exception("float_neon_0 needs at least two auxiliary registers")
    evens = aux_registers[0]
    signed_odds = aux_registers[1]
    instructions = [
        # evens <- AACC
        f'"TRN1 {evens}.4S, {register}.4S, {register}.4S\\n"\n',
        # signed_odds <- -A -B -C -D
        f'"FNEG {signed_odds}.4S, {register}.4S\\n"\n',
        # signed_odds <- B -B D -D
        f'"TRN2 {signed_odds}.4S, {register}.4S, {signed_odds}.4S\\n"\n',
        # register <- (A+B)(A-B)(C+D)(C-D)
        f'"FADD {register}.4S, {evens}.4S, {signed_odds}.4S\\n"\n',
    ]
    return "".join(ident + line for line in instructions)
359
360
def float_neon_1(register, aux_registers, ident=""):
    """Emit the NEON butterfly between the two 64-bit halves of a register.

    Given ``register`` = ABCD, the emitted instructions leave
    (A+C)(B+D)(A-C)(B-D) in ``register``.

    Args:
        register: name of the register that is transformed in place.
        aux_registers: scratch register names; they must be distinct, must
            not contain ``register``, and at least two are required (the
            first two are overwritten).
        ident: prefix prepended to every emitted line.

    Returns:
        One quoted C string-literal line per instruction, concatenated.

    Raises:
        Exception: if ``aux_registers`` violates the constraints above.
    """
    if not is_distinct(aux_registers):
        raise Exception("auxiliary registers must be distinct")
    if register in aux_registers:
        raise Exception("the main register can't be one of the auxiliary ones")
    if len(aux_registers) < 2:
        raise Exception("float_neon_1 needs at least two auxiliary registers")
    low_twice = aux_registers[0]
    signed_high = aux_registers[1]
    instructions = [
        # low_twice <- ABAB
        f'"DUP {low_twice}.2D, {register}.D[0]\\n"\n',
        # signed_high <- -A -B -C -D
        f'"FNEG {signed_high}.4S, {register}.4S\\n"\n',
        # signed_high <- C D -C -D
        f'"INS {signed_high}.D[0], {register}.D[1]\\n"\n',
        # register <- (A+C)(B+D)(A-C)(B-D)
        f'"FADD {register}.4S, {low_twice}.4S, {signed_high}.4S\\n"\n',
    ]
    return "".join(ident + line for line in instructions)
379
380
def float_neon_2_etc(
    from_register_0, from_register_1, to_register_0, to_register_1, ident=""
):
    """Emit the NEON float butterfly between two whole registers.

    The emitted instructions compute
    ``to_register_0 <- from_register_0 + from_register_1`` and
    ``to_register_1 <- from_register_0 - from_register_1``.

    Raises:
        Exception: if the four register names are not pairwise distinct.
    """
    operands = [from_register_0, from_register_1, to_register_0, to_register_1]
    if not is_distinct(operands):
        raise Exception("four registers must be distinct")
    instructions = [
        f'"FADD {to_register_0}.4S, {from_register_0}.4S, {from_register_1}.4S\\n"\n',
        f'"FSUB {to_register_1}.4S, {from_register_0}.4S, {from_register_1}.4S\\n"\n',
    ]
    return "".join(ident + line for line in instructions)
391
392
def plain_step(type_name, buf_name, log_n, it, ident=""):
    """Generate plain C for butterfly level ``it`` over ``1 << log_n`` elements.

    The emitted double loop pairs ``buf[j + k]`` with
    ``buf[j + k + (1 << it)]`` and replaces them with their sum and
    difference.

    Args:
        type_name: C element type used for the temporaries.
        buf_name: C expression naming the buffer.
        log_n: log2 of the number of elements; must be positive.
        it: level index, ``0 <= it < log_n``.
        ident: prefix prepended to every emitted line.

    Raises:
        Exception: if ``log_n`` or ``it`` is out of range.
    """
    if log_n <= 0:
        raise Exception("log_n must be positive")
    if it < 0:
        raise Exception("it must be non-negative")
    if it >= log_n:
        raise Exception("it must be smaller than log_n")
    total = 1 << log_n
    half = 1 << it
    body = [
        "for (int j = 0; j < %d; j += %d) {\n" % (total, 1 << (it + 1)),
        "  for (int k = 0; k < %d; ++k) {\n" % half,
        "    %s u = %s[j + k];\n" % (type_name, buf_name),
        "    %s v = %s[j + k + %d];\n" % (type_name, buf_name, half),
        "    %s[j + k] = u + v;\n" % buf_name,
        "    %s[j + k + %d] = u - v;\n" % (buf_name, half),
        "  }\n",
        "}\n",
    ]
    return "".join(ident + line for line in body)
410
411
# Sentinel passed as composite_step's move_instruction to request NEON
# LD1/ST1 loads and stores instead of an x86-style "<mov> (mem), reg" line.
MOVE_INSTRUCTION_USE_NEON = "NEON MOV"
413
414
def composite_step(
    buf_name,
    log_n,
    from_it,
    to_it,
    log_w,
    registers,
    move_instruction,
    special_steps,
    main_step,
    ident="",
):
    """Generate C with inline asm performing levels ``from_it..to_it - 1``.

    Levels below ``log_w`` are handled inside a single register by
    ``special_steps[level]``; levels at or above ``log_w`` combine pairs of
    registers through ``main_step``.

    Args:
        buf_name: C expression naming the buffer.
        log_n: log2 of the number of elements in the buffer.
        from_it: first level to perform.
        to_it: one past the last level to perform.
        log_w: log2 of the number of buffer elements covered by one register
            load (the generated loops advance by ``1 << log_w``).
        registers: register names; the first half are loaded from memory,
            the second half serve as scratch/outputs. Must be even in count.
        move_instruction: load/store mnemonic, or MOVE_INSTRUCTION_USE_NEON
            to emit LD1/ST1.
        special_steps: per-level generators called as
            ``step(register, aux_registers, ident)``.
        main_step: generator called as
            ``main_step(from_0, from_1, to_0, to_1, ident)``.
        ident: prefix prepended to every emitted line.

    Returns:
        The generated C source as a string.

    Raises:
        Exception: if ``log_n < log_w``, the register count is odd, or the
            requested levels need more registers than are available.
    """
    # HACK: NEON needs different syntax for loads and stores.
    neon = move_instruction == MOVE_INSTRUCTION_USE_NEON
    if log_n < log_w:
        raise Exception("need at least %d elements" % (1 << log_w))
    num_registers = len(registers)
    if num_registers % 2 == 1:
        raise Exception("odd number of registers: %d" % num_registers)
    # Number of levels that have to combine distinct registers.
    levels = 0
    if to_it > log_w:
        levels = to_it - max(from_it, log_w)
        if 1 << levels > num_registers / 2:
            raise Exception("not enough registers")
    n = 1 << log_n
    half = num_registers // 2
    sources = list(registers[:half])
    targets = list(registers[half:])
    clobber = ", ".join('"%%%s"' % name for name in registers)
    out = []

    if levels == 0:
        # Everything happens inside one register per loop iteration.
        reg = sources[0]
        asm_ident = ident + "    "
        out.append(ident + "for (int j = 0; j < %d; j += %d) {\n" % (n, 1 << log_w))
        out.append(ident + "  __asm__ volatile (\n")
        if neon:
            out.append(f'{ident}    "LD1 {{{reg}.4S}}, [%0]\\n"\n')
        else:
            out.append(ident + '    "%s (%%0), %%%%%s\\n"\n' % (move_instruction, reg))
        for it in range(from_it, to_it):
            out.append(special_steps[it](reg, targets, asm_ident))
        if neon:
            out.append(f'{ident}    "ST1 {{{reg}.4S}}, [%0]\\n"\n')
        else:
            out.append(ident + '    "%s %%%%%s, (%%0)\\n"\n' % (move_instruction, reg))
        out.append(ident + '    :: "r"(%s + j) : %s, "memory"\n' % (buf_name, clobber))
        out.append(ident + "  );\n")
        out.append(ident + "}\n")
        return "".join(out)

    lanes = 1 << levels
    stride = 1 << (to_it - levels)
    asm_ident = ident + "      "
    out.append(ident + "for (int j = 0; j < %d; j += %d) {\n" % (n, 1 << to_it))
    out.append(
        ident + "  for (int k = 0; k < %d; k += %d) {\n" % (stride, 1 << log_w)
    )
    # One memory operand per register loaded in this iteration.
    offsets = ["j + k + " + str(lane * stride) for lane in range(lanes)]
    out.append(ident + "    __asm__ volatile (\n")
    for lane in range(lanes):
        if neon:
            out.append(f'{ident}      "LD1 {{{sources[lane]}.4S}}, [%{lane}]\\n"\n')
        else:
            out.append(
                ident
                + '      "%s (%%%d), %%%%%s\\n"\n'
                % (move_instruction, lane, sources[lane])
            )
    # In-register levels, applied to every loaded register.
    for it in range(from_it, log_w):
        for lane in range(lanes):
            out.append(special_steps[it](sources[lane], targets, asm_ident))
    # Cross-register levels; results land in the other half of the register
    # file, so the two halves trade roles after each level.
    for level in range(levels):
        span = 1 << level
        for base in range(0, lanes, span << 1):
            for lo in range(base, base + span):
                out.append(
                    main_step(
                        sources[lo],
                        sources[lo + span],
                        targets[lo],
                        targets[lo + span],
                        asm_ident,
                    )
                )
        sources, targets = targets, sources
    for lane in range(lanes):
        if neon:
            out.append(f'{ident}      "ST1 {{{sources[lane]}.4S}}, [%{lane}]\\n"\n')
        else:
            out.append(
                ident
                + '      "%s %%%%%s, (%%%d)\\n"\n'
                % (move_instruction, sources[lane], lane)
            )
    operands = ", ".join('"r"(%s + %s)' % (buf_name, off) for off in offsets)
    out.append(ident + '      :: %s : %s, "memory"\n' % (operands, clobber))
    out.append(ident + "    );\n")
    out.append(ident + "  }\n")
    out.append(ident + "}\n")
    return "".join(out)
527
528
def float_avx_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """Run ``composite_step`` configured for floats in ymm0..ymm15 (AVX)."""
    ymm_registers = ["ymm%d" % index for index in range(16)]
    return composite_step(
        buf_name=buf_name,
        log_n=log_n,
        from_it=from_it,
        to_it=to_it,
        log_w=3,
        registers=ymm_registers,
        move_instruction="vmovups",
        special_steps=[float_avx_0, float_avx_1, float_avx_2],
        main_step=float_avx_3_etc,
        ident=ident,
    )
542
543
def double_avx_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """Run ``composite_step`` configured for doubles in ymm0..ymm15 (AVX)."""
    ymm_registers = ["ymm%d" % index for index in range(16)]
    return composite_step(
        buf_name=buf_name,
        log_n=log_n,
        from_it=from_it,
        to_it=to_it,
        log_w=2,
        registers=ymm_registers,
        move_instruction="vmovupd",
        special_steps=[double_avx_0, double_avx_1],
        main_step=double_avx_2_etc,
        ident=ident,
    )
557
558
def float_sse_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """Run ``composite_step`` configured for floats in xmm0..xmm15 (SSE)."""
    xmm_registers = ["xmm%d" % index for index in range(16)]
    return composite_step(
        buf_name=buf_name,
        log_n=log_n,
        from_it=from_it,
        to_it=to_it,
        log_w=2,
        registers=xmm_registers,
        move_instruction="movups",
        special_steps=[float_sse_0, float_sse_1],
        main_step=float_sse_2_etc,
        ident=ident,
    )
572
573
def double_sse_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """Run ``composite_step`` configured for doubles in xmm0..xmm15 (SSE)."""
    xmm_registers = ["xmm%d" % index for index in range(16)]
    return composite_step(
        buf_name=buf_name,
        log_n=log_n,
        from_it=from_it,
        to_it=to_it,
        log_w=1,
        registers=xmm_registers,
        move_instruction="movupd",
        special_steps=[double_sse_0],
        main_step=double_sse_1_etc,
        ident=ident,
    )
587
588
# Register names v0..v31, handed to composite_step as the register pool by
# float_neon_composite_step.
NEON_VECTOR_REGS = [f"v{x}" for x in range(0, 32)]
590
591
def float_neon_composite_step(buf_name, log_n, from_it, to_it, ident=""):
    """Run ``composite_step`` configured for floats in the NEON v registers."""
    return composite_step(
        buf_name=buf_name,
        log_n=log_n,
        from_it=from_it,
        to_it=to_it,
        log_w=2,
        registers=NEON_VECTOR_REGS,
        move_instruction=MOVE_INSTRUCTION_USE_NEON,
        special_steps=[float_neon_0, float_neon_1],
        main_step=float_neon_2_etc,
        ident=ident,
    )
605
606
def plain_unmerged(type_name, log_n):
    """Generate ``helper_<type>_<log_n>`` from plain C loops only.

    The helper is declared and then defined as ``static inline``; its body
    is one ``plain_step`` per level 0..log_n-1, indented by two spaces.
    """
    prototype = "static inline void helper_%s_%d(%s *buf)" % (
        type_name,
        log_n,
        type_name,
    )
    pieces = ["%s;\n" % prototype, "%s {\n" % prototype]
    pieces.extend(
        plain_step(type_name, "buf", log_n, level, "  ") for level in range(log_n)
    )
    pieces.append("}\n")
    return "".join(pieces)
619
620
def greedy_merged(type_name, log_n, composite_step):
    """Generate ``helper_<type>_<log_n>`` from greedily merged composite steps.

    Starting at level 0, each emitted step covers the widest level range
    ``[cur_it, upper)`` that ``composite_step`` accepts; every rejected
    attempt is reported with a ``warning:`` line on stdout.

    Args:
        type_name: C element type, also used in the helper's name.
        log_n: log2 of the transform size.
        composite_step: callable
            ``(buf_name, log_n, from_it, to_it, ident="") -> str`` that
            raises when the requested range cannot be generated.

    Raises:
        Exception: if ``composite_step`` rejects even an empty level range
            for this ``log_n``.
    """
    try:
        composite_step("buf", log_n, 0, 0)
    except Exception:
        raise Exception("log_n is too small: %d" % log_n)
    prototype = "static inline void helper_%s_%d(%s *buf)" % (
        type_name,
        log_n,
        type_name,
    )
    pieces = ["%s;\n" % prototype, "%s {\n" % prototype]
    level = 0
    while level < log_n:
        upper = log_n
        # Probe without indentation until a range is accepted.
        while True:
            try:
                composite_step("buf", log_n, level, upper)
            except Exception as e:
                print(f"warning: {e}")
                upper -= 1
            else:
                break
        pieces.append(composite_step("buf", log_n, level, upper, "  "))
        level = upper
    pieces.append("}\n")
    return "".join(pieces)
648
649
def greedy_merged_recursive(type_name, log_n, threshold_step, composite_step):
    """Generate a recursive helper plus its ``helper_<type>_<log_n>`` wrapper.

    The emitted ``helper_<type>_<log_n>_recursive(buf, depth)`` consists of
    one ``if (depth == D)`` branch per stage:

    * ``depth == threshold_step`` is the base case. When
      ``threshold_step == log_n`` the composite steps are emitted inline;
      otherwise the branch calls ``helper_<type>_<threshold_step>(buf)``,
      which must already be defined in the surrounding generated code.
    * Each further branch recurses into the sub-blocks of the previous depth
      and then applies the widest level range ``composite_step`` accepts.

    The wrapper ``helper_<type>_<log_n>(buf)`` calls the recursive helper
    with ``depth == log_n``.

    Args:
        type_name: C element type, also used in the helper names.
        log_n: log2 of the transform size.
        threshold_step: depth at which recursion stops; at most ``log_n``.
        composite_step: callable
            ``(buf_name, log_n, from_it, to_it, ident="") -> str`` that
            raises when the requested range cannot be generated.

    Returns:
        The generated C source as a string; every emitted statement ends
        with a newline.

    Raises:
        Exception: if ``threshold_step > log_n`` or ``composite_step``
            rejects even an empty level range at ``threshold_step``.
    """
    if threshold_step > log_n:
        raise Exception("threshold_step must be at most log_n")
    try:
        composite_step("buf", threshold_step, 0, 0)
    except Exception as err:
        raise Exception("threshold_step is too small: %d" % threshold_step) from err
    signature = "void helper_%s_%d_recursive(%s *buf, int depth)" % (
        type_name,
        log_n,
        type_name,
    )
    res = "%s;\n" % signature
    res += "%s {\n" % signature
    res += "  if (depth == %d) {\n" % threshold_step
    if threshold_step == log_n:
        # No smaller helper to delegate to: emit the whole transform inline.
        cur_it = 0
        while cur_it < threshold_step:
            cur_to_it = threshold_step
            while True:
                try:
                    composite_step("buf", threshold_step, cur_it, cur_to_it)
                    break
                except Exception:
                    cur_to_it -= 1
                    continue
            res += composite_step("buf", threshold_step, cur_it, cur_to_it, "    ")
            cur_it = cur_to_it
    else:
        res += "    helper_%s_%d(buf);\n" % (type_name, threshold_step)

    res += "    return;\n"
    res += "  }\n"
    cur_it = threshold_step
    while cur_it < log_n:
        # Greedily pick the widest range [cur_it, cur_to_it) that is accepted.
        cur_to_it = log_n
        while True:
            try:
                composite_step("buf", cur_to_it, cur_it, cur_to_it)
                break
            except Exception:
                cur_to_it -= 1
                continue
        res += "  if (depth == %d) {\n" % cur_to_it
        for i in range(1 << (cur_to_it - cur_it)):
            res += "    helper_%s_%d_recursive(buf + %d, %d);\n" % (
                type_name,
                log_n,
                i * (1 << cur_it),
                cur_it,
            )
        if cur_to_it < log_n:
            # NOTE(review): this calls the non-recursive helper for size
            # cur_to_it after the recursive sub-calls above; confirm that is
            # the intended combine step.
            # The trailing newline keeps the following "return;" on its own
            # line and keeps newline-based code-size counts accurate.
            res += "    helper_%s_%d(buf);\n" % (type_name, cur_to_it)
        else:
            res += composite_step("buf", cur_to_it, cur_it, cur_to_it, "    ")
        res += "    return;\n"
        res += "  }\n"
        cur_it = cur_to_it
    res += "}\n"
    signature = "void helper_%s_%d(%s *buf)" % (type_name, log_n, type_name)
    res += "%s;\n" % signature
    res += "%s {\n" % signature
    res += "  helper_%s_%d_recursive(buf, %d);\n" % (type_name, log_n, log_n)
    res += "}\n"
    return res
715
716
def extract_time(data):
    """Return the ``cpu_time`` of one benchmark CSV row, in seconds.

    Args:
        data: mapping with a ``"cpu_time"`` entry convertible to float and
            a ``"time_unit"`` entry.

    Raises:
        Exception: if ``time_unit`` is anything other than ``"ns"``.
    """
    nanoseconds = float(data["cpu_time"])
    if data["time_unit"] == "ns":
        return nanoseconds / 1e9
    raise Exception("nanoseconds expected")
723
724
def get_mean_stddev():
    """Return the mean CPU time, in seconds, of the last benchmark run.

    Reads ``measurements/output.csv``, whose first row is the column header,
    and converts the row named ``benchmark_fht_mean`` with ``extract_time``.
    Despite the function's name only the mean is returned; the stddev row is
    not read.

    Raises:
        Exception: if the file contains no ``benchmark_fht_mean`` row, or
            that row is not reported in nanoseconds.
    """
    mean = None
    with open("measurements/output.csv", "r", newline="") as csvfile:
        # DictReader maps each row onto the header and skips blank rows.
        for data in csv.DictReader(csvfile):
            if data["name"] == "benchmark_fht_mean":
                mean = extract_time(data)
    if mean is None:
        # Fail with a clear message instead of an UnboundLocalError.
        raise Exception("benchmark_fht_mean not found in measurements/output.csv")
    return mean
742
743
def measure_time(code, log_n, type_name, method_name, num_it=3):
    """Build and benchmark ``code``; return the mean CPU time in seconds.

    Writes ``code`` plus a ``run(buf)`` wrapper that calls ``method_name``
    to ``measurements/to_run.h``, builds the ``run_<type_name>`` target with
    ``make`` inside ``measurements/``, runs the resulting binary with CSV
    output redirected to ``measurements/output.csv``, and returns the value
    parsed by ``get_mean_stddev``.

    Args:
        code: generated C source to benchmark.
        log_n: value written as the ``log_n`` constant in the header.
        type_name: C element type; also selects the make target and binary.
        method_name: name of the generated function that ``run`` calls.
        num_it: number of benchmark repetitions; must be odd.

    Raises:
        Exception: if ``num_it`` is even, or the build or the benchmark
            exits with a non-zero status.
    """
    if num_it % 2 == 0:
        raise Exception("even number of runs: %d" % num_it)
    with open("measurements/to_run.h", "w") as output:
        output.write(code)
        output.write("const int log_n = %d;\n" % log_n)
        signature = "void run(%s *buf)" % type_name
        output.write("%s;\n" % signature)
        output.write("%s {\n" % signature)
        output.write("  %s(buf);\n" % method_name)
        output.write("}\n")
    # Argument lists (no shell) avoid quoting problems with interpolated
    # names; subprocess.DEVNULL replaces the hard-coded /dev/null path.
    exit_code = subprocess.call(
        ["make", "run_%s" % type_name],
        cwd="measurements",
        stdout=subprocess.DEVNULL,
    )
    if exit_code != 0:
        raise Exception("bad exit code")
    with open("measurements/output.csv", "w") as csv_output:
        exit_code = subprocess.call(
            [
                "./measurements/run_%s" % type_name,
                "--benchmark_repetitions=%d" % num_it,
                "--benchmark_format=csv",
            ],
            stdout=csv_output,
            stderr=subprocess.DEVNULL,
        )
    if exit_code != 0:
        raise Exception("bad exit code")
    return get_mean_stddev()
770
771
# Configuration parameter; set to False if you want the absolute fastest code
# without regard to size.
CARE_ABOUT_CODE_SIZE = True

# When CARE_ABOUT_CODE_SIZE is set, accept the smallest candidate whose time
# is at most MAX_PERFORMANCE_PENALTY_FOR_REDUCED_SIZE times the fastest time
# (so 1.1 tolerates a 10% slowdown).
MAX_PERFORMANCE_PENALTY_FOR_REDUCED_SIZE = 1.1
778
779
if __name__ == "__main__":
    # Driver: for every log_n, benchmark the iterative and the recursive
    # candidates, keep the preferred one, then write all winners plus the
    # fht_<type>() dispatcher to fht_neon.c and a summary to
    # hall_of_fame_neon.txt.
    final_code = '// @generated\n#include "fht.h"\n'
    # Winners chosen for smaller log_n; prepended to every candidate that is
    # benchmarked, since recursive candidates call helpers of smaller sizes.
    code_so_far = ""
    hall_of_fame = []
    for type_name, composite_step_generator in [("float", float_neon_composite_step)]:
        for log_n in range(1, max_log_n + 1):
            sys.stdout.write("log_n = %d\n" % log_n)
            # Candidates as (time, code, code_size, description) tuples.
            times = []
            # Iterative candidate: merged composite steps when the generator
            # accepts this log_n, plain C loops otherwise.
            try:
                (res, desc) = (
                    greedy_merged(type_name, log_n, composite_step_generator),
                    "greedy_merged",
                )
            except Exception:
                (res, desc) = (plain_unmerged(type_name, log_n), "plain_unmerged")
            time = measure_time(
                code_so_far + res, log_n, type_name, "helper_%s_%d" % (type_name, log_n)
            )
            # Code size is measured as the number of generated lines.
            code_size = res.count("\n")
            times.append((time, res, code_size, desc))
            sys.stdout.write(
                "log_n = %d; iterative; code_size = %d; time = %.10e\n"
                % (log_n, code_size, time)
            )
            # Recursive candidates, one per recursion threshold; a failing
            # threshold is reported and skipped.
            for threshold_step in range(1, log_n + 1):
                try:
                    res = greedy_merged_recursive(
                        type_name, log_n, threshold_step, composite_step_generator
                    )
                    time = measure_time(
                        code_so_far + res,
                        log_n,
                        type_name,
                        "helper_%s_%d" % (type_name, log_n),
                    )
                    code_size = res.count("\n")
                    times.append(
                        (
                            time,
                            res,
                            code_size,
                            "greedy_merged_recursive %d" % threshold_step,
                        )
                    )
                    sys.stdout.write(
                        "log_n = %d; threshold_step = %d; code_size = %d; time = %.10e\n"
                        % (log_n, threshold_step, code_size, time)
                    )
                except Exception as e:
                    sys.stdout.write(f"FAIL: {threshold_step} ({e})\n")
            if CARE_ABOUT_CODE_SIZE:
                # Smallest candidate within the allowed slowdown; the fastest
                # candidate always qualifies, so the loop always finds one.
                fastest_time = min(times)[0]
                times_by_size = sorted(times, key=lambda x: x[2])
                for x in times_by_size:
                    if x[0] <= fastest_time * MAX_PERFORMANCE_PENALTY_FOR_REDUCED_SIZE:
                        smallest_acceptable = x
                        break
                (best_time, best_code, best_code_size, best_desc) = smallest_acceptable
            else:
                (best_time, best_code, best_code_size, best_desc) = min(times)
            hall_of_fame.append((type_name, log_n, best_time, best_desc))
            final_code += best_code
            code_so_far += best_code
            sys.stdout.write(
                "log_n = %d; best_time = %.10e; %s\n" % (log_n, best_time, best_desc)
            )
        # Public entry point: dispatch on log_n; returns 0 on success and 1
        # for an unsupported log_n.
        final_code += "int fht_%s(%s *buf, int log_n) {\n" % (type_name, type_name)
        final_code += "  if (log_n == 0) {\n"
        final_code += "    return 0;\n"
        final_code += "  }\n"
        for i in range(1, max_log_n + 1):
            final_code += "  if (log_n == %d) {\n" % i
            final_code += "    helper_%s_%d(buf);\n" % (type_name, i)
            final_code += "    return 0;\n"
            final_code += "  }\n"
        final_code += "  return 1;\n"
        final_code += "}\n"
    with open("fht_neon.c", "w") as output:
        output.write(final_code)
    sys.stdout.write("hall of fame\n")
    with open("hall_of_fame_neon.txt", "w") as hof:
        for type_name, log_n, best_time, best_desc in hall_of_fame:
            s = "type_name = %s; log_n = %d; best_time = %.10e; best_desc = %s\n" % (
                type_name,
                log_n,
                best_time,
                best_desc,
            )
            sys.stdout.write(s)
            hof.write(s)
870