• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#encoding=utf-8
2
3# Copyright 2016 Intel Corporation
4# Copyright 2016 Broadcom
5# Copyright 2020 Collabora, Ltd.
6# SPDX-License-Identifier: MIT
7
8import xml.parsers.expat
9import sys
10import operator
11import math
12import platform
13from functools import reduce
14
global_prefix = "agx"

# Characters that cannot appear in a C identifier, mapped to their
# replacement ('_' for separators, '' to drop the character).  No replacement
# is itself a key, so a single translate() pass gives the same result as
# applying the substitutions one after another.
_ALPHANUM_TABLE = str.maketrans({
    ' ': '_',
    '/': '_',
    '[': '',
    ']': '',
    '(': '',
    ')': '',
    '-': '_',
    ':': '',
    '.': '',
    ',': '',
    '=': '',
    '>': '',
    '#': '',
    '&': '',
    '*': '',
    '"': '',
    '+': '',
    '\'': '',
    '?': '',
})


def to_alphanum(name):
    """Return *name* with characters that are invalid in C identifiers removed.

    Spaces, '/' and '-' become '_'; brackets, punctuation and quotes are
    dropped.  Every other character is passed through unchanged.
    """
    return name.translate(_ALPHANUM_TABLE)


def safe_name(name):
    """Return *name* as a C identifier fragment.

    The name is cleaned with to_alphanum(), and '_' is prepended when the
    result does not start with a letter (e.g. a leading digit).  A name that
    becomes empty after cleaning yields '_' instead of raising IndexError.
    """
    name = to_alphanum(name)
    # name[:1] is '' for an empty string and ''.isalpha() is False.
    if not name[:1].isalpha():
        name = '_' + name

    return name


def prefixed_upper_name(prefix, name):
    """Return the upper-case identifier PREFIX_NAME, or NAME if *prefix* is empty."""
    if prefix:
        name = prefix + "_" + name
    return safe_name(name).upper()


def enum_name(name):
    """Return the lower-case C enum type name (``<global_prefix>_<name>``)."""
    return f"{global_prefix}_{safe_name(name)}".lower()
59
MODIFIERS = ["shr", "minus", "align", "log2", "groups"]


def parse_modifier(modifier):
    """Parse a field's ``modifier`` attribute.

    Returns None when *modifier* is None, ``["log2"]`` for the bare ``log2``
    modifier, and ``[name, int_arg]`` for ``shr(N)``, ``minus(N)``,
    ``align(N)`` and ``groups(N)``.

    Raises ValueError for any other string, or when an ``align`` argument is
    not a positive power of two.  Raising (rather than print() + assert)
    keeps the diagnostic off stdout, which is the generated header, and keeps
    the check active under ``python -O``.
    """
    if modifier is None:
        return None

    for mod in MODIFIERS:
        if not modifier.startswith(mod):
            continue

        if mod == "log2":
            # log2 takes no argument.
            if modifier != mod:
                raise ValueError(f"Invalid modifier: {modifier!r}")
            return [mod]

        rest = modifier[len(mod):]
        if len(rest) >= 2 and rest[0] == '(' and rest[-1] == ')':
            ret = [mod, int(rest[1:-1])]
            if mod == 'align':
                align = ret[1]
                # Make sure the alignment is a power of 2
                if align <= 0 or align & (align - 1):
                    raise ValueError(f"align modifier must be a power of two: {modifier!r}")

            return ret

    raise ValueError(f"Invalid modifier: {modifier!r}")
