# Copyright © 2024 Imagination Technologies Ltd.
# SPDX-License-Identifier: MIT

"""Helpers for declaring PCO types, enums, modifiers, bit encodings and ops.

Every helper validates its arguments with ``assert`` and records what it
builds in a module-level registry (``types``, ``enums``, ``field_types``,
``op_mods``, ``ref_mods``, ``bit_sets``, ``ops``, ...). Several attributes
hold C identifiers or C expression strings containing ``{}`` placeholders
(e.g. ``BitField.validate`` and ``Encoding``); these are presumably filled in
by the code generators that import this module.
"""

from enum import Enum, auto
from itertools import combinations

# Not referenced in this module; presumably a placeholder for the definition
# files that import it.
_ = None

# Sentinel passed as num_dests/num_srcs for ops with a variable operand count.
VARIABLE = ~0

# Prefix used for every generated C identifier.
prefix = 'pco'


def val_fits_in_bits(val, num_bits):
    """Return True if ``val`` is below ``2 ** num_bits``.

    Only the upper bound is checked; negative values are not rejected.
    """
    return val < pow(2, num_bits)


class BaseType(Enum):
    """Underlying kind of a Type."""
    bool = auto()
    uint = auto()
    enum = auto()


class Type(object):
    """A registered field/modifier type.

    ``name`` is the C type name (e.g. ``unsigned`` or ``enum pco_foo``) and
    ``tname`` is the name the type was registered under.
    """

    def __init__(self, name, tname, base_type, num_bits, dec_bits, check, encode, nzdefault, print_early, unset, enum):
        self.name = name
        self.tname = tname
        self.base_type = base_type
        self.num_bits = num_bits
        self.dec_bits = dec_bits
        self.check = check
        self.encode = encode
        self.nzdefault = nzdefault
        self.print_early = print_early
        self.unset = unset
        self.enum = enum


# Every Type, keyed by its registered name.
types = {}


def type(name, base_type, num_bits=None, dec_bits=None, check=None, encode=None, nzdefault=None, print_early=False, unset=False, enum=None):
    """Create, register and return a Type called ``name``.

    The C type name is derived from ``base_type``. Note that this function
    shadows the ``type`` builtin within this module.
    """
    assert name not in types, f'Duplicate type "{name}".'

    if base_type == BaseType.bool:
        _name = 'bool'
    elif base_type == BaseType.uint:
        _name = 'unsigned'
    elif base_type == BaseType.enum:
        _name = f'enum {prefix}_{name}'
    else:
        assert False, f'Invalid base type for type {name}.'

    t = Type(_name, name, base_type, num_bits, dec_bits, check, encode, nzdefault, print_early, unset, enum)
    types[name] = t
    return t


# Enum types.
class EnumType(object):
    """Description of an enum (or enum subtype).

    name:         C enum name, ``<prefix>_<ename>``.
    ename:        registered (unprefixed) name.
    elems:        dict of element name -> EnumElem; None for subtypes.
    valid:        set of valid values, or an OR-ed bitmask for bitsets.
    unique_count: number of distinct values / set bits; None for subtypes.
    is_bitset:    whether element values are single-bit flags.
    parent:       the parent Type for subtypes, otherwise None.
    """

    def __init__(self, name, ename, elems, valid, unique_count, is_bitset, parent):
        self.name = name
        self.ename = ename
        self.elems = elems
        self.valid = valid
        self.unique_count = unique_count
        self.is_bitset = is_bitset
        self.parent = parent


class EnumElem(object):
    """One enum element.

    ``string`` is None when an earlier element already used the same value
    (or, for bitsets, any of the same bits).
    """

    def __init__(self, cname, value, string):
        self.cname = cname
        self.value = value
        self.string = string


# Every EnumType, keyed by its registered name.
enums = {}


