• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#! /usr/bin/env python3
2
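# Script to parse NVIDIA class (CL) headers and generate the C inlines/macros
# (and optional C dump helpers and Rust bindings) used for pushbuffer encoding.
# Requires Python 3.9+ (uses str.removeprefix() / str.removesuffix()).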
5
6import argparse
7import os.path
8import sys
9import re
10import subprocess
11
12from mako.template import Template
13
# Element counts for indexed methods, i.e. those declared as NVxxxx_FOO(i) or
# NVxxxx_FOO(j) in the class header.  Keys are method names without the class
# prefix; a key ending in '*' matches every method whose name starts with the
# text before the '*' (see glob_match()).  method.array_size scans this table
# in insertion order and returns the first match, so entry order matters.
# Indexed methods with no matching entry get size 0 and are left out of the
# generated P_PARSE_/P_DUMP_ switch statements.
METHOD_ARRAY_SIZES = {
    'BIND_GROUP_CONSTANT_BUFFER': 16,
    'CALL_MME_DATA': 256,
    'CALL_MME_MACRO': 256,
    'LOAD_CONSTANT_BUFFER': 16,
    'LOAD_INLINE_QMD_DATA': 64,
    'SET_ANTI_ALIAS_SAMPLE_POSITIONS': 4,
    'SET_BLEND': 8,
    'SET_BLEND_PER_TARGET_*': 8,
    'SET_COLOR_TARGET_*': 8,
    'SET_COLOR_COMPRESSION': 8,
    'SET_COLOR_CLEAR_VALUE': 4,
    'SET_CT_WRITE': 8,
    'SET_MME_SHADOW_SCRATCH': 256,
    'SET_MULTI_VIEW_RENDER_TARGET_ARRAY_INDEX_OFFSET': 4,
    'SET_PIPELINE_*': 6,
    'SET_ROOT_TABLE_VISIBILITY': 8,
    'SET_SCG_COMPUTE_SCHEDULING_PARAMETERS': 16,
    'SET_SCG_GRAPHICS_SCHEDULING_PARAMETERS': 16,
    'SET_SCISSOR_*': 16,
    'SET_SHADER_PERFORMANCE_SNAPSHOT_COUNTER_VALUE*': 8,
    'SET_SHADER_PERFORMANCE_COUNTER_VALUE*': 8,
    'SET_SHADER_PERFORMANCE_COUNTER_EVENT': 8,
    'SET_SHADER_PERFORMANCE_COUNTER_CONTROL_A': 8,
    'SET_SHADER_PERFORMANCE_COUNTER_CONTROL_B': 8,
    'SET_SHADING_RATE_INDEX_SURFACE_*': 1,
    'SET_SPARE_MULTI_VIEW_RENDER_TARGET_ARRAY_INDEX_OFFSET': 4,
    'SET_STREAM_OUT_BUFFER_*': 4,
    'SET_STREAM_OUT_CONTROL_*': 4,
    'SET_VARIABLE_PIXEL_RATE_SAMPLE_ORDER': 13,
    'SET_VARIABLE_PIXEL_RATE_SHADING_CONTROL*': 16,
    'SET_VARIABLE_PIXEL_RATE_SHADING_INDEX_TO_RATE*': 16,
    'SET_VIEWPORT_*': 16,
    'SET_VERTEX_ATTRIBUTE_*': 16,
    'SET_VERTEX_STREAM_*': 16,
    'SET_WINDOW_CLIP_*': 8,
    'SET_CLIP_ID_EXTENT_*': 4,
}
52
# Methods whose (single) data field holds a 32-bit float rather than an
# integer.  Entries use the same exact-name / trailing-'*' prefix syntax as
# the METHOD_ARRAY_SIZES keys.  The C dump helper prints these through uif()
# and the Rust bindings type the field as f32; method.is_float asserts that a
# matching method has exactly one field.
METHOD_IS_FLOAT = [
    # blend / depth-bias state
    'SET_BLEND_CONST_*', 'SET_DEPTH_BIAS', 'SET_SLOPE_SCALE_DEPTH_BIAS',
    'SET_DEPTH_BIAS_CLAMP', 'SET_DEPTH_BOUNDS_M*',
    # line widths
    'SET_LINE_WIDTH_FLOAT', 'SET_ALIASED_LINE_WIDTH_FLOAT',
    # viewport transform and clipping
    'SET_VIEWPORT_SCALE_*', 'SET_VIEWPORT_OFFSET_*',
    'SET_VIEWPORT_CLIP_MIN_Z', 'SET_VIEWPORT_CLIP_MAX_Z',
    # clears
    'SET_Z_CLEAR_VALUE',
]
67
# Mako template for the generated C header.  For every method in `mthddict`
# it emits:
#   * `struct nv_<class>_<MTHD>` with one uint32_t member per bitfield;
#   * `__<CLASS>_<MTHD>()`, which asserts that each field narrower than 32
#     bits fits its width and ORs the fields, shifted to their start bit, into
#     one 32-bit value;
#   * `V_<CLASS>_<MTHD>(val, ...)`, which packs its arguments into `val`
#     (a bare value for single-field methods, a struct initializer otherwise);
#   * `VA_<CLASS>_<MTHD>`, an alias of V_ that, for array methods, takes
#     (and ignores) an index argument;
#   * `P_<CLASS>_<MTHD>(push, [idx,] ...)`, which packs the value and emits it
#     with nv_push_val().
# Inside V_/P_ each enumerant is declared as an UNUSED local named
# <FIELD>_<ENUM>, so macro arguments can use that short name.  Prototypes of
# the P_PARSE_/P_DUMP_ helpers (generated from TEMPLATE_C) close the file.
# Render variables: nvcl, clheader, mthddict, and bs (a single backslash used
# for macro line continuations).
TEMPLATE_H = Template("""\
/* parsed class ${nvcl} */

#include "nvtypes.h"
#include "${clheader}"

#include <assert.h>
#include <stdio.h>
#include "util/u_math.h"

%for mthd in mthddict:
struct nv_${nvcl.lower()}_${mthd} {
  %for field_name in mthddict[mthd].field_name_start:
    uint32_t ${field_name.lower()};
  %endfor
};

static inline void
__${nvcl}_${mthd}(uint32_t *val_out, struct nv_${nvcl.lower()}_${mthd} st)
{
    uint32_t val = 0;
  %for field_name in mthddict[mthd].field_name_start:
    <%
        field_start = int(mthddict[mthd].field_name_start[field_name])
        field_end = int(mthddict[mthd].field_name_end[field_name])
        field_width = field_end - field_start + 1
    %>
    %if field_width == 32:
    val |= st.${field_name.lower()};
    %else:
    assert(st.${field_name.lower()} < (1ULL << ${field_width}));
    val |= st.${field_name.lower()} << ${field_start};
    %endif
  %endfor
    *val_out = val;
}

#define V_${nvcl}_${mthd}(val, args...) { ${bs}
  %for field_name in mthddict[mthd].field_name_start:
    %for d in mthddict[mthd].field_defs[field_name]:
    UNUSED uint32_t ${field_name}_${d} = ${nvcl}_${mthd}_${field_name}_${d}; ${bs}
    %endfor
  %endfor
  %if len(mthddict[mthd].field_name_start) > 1:
    struct nv_${nvcl.lower()}_${mthd} __data = args; ${bs}
  %else:
<% field_name = next(iter(mthddict[mthd].field_name_start)).lower() %>\
    struct nv_${nvcl.lower()}_${mthd} __data = { .${field_name} = (args) }; ${bs}
  %endif
    __${nvcl}_${mthd}(&val, __data); ${bs}
}

%if mthddict[mthd].is_array:
#define VA_${nvcl}_${mthd}(i) V_${nvcl}_${mthd}
%else:
#define VA_${nvcl}_${mthd} V_${nvcl}_${mthd}
%endif

%if mthddict[mthd].is_array:
#define P_${nvcl}_${mthd}(push, idx, args...) do { ${bs}
%else:
#define P_${nvcl}_${mthd}(push, args...) do { ${bs}
%endif
  %for field_name in mthddict[mthd].field_name_start:
    %for d in mthddict[mthd].field_defs[field_name]:
    UNUSED uint32_t ${field_name}_${d} = ${nvcl}_${mthd}_${field_name}_${d}; ${bs}
    %endfor
  %endfor
    uint32_t nvk_p_ret; ${bs}
    V_${nvcl}_${mthd}(nvk_p_ret, args); ${bs}
    %if mthddict[mthd].is_array:
    nv_push_val(push, ${nvcl}_${mthd}(idx), nvk_p_ret); ${bs}
    %else:
    nv_push_val(push, ${nvcl}_${mthd}, nvk_p_ret); ${bs}
    %endif
} while(0)

%endfor

const char *P_PARSE_${nvcl}_MTHD(uint16_t idx);
void P_DUMP_${nvcl}_MTHD_DATA(FILE *fp, uint16_t idx, uint32_t data,
                              const char *prefix);
""")
151
# Mako template for the generated C source, which provides two debugging
# helpers:
#   * P_PARSE_<CLASS>_MTHD(idx) returns the method name for a method address,
#     or "unknown method".  Array methods are expanded for indices
#     0..array_size-1 and skipped entirely when array_size is 0.
#   * P_DUMP_<CLASS>_MTHD_DATA(fp, idx, data, prefix) extracts every field of
#     `data` and prints it: the enumerant name when the field has enumerants,
#     otherwise the hex value, plus the float value via uif() for methods
#     listed in METHOD_IS_FLOAT.  Unknown addresses print the raw value.
# Render variables: header (basename of the generated .h, which must be
# known), nvcl, mthddict, and bs (a single backslash, used for "\n").
TEMPLATE_C = Template("""\
#include "${header}"

#include <stdio.h>

const char*
P_PARSE_${nvcl}_MTHD(uint16_t idx)
{
    switch (idx) {
%for mthd in mthddict:
  %if mthddict[mthd].is_array and mthddict[mthd].array_size == 0:
    <% continue %>
  %endif
  %if mthddict[mthd].is_array:
    %for i in range(mthddict[mthd].array_size):
    case ${nvcl}_${mthd}(${i}):
        return "${nvcl}_${mthd}(${i})";
    %endfor
  % else:
    case ${nvcl}_${mthd}:
        return "${nvcl}_${mthd}";
  %endif
%endfor
    default:
        return "unknown method";
    }
}

void
P_DUMP_${nvcl}_MTHD_DATA(FILE *fp, uint16_t idx, uint32_t data,
                         const char *prefix)
{
    uint32_t parsed;
    switch (idx) {
%for mthd in mthddict:
  %if mthddict[mthd].is_array and mthddict[mthd].array_size == 0:
    <% continue %>
  %endif
  %if mthddict[mthd].is_array:
    %for i in range(mthddict[mthd].array_size):
    case ${nvcl}_${mthd}(${i}):
    %endfor
  % else:
    case ${nvcl}_${mthd}:
  %endif
  %for field_name in mthddict[mthd].field_name_start:
    <%
        field_start = int(mthddict[mthd].field_name_start[field_name])
        field_end = int(mthddict[mthd].field_name_end[field_name])
        field_width = field_end - field_start + 1
    %>
    %if field_width == 32:
        parsed = data;
    %else:
        parsed = (data >> ${field_start}) & ((1u << ${field_width}) - 1);
    %endif
        fprintf(fp, "%s.${field_name} = ", prefix);
    %if len(mthddict[mthd].field_defs[field_name]):
        switch (parsed) {
      %for d in mthddict[mthd].field_defs[field_name]:
        case ${nvcl}_${mthd}_${field_name}_${d}:
            fprintf(fp, "${d}${bs}n");
            break;
      %endfor
        default:
            fprintf(fp, "0x%x${bs}n", parsed);
            break;
        }
    %else:
      %if mthddict[mthd].is_float:
        fprintf(fp, "%ff (0x%x)${bs}n", uif(parsed), parsed);
      %else:
        fprintf(fp, "(0x%x)${bs}n", parsed);
      %endif
    %endif
  %endfor
        break;
%endfor
    default:
        fprintf(fp, "%s.VALUE = 0x%x${bs}n", prefix, data);
        break;
    }
}
""")
236
# Mako template for the small Rust file: besides a header comment it only
# exports the class id.  When parse_header() found a version define, this is
# `pub const <NAME>: u16 = <VALUE>;`.
# Render variables: nvcl, version.
TEMPLATE_RS = Template("""\
// parsed class ${nvcl}

% if version is not None:
pub const ${version[0]}: u16 = ${version[1]};
% endif
""")
244
# Mako template for the per-method Rust API.  For every method it:
#   * picks a Rust type per field: an enum when the field has enumerants,
#     bool for TRUE/FALSE pairs, f32 for METHOD_IS_FLOAT methods, u32
#     otherwise (these are recorded on the method object as a side effect);
#   * emits those enums and a struct with one public member per field;
#   * implements Mthd (fixed ADDR) or ArrayMthd (indexed addr()), whose
#     to_bits() packs the fields into one 32-bit word.
# f32 fields are packed with f32::to_bits(), i.e. their IEEE bit pattern.
# That matches the C side, which stores these fields as uint32_t and decodes
# them with uif().  `as u32` would numerically convert the float instead.
# Render variables: nvcl, version, mthddict, to_camel, rs_field_name.
TEMPLATE_RS_MTHD = Template("""\

// parsed class ${nvcl}

## Write out the methods in Rust
%for mthd_name, mthd in mthddict.items():
## Identify the field type.
<%
for field_name, field_value in mthd.field_defs.items():
    if field_name == 'V' and len(field_value) > 0:
        mthd.field_rs_types[field_name] = to_camel(mthd_name) + 'V'
        mthd.field_is_rs_enum[field_name] = True
    elif len(field_value) > 0:
        assert(field_name != "")
        mthd.field_rs_types[field_name] = to_camel(mthd_name) + to_camel(field_name)
        mthd.field_is_rs_enum[field_name] = True
    elif mthd.is_float:
        mthd.field_rs_types[field_name] = "f32"
        mthd.field_is_rs_enum[field_name] = False
    else:
        mthd.field_rs_types[field_name] = "u32"
        mthd.field_is_rs_enum[field_name] = False

    # TRUE and FALSE are special cases.
    if len(field_value) == 2:
        for enumerant in field_value:
            if enumerant.lower() == 'true' or enumerant.lower() == 'false':
                mthd.field_rs_types[field_name] = "bool"
                mthd.field_is_rs_enum[field_name] = False
                break
%>

## If there are a range of values for a field, we define an enum.
%for field_name in mthd.field_defs:
    %if mthd.field_is_rs_enum[field_name]:
#[repr(u16)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ${mthd.field_rs_types[field_name]} {
    %for field_name, field_value in mthd.field_defs[field_name].items():
    ${to_camel(rs_field_name(field_name))} = ${field_value.lower()},
    %endfor
}
    %endif
%endfor

## We also define a struct with the fields for the mthd.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct ${to_camel(mthd_name)} {
  %for field_name in mthddict[mthd_name].field_name_start:
    pub ${rs_field_name(field_name.lower())}: ${mthd.field_rs_types[field_name]},
  %endfor
}

## Notice that the "to_bits" implementation is identical, so the first brace is
## not closed.
% if not mthd.is_array:
## This trait lays out how the conversion to u32 happens
impl Mthd for ${to_camel(mthd_name)} {
    const ADDR: u16 = ${mthd.addr.replace('(', '').replace(')', '')};
    const CLASS: u16 = ${version[1].lower() if version is not None else nvcl.lower().replace("nv", "0x")};

%else:
impl ArrayMthd for ${to_camel(mthd_name)} {
    const CLASS: u16 = ${version[1].lower() if version is not None else nvcl.lower().replace("nv", "0x")};

    fn addr(i: usize) -> u16 {
        <% assert not ('i' in mthd.addr and 'j' in mthd.addr) %>
        (${mthd.addr.replace('j', 'i').replace('(', '').replace(')', '')}).try_into().unwrap()
    }
%endif

    #[inline]
    fn to_bits(self) -> u32 {
        let mut val = 0;
        %for field_name in mthddict[mthd_name].field_name_start:
            <%
                field_start = int(mthd.field_name_start[field_name])
                field_end = int(mthd.field_name_end[field_name])
                field_width = field_end - field_start + 1
                field_type = mthd.field_rs_types[field_name]
                if field_type == "u32":
                    field = rs_field_name(field_name.lower())
                elif field_type == "f32":
                    field = rs_field_name(field_name) + ".to_bits()"
                else:
                    field = f"{rs_field_name(field_name)} as u32"
            %>
            %if field_width == 32:
        val |= self.${field};
            %else:
                %if "as u32" in field:
        assert!((self.${field}) < (1 << ${field_width}));
        val |= (self.${field}) << ${field_start};
                %else:
        assert!(self.${field} < (1 << ${field_width}));
        val |= self.${field} << ${field_start};
                %endif
            %endif
        %endfor

        val
    }
## Close the first brace.
}
%endfor
""")
345
def to_camel(snake_str):
    """Convert a snake_case (or UPPER_SNAKE_CASE) name to CamelCase.

    Each '_'-separated word is title-cased and the words are concatenated,
    e.g. ``SET_BLEND`` -> ``SetBlend``.  Rust identifiers cannot start with a
    digit, so a result that would begin with one gets a ``_`` prefix
    (``2D_ARRAY`` -> ``_2DArray``).  An input that yields no characters
    (``""``, ``"_"``) returns ``""`` instead of raising IndexError.
    """
    result = ''.join(word.title() for word in snake_str.split('_'))
    if result and result[0].isdigit():
        return '_' + result
    return result
351
# Rust strict and reserved keywords (all lower case).  A lower-cased header
# name equal to one of these is not a valid plain Rust identifier.
_RUST_KEYWORDS = frozenset({
    # strict keywords
    'as', 'async', 'await', 'break', 'const', 'continue', 'crate', 'dyn',
    'else', 'enum', 'extern', 'false', 'fn', 'for', 'if', 'impl', 'in',
    'let', 'loop', 'match', 'mod', 'move', 'mut', 'pub', 'ref', 'return',
    'self', 'static', 'struct', 'super', 'trait', 'true', 'type', 'unsafe',
    'use', 'where', 'while',
    # reserved for future use
    'abstract', 'become', 'box', 'do', 'final', 'macro', 'override', 'priv',
    'try', 'typeof', 'unsized', 'virtual', 'yield',
})


def rs_field_name(name):
    """Turn a header field/enumerant name into a valid Rust identifier.

    The name is lower-cased.  A Rust keyword gets a trailing ``_``
    (``TYPE`` -> ``type_``, ``MATCH`` -> ``match_``), and a name starting with
    a digit gets a leading ``_`` (``2D`` -> ``_2d``).  Every other name,
    including the empty string, is returned lower-cased and otherwise
    unchanged.
    """
    name = name.lower()
    if name in _RUST_KEYWORDS:
        return name + '_'
    # Slicing instead of indexing keeps the empty string from raising.
    if name[:1].isdigit():
        return '_' + name
    return name
364
def glob_match(glob, name):
    """Return whether *name* matches the simple pattern *glob*.

    Only a trailing ``*`` wildcard is understood: ``'SET_SCISSOR_*'`` matches
    any name that starts with ``'SET_SCISSOR_'``.  A pattern without a
    trailing ``*`` must equal *name* exactly and may contain no ``*`` at all
    (checked with an assert).
    """
    if not glob.endswith('*'):
        assert '*' not in glob
        return glob == name
    prefix = glob[:-1]
    return name.startswith(prefix)
371
class method:
    """One class method ("mthd") parsed from a CL header.

    Instances start out empty; parse_header() assigns ``name``, ``addr``,
    ``is_array``, ``field_name_start``, ``field_name_end``, ``field_defs``,
    ``field_rs_types`` and ``field_is_rs_enum`` directly after construction.
    """

    @property
    def array_size(self):
        """Element count from the first METHOD_ARRAY_SIZES glob matching
        ``self.name`` (in table order), or 0 when none matches."""
        return next(
            (size for pattern, size in METHOD_ARRAY_SIZES.items()
             if glob_match(pattern, self.name)),
            0,
        )

    @property
    def is_float(self):
        """True when ``self.name`` matches a METHOD_IS_FLOAT glob.

        A matching method is asserted to have exactly one field.
        """
        matched = any(glob_match(pattern, self.name)
                      for pattern in METHOD_IS_FLOAT)
        if matched:
            assert len(self.field_defs) == 1
        return matched
387
def _drop_if_fieldless(mthddict, mthd):
    """Remove *mthd* from *mthddict* if no bitfields were parsed for it.

    Does nothing when *mthd* is None or has at least one field.
    """
    if mthd is not None and not mthd.field_name_start:
        mthddict.pop(mthd.name, None)


def parse_header(nvcl, f):
    """Parse an NVIDIA class header into method descriptions.

    Args:
        nvcl: Prefix of the class's defines, e.g. ``"NVC597"``.
        f: Iterable of header text lines (e.g. an open text file).

    Returns:
        ``(version, mthddict)``.  ``version`` is the ``(name, value)`` pair of
        the single ``#define`` whose name does not start with *nvcl* and whose
        value starts with ``0x``, or None if there is none.  ``mthddict`` maps
        method names (class prefix and any ``(i)``/``(j)`` suffix removed) to
        ``method`` objects.  Methods that end up without bitfields are dropped.

    The parser is a small state machine:
      state 0 -- looking for a new method define
      state 1 -- looking for bitfield (``hi:lo``) defines of the current method
      state 2 -- looking for enumerant defines of the current field
    Blank lines reset it to state 0.
    """
    version = None
    state = 0
    mthddict = {}
    curmthd = None   # method currently being filled in
    curfield = None  # field of curmthd whose enumerants are being collected

    for line in f:
        if line.strip() == "":
            # A blank line closes the current method block.
            state = 0
            _drop_if_fieldless(mthddict, curmthd)
            curmthd = None
            continue

        if not line.startswith("#define"):
            continue

        tokens = line.split()
        # Every define handled below needs a name and a value.  Skip shorter
        # ones (a bare "#define", value-less macros) instead of raising
        # IndexError.
        if len(tokens) < 3:
            continue
        name, value = tokens[1], tokens[2]

        if "_cl_" in name:
            continue

        if not name.startswith(nvcl):
            # A foreign define with a hex value is taken to be the class id;
            # there must be only one.
            if value.startswith("0x"):
                assert version is None
                version = (name, value)
            continue

        if name.endswith("TYPEDEF"):
            continue

        if state == 2:
            teststr = nvcl + "_" + curmthd.name + "_" + curfield + "_"
            if ":" in value or teststr not in name:
                # Not an enumerant of curfield; re-examine it as a field below.
                state = 1
            else:
                curmthd.field_defs[curfield][name.removeprefix(teststr)] = value

        if state == 1:
            teststr = nvcl + "_" + curmthd.name + "_"
            if teststr in name:
                # Only "hi:lo" values describe a bitfield.  Other values under
                # the method (hex or plain decimal constants) are ignored.
                if "0x" not in value and ":" in value:
                    field = name.removeprefix(teststr)
                    bitfield = value.split(":")
                    curmthd.field_name_start[field] = bitfield[1]
                    curmthd.field_name_end[field] = bitfield[0]
                    curmthd.field_defs[field] = {}
                    curfield = field
                    state = 2
            else:
                # The define belongs to some other method: close this one.
                if not curmthd.field_name_start:
                    mthddict.pop(curmthd.name, None)
                    curmthd = None
                state = 0

        if state == 0:
            _drop_if_fieldless(mthddict, curmthd)
            # A bitfield define cannot start a new method.
            if ":" in value:
                continue
            mthd_name = name.removeprefix(nvcl + "_")
            is_array = 0
            # Indexed methods are declared as NAME(i) or NAME(j).
            for suffix in ("(i)", "(j)"):
                if mthd_name.endswith(suffix):
                    is_array = 1
                    mthd_name = mthd_name.removesuffix(suffix)

            x = method()
            x.name = mthd_name
            x.addr = value
            x.is_array = is_array
            x.field_name_start = {}
            x.field_name_end = {}
            x.field_defs = {}
            x.field_rs_types = {}
            x.field_is_rs_enum = {}
            mthddict[x.name] = x

            curmthd = x
            state = 1

    # A method still open at end of input (no trailing blank line) must not
    # be kept without fields either; the templates expect at least one.
    _drop_if_fieldless(mthddict, curmthd)
    return (version, mthddict)
481
# `#define NAME(arg) expr` -- an address macro for an indexed method.
_RS_FN_MACRO_RE = re.compile(r'#define\s+(\w+)\((\w+)\)\s+(.+)')
# `#define NAME hi:lo` or `#define NAME MW(hi:lo)` -- a bit range.
_RS_BITRANGE_RE = re.compile(r'#define\s+(\w+)\s+(?:MW\()?(\d+):(\d+)\)?')
# `#define NAME 0xHEX`, optionally parenthesized.
_RS_HEX_RE = re.compile(r'#define\s+(\w+)\s+\(?0x(\w+)\)?')
# `#define NAME DECIMAL`, optionally parenthesized.
_RS_DEC_RE = re.compile(r'#define\s+(\w+)\s+\(?(\d+)\)?')


def _strip_wrapping_parens(expr):
    """Remove parentheses that enclose the whole of *expr*, repeatedly.

    ``(0x10+i*4)`` -> ``0x10+i*4``, but ``(i+1)*4`` is returned unchanged
    because its first ``(`` closes before the end of the expression.
    """
    expr = expr.strip()
    while expr.startswith('(') and expr.endswith(')'):
        depth = 0
        for pos, ch in enumerate(expr):
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0 and pos != len(expr) - 1:
                    return expr
        expr = expr[1:-1].strip()
    return expr


def convert_to_rust_constants(filename):
    """Translate the simple ``#define`` lines of a class header into Rust.

    Handled forms, tried in this order:
      * ``NAME(arg) expr``   -> ``#[inline] pub fn name(arg: u32) -> u32``
      * ``NAME hi:lo``       -> ``pub const NAME: Range<u32> = lo..hi+1;``
      * ``NAME 0xHEX``       -> ``pub const NAME: u32 = 0xHEX;``
      * ``NAME DECIMAL``     -> ``pub const NAME: u32 = DECIMAL;``
    The ``NV<class>_`` prefix derived from the file name is removed from the
    item names.  A name seen before gets a ``_1``, ``_2``, ... suffix.  Other
    lines are ignored.

    Returns:
        The Rust items joined with newlines.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    # "clc597.h" -> "NVC597_".
    file_prefix = "NV" + os.path.splitext(os.path.basename(filename))[0].upper() + "_"
    file_prefix = file_prefix.replace('CL', '')

    seen = {}

    def unique(name):
        # Repeated macro names get a numeric suffix so Rust items stay unique.
        if name in seen:
            seen[name] += 1
            return f"{name}_{seen[name]}"
        seen[name] = 0
        return name

    rust_items = []
    for raw_line in lines:
        line = raw_line.strip()
        if match := _RS_FN_MACRO_RE.match(line):
            name, arg, expr = match.groups()
            name = unique(name).replace(file_prefix, '')
            # convert to snake case
            name = re.sub(r'(?<=[a-z])(?=[A-Z])', '_', name).lower()
            # Drop the parentheses around each use of the argument and any
            # wrapping the whole expression, but keep grouping parentheses.
            # Deleting every paren would turn ((i)+1)*4 into i+1*4.
            expr = re.sub(r'\(\s*' + re.escape(arg) + r'\s*\)', arg, expr)
            expr = _strip_wrapping_parens(expr)
            rust_items.append(f"#[inline]\npub fn {name}  ({arg}: u32) -> u32 {{ {expr} }} ")
        elif match := _RS_BITRANGE_RE.match(line):
            name, high, low = match.groups()
            name = unique(name).replace(file_prefix, '')
            # Header ranges are inclusive (hi:lo); Rust ranges exclude the end.
            rust_items.append(f"pub const {name}: Range<u32> = {low}..{int(high) + 1};")
        elif match := _RS_HEX_RE.match(line):
            name, value = match.groups()
            name = unique(name).replace(file_prefix, '')
            rust_items.append(f"pub const {name}: u32 = 0x{value};")
        elif match := _RS_DEC_RE.match(line):
            name, value = match.groups()
            name = unique(name).replace(file_prefix, '')
            rust_items.append(f"pub const {name}: u32 = {value};")

    return '\n'.join(rust_items)
540
def main():
    """Command-line entry point.

    Parses the class header given by ``--in-h`` and writes whichever of the
    ``--out-h`` / ``--out-c`` / ``--out-rs`` / ``--out-rs-mthd`` outputs were
    requested.  ``--out-c`` requires ``--out-h``.  An exception raised while
    rendering or writing the outputs is printed via Mako's text error
    template on stderr, and the process exits with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--out-h', required=False, help='Output C header.')
    parser.add_argument('--out-c', required=False, help='Output C file.')
    parser.add_argument('--out-rs', required=False, help='Output Rust file.')
    parser.add_argument('--out-rs-mthd', required=False,
                        help='Output Rust file for methods.')
    parser.add_argument('--in-h',
                        help='Input class header file.',
                        required=True)
    args = parser.parse_args()

    # TEMPLATE_C #includes the generated header by its basename.  Without
    # --out-h that name is unknown, and rendering would fail on an undefined
    # template variable.
    if args.out_c is not None and args.out_h is None:
        parser.error('--out-c requires --out-h')

    # "clc597.h" -> "NVC597": the prefix shared by the header's defines.
    clheader = os.path.basename(args.in_h)
    nvcl = "NV" + clheader.removeprefix("cl").removesuffix(".h").upper()

    with open(args.in_h, 'r', encoding='utf-8') as f:
        (version, mthddict) = parse_header(nvcl, f)

    environment = {
        'clheader': clheader,
        'nvcl': nvcl,
        'version': version,
        'mthddict': mthddict,
        'rs_field_name': rs_field_name,
        'to_camel': to_camel,
        # A single backslash, used for C macro continuations and "\n".
        'bs': '\\',
    }
    if args.out_h is not None:
        environment['header'] = os.path.basename(args.out_h)

    try:
        if args.out_h is not None:
            with open(args.out_h, 'w', encoding='utf-8') as f:
                f.write(TEMPLATE_H.render(**environment))
        if args.out_c is not None:
            with open(args.out_c, 'w', encoding='utf-8') as f:
                f.write(TEMPLATE_C.render(**environment))
        if args.out_rs is not None:
            with open(args.out_rs, 'w', encoding='utf-8') as f:
                f.write(TEMPLATE_RS.render(**environment))
        if args.out_rs_mthd is not None:
            with open(args.out_rs_mthd, 'w', encoding='utf-8') as f:
                f.write("#![allow(non_camel_case_types)]\n"
                        "#![allow(non_snake_case)]\n"
                        "#![allow(non_upper_case_globals)]\n\n"
                        "use std::ops::Range;\n"
                        "use crate::Mthd;\n"
                        "use crate::ArrayMthd;\n"
                        "\n")
                f.write(convert_to_rust_constants(args.in_h))
                f.write("\n")
                f.write(TEMPLATE_RS_MTHD.render(**environment))

    except Exception:
        # Always report and fail: format the active exception with Mako's
        # text error template (useful for errors raised inside templates),
        # print it to stderr and exit with status 1.
        from mako import exceptions
        print(exceptions.text_error_template().render(), file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
609