83
class Field(object):
    """One <field> element of a genxml struct.

    ``start`` and ``end`` are absolute, inclusive bit positions within the
    struct; the XML ``start`` attribute may be a plain bit index or
    ``word:bit`` with 32-bit words.  ``default`` is the value that packs to
    the field's ``exact`` value, or (for modified fields) the value that
    encodes to zero.  ``values`` (a list of Value) is attached by the parser
    when the field's end tag is seen.
    """

    def __init__(self, parser, attrs):
        self.parser = parser
        if "name" in attrs:
            self.name = safe_name(attrs["name"]).lower()
            self.human_name = attrs["name"]

        if ":" in str(attrs["start"]):
            (word, bit) = attrs["start"].split(":")
            self.start = (int(word) * 32) + int(bit)
        else:
            self.start = int(attrs["start"])

        self.end = self.start + int(attrs["size"]) - 1
        self.type = attrs["type"]

        if self.type == 'bool' and self.start != self.end:
            # Printed into the generated header so the C build fails loudly.
            print(f"#error Field {self.name} has bool type but more than one bit of size")

        if "prefix" in attrs:
            self.prefix = safe_name(attrs["prefix"]).upper()
        else:
            self.prefix = None

        self.modifier = parse_modifier(attrs.get("modifier"))
        self.exact = attrs.get("exact")
        self.default = None

        if self.exact is not None:
            self.default = self.exact
        elif self.modifier is not None:
            # Set the default value to encode to zero
            mod = self.modifier
            if mod[0] == 'log2':
                self.default = 1
            elif mod[0] == 'minus':
                self.default = mod[1]
            elif mod[0] == 'groups':
                # The zero encoding means "all"
                self.default = (1 << int(attrs["size"])) * mod[1]
            elif mod[0] in ['shr', 'align']:
                # Zero encodes to zero
                pass
            else:
                assert(0)

        # Map enum values
        if self.type in self.parser.enums and self.default is not None:
            self.default = safe_name(f'{global_prefix}_{self.type}_{self.default}').upper()

    def emit_template_struct(self, dim):
        """Print this field's C struct member (with array suffix *dim*),
        followed by a ``#define`` for each <value> attached to the field."""
        # start/end are inclusive, so the width is end - start + 1.
        bits = self.end - self.start + 1
        if self.type == 'address':
            c_type = 'uint64_t'
        elif self.type == 'bool':
            c_type = 'bool'
        elif self.type in ['float', 'half', 'lod']:
            c_type = 'float'
        elif self.type in ['uint', 'hex'] and bits > 32:
            # Anything wider than 32 bits needs a 64-bit member.  This is the
            # same threshold Group.emit_print_function uses to pick PRIx64.
            c_type = 'uint64_t'
        elif self.type == 'int':
            c_type = 'int32_t'
        elif self.type in ['uint', 'hex']:
            c_type = 'uint32_t'
        elif self.type in self.parser.structs:
            c_type = 'struct ' + self.parser.gen_prefix(safe_name(self.type.upper()))
        elif self.type in self.parser.enums:
            c_type = 'enum ' + enum_name(self.type)
        else:
            print(f"#error unhandled type: {self.type}")
            c_type = "uint32_t"

        print("   %-36s %s%s;" % (c_type, self.name, dim))

        for value in self.values:
            name = prefixed_upper_name(self.prefix, value.name)
            print("#define %-40s %d" % (name, value.value))

    def overlaps(self, field):
        """Return True if *field* is a different Field sharing any bit with this one."""
        return self != field and max(self.start, field.start) <= min(self.end, field.end)
