• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright © 2024 Imagination Technologies Ltd.
2# SPDX-License-Identifier: MIT
3
4from enum import Enum, auto
5
# Shorthand for None; not used by the code in this section — presumably an
# "unset" placeholder for definition tables elsewhere in the file (confirm).
_ = None
# Operand-count sentinel: op() compares num_dests / num_srcs against it and,
# on a match, emits a count plus a pointer in the builder signature instead of
# one parameter per operand. ~0 == -1 (all bits set).
VARIABLE = ~0

# Prefix prepended to the generated C identifiers (types, enums, ops, builders).
prefix = 'pco'
10
def val_fits_in_bits(val, num_bits):
   """Return whether `val` is strictly below 2**num_bits.

   Only the upper bound is checked; negative values are not rejected.
   """
   upper_bound = 2 ** num_bits
   return val < upper_bound
13
class BaseType(Enum):
   """Underlying kind of a generated type; type() maps each to a C spelling."""
   bool = 1  # spelled 'bool'
   uint = 2  # spelled 'unsigned'
   enum = 3  # spelled 'enum <prefix>_<name>'
18
class Type:
   """A registered type and the attributes the generators read from it.

   `name` is the emitted type spelling (e.g. 'unsigned' or 'enum pco_foo'),
   `tname` the short key it is registered under; everything else is stored
   exactly as passed.
   """

   def __init__(self, name, tname, base_type, num_bits, dec_bits, check, encode, nzdefault, print_early, unset, enum):
      # Identity.
      self.name, self.tname, self.base_type = name, tname, base_type
      # Widths in bits (bit_struct() prefers dec_bits, when set, for struct fields).
      self.num_bits, self.dec_bits = num_bits, dec_bits
      # Generator hooks and flags, stored verbatim.
      self.check, self.encode = check, encode
      self.nzdefault, self.print_early, self.unset = nzdefault, print_early, unset
      # Backing EnumType for enum types, otherwise whatever was passed (None by default).
      self.enum = enum
32
types = {}

def type(name, base_type, num_bits=None, dec_bits=None, check=None, encode=None, nzdefault=None, print_early=False, unset=False, enum=None):
   """Create a Type called `name`, register it in `types`, and return it.

   The emitted spelling follows `base_type`: 'bool', 'unsigned', or
   'enum <prefix>_<name>'. The remaining arguments are stored on the Type
   unchanged. Fails an assertion on a duplicate name or unknown base type.

   NOTE: this module-level function shadows the builtin type().
   """
   assert name not in types.keys(), f'Duplicate type "{name}".'

   if base_type == BaseType.enum:
      spelling = f'enum {prefix}_{name}'
   elif base_type == BaseType.uint:
      spelling = 'unsigned'
   elif base_type == BaseType.bool:
      spelling = 'bool'
   else:
      assert False, f'Invalid base type for type {name}.'

   new_type = Type(spelling, name, base_type, num_bits, dec_bits, check, encode, nzdefault, print_early, unset, enum)
   types[name] = new_type
   return new_type
50
51# Enum types.
class EnumType:
   """Description of a generated enum.

   `valid` is the set of element values, or for bitsets the OR of all values.
   For a subtype built by enum_subtype(), `elems` and `unique_count` are None
   and `parent` is the parent enum's Type; otherwise `parent` is None.
   """

   def __init__(self, name, ename, elems, valid, unique_count, is_bitset, parent):
      self.name, self.ename = name, ename
      self.elems = elems
      self.valid, self.unique_count = valid, unique_count
      self.is_bitset, self.parent = is_bitset, parent
61
class EnumElem:
   """One enum element: its upper-case C constant name, value, and string.

   `string` is None when an earlier element of the same enum already used
   this value (for bitsets: any of its bits).
   """

   def __init__(self, cname, value, string):
      self.cname, self.value, self.string = cname, value, string
