# Copyright © 2024 Tomeu Vizoso <tomeu.vizoso@tomeuvizoso.net>
# SPDX-License-Identifier: MIT

"""Parse the coefficients data blob and print something more readable.

The blob is read from stdin as a hex dump; all whitespace (including CRLF
line endings) is ignored.  The VSIMULATOR_CONFIG environment variable selects
the hardware variant whose layout quirks are applied when printing.
"""

import math
import os
import sys

DEBUG = False

VERSION = os.environ.get("VSIMULATOR_CONFIG", "VIPPICO_V3_PID0X99")

weights_width = 2
weights_height = weights_width
input_channels = 8
output_channels = 0


def assertEqual(actual, expected):
    """Fail if *actual* differs from *expected* by more than 1; warn if not exact."""
    assert abs(actual - expected) <= 1, "Expecting 0x%02x, got 0x%02x instead" % (expected, actual)
    if actual != expected:
        print("Warning: expecting 0x%02x, got 0x%02x instead" % (expected, actual))


def dbg(*args, **kwargs):
    """Print *args* joined by spaces, but only when DEBUG is enabled."""
    if DEBUG:
        print(" ".join(map(str, args)), **kwargs)


class BitStream:
    """LSB-first bit reader over a list of byte values.

    Bytes are popped from the front of ``_bytes`` into an integer buffer as
    needed; each read returns the lowest bits of that buffer.
    """

    def __init__(self, bytes):
        self._bytes = bytes
        self._buffer = 0
        self._bits_in_buffer = 0

    def read(self, bits):
        """Return the next *bits* bits as an unsigned integer (0 bits -> 0)."""
        if bits == 0:
            dbg("read %d bits: %d" % (bits, 0))
            return 0

        while bits > self._bits_in_buffer:
            self._buffer |= self._bytes.pop(0) << self._bits_in_buffer
            self._bits_in_buffer += 8

        temp = self._buffer & ((1 << bits) - 1)
        self._bits_in_buffer -= bits
        self._buffer >>= bits

        dbg("read %d bits: %d" % (bits, temp))
        return temp

    def read32(self):
        """Return the next 32 bits."""
        return self.read(32)

    def read16(self):
        """Return the next 16 bits."""
        return self.read(16)

    def reset(self):
        """Discard any buffered, not-yet-returned bits."""
        self._buffer = 0
        self._bits_in_buffer = 0

    def read_bytes(self, length):
        """Split off the next *length* bytes as a new, independent BitStream.

        Only valid on a byte boundary, i.e. when no bits are buffered.
        """
        assert(self._bits_in_buffer == 0)

        new = BitStream(self._bytes[:length])
        self._bytes = self._bytes[length:]
        return new


def parse_hex_dump(text):
    """Convert a hex dump into a list of byte values, ignoring all whitespace."""
    digits = "".join(text.split())
    return [int(digits[i:i + 2], 16) for i in range(0, len(digits), 2)]


def get_symbol(part0, part1):
    """Map a code's part0 (3 bits) and part1 (0 or 2 bits) to a symbol.

    Returns ``(symbol, unk1)``.  ``symbol`` indexes ``symbol_map``.  ``unk1``
    is an extra value bit already carried by part0/part1, or -1 when that bit
    still has to be read from part2.
    """
    dbg("get_symbol part0 %d part1 %d" % (part0, part1))
    if part0 == 0:
        return 0, part0 >> 2
    elif part0 == 1:
        return 1, part0 >> 2
    elif part0 == 2:
        if part1 == 1 or part1 == 3:
            return 5, part1 >> 1
        elif part1 == 0:
            return 7, -1
        elif part1 == 2:
            return 6, -1
        else:
            assert False
    elif part0 == 3:
        return 3, -1
    elif part0 == 4:
        return 0, part0 >> 2
    elif part0 == 5:
        return 1, part0 >> 2
    elif part0 == 6:
        return 4, -1
    elif part0 == 7:
        return 2, -1
    else:
        assert False


class Code:
    """Decode state for one code as it moves through the read_pair pipeline."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all fields before a new code is started in this slot."""
        self.part0 = 0
        self.part1 = 0
        self.part2 = 0
        self.part1_len = 0
        self.part2_len = 0
        self.bit_len = 0
        # Set by stage 2 via get_symbol(); initialised so the attribute always exists.
        self.unk1 = -1


RING_BUFFER_SIZE = 6
ring_buffer = [Code() for _ in range(RING_BUFFER_SIZE)]