def enum_type(name, elems, is_bitset=False, num_bits=None, *args, **kwargs):
    """Register an enum and return its (BaseType.enum) Type.

    Each entry of ``elems`` is one of:
      * ``'elem'``                    auto value, string ``'elem'``;
      * ``('elem', 'string')``        auto value, custom string;
      * ``('elem', value)``           explicit value, string ``'elem'``;
      * ``('elem', value, 'string')`` explicit value, custom string.
    Auto values count up over auto-valued elements only; for bitsets the
    auto value is ``1 << index``. Extra ``args``/``kwargs`` go to ``type()``.
    """
    assert name not in enums, f'Duplicate enum "{name}".'

    _elems = {}
    _valid_vals = set()
    _valid_valmask = 0
    next_value = 0
    for e in elems:
        if isinstance(e, str):
            elem = e
            value = (1 << next_value) if is_bitset else next_value
            next_value += 1
            string = elem
        else:
            assert isinstance(e, tuple) and len(e) > 1
            elem = e[0]
            if isinstance(e[1], str) and len(e) == 2:
                value = (1 << next_value) if is_bitset else next_value
                next_value += 1
                string = e[1]
            elif isinstance(e[1], int):
                value = e[1]
                string = e[2] if len(e) == 3 else elem
            else:
                assert False, f'Invalid definition for element "{elem}" in enum "{name}".'

        assert isinstance(elem, str) and isinstance(value, int) and isinstance(string, str)

        assert not num_bits or val_fits_in_bits(value, num_bits), f'Element "{elem}" in enum "{name}" with value "{value}" does not fit into {num_bits} bits.'

        # Collect valid values; an element repeating an already-seen value
        # (or already-used bits, for bitsets) gets no string of its own.
        if is_bitset:
            if (_valid_valmask & value) != 0:
                string = None
            _valid_valmask |= value
        else:
            if value in _valid_vals:
                string = None
            _valid_vals.add(value)

        assert elem not in _elems, f'Duplicate element "{elem}" in enum "{name}".'
        cname = f'{prefix}_{name}_{elem}'.upper()
        _elems[elem] = EnumElem(cname, value, string)

    _name = f'{prefix}_{name}'
    _valid = _valid_valmask if is_bitset else _valid_vals
    _unique_count = bin(_valid_valmask).count('1') if is_bitset else len(_valid_vals)
    enum = EnumType(_name, name, _elems, _valid, _unique_count, is_bitset, parent=None)
    enums[name] = enum

    return type(name, BaseType.enum, num_bits, *args, **kwargs, enum=enum)


def enum_subtype(name, parent, num_bits):
    """Register a narrower view of a non-bitset enum Type ``parent``.

    The subtype has no elements of its own; its valid values are the
    parent's valid values that fit in ``num_bits`` bits.
    """
    assert name not in enums, f'Duplicate enum "{name}".'

    assert parent.enum is not None
    assert not parent.enum.is_bitset
    assert parent.num_bits is not None and parent.num_bits > num_bits

    _name = f'{prefix}_{name}'
    # Validation of subtype - values that will fit in the smaller bit size.
    _valid = {val for val in parent.enum.valid if val_fits_in_bits(val, num_bits)}
    enum = EnumType(_name, name, None, _valid, None, False, parent)
    enums[name] = enum
    return type(name, BaseType.enum, num_bits, enum=enum)


# Type specializations.

field_types = {}
field_enum_types = {}


def field_type(name, *args, **kwargs):
    """Register a Type usable as a bit-field type."""
    assert name not in field_types, f'Duplicate field type "{name}".'
    t = type(name, *args, **kwargs)
    field_types[name] = t
    return t


def field_enum_type(name, *args, **kwargs):
    """Register an enum Type usable as a bit-field type."""
    assert name not in field_types and name not in field_enum_types, f'Duplicate field enum type "{name}".'
    t = enum_type(name, *args, **kwargs)
    field_types[name] = t
    field_enum_types[name] = enums[name]
    return t


def field_enum_subtype(name, *args, **kwargs):
    """Register an enum subtype usable as a bit-field type."""
    assert name not in field_types and name not in field_enum_types, f'Duplicate field enum (sub)type "{name}".'
    t = enum_subtype(name, *args, **kwargs)
    field_types[name] = t
    field_enum_types[name] = enums[name]
    return t


def _mod_ctype(t):
    """Return the C modifier-type constant for Type ``t``, e.g. PCO_MOD_TYPE_BOOL."""
    return f'{prefix}_mod_type_{t.base_type.name}'.upper()


class OpMod(object):
    """An op modifier: its Type, C enumerator name and C modifier-type name."""

    def __init__(self, t, cname, ctype):
        self.t = t
        self.cname = cname
        self.ctype = ctype


op_mods = {}
op_mod_enums = {}


def op_mod(name, *args, **kwargs):
    """Register an op modifier of a plain (bool/uint) type."""
    assert name not in op_mods, f'Duplicate op mod "{name}".'
    t = type(name, *args, **kwargs)
    cname = f'{prefix}_op_mod_{name}'.upper()
    om = op_mods[name] = OpMod(t, cname, _mod_ctype(t))
    # Op mod masks are emitted as (1ULL << cname), so at most 64 fit.
    assert len(op_mods) <= 64, f'Too many op mods ({len(op_mods)})!'
    return om