67
enums = {}
def enum_type(name, elems, is_bitset=False, num_bits=None, *args, **kwargs):
   """Define enum `name`, register it in `enums`, and return its Type.

   Each entry of `elems` is one of:
     'elem'                     auto-assigned value; string is 'elem'.
     ('elem', 'string')         auto-assigned value; explicit string.
     ('elem', value)            explicit int value; string is 'elem'.
     ('elem', value, 'string')  explicit int value and string.

   Auto-assigned values count 0, 1, 2, ... over the auto-assigned entries only
   (explicit values do not advance the count); for bitsets the value is
   1 << count instead. If `num_bits` is given, every value must fit in it.

   An element whose value repeats an earlier one (for bitsets: shares any bit
   with earlier values) gets a None string, so only the first keeps a string.

   The EnumType's `valid` is the set of values, or for bitsets the OR of all
   values. Extra positional/keyword arguments are forwarded to type() after
   `num_bits`.
   """
   assert name not in enums.keys(), f'Duplicate enum "{name}".'

   _elems = {}
   _valid_vals = set()
   _valid_valmask = 0
   next_value = 0
   for e in elems:
      if isinstance(e, str):
         elem = e
         value = (1 << next_value) if is_bitset else next_value
         next_value += 1
         string = elem
      else:
         assert isinstance(e, tuple) and len(e) > 1, f'Invalid definition {e!r} in enum "{name}".'
         elem = e[0]
         if isinstance(e[1], str) and len(e) == 2:
            value = (1 << next_value) if is_bitset else next_value
            next_value += 1
            string = e[1]
         elif isinstance(e[1], int):
            value = e[1]
            string = e[2] if len(e) == 3 else elem
         else:
            assert False, f'Invalid definition for element "{elem}" in enum "{name}".'

      assert isinstance(elem, str) and isinstance(value, int) and isinstance(string, str), f'Invalid definition for element "{elem}" in enum "{name}".'

      assert not num_bits or val_fits_in_bits(value, num_bits), f'Element "{elem}" in enum "{name}" with value "{value}" does not fit into {num_bits} bits.'
      # Collect valid values, ensure that elements with repeated values only have one string set.
      if is_bitset:
         if (_valid_valmask & value) != 0:
            string = None
         _valid_valmask |= value
      else:
         if value in _valid_vals:
            string = None
         _valid_vals.add(value)

      assert elem not in _elems.keys(), f'Duplicate element "{elem}" in enum "{name}".'
      cname = f'{prefix}_{name}_{elem}'.upper()
      _elems[elem] = EnumElem(cname, value, string)

   _name = f'{prefix}_{name}'
   _valid = _valid_valmask if is_bitset else _valid_vals
   _unique_count = bin(_valid_valmask).count('1') if is_bitset else len(_valid_vals)
   enum = EnumType(_name, name, _elems, _valid, _unique_count, is_bitset, parent=None)
   enums[name] = enum

   return type(name, BaseType.enum, num_bits, *args, **kwargs, enum=enum)
119
def enum_subtype(name, parent, num_bits):
   """Define enum `name` as a narrower view of the non-bitset enum Type `parent`.

   The subtype has no elements of its own: its valid values are the parent's
   valid values that fit in `num_bits`, which must be smaller than the
   parent's width. It is registered in `enums` and its Type is returned.
   """
   assert name not in enums.keys(), f'Duplicate enum "{name}".'

   parent_enum = parent.enum
   assert parent_enum is not None
   assert not parent_enum.is_bitset
   assert parent.num_bits is not None and parent.num_bits > num_bits

   # Keep only the parent's values representable in the narrower width.
   narrowed = set(filter(lambda v: val_fits_in_bits(v, num_bits), parent_enum.valid))
   sub_enum = EnumType(f'{prefix}_{name}', name, None, narrowed, None, False, parent)
   enums[name] = sub_enum
   return type(name, BaseType.enum, num_bits, enum=sub_enum)