# Decoded values accumulated by read_pair() for the current core.
weights = []
# Pipeline position; read_pair() advances it by 2 per call.
total_read = 1


def decode_one_char(part2, bit_length, unk1, zero):
    """Reconstruct a value from its part2 bits.

    When *unk1* is -1, the low bit of *part2* is taken as the flag bit and
    shifted out.  For a non-zero *bit_length*, bit ``bit_length - 1`` is set.
    When the flag bit is non-zero, the value is inverted within 8 bits and
    *zero* is added.
    """
    if unk1 == -1:
        unk1 = part2 & 1
        part2 = part2 >> 1

    if bit_length != 0:
        part2 = part2 | (1 << ((bit_length - 1) & 0x1f))

    if unk1 != 0:
        part2 = (part2 ^ 0xff) + zero

    return part2


def uint8(val):
    """Fold values above 255 back by subtracting 256 once.

    NOTE(review): values >= 512 are not wrapped further.
    """
    if val > 255:
        return val - 256
    else:
        return val


def read_pair(bytes):
    """Advance the three-stage decode pipeline by one step of two codes.

    Each call reads from *bytes* (a BitStream), in this order:
      * stage 1: part0 (3 bits) for the two newest codes;
      * stage 2: part1 for the two codes started one call earlier, deriving
        their symbol, bit length and part2 length;
      * stage 3: part2 for the two codes started two calls earlier, appending
        their decoded value(s) to the global ``weights`` list.

    Codes live in ``ring_buffer``, indexed relative to the global
    ``total_read``.
    """
    global total_read

    zero_point = avg_bias if avg_bias > 0 else 0x80

    dbg(">>>>>>>>>> Stage 1: total_read %d" % total_read)
    for i in range(1, -1, -1):
        slot = (total_read - i) % RING_BUFFER_SIZE
        code = ring_buffer[slot]
        code.reset()
        code.part0 = bytes.read(3)
        code.part1_len = 2 if code.part0 == 2 else 0
        dbg("code at %d has part0 %d part1_len %d" % (slot, code.part0, code.part1_len))

    if total_read >= 2:
        dbg(">>>>>>>>>> Stage 2")
        for i in range(3, 1, -1):
            code = ring_buffer[(total_read - i) % RING_BUFFER_SIZE]
            code.part1 = bytes.read(code.part1_len)
            symbol, code.unk1 = get_symbol(code.part0, code.part1)
            code.bit_len = symbol_map[symbol]
            if run_length_table_size == 0:
                code.part2_len = max(code.bit_len, 1)
            elif run_length_table_size <= 4:
                code.part2_len = 1
            elif run_length_table_size <= 6:
                code.part2_len = 2
            elif run_length_table_size <= 10:
                code.part2_len = 3
            else:
                code.part2_len = 4

            # One value bit was already supplied by part0/part1.
            if code.unk1 != -1:
                code.part2_len -= 1

            dbg("part1 %d bit_len %d part2_len %d" % (code.part1, code.bit_len, code.part2_len))

    if total_read >= 4:
        dbg(">>>>>>>>>> Stage 3")
        for i in range(5, 3, -1):
            code = ring_buffer[(total_read - i) % RING_BUFFER_SIZE]
            code.part2 = bytes.read(code.part2_len)

            if run_length_table_size == 0:
                char = decode_one_char(code.part2, code.bit_len, code.unk1, 0)
                weights.append(uint8(char + avg_bias))
                dbg("run_length_table_size == 0: uint8(char + avg_bias) %d" % uint8(char + avg_bias))
            elif code.bit_len == 7:
                # Raw value; its low bit may have been carried in unk1.
                if code.unk1 == -1:
                    char = code.part2
                else:
                    char = code.unk1 + code.part2 * 2
                weights.append(uint8(char + avg_bias))
                dbg("7: char %d uint8(char + avg_bias) %d" % (char, uint8(char + avg_bias)))
            elif code.bit_len == 8:
                # Run of avg_bias values; run length comes from run_length_table.
                if code.unk1 == -1:
                    index = code.part2 + 2
                else:
                    index = code.part2 * 2 + code.unk1
                char = run_length_table[index] + 1
                run = [0x0 + avg_bias] * char
                weights.extend(run)
                # Format the run as a single argument; `"%r" % [x] * n` would
                # repeat the formatted string n times instead.
                dbg("8: [0x0 + avg_bias] * char %r" % (run,))
            elif code.bit_len == 0:
                # Run of zero-point (or 0, depending on VERSION) values.
                if code.unk1 == -1:
                    symbol = code.part2
                else:
                    symbol = code.unk1
                char = run_length_table[symbol] + 1
                fill = zero_point if VERSION == "VIPPICO_V3_PID0X99" else 0x0
                run = [fill] * char
                weights.extend(run)
                dbg("0: [fill] * char %r" % (run,))
            else:
                char = decode_one_char(code.part2, code.bit_len, code.unk1, 0)
                weights.append(uint8(char + avg_bias))
                dbg("else: uint8(char + avg_bias) %d" % uint8(char + avg_bias))

            dbg("bit_len %d part2_len %d part2 %d char %02x" % (code.bit_len, code.part2_len, code.part2, char))

    dbg()
    total_read += 2