def op_mod_enum(name, *args, **kwargs):
    """Register an op modifier whose type is a new enum."""
    assert name not in op_mods and name not in op_mod_enums, f'Duplicate op mod enum "{name}".'
    t = enum_type(name, *args, **kwargs)
    cname = f'{prefix}_op_mod_{name}'.upper()
    om = op_mods[name] = OpMod(t, cname, _mod_ctype(t))
    op_mod_enums[name] = enums[name]
    assert len(op_mods) <= 64, f'Too many op mods ({len(op_mods)})!'
    return om


class RefMod(object):
    """A ref (operand) modifier: its Type, C enumerator name and C modifier-type name."""

    def __init__(self, t, cname, ctype):
        self.t = t
        self.cname = cname
        self.ctype = ctype


ref_mods = {}
ref_mod_enums = {}


def ref_mod(name, *args, **kwargs):
    """Register a ref modifier of a plain (bool/uint) type."""
    assert name not in ref_mods, f'Duplicate ref mod "{name}".'
    t = type(name, *args, **kwargs)
    cname = f'{prefix}_ref_mod_{name}'.upper()
    rm = ref_mods[name] = RefMod(t, cname, _mod_ctype(t))
    # Ref mod masks are emitted as (1ULL << cname), so at most 64 fit.
    assert len(ref_mods) <= 64, f'Too many ref mods ({len(ref_mods)})!'
    return rm


def ref_mod_enum(name, *args, **kwargs):
    """Register a ref modifier whose type is a new enum."""
    assert name not in ref_mods and name not in ref_mod_enums, f'Duplicate ref mod enum "{name}".'
    t = enum_type(name, *args, **kwargs)
    cname = f'{prefix}_ref_mod_{name}'.upper()
    rm = ref_mods[name] = RefMod(t, cname, _mod_ctype(t))
    ref_mod_enums[name] = enums[name]
    assert len(ref_mods) <= 64, f'Too many ref mods ({len(ref_mods)})!'
    return rm


# Bit encoding definition helpers.

class BitPiece(object):
    """A contiguous run of bits [lo_bit, hi_bit] within byte ``byte``."""

    def __init__(self, name, byte, hi_bit, lo_bit, num_bits):
        self.name = name
        self.byte = byte
        self.hi_bit = hi_bit
        self.lo_bit = lo_bit
        self.num_bits = num_bits


def bit_piece(name, byte, bit_range):
    """Build a BitPiece covering ``bit_range`` of byte ``byte``.

    ``bit_range`` is a single bit (``'5'``) or an inclusive ``'hi:lo'`` range
    (``'7:4'``); every bit index must lie within 0..7.
    """
    assert bit_range.count(':') <= 1, f'Invalid bit range specification in bit piece {name}.'

    hi_str, sep, lo_str = bit_range.partition(':')
    hi_bit = int(hi_str)
    lo_bit = int(lo_str) if sep else hi_bit
    assert 0 <= lo_bit <= hi_bit < 8

    return BitPiece(name, byte, hi_bit, lo_bit, hi_bit - lo_bit + 1)


class BitField(object):
    """A typed field assembled from one or more BitPieces.

    ``validate`` is a C expression with one ``{}`` for the value; each
    Encoding in ``encoding`` has ``clear``/``set`` C statements (see bit_field).
    """

    def __init__(self, name, cname, field_type, pieces, reserved, validate, encoding, encoded_bits):
        self.name = name
        self.cname = cname
        self.field_type = field_type
        self.pieces = pieces
        self.reserved = reserved
        self.validate = validate
        self.encoding = encoding
        self.encoded_bits = encoded_bits


class Encoding(object):
    """C statements that clear and then set one piece of a bit field."""

    def __init__(self, clear, set):
        self.clear = clear
        self.set = set