133
134# Type specializations.
135
field_types = {}
field_enum_types = {}
def field_type(name, *args, **kwargs):
   """Create a field type via type() and record it in `field_types`."""
   assert name not in field_types.keys(), f'Duplicate field type "{name}".'
   new_type = type(name, *args, **kwargs)
   field_types[name] = new_type
   return new_type

def _register_field_enum(name, what, factory, *args, **kwargs):
   """Build an enum field type with `factory` and record it in both registries.

   `what` names the kind of definition in the duplicate-name assertion.
   """
   assert name not in field_types.keys() and name not in field_enum_types.keys(), f'Duplicate {what} "{name}".'
   new_type = factory(name, *args, **kwargs)
   field_types[name] = new_type
   field_enum_types[name] = enums[name]
   return new_type

def field_enum_type(name, *args, **kwargs):
   """Create an enum field type via enum_type() and register it."""
   return _register_field_enum(name, 'field enum type', enum_type, *args, **kwargs)

def field_enum_subtype(name, *args, **kwargs):
   """Create an enum field subtype via enum_subtype() and register it."""
   return _register_field_enum(name, 'field enum (sub)type', enum_subtype, *args, **kwargs)
157
class OpMod:
   """An op modifier: its Type, its PCO_OP_MOD_* name and PCO_MOD_TYPE_* kind."""

   def __init__(self, t, cname, ctype):
      self.t, self.cname, self.ctype = t, cname, ctype
163
op_mods = {}
op_mod_enums = {}

def _new_op_mod(name, t):
   """Wrap Type `t` in an OpMod called `name` and record it in `op_mods`."""
   cname = f'{prefix}_op_mod_{name}'.upper()
   ctype = f'{prefix}_mod_type_{t.base_type.name}'.upper()
   new_mod = OpMod(t, cname, ctype)
   op_mods[name] = new_mod
   return new_mod

def op_mod(name, *args, **kwargs):
   """Define a non-enum op modifier; extra arguments go to type()."""
   assert name not in op_mods.keys(), f'Duplicate op mod "{name}".'
   new_mod = _new_op_mod(name, type(name, *args, **kwargs))
   # op() emits modifier sets as (1ULL << cname) masks, presumably the reason for the 64 limit.
   assert len(op_mods) <= 64, f'Too many op mods ({len(op_mods)})!'
   return new_mod

def op_mod_enum(name, *args, **kwargs):
   """Define an enum op modifier; extra arguments go to enum_type()."""
   assert name not in op_mods.keys() and name not in op_mod_enums.keys(), f'Duplicate op mod enum "{name}".'
   new_mod = _new_op_mod(name, enum_type(name, *args, **kwargs))
   op_mod_enums[name] = enums[name]
   assert len(op_mods) <= 64, f'Too many op mods ({len(op_mods)})!'
   return new_mod
184
class RefMod:
   """A ref (operand) modifier: its Type, PCO_REF_MOD_* name and PCO_MOD_TYPE_* kind."""

   def __init__(self, t, cname, ctype):
      self.t, self.cname, self.ctype = t, cname, ctype
190
ref_mods = {}
ref_mod_enums = {}

def _new_ref_mod(name, t):
   """Wrap Type `t` in a RefMod called `name` and record it in `ref_mods`."""
   cname = f'{prefix}_ref_mod_{name}'.upper()
   ctype = f'{prefix}_mod_type_{t.base_type.name}'.upper()
   new_mod = RefMod(t, cname, ctype)
   ref_mods[name] = new_mod
   return new_mod

def ref_mod(name, *args, **kwargs):
   """Define a non-enum ref modifier; extra arguments go to type()."""
   assert name not in ref_mods.keys(), f'Duplicate ref mod "{name}".'
   new_mod = _new_ref_mod(name, type(name, *args, **kwargs))
   # op() emits per-operand modifier sets as (1ULL << cname) masks, presumably the reason for the 64 limit.
   assert len(ref_mods) <= 64, f'Too many ref mods ({len(ref_mods)})!'
   return new_mod