def align(num, alignment):
    """Round *num* up to the next multiple of *alignment*."""
    if num % alignment == 0:
        return num
    return num + alignment - (num % alignment)


def pop_int32(weights):
    """Pop four decoded values from the front of *weights* into a 32-bit int.

    The first value is the low byte.  Each of the three upper values is
    incremented when non-zero before being shifted into place.
    NOTE(review): that +1 looks like an empirically derived correction for
    how the encoder stores these bytes -- confirm.
    """
    value = weights.pop(0)
    for shift in (8, 16, 24):
        byte = weights.pop(0)
        if byte > 0:
            byte += 1
        value |= byte << shift
    return value


if __name__ == "__main__":
    stream = BitStream(parse_hex_dump(sys.stdin.read()))

    # Fixed-size header (512 bits in total).
    precode = stream.read(1)
    bit16 = stream.read(1)
    fp16 = stream.read(1)
    reserved1 = stream.read(1)
    version = stream.read(4)

    run_length_table_size = stream.read(8)
    run_length_table = [stream.read(8) for _ in range(18)]
    symbol_map = [stream.read(4) for _ in range(8)]

    avg_bias = stream.read(16)
    reserved2 = stream.read(16)

    stream_sizes = [stream.read32() for _ in range(8)]

    padding = stream.read32()

    print("Precode: %d" % precode)
    print("Bit16: %d" % bit16)
    print("FP16: %d" % fp16)
    print("Reserved 1: %d" % reserved1)
    print("Version: %d" % version)
    print("Run length table size: %d" % run_length_table_size)
    print("Run length table: %r" % run_length_table)
    print("Symbol map: %r" % symbol_map)
    print("Avg bias: %d" % avg_bias)
    print("Reserved 2: %d" % reserved2)
    print("Stream sizes: %r" % stream_sizes)
    print("Padding: %d" % padding)

    core = 0
    for stream_size in stream_sizes:
        if stream_size == 0:
            break
        # stream_size appears to be in bits (it is divided by 8 to get a byte
        # count); each core's data is padded to a multiple of 64 bytes.
        stream_bytes = math.ceil(stream_size / 8)
        aligned_size = align(stream_bytes, 64)
        core_bytes = stream.read_bytes(aligned_size)
        # Decode until the unpadded bytes have been pulled into the reader.
        while len(core_bytes._bytes) > aligned_size - stream_bytes:
            read_pair(core_bytes)

        print()
        print("Raw data for core %d: %r" % (core, weights))

        vz1 = weights.pop(0)
        vz2 = weights.pop(0)
        kernels_per_core = vz1 | (vz2 << 8)
        output_channels += kernels_per_core
        print("Kernels per core: %r" % (kernels_per_core))

        for ic in range(0, input_channels):
            for kernel in range(0, kernels_per_core):
                if VERSION == "VIPPICO_V3_PID0X99" and ic == 0:
                    bias = pop_int32(weights)
                    print("Bias: 0x%x" % (bias))

                kernel_size = weights_width * weights_height
                channel_weights = weights[:kernel_size]
                weights = weights[kernel_size:]

                if VERSION == "VIP8000NANOSI_PLUS_PID0X9F":
                    converted = []
                    for weight in channel_weights:
                        unsigned = weight + 0x80
                        if unsigned > 255:
                            unsigned = unsigned & 0xFF
                        converted.append(unsigned)
                    channel_weights = converted

                print("Weights: %r" % channel_weights)

                if ic == input_channels - 1:
                    if VERSION == "VIPPICO_V3_PID0X99":
                        out_offset = pop_int32(weights)
                        print("Output offset: %r" % out_offset)

        weights.clear()

        # Restart the decode pipeline for the next core's stream.
        total_read = 0

        core += 1

    for oc in range(0, output_channels):
        print("%x" % stream.read32())