def bit_field(bit_set_name, name, bit_set_pieces, field_type, pieces, reserved=None):
    """Build a BitField from the pieces named in ``pieces``.

    The pieces' total width must equal ``field_type.num_bits``. If given,
    ``reserved`` must fit in that width.
    """
    _pieces = [bit_set_pieces[p] for p in pieces]

    total_bits = sum(p.num_bits for p in _pieces)
    assert total_bits == field_type.num_bits, f'Expected {field_type.num_bits}, got {total_bits} in bit field {name}.'

    if reserved is not None:
        assert val_fits_in_bits(reserved, total_bits), f'Reserved value for bit field {name} is too large.'

    cname = f'{bit_set_name}_{name}'.upper()
    if field_type.base_type == BaseType.enum:
        validate = f'{prefix}_{field_type.enum.ename}_valid({{}})'
    else:
        validate = f'{{}} < (1ULL << {field_type.num_bits})'

    # The last listed piece receives the least-significant bits of the value,
    # so walk the pieces in reverse, shifting the value down by the bits already
    # placed. In the generated statements the first {} is indexed by byte
    # (presumably the output byte buffer) and the second {} is the field value.
    encoding = []
    bits_consumed = 0
    for piece in reversed(_pieces):
        piece_mask = (1 << piece.num_bits) - 1
        enc_clear = f'{{}}[{piece.byte}] &= {hex((piece_mask << piece.lo_bit) ^ 0xff)}'

        enc_set = f'{{}}[{piece.byte}] |= ('
        enc_set += f'({{}} >> {bits_consumed})' if bits_consumed > 0 else '{}'
        enc_set += f' & {hex(piece_mask)})'
        enc_set += f' << {piece.lo_bit}' if piece.lo_bit > 0 else ''
        encoding.append(Encoding(enc_clear, enc_set))

        bits_consumed += piece.num_bits

    return BitField(name, cname, field_type, _pieces, reserved, validate, encoding, bits_consumed)


class BitSet(object):
    """A named collection of pieces and fields, plus the bit structs built on it."""

    def __init__(self, name, bsname, pieces, fields):
        self.name = name
        self.bsname = bsname
        self.pieces = pieces
        self.fields = fields
        self.bit_structs = {}
        self.variants = []


bit_sets = {}


def bit_set(name, pieces, fields):
    """Register a BitSet.

    ``pieces`` is a sequence of ``(piece_name, (byte, bit_range))`` and
    ``fields`` a sequence of ``(field_name, (field_type, piece_names[, reserved]))``.
    """
    assert name not in bit_sets, f'Duplicate bit set "{name}".'
    _name = f'{prefix}_{name}'

    _pieces = {}
    for (piece, spec) in pieces:
        assert piece not in _pieces, f'Duplicate bit piece "{piece}" in bit set "{name}".'
        _pieces[piece] = bit_piece(piece, *spec)

    _fields = {}
    for (field, spec) in fields:
        assert field not in _fields, f'Duplicate bit field "{field}" in bit set "{name}".'
        _fields[field] = bit_field(_name, field, _pieces, *spec)

    bs = BitSet(_name, name, _pieces, _fields)
    bit_sets[name] = bs
    return bs


class BitStruct(object):
    """A byte-aligned layout of bit-set fields.

    ``struct_fields`` are the settable fields; ``encode_fields`` describe the
    value written into every mapped field.
    """

    def __init__(self, name, bsname, struct_fields, encode_fields, num_bytes, data, bit_set):
        self.name = name
        self.bsname = bsname
        self.struct_fields = struct_fields
        self.encode_fields = encode_fields
        self.num_bytes = num_bytes
        self.data = data
        self.bit_set = bit_set


class StructField(object):
    """A settable struct field: its Type, name and declared bit width."""

    def __init__(self, type, field, bits):
        self.type = type
        self.field = field
        self.bits = bits


class EncodeField(object):
    """A bit field's C name and the value (C expression or int) encoded into it."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class Variant(object):
    """A bit struct's C variant name and its size in bytes."""

    def __init__(self, cname, bytes):
        self.cname = cname
        self.bytes = bytes