def ref_mod_enum(name, *args, **kwargs):
   """Define an enum ref modifier; extra arguments go to enum_type()."""
   assert name not in ref_mods.keys() and name not in ref_mod_enums.keys(), f'Duplicate ref mod enum "{name}".'
   new_mod = _new_ref_mod(name, enum_type(name, *args, **kwargs))
   ref_mod_enums[name] = enums[name]
   assert len(ref_mods) <= 64, f'Too many ref mods ({len(ref_mods)})!'
   return new_mod
211
212# Bit encoding definition helpers.
213
class BitPiece(object):
   """A contiguous run of bits hi_bit..lo_bit (0-7) within byte index `byte`."""
   def __init__(self, name, byte, hi_bit, lo_bit, num_bits):
      self.name = name
      self.byte = byte
      self.hi_bit = hi_bit
      self.lo_bit = lo_bit
      self.num_bits = num_bits

def bit_piece(name, byte, bit_range):
   """Parse `bit_range` ('hi:lo' or a single bit 'n') into a BitPiece.

   Bits are numbered 0-7 within the byte and hi must be >= lo.

   Raises AssertionError (mentioning `name`) on a malformed or out-of-range
   specification; ValueError if a bound is not an integer.
   """
   assert bit_range.count(':') <= 1, f'Invalid bit range specification in bit piece {name}.'
   is_one_bit = not bit_range.count(':')

   split_range = [bit_range, bit_range] if is_one_bit else bit_range.split(':', 1)
   hi_bit, lo_bit = map(int, split_range)
   assert 0 <= lo_bit <= hi_bit < 8, f'Invalid bit range "{bit_range}" in bit piece {name}: expected 7 >= hi >= lo >= 0.'

   _num_bits = hi_bit - lo_bit + 1
   return BitPiece(name, byte, hi_bit, lo_bit, _num_bits)
232
class BitField:
   """A named field of a bit set, as produced by bit_field().

   `validate` is a C expression template with one `{}` hole, `encoding` holds
   one Encoding per piece, and `encoded_bits` is the field's total width.
   """
   def __init__(self, name, cname, field_type, pieces, reserved, validate, encoding, encoded_bits):
      self.name, self.cname, self.field_type = name, cname, field_type
      self.pieces, self.reserved = pieces, reserved
      self.validate, self.encoding, self.encoded_bits = validate, encoding, encoded_bits
243
class Encoding:
   """C statement templates for one bit piece: `clear` its bits, then `set` them."""
   def __init__(self, clear, set):
      self.clear, self.set = clear, set