158        for value in self.values:
159            name = prefixed_upper_name(self.prefix, value.name)
160            print("#define %-40s %d" % (name, value.value))
161
162    def overlaps(self, field):
163        return self != field and max(self.start, field.start) <= min(self.end, field.end)
164
165class Group(object):
166    def __init__(self, parser, parent, start, count, label):
167        self.parser = parser
168        self.parent = parent
169        self.start = start
170        self.count = count
171        self.label = label
172        self.size = 0
173        self.length = 0
174        self.fields = []
175
176    def get_length(self):
177        # Determine number of bytes in this group.
178        calculated = max(field.end // 8 for field in self.fields) + 1 if len(self.fields) > 0 else 0
179        if self.length > 0:
180            assert(self.length >= calculated)
181        else:
182            self.length = calculated
183        return self.length
184
185
186    def emit_template_struct(self, dim):
187        if self.count == 0:
188            print("   /* variable length fields follow */")
189        else:
190            if self.count > 1:
191                dim = "%s[%d]" % (dim, self.count)
192
193            any_fields = False
194            for field in self.fields:
195                if not field.exact:
196                    field.emit_template_struct(dim)
197                    any_fields = True
198
199            if not any_fields:
200                print("   int dummy;")
201
202    class Word:
203        def __init__(self):
204            self.size = 32
205            self.contributors = []
206
207    class FieldRef:
208        def __init__(self, field, path, start, end):
209            self.field = field
210            self.path = path
211            self.start = start
212            self.end = end
213
214    def collect_fields(self, fields, offset, path, all_fields):
215        for field in fields:
216            field_path = f'{path}{field.name}'
217            field_offset = offset + field.start
218
219            if field.type in self.parser.structs:
220                sub_struct = self.parser.structs[field.type]
221                self.collect_fields(sub_struct.fields, field_offset, field_path + '.', all_fields)
222                continue
223
224            start = field_offset
225            end = offset + field.end
226            all_fields.append(self.FieldRef(field, field_path, start, end))
227
228    def collect_words(self, fields, offset, path, words):
229        for field in fields:
230            field_path = f'{path}{field.name}'
231            start = offset + field.start
232
233            if field.type in self.parser.structs:
234                sub_fields = self.parser.structs[field.type].fields
235                self.collect_words(sub_fields, start, field_path + '.', words)
236                continue
237
238            end = offset + field.end
239            contributor = self.FieldRef(field, field_path, start, end)
240            first_word = contributor.start // 32
241            last_word = contributor.end // 32
242            for b in range(first_word, last_word + 1):
243                if not b in words:
244                    words[b] = self.Word()
245                words[b].contributors.append(contributor)
246
247    def emit_pack_function(self):
248        self.get_length()
249
250        words = {}
251        self.collect_words(self.fields, 0, '', words)
252
253        # Validate the modifier is lossless
254        for field in self.fields:
255            if field.modifier is None:
256                continue
257
258            if field.modifier[0] == "shr":
259                shift = field.modifier[1]
260                mask = hex((1 << shift) - 1)
261                print(f"   assert((values->{field.name} & {mask}) == 0);")
262            elif field.modifier[0] == "minus":
263                print(f"   assert(values->{field.name} >= {field.modifier[1]});")
264            elif field.modifier[0] == "log2":
265                print(f"   assert(IS_POT_NONZERO(values->{field.name}));")
266
267        for index in range(math.ceil(self.length / 4)):
268            # Handle MBZ words
269            if not index in words:
270                print("   cl[%2d] = 0;" % index)
271                continue
272
273            word = words[index]
274
275            word_start = index * 32
276
277            v = None
278            prefix = "   cl[%2d] =" % index
279
280            lines = []
281
282            for contributor in word.contributors:
283                field = contributor.field
284                name = field.name
285                start = contributor.start
286                end = contributor.end
287                contrib_word_start = (start // 32) * 32
288                start -= contrib_word_start
289                end -= contrib_word_start
290
291                value = f"values->{contributor.path}"
292                if field.exact:
293                    value = field.default
294
295                # These types all use util_bitpack_uint
296                pack_as_uint = field.type in ["uint", "hex", "address", "bool"]
297                pack_as_uint |= field.type in self.parser.enums
298                start_adjusted = start
299                value_unshifted = None
300
301                if field.modifier is not None:
302                    if field.modifier[0] == "shr":
303                        if pack_as_uint and start >= field.modifier[1]:
304                            # For uint, we fast path.  If we do `(a >> 2) << 2`,
305                            # clang will generate a mask in release builds, even
306                            # though we know we're aligned. So don't generate
307                            # that to avoid the masking.
308                            start_adjusted = start - field.modifier[1]
309                        else:
310                            value = f"{value} >> {field.modifier[1]}"
311                    elif field.modifier[0] == "minus":
312                        value = f"{value} - {field.modifier[1]}"
313                    elif field.modifier[0] == "align":
314                        value = f"ALIGN_POT({value}, {field.modifier[1]})"
315                    elif field.modifier[0] == "log2":
316                        value = f"util_logbase2({value})"
317                    elif field.modifier[0] == "groups":
318                        value = "__gen_to_groups({}, {}, {})".format(value,
319                                field.modifier[1], end - start + 1)
320
321                if pack_as_uint:
322                    bits = (end - start_adjusted + 1)
323                    if bits < 64 and not field.exact:
324                        # Add some nicer error checking
325                        label = f"{self.label}::{name}"
326                        bound = hex(1 << bits)
327                        print(f"   agx_genxml_validate_bounds(\"{label}\", {value}, {bound}ull);")
328
329                    s = f"util_bitpack_uint({value}, {start_adjusted}, {end})"
330                elif field.type == "int":
331                    s = "util_bitpack_sint(%s, %d, %d)" % \
332                        (value, start, end)
333                elif field.type == "float":
334                    assert(start == 0 and end == 31)
335                    s = f"util_bitpack_float({value})"
336                elif field.type == "half":
337                    assert(start == 0 and end == 15)
338                    s = f"_mesa_float_to_half({value})"
339                elif field.type == "lod":
340                    assert(end - start + 1 == 10)
341                    s = "__gen_pack_lod(%s, %d, %d)" % (value, start, end)
342                else:
343                    s = f"#error unhandled field {contributor.path}, type {field.type}"
344
345                if not s == None:
346                    shift = word_start - contrib_word_start
347                    if shift:
348                        s = "%s >> %d" % (s, shift)
349
350                    if contributor == word.contributors[-1]:
351                        lines.append(f"{prefix} {s};")
352                    else:
353                        lines.append(f"{prefix} {s} |")
354                    prefix = "           "
355
356            for ln in lines:
357                print(ln)
358
359            continue
360
361    # Given a field (start, end) contained in word `index`, generate the 32-bit
362    # mask of present bits relative to the word
363    def mask_for_word(self, index, start, end):
364        field_word_start = index * 32
365        start -= field_word_start
366        end -= field_word_start
367        # Cap multiword at one word
368        start = max(start, 0)
369        end = min(end, 32 - 1)
370        count = (end - start + 1)
371        return (((1 << count) - 1) << start)
372
373    def emit_unpack_function(self):
374        # First, verify there is no garbage in unused bits
375        words = {}
376        self.collect_words(self.fields, 0, '', words)
377        validation = []
378
379        for index in range(self.length // 4):
380            base = index * 32
381            word = words.get(index, self.Word())
382            masks = [self.mask_for_word(index, c.start, c.end) for c in word.contributors]
383            mask = reduce(lambda x,y: x | y, masks, 0)
384
385            ALL_ONES = 0xffffffff
386
387            if mask != ALL_ONES:
388                bad_mask = hex(mask ^ ALL_ONES)
389                validation.append(f'agx_genxml_validate_mask(fp, \"{self.label}\", cl, {index}, {bad_mask})')
390
391        fieldrefs = []
392        self.collect_fields(self.fields, 0, '', fieldrefs)
393        for fieldref in fieldrefs:
394            field = fieldref.field
395            convert = None
396
397            args = []
398            args.append('(CONST uint32_t *) cl')
399            args.append(str(fieldref.start))
400            args.append(str(fieldref.end))
401
402            if field.type in set(["uint", "address", "hex"]) | self.parser.enums:
403                convert = "__gen_unpack_uint"
404            elif field.type == "int":
405                convert = "__gen_unpack_sint"
406            elif field.type == "bool":
407                convert = "__gen_unpack_uint"
408            elif field.type == "float":
409                convert = "__gen_unpack_float"
410            elif field.type == "half":
411                convert = "__gen_unpack_half"
412            elif field.type == "lod":
413                convert = "__gen_unpack_lod"
414            else:
415                s = f"/* unhandled field {field.name}, type {field.type} */\n"
416
417            suffix = ""
418            prefix = ""
419            if field.modifier:
420                if field.modifier[0] == "minus":
421                    suffix = f" + {field.modifier[1]}"
422                elif field.modifier[0] == "shr":
423                    suffix = f" << {field.modifier[1]}"
424                if field.modifier[0] == "log2":
425                    prefix = "1 << "
426                elif field.modifier[0] == "groups":
427                    prefix = "__gen_from_groups("
428                    suffix = ", {}, {})".format(field.modifier[1],
429                                                fieldref.end - fieldref.start + 1)
430
431            if field.type in self.parser.enums and not field.exact:
432                prefix = f"(enum {enum_name(field.type)}) {prefix}"
433
434            decoded = f"{prefix}{convert}({', '.join(args)}){suffix}"
435
436            if field.exact:
437                name = self.label
438                validation.append(f'agx_genxml_validate_exact(fp, \"{name}\", {decoded}, {field.default})')
439            else:
440                print(f'   values->{fieldref.path} = {decoded};')
441
442            if field.modifier and field.modifier[0] == "align":
443                assert(not field.exact)
444                mask = hex(field.modifier[1] - 1)
445                print(f'   assert(!(values->{fieldref.path} & {mask}));')
446
447        if len(validation) > 1:
448            print('   bool valid = true;')
449            for v in validation:
450                print(f'   valid &= {v};')
451            print("   return valid;")
452        elif len(validation) == 1:
453            print(f"   return {validation[0]};")
454        else:
455            print("   return true;")
456
457    def emit_print_function(self):
458        for field in self.fields:
459            convert = None
460            name, val = field.human_name, f'values->{field.name}'
461
462            if field.exact:
463                continue
464
465            if field.type in self.parser.structs:
466                pack_name = self.parser.gen_prefix(safe_name(field.type)).upper()
467                print(f'   fprintf(fp, "%*s{field.human_name}:\\n", indent, "");')
468                print(f"   {pack_name}_print(fp, &values->{field.name}, indent + 2);")
469            elif field.type == "address":
470                # TODO resolve to name
471                print(f'   fprintf(fp, "%*s{name}: 0x%" PRIx64 "\\n", indent, "", {val});')
472            elif field.type in self.parser.enums:
473                print(f'   if ({enum_name(field.type)}_as_str({val}))')
474                print(f'     fprintf(fp, "%*s{name}: %s\\n", indent, "", {enum_name(field.type)}_as_str({val}));')
475                print(f'   else')
476                print(f'     fprintf(fp, "%*s{name}: unknown %X (XXX)\\n", indent, "", {val});')
477            elif field.type == "int":
478                print(f'   fprintf(fp, "%*s{name}: %d\\n", indent, "", {val});')
479            elif field.type == "bool":
480                print(f'   fprintf(fp, "%*s{name}: %s\\n", indent, "", {val} ? "true" : "false");')
481            elif field.type in ["float", "lod", "half"]:
482                print(f'   fprintf(fp, "%*s{name}: %f\\n", indent, "", {val});')
483            elif field.type in ["uint", "hex"] and (field.end - field.start) >= 32:
484                print(f'   fprintf(fp, "%*s{name}: 0x%" PRIx64 "\\n", indent, "", {val});')
485            elif field.type == "hex":
486                print(f'   fprintf(fp, "%*s{name}: 0x%" PRIx32 "\\n", indent, "", {val});')
487            else:
488                print(f'   fprintf(fp, "%*s{name}: %u\\n", indent, "", {val});')
489
490class Value(object):
491    def __init__(self, attrs):
492        self.name = attrs["name"]
493        self.value = int(attrs["value"], 0)
494
495class Parser(object):
496    def __init__(self):
497        self.parser = xml.parsers.expat.ParserCreate()
498        self.parser.StartElementHandler = self.start_element
499        self.parser.EndElementHandler = self.end_element
500        self.os = platform.system().lower()
501
502        self.struct = None
503        self.structs = {}
504        # Set of enum names we've seen.
505        self.enums = set()
506
507    def gen_prefix(self, name):
508        return f'{global_prefix.upper()}_{name}'
509
510    def start_element(self, name, attrs):
511        if "os" in attrs and attrs["os"] != self.os:
512            return
513
514        if name == "genxml":
515            print(pack_header)
516        elif name == "struct":
517            name = attrs["name"]
518            object_name = self.gen_prefix(safe_name(name.upper()))
519            self.struct = object_name
520
521            self.group = Group(self, None, 0, 1, name)
522            if "size" in attrs:
523                self.group.length = int(attrs["size"])
524            self.group.align = int(attrs["align"]) if "align" in attrs else None
525            self.structs[attrs["name"]] = self.group
526        elif name == "field" and self.group is not None:
527            self.group.fields.append(Field(self, attrs))
528            self.values = []
529        elif name == "enum":
530            self.values = []
531            self.enum = safe_name(attrs["name"])
532            self.enums.add(attrs["name"])
533            if "prefix" in attrs:
534                self.prefix = attrs["prefix"]
535            else:
536                self.prefix= None
537        elif name == "value":
538            self.values.append(Value(attrs))
539
540    def end_element(self, name):
541        if name == "struct":
542            if self.struct is not None:
543                self.emit_struct()
544                self.struct = None
545
546            self.group = None
547        elif name  == "field" and self.group is not None:
548            self.group.fields[-1].values = self.values
549        elif name  == "enum":
550            self.emit_enum()
551            self.enum = None
552
553    def emit_header(self, name):
554        default_fields = []
555        for field in self.group.fields:
556            if not type(field) is Field or field.exact:
557                continue
558            if field.default is not None:
559                default_fields.append(f"   .{field.name} = {field.default}")
560            elif field.type in self.structs:
561                default_fields.append(f"   .{field.name} = {{ {self.gen_prefix(safe_name(field.type.upper()))}_header }}")
562
563        if default_fields:
564            print('#define %-40s\\' % (name + '_header'))
565            print(",  \\\n".join(default_fields))
566        else:
567            print(f'#define {name}_header 0')
568        print('')
569
570    def emit_template_struct(self, name, group):
571        print("struct %s {" % name)
572        group.emit_template_struct("")
573        print("};\n")
574
575    def emit_pack_function(self, name, group):
576        print("static inline void\n%s_pack(GLOBAL uint32_t * restrict cl,\n%sconst struct %s * restrict values)\n{" %
577              (name, ' ' * (len(name) + 6), name))
578
579        group.emit_pack_function()
580
581        print("}\n")
582
583        print(f"#define {name + '_LENGTH'} {self.group.length}")
584        if self.group.align != None:
585            print(f"#define {name + '_ALIGN'} {self.group.align}")
586
587        # round up to handle 6 half-word USC structures
588        words = (self.group.length + 4 - 1) // 4
589        print(f'struct {name.lower()}_packed {{ uint32_t opaque[{words}];}};')
590
591    def emit_unpack_function(self, name, group):
592        print("static inline bool")
593        print("%s_unpack(FILE *fp, CONST uint8_t * restrict cl,\n%sstruct %s * restrict values)\n{" %
594              (name.upper(), ' ' * (len(name) + 8), name))
595
596        group.emit_unpack_function()
597
598        print("}\n")
599
600    def emit_print_function(self, name, group):
601        print("#ifndef __OPENCL_VERSION__")
602        print("static inline void")
603        print(f"{name.upper()}_print(FILE *fp, const struct {name} * values, unsigned indent)\n{{")
604
605        group.emit_print_function()
606
607        print("}")
608        print("#endif\n")
609
610    def emit_struct(self):
611        name = self.struct
612
613        self.emit_template_struct(self.struct, self.group)
614        self.emit_header(name)
615        self.emit_pack_function(self.struct, self.group)
616        self.emit_unpack_function(self.struct, self.group)
617        self.emit_print_function(self.struct, self.group)
618
619    def enum_prefix(self, name):
620        return
621
622    def emit_enum(self):
623        e_name = enum_name(self.enum)
624        prefix = e_name if self.enum != 'Format' else global_prefix
625        print(f'enum {e_name} {{')
626
627        for value in self.values:
628            name = f'{prefix}_{value.name}'
629            name = safe_name(name).upper()
630            print(f'   {name} = {value.value},')
631        print('};\n')
632
633        print("#ifndef __OPENCL_VERSION__")
634        print("static inline const char *")
635        print(f"{e_name.lower()}_as_str(enum {e_name} imm)\n{{")
636        print("    switch (imm) {")
637        for value in self.values:
638            name = f'{prefix}_{value.name}'
639            name = safe_name(name).upper()
640            print(f'    case {name}: return "{value.name}";')
641        print('    default: return NULL;')
642        print("    }")
643        print("}")
644        print("#endif\n")
645
646    def parse(self, filename):
647        file = open(filename, "rb")
648        self.parser.ParseFile(file)
649        file.close()
650
651if len(sys.argv) < 3:
652    print("Missing input files file specified")
653    sys.exit(1)
654
655input_file = sys.argv[1]
656pack_header = open(sys.argv[2]).read()
657
658p = Parser()
659p.parse(input_file)
660