def bit_struct(name, bit_set, field_mappings, data=None):
    """Register a bit struct on ``bit_set`` and return it.

    Each entry in ``field_mappings`` is either a field name (struct field of
    the same name), ``(struct_field, field)``, or ``(struct_field, field,
    fixed_value)``. A fixed value may be an int/bool, or an element name for
    enum fields (including enum subtypes, whose elements come from the parent
    enum). Fields with a fixed or reserved value are not settable.

    Raises AssertionError on duplicate names, unknown fields, bad fixed
    values, overlapping pieces, or a non-byte-aligned total width.
    """
    assert name not in bit_set.bit_structs, f'Duplicate bit struct "{name}" in bit set "{bit_set.name}".'

    struct_fields = {}
    encode_fields = []
    all_pieces = []
    total_bits = 0
    for mapping in field_mappings:
        if isinstance(mapping, str):
            struct_field = mapping
            _field = mapping
            fixed_value = None
        else:
            assert isinstance(mapping, tuple)
            struct_field, _field, *fixed_value = mapping
            assert len(fixed_value) <= 1
            fixed_value = fixed_value[0] if fixed_value else None

        assert struct_field not in struct_fields, f'Duplicate struct field "{struct_field}" in bit struct "{name}".'
        assert _field in bit_set.fields, f'Field "{_field}" in mapping for struct field "{name}.{struct_field}" not defined in bit set "{bit_set.name}".'
        field = bit_set.fields[_field]
        field_type = field.field_type
        is_enum = field_type.base_type == BaseType.enum

        if fixed_value is not None:
            assert field.reserved is None, f'Fixed value for field mapping "{struct_field}" using field "{_field}" cannot overwrite its reserved value.'

            if is_enum and isinstance(fixed_value, str):
                enum = field_type.enum
                # Subtypes have no elements of their own; find the enum that does.
                source = enum
                while source.elems is None:
                    source = source.parent.enum
                assert fixed_value in source.elems, f'Fixed value for field mapping "{struct_field}" using field "{_field}" is not an element of enum "{field_type.tname}".'
                elem = source.elems[fixed_value]
                if enum.parent is not None:
                    assert elem.value in enum.valid, f'Fixed value "{fixed_value}" for field mapping "{struct_field}" using field "{_field}" does not fit in enum subtype "{field_type.tname}".'
                fixed_value = elem.cname
            else:
                if isinstance(fixed_value, bool):
                    fixed_value = int(fixed_value)

                assert isinstance(fixed_value, int)
                assert val_fits_in_bits(fixed_value, field_type.num_bits), f'Fixed value for field mapping "{struct_field}" using field "{_field}" is too large.'

        # Absolute (struct-wide) bit ranges of this field's pieces.
        all_pieces.extend((piece.lo_bit + (8 * piece.byte), piece.hi_bit + (8 * piece.byte), piece.name) for piece in field.pieces)
        total_bits += field_type.num_bits

        # Describe how to encode the bit struct.
        encode_field = f'{bit_set.name}_{_field}'.upper()
        if fixed_value is not None:
            encode_value = fixed_value
        elif field.reserved is not None:
            encode_value = field.reserved
        else:
            encode_value = f's.{struct_field}'
        encode_fields.append(EncodeField(encode_field, encode_value))

        # Describe settable fields.
        if field.reserved is None and fixed_value is None:
            # Use parent enum for struct fields.
            if is_enum and field_type.enum.parent is not None:
                field_type = field_type.enum.parent

            struct_field_bits = field_type.dec_bits if field_type.dec_bits is not None else field_type.num_bits
            struct_fields[struct_field] = StructField(field_type, struct_field, struct_field_bits)

    # Check for overlapping pieces. Each unordered pair is checked exactly once,
    # so a piece mapped twice (two identical entries) is also reported.
    for p0, p1 in combinations(all_pieces, 2):
        assert p0[1] < p1[0] or p0[0] > p1[1], f'Pieces "{p0[2]}" and "{p1[2]}" overlap in bit struct "{name}".'

    # Check for byte-alignment.
    assert (total_bits % 8) == 0, f'Bit struct "{name}" has a non-byte-aligned number of bits ({total_bits}).'

    _name = f'{bit_set.name}_{name}'
    total_bytes = total_bits // 8
    bs = BitStruct(_name, name, struct_fields, encode_fields, total_bytes, data, bit_set)
    bit_set.bit_structs[name] = bs
    bit_set.variants.append(Variant(_name.upper(), total_bytes))

    return bs


# Op definitions.
class Op(object):
    """An op definition.

    ``cop_mods``/``cdest_mods``/``csrc_mods`` are C bitmask expressions (or 0)
    of the allowed modifiers. ``builder_params`` holds five C snippets:
    [0] builder parameter list, [1] matching argument list, [2]/[3] variadic
    mods parameter/initializer (empty without op mods), [4] an expression
    yielding the pco_func; presumably consumed by a builder code generator.
    """

    def __init__(self, name, cname, bname, op_type, op_mods, cop_mods, op_mod_map, num_dests, num_srcs, dest_mods, cdest_mods, src_mods, csrc_mods, has_target_cf_node, builder_params):
        self.name = name
        self.cname = cname
        self.bname = bname
        self.op_type = op_type
        self.op_mods = op_mods
        self.cop_mods = cop_mods
        self.op_mod_map = op_mod_map
        self.num_dests = num_dests
        self.num_srcs = num_srcs
        self.dest_mods = dest_mods
        self.cdest_mods = cdest_mods
        self.src_mods = src_mods
        self.csrc_mods = csrc_mods
        self.has_target_cf_node = has_target_cf_node
        self.builder_params = builder_params