248
def bit_field(bit_set_name, name, bit_set_pieces, field_type, pieces, reserved=None):
   """Describe field `name` of bit set `bit_set_name`, made of the named `pieces`.

   `pieces` lists keys of `bit_set_pieces`; the LAST listed piece receives the
   low bits of the value. Their widths must add up to `field_type.num_bits`.
   `reserved`, if given, is a fixed value that must fit in the field.

   Returns a BitField with C templates containing `{}` holes:
     - `validate`: one hole for the value.
     - `encoding`: one Encoding per piece, lowest piece first. `clear` has one
       hole and `set` two; the first hole is indexed (`{}[byte]`), presumably
       the output byte buffer, and the second is the value.
   """
   for p in pieces:
      assert p in bit_set_pieces, f'Piece "{p}" in bit field {name} not defined in bit set "{bit_set_name}".'
   _pieces = [bit_set_pieces[p] for p in pieces]

   total_bits = sum(p.num_bits for p in _pieces)
   assert total_bits == field_type.num_bits, f'Expected {field_type.num_bits}, got {total_bits} in bit field {name}.'

   if reserved is not None:
      assert val_fits_in_bits(reserved, total_bits), f'Reserved value for bit field {name} is too large.'

   cname = f'{bit_set_name}_{name}'.upper()
   if field_type.base_type == BaseType.enum:
      validate = f'{prefix}_{field_type.enum.ename}_valid({{}})'
   else:
      validate = f'{{}} < (1ULL << {field_type.num_bits})'

   encoding = []
   bits_consumed = 0
   # Walk from the last (least significant) piece upwards; each piece takes
   # the next `num_bits` bits of the value.
   for piece in reversed(_pieces):
      piece_mask = (1 << piece.num_bits) - 1
      # Clear only this piece's bits within its byte.
      enc_clear = f'{{}}[{piece.byte}] &= {hex((piece_mask << piece.lo_bit) ^ 0xff)}'

      # OR in the value's slice, shifted into position.
      enc_set = f'{{}}[{piece.byte}] |= ('
      enc_set += f'({{}} >> {bits_consumed})' if bits_consumed > 0 else '{}'
      enc_set += f' & {hex(piece_mask)})'
      enc_set += f' << {piece.lo_bit}' if piece.lo_bit > 0 else ''
      encoding.append(Encoding(enc_clear, enc_set))

      bits_consumed += piece.num_bits

   return BitField(name, cname, field_type, _pieces, reserved, validate, encoding, bits_consumed)
278
class BitSet:
   """A bit set: its pieces and fields plus the bit structs built over it.

   `name` is the prefixed name, `bsname` the name it was defined with.
   `bit_structs` and `variants` start empty and are filled by bit_struct().
   """
   def __init__(self, name, bsname, pieces, fields):
      self.name, self.bsname = name, bsname
      self.pieces, self.fields = pieces, fields
      self.bit_structs, self.variants = {}, []
287
bit_sets = {}

def bit_set(name, pieces, fields):
   """Define bit set `name`, register it in `bit_sets`, and return it.

   `pieces` is a sequence of (piece_name, spec) pairs, each spec unpacked into
   bit_piece(); `fields` is a sequence of (field_name, spec) pairs, each spec
   unpacked into bit_field() after the set's name and pieces. Piece and field
   names must be unique within the set.
   """
   assert name not in bit_sets.keys(), f'Duplicate bit set "{name}".'
   set_cname = f'{prefix}_{name}'

   piece_map = {}
   for piece_name, piece_spec in pieces:
      assert piece_name not in piece_map, f'Duplicate bit piece "{piece_name}" in bit set "{name}".'
      piece_map[piece_name] = bit_piece(piece_name, *piece_spec)

   field_map = {}
   for field_name, field_spec in fields:
      assert field_name not in field_map, f'Duplicate bit field "{field_name}" in bit set "{name}".'
      field_map[field_name] = bit_field(set_cname, field_name, piece_map, *field_spec)

   new_set = BitSet(set_cname, name, piece_map, field_map)
   bit_sets[name] = new_set
   return new_set
307
class BitStruct:
   """An encodable struct over a bit set, as produced by bit_struct().

   `name` is '<bit set name>_<struct name>' and `bsname` the bare struct name;
   `struct_fields` maps settable names to StructField, `encode_fields` lists one
   EncodeField per mapped field, and `num_bytes` is the encoded size.
   """
   def __init__(self, name, bsname, struct_fields, encode_fields, num_bytes, data, bit_set):
      self.name, self.bsname = name, bsname
      self.struct_fields, self.encode_fields = struct_fields, encode_fields
      self.num_bytes, self.data, self.bit_set = num_bytes, data, bit_set
317
class StructField:
   """A settable member of a bit struct: its Type, name, and width in bits."""
   def __init__(self, type, field, bits):
      self.type, self.field, self.bits = type, field, bits
323
class EncodeField:
   """One field to encode: its upper-case C name and the value to put in it.

   As built by bit_struct(), `value` is an int, a C constant name, or the C
   expression 's.<struct field>'.
   """
   def __init__(self, name, value):
      self.name, self.value = name, value
328
class Variant:
   """A bit struct variant: its upper-case C name and encoded size in bytes."""
   def __init__(self, cname, bytes):
      self.cname, self.bytes = cname, bytes
333
def bit_struct(name, bit_set, field_mappings, data=None):
   """Define bit struct `name` over `bit_set` and register it.

   Each entry of `field_mappings` is either a field name (the struct field
   takes the same name) or a tuple (struct_field, field[, fixed_value]).
   A fixed value (int, bool, or for enum fields an element name) is always
   encoded; a field with a reserved value is encoded with that value; every
   other mapping becomes a settable struct field.

   The struct is stored in `bit_set.bit_structs` and a Variant for it is
   appended to `bit_set.variants`. Fails an assertion on unknown or duplicate
   fields, overlapping pieces, or a total width that is not whole bytes.
   """
   assert name not in bit_set.bit_structs.keys(), f'Duplicate bit struct "{name}" in bit set "{bit_set.name}".'

   struct_fields = {}
   encode_fields = []
   all_pieces = []
   total_bits = 0
   for mapping in field_mappings:
      if isinstance(mapping, str):
         struct_field = mapping
         _field = mapping
         fixed_value = None
      else:
         assert isinstance(mapping, tuple), f'Invalid field mapping {mapping!r} in bit struct "{name}".'
         struct_field, _field, *fixed_value = mapping
         assert len(fixed_value) <= 1, f'Invalid field mapping {mapping!r} in bit struct "{name}".'
         fixed_value = fixed_value[0] if fixed_value else None

      assert struct_field not in struct_fields.keys(), f'Duplicate struct field "{struct_field}" in bit struct "{name}".'
      assert _field in bit_set.fields.keys(), f'Field "{_field}" in mapping for struct field "{name}.{struct_field}" not defined in bit set "{bit_set.name}".'
      field = bit_set.fields[_field]
      field_type = field.field_type
      is_enum = field_type.base_type == BaseType.enum

      if fixed_value is not None:
         assert field.reserved is None, f'Fixed value for field mapping "{struct_field}" using field "{_field}" cannot overwrite its reserved value.'

         if is_enum and isinstance(fixed_value, str):
            enum = field_type.enum
            # Subtypes (enum_subtype) carry no elements of their own; resolve the
            # element name through the parent chain rather than calling .keys() on None.
            owner = enum
            while owner.elems is None and owner.parent is not None:
               owner = owner.parent.enum
            assert owner.elems is not None and fixed_value in owner.elems.keys(), f'Fixed value for field mapping "{struct_field}" using field "{_field}" is not an element of enum {field_type.name}.'
            elem = owner.elems[fixed_value]
            # An element borrowed from a parent must also be valid in the narrower subtype.
            assert owner is enum or elem.value in enum.valid, f'Fixed value for field mapping "{struct_field}" using field "{_field}" does not fit enum {field_type.name}.'
            # Element cnames are already upper-case.
            fixed_value = elem.cname
         else:
            if isinstance(fixed_value, bool):
               fixed_value = int(fixed_value)

            assert isinstance(fixed_value, int), f'Fixed value for field mapping "{struct_field}" using field "{_field}" must be an integer.'
            assert val_fits_in_bits(fixed_value, field_type.num_bits), f'Fixed value for field mapping "{struct_field}" using field "{_field}" is too large.'

      # Absolute (lo, hi, name) bit spans, used for the overlap check below.
      all_pieces.extend((8 * piece.byte + piece.lo_bit, 8 * piece.byte + piece.hi_bit, piece.name) for piece in field.pieces)
      total_bits += field_type.num_bits

      # Describe how to encode the bit struct: fixed value, else reserved value, else the settable member.
      encode_field = f'{bit_set.name}_{_field}'.upper()
      if fixed_value is not None:
         encode_value = fixed_value
      elif field.reserved is not None:
         encode_value = field.reserved
      else:
         encode_value = f's.{struct_field}'
      encode_fields.append(EncodeField(encode_field, encode_value))

      # Describe settable fields.
      if field.reserved is None and fixed_value is None:
         # Use parent enum for struct fields.
         if is_enum and field_type.enum.parent is not None:
            field_type = field_type.enum.parent

         struct_field_bits = field_type.dec_bits if field_type.dec_bits is not None else field_type.num_bits
         struct_fields[struct_field] = StructField(field_type, struct_field, struct_field_bits)

   # Check for overlapping pieces. Compare each pair of distinct positions once:
   # comparing the tuples by value would skip a piece shared by two mapped fields
   # (identical tuples), which is precisely an overlap.
   for i, p0 in enumerate(all_pieces):
      for p1 in all_pieces[i + 1:]:
         assert p0[1] < p1[0] or p0[0] > p1[1], f'Pieces "{p0[2]}" and "{p1[2]}" overlap in bit struct "{name}".'

   # Check for byte-alignment.
   assert (total_bits % 8) == 0, f'Bit struct "{name}" has a non-byte-aligned number of bits ({total_bits}).'

   _name = f'{bit_set.name}_{name}'
   total_bytes = total_bits // 8
   bs = BitStruct(_name, name, struct_fields, encode_fields, total_bytes, data, bit_set)
   bit_set.bit_structs[name] = bs
   bit_set.variants.append(Variant(f'{bit_set.name}_{name}'.upper(), total_bytes))

   return bs