ops = {}


def _mod_mask(mods):
    """Return the C bitmask expression for ``mods``, or 0 when it is empty."""
    if not mods:
        return 0
    return ' | '.join(f'(1ULL << {mod.cname})' for mod in mods)


def op(name, op_type, op_mods, num_dests, num_srcs, dest_mods, src_mods, has_target_cf_node):
    """Register an op and return it.

    ``dest_mods``/``src_mods`` are per-operand lists of RefMods; an empty
    per-operand list yields a mask of 0. ``num_dests``/``num_srcs`` may be
    VARIABLE.
    """
    assert name not in ops, f'Duplicate op "{name}".'

    _name = name.replace('.', '_')
    cname = f'{prefix}_op_{_name}'
    bname = f'{prefix}_{_name}'
    cop_mods = _mod_mask(op_mods)
    # Index 0 is left free; op mods are numbered from 1.
    op_mod_map = {op_mod.cname: index + 1 for index, op_mod in enumerate(op_mods)}
    cdest_mods = {i: _mod_mask(destn_mods) for i, destn_mods in enumerate(dest_mods)}
    csrc_mods = {i: _mod_mask(srcn_mods) for i, srcn_mods in enumerate(src_mods)}

    builder_params = ['', '', '', '', '']

    if op_type != 'hw_direct':
        builder_params[0] = 'pco_builder *b'
        builder_params[1] = 'b'
        builder_params[4] = 'pco_cursor_func(b->cursor)'
    else:
        builder_params[0] = 'pco_func *func'
        builder_params[1] = 'func'
        builder_params[4] = 'func'

    if has_target_cf_node:
        builder_params[0] += ', pco_cf_node *target_cf_node'
        builder_params[1] += ', target_cf_node'

    # NOTE(review): for VARIABLE counts the parameter list names the array
    # dest/src while the argument list uses dests/srcs; confirm this matches
    # the consuming template.
    if num_dests == VARIABLE:
        builder_params[0] += ', unsigned num_dests, pco_ref *dest'
        builder_params[1] += ', num_dests, dests'
    else:
        for d in range(num_dests):
            builder_params[0] += f', pco_ref dest{d}'
            builder_params[1] += f', dest{d}'

    if num_srcs == VARIABLE:
        builder_params[0] += ', unsigned num_srcs, pco_ref *src'
        builder_params[1] += ', num_srcs, srcs'
    else:
        for s in range(num_srcs):
            builder_params[0] += f', pco_ref src{s}'
            builder_params[1] += f', src{s}'

    if op_mods:
        builder_params[0] += f', struct {bname}_mods mods'
        builder_params[2] = ', ...'
        builder_params[3] = f', (struct {bname}_mods){{0, ##__VA_ARGS__}}'

    op = Op(name, cname, bname, op_type, op_mods, cop_mods, op_mod_map, num_dests, num_srcs, dest_mods, cdest_mods, src_mods, csrc_mods, has_target_cf_node, builder_params)
    ops[name] = op
    return op


def pseudo_op(name, op_mods=None, num_dests=0, num_srcs=0, dest_mods=None, src_mods=None, has_target_cf_node=False):
    """Register a 'pseudo' op. List arguments default to fresh empty lists."""
    return op(name, 'pseudo',
              [] if op_mods is None else op_mods,
              num_dests, num_srcs,
              [] if dest_mods is None else dest_mods,
              [] if src_mods is None else src_mods,
              has_target_cf_node)


def hw_op(name, op_mods=None, num_dests=0, num_srcs=0, dest_mods=None, src_mods=None, has_target_cf_node=False):
    """Register a 'hw' op. List arguments default to fresh empty lists."""
    return op(name, 'hw',
              [] if op_mods is None else op_mods,
              num_dests, num_srcs,
              [] if dest_mods is None else dest_mods,
              [] if src_mods is None else src_mods,
              has_target_cf_node)


def hw_direct_op(name, num_dests=0, num_srcs=0, has_target_cf_node=False):
    """Register a 'hw_direct' op: no modifiers, built from a pco_func."""
    return op(name, 'hw_direct', [], num_dests, num_srcs, [], [], has_target_cf_node)