411
412# Op definitions.
class Op:
   """An op definition as assembled by op().

   `cop_mods`, `cdest_mods` and `csrc_mods` hold the C modifier-mask
   expressions; `builder_params` is op()'s five-element list of C fragments.
   """

   def __init__(self, name, cname, bname, op_type, op_mods, cop_mods, op_mod_map, num_dests, num_srcs, dest_mods, cdest_mods, src_mods, csrc_mods, has_target_cf_node, builder_params):
      # Names.
      self.name, self.cname, self.bname = name, cname, bname
      self.op_type = op_type
      # Op modifiers.
      self.op_mods, self.cop_mods, self.op_mod_map = op_mods, cop_mods, op_mod_map
      # Operands and their ref modifiers.
      self.num_dests, self.num_srcs = num_dests, num_srcs
      self.dest_mods, self.cdest_mods = dest_mods, cdest_mods
      self.src_mods, self.csrc_mods = src_mods, csrc_mods
      # Builder generation.
      self.has_target_cf_node = has_target_cf_node
      self.builder_params = builder_params
430
ops = {}

def op(name, op_type, op_mods, num_dests, num_srcs, dest_mods, src_mods, has_target_cf_node):
   """Define op `name`, register it in `ops`, and return it.

   `op_mods` is a list of OpMods; `dest_mods` / `src_mods` hold one list of
   RefMods per operand. `num_dests` / `num_srcs` are counts or VARIABLE.
   'hw_direct' ops get builders taking a `pco_func *` instead of a
   `pco_builder *`.

   builder_params holds C fragments:
     [0] parameter declarations,  [1] matching argument names,
     [2] ', ...' when the op has modifiers,
     [3] a compound literal building the mods struct from ##__VA_ARGS__,
     [4] the expression yielding the pco_func.
   """
   assert name not in ops.keys(), f'Duplicate op "{name}".'

   _name = name.replace('.', '_')
   cname = f'{prefix}_op_{_name}'
   bname = f'{prefix}_{_name}'

   def mod_mask(mods):
      # OR of (1ULL << CNAME) per modifier, or 0 when there are none so the
      # generated C is never an empty expression.
      return 0 if not mods else ' | '.join([f'(1ULL << {mod.cname})' for mod in mods])

   cop_mods = mod_mask(op_mods)
   op_mod_map = {op_mod.cname: index + 1 for index, op_mod in enumerate(op_mods)}
   # Test each operand's own list: an operand with no modifiers must give 0, not ''.
   cdest_mods = {i: mod_mask(destn_mods) for i, destn_mods in enumerate(dest_mods)}
   csrc_mods = {i: mod_mask(srcn_mods) for i, srcn_mods in enumerate(src_mods)}

   builder_params = ['', '', '', '', '']

   if op_type != 'hw_direct':
      builder_params[0] = 'pco_builder *b'
      builder_params[1] = 'b'
      builder_params[4] = 'pco_cursor_func(b->cursor)'
   else:
      builder_params[0] = 'pco_func *func'
      builder_params[1] = 'func'
      builder_params[4] = 'func'

   if has_target_cf_node:
      builder_params[0] += ', pco_cf_node *target_cf_node'
      builder_params[1] += ', target_cf_node'

   # NOTE(review): for VARIABLE counts the declaration names the pointer
   # "dest"/"src" while the argument list uses "dests"/"srcs"; confirm which
   # spelling the builder template expects.
   if num_dests == VARIABLE:
      builder_params[0] += ', unsigned num_dests, pco_ref *dest'
      builder_params[1] += ', num_dests, dests'
   else:
      for d in range(num_dests):
         builder_params[0] += f', pco_ref dest{d}'
         builder_params[1] += f', dest{d}'

   if num_srcs == VARIABLE:
      builder_params[0] += ', unsigned num_srcs, pco_ref *src'
      builder_params[1] += ', num_srcs, srcs'
   else:
      for s in range(num_srcs):
         builder_params[0] += f', pco_ref src{s}'
         builder_params[1] += f', src{s}'

   if op_mods:
      builder_params[0] += f', struct {bname}_mods mods'
      builder_params[2] = ', ...'
      builder_params[3] = f', (struct {bname}_mods){{0, ##__VA_ARGS__}}'

   op = Op(name, cname, bname, op_type, op_mods, cop_mods, op_mod_map, num_dests, num_srcs, dest_mods, cdest_mods, src_mods, csrc_mods, has_target_cf_node, builder_params)
   ops[name] = op
   return op
483
def pseudo_op(name, op_mods=None, num_dests=0, num_srcs=0, dest_mods=None, src_mods=None, has_target_cf_node=False):
   """Define a 'pseudo' op via op().

   The list arguments default to None (treated as empty) rather than a shared
   `[]`: op() stores these lists on the Op, so a mutable default would be one
   list object shared by every op defined without that argument.
   """
   op_mods = [] if op_mods is None else op_mods
   dest_mods = [] if dest_mods is None else dest_mods
   src_mods = [] if src_mods is None else src_mods
   return op(name, 'pseudo', op_mods, num_dests, num_srcs, dest_mods, src_mods, has_target_cf_node)
486
def hw_op(name, op_mods=None, num_dests=0, num_srcs=0, dest_mods=None, src_mods=None, has_target_cf_node=False):
   """Define a 'hw' op via op().

   The list arguments default to None (treated as empty) rather than a shared
   `[]`: op() stores these lists on the Op, so a mutable default would be one
   list object shared by every op defined without that argument.
   """
   op_mods = [] if op_mods is None else op_mods
   dest_mods = [] if dest_mods is None else dest_mods
   src_mods = [] if src_mods is None else src_mods
   return op(name, 'hw', op_mods, num_dests, num_srcs, dest_mods, src_mods, has_target_cf_node)
489
def hw_direct_op(name, num_dests=0, num_srcs=0, has_target_cf_node=False):
   """Define a 'hw_direct' op via op(): no op, dest or src modifiers."""
   return op(name, 'hw_direct', op_mods=[], num_dests=num_dests, num_srcs=num_srcs,
             dest_mods=[], src_mods=[], has_target_cf_node=has_target_cf_node)
492