• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright © 2024 Tomeu Vizoso <tomeu.vizoso@tomeuvizoso.net>
2# SPDX-License-Identifier: MIT
3
4"""Parse the coefficients data blob and print something more readable."""
5
6import math
7import os
8import sys
9
# When True, dbg() prints a trace of every bitstream read and decode stage.
DEBUG = False

# Hardware configuration name, taken from the VSIMULATOR_CONFIG environment
# variable. The decoder compares it against "VIPPICO_V3_PID0X99" and
# "VIP8000NANOSI_PLUS_PID0X9F" to select version-specific handling.
VERSION = os.environ.get("VSIMULATOR_CONFIG", "VIPPICO_V3_PID0X99")

# Kernel geometry assumed when slicing the decoded weights (square kernels).
weights_width = weights_height = 2
input_channels = 8
# Accumulated while parsing: the sum of the kernel counts found per core.
output_channels = 0
18
def assertEqual(actual, expected):
    """Check that *actual* is within one unit of *expected*.

    An exact match passes silently. A difference of exactly one is tolerated
    but reported with a warning on stdout. Anything further apart fails with
    an AssertionError naming both values in hex.
    """
    distance = abs(actual - expected)
    assert distance <= 1, "Expecting 0x%02x, got 0x%02x instead" % (expected, actual)

    if actual == expected:
        return

    # Off by one: tolerated, but worth telling the user about.
    print("Warning: expecting 0x%02x, got 0x%02x instead" % (expected, actual))
23
def dbg(*args, **kwargs):
    """Print a debug line when the module-level DEBUG flag is set.

    The positional arguments are converted with str() and joined with single
    spaces; keyword arguments are passed through to print() unchanged. Does
    nothing when DEBUG is false.
    """
    if not DEBUG:
        return

    line = " ".join(str(arg) for arg in args)
    print(line, **kwargs)
27
class BitStream():
    """LSB-first bit reader over a list of byte values.

    Bytes are taken from the front of the list as needed; within each byte
    the least significant bit is delivered first, and earlier bytes occupy
    the lower bits of a multi-byte read. The list handed to the constructor
    is consumed in place.
    """

    def __init__(self, bytes):
        # Byte values (ints) not yet pulled into the bit buffer.
        self._bytes = bytes
        # Bits pulled from _bytes but not yet returned; the next bit to
        # deliver is bit 0.
        self._buffer = 0
        self._bits_in_buffer = 0

    def read(self, bits):
        """Return the next *bits* bits as an unsigned integer.

        Reading zero bits returns 0 without touching the stream. Running out
        of bytes raises IndexError from the underlying list.
        """
        if bits == 0:
            dbg("read %d bits: %d" % (bits, 0))
            return 0

        # Refill one whole byte at a time until the request can be served.
        while self._bits_in_buffer < bits:
            next_byte = self._bytes.pop(0)
            self._buffer |= next_byte << self._bits_in_buffer
            self._bits_in_buffer += 8

        value = self._buffer & ((1 << bits) - 1)
        self._buffer >>= bits
        self._bits_in_buffer -= bits

        dbg("read %d bits: %d" % (bits, value))
        return value

    def read32(self):
        """Return the next 32 bits as an unsigned integer."""
        return self.read(32)

    def read16(self):
        """Return the next 16 bits as an unsigned integer."""
        return self.read(16)

    def reset(self):
        """Discard any bits buffered but not yet returned."""
        self._buffer = 0
        self._bits_in_buffer = 0

    def read_bytes(self, length):
        """Split off the next *length* bytes as an independent BitStream.

        Only valid on a byte boundary, i.e. when no partially consumed byte
        is sitting in the bit buffer.
        """
        assert(self._bits_in_buffer == 0)

        head, tail = self._bytes[:length], self._bytes[length:]
        self._bytes = tail
        return BitStream(head)
70
# Read the whole hex dump from stdin. Every kind of whitespace is dropped
# (newlines, carriage returns, spaces, tabs) so that CRLF files and
# space-separated dumps are not mis-paired into wrong byte values.
content = "".join(sys.stdin.read().split())
bytes = [int(content[i:i + 2], 16) for i in range(0, len(content), 2)]

# From here on `bytes` is the bit reader positioned at the start of the blob.
# The name shadows the builtin but is what the rest of the script refers to.
bytes = BitStream(bytes)
77
# Blob header. Fields are read in stream order, so the order of these
# statements must not change.
precode = bytes.read(1)
bit16 = bytes.read(1)
fp16 = bytes.read(1)
reserved1 = bytes.read(1)
version = bytes.read(4)

run_length_table_size = bytes.read(8)

# 18 one-byte table slots are always read, whatever size was just declared.
run_length_table = [bytes.read(8) for _ in range(18)]

# Eight 4-bit entries, indexed by the symbol number from get_symbol().
symbol_map = [bytes.read(4) for _ in range(8)]

avg_bias = bytes.read(16)
reserved2 = bytes.read(16)

# One 32-bit size per stream; the decode loop below treats it as a bit count
# and stops at the first zero entry.
stream_sizes = [bytes.read32() for _ in range(8)]

padding = bytes.read32()
102
# Dump the parsed header, one field per line, in stream order.
header_report = (
    ("Precode: %d", precode),
    ("Bit16: %d", bit16),
    ("FP16: %d", fp16),
    ("Reserved 1: %d", reserved1),
    ("Version: %d", version),
    ("Run length table size: %d", run_length_table_size),
    ("Run length table: %r", run_length_table),
    ("Symbol map: %r", symbol_map),
    ("Avg bias: %d", avg_bias),
    ("Reserved 2: %d", reserved2),
    ("Stream sizes: %r", stream_sizes),
    ("Padding: %d", padding),
)
for line_format, field_value in header_report:
    print(line_format % field_value)
115
def get_symbol(part0, part1):
    """Translate a code prefix into ``(symbol, unk1)``.

    *part0* is the leading 3-bit field of a code. *part1* is only consulted
    when *part0* is 2, where it selects between four further entries.

    The first element of the result is the symbol number (0..7) the caller
    uses to index ``symbol_map``. The second is 0 or 1 when the prefix
    itself carries that bit, or -1 when it does not.

    Raises AssertionError for prefixes outside the known tables.
    """
    dbg("get_symbol part0 %d part1 %d" % (part0, part1))

    if part0 == 2:
        # Escape prefix: the following 2-bit field picks the entry.
        decoded = {
            0: (7, -1),
            1: (5, 0),
            2: (6, -1),
            3: (5, 1),
        }.get(part1)
    else:
        # For symbols 0 and 1 the top bit of part0 doubles as the unk1 bit.
        decoded = {
            0: (0, 0),
            1: (1, 0),
            3: (3, -1),
            4: (0, 1),
            5: (1, 1),
            6: (4, -1),
            7: (2, -1),
        }.get(part0)

    assert decoded is not None
    return decoded
143
class Code:
    """Decoding state for one code while it moves through the read stages.

    A code is read in three parts (part0, part1, part2) at different points
    in the stream, so the partial result has to be kept between reads.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every field to its initial value so the slot can be reused."""
        self.part0 = 0
        self.part1 = 0
        self.part2 = 0
        self.part1_len = 0
        self.part2_len = 0
        self.bit_len = 0
        # Second element of get_symbol()'s result. Initialised here so the
        # attribute always exists and a recycled slot never carries a stale
        # value; -1 is the marker get_symbol() uses for "not supplied".
        self.unk1 = -1
155
# read_pair() works on two codes per stage and has three stages, so six
# codes are in flight at any time.
RING_BUFFER_SIZE = 6
ring_buffer = [Code() for _ in range(RING_BUFFER_SIZE)]

# Decoded values for the core currently being processed.
weights = []
# Position counter for read_pair(); it selects ring buffer slots and gates
# the later stages until enough codes have been read.
total_read = 1
163
def decode_one_char(part2, bit_length, unk1, zero):
    """Rebuild one value from its variable-length payload.

    *part2* holds the payload bits. When *unk1* is -1, the lowest payload bit
    is taken as unk1 and removed from the payload. If *bit_length* is
    non-zero, bit ``(bit_length - 1) & 0x1f`` is then forced on. Finally, a
    non-zero unk1 inverts the low eight bits and adds *zero*.
    """
    if unk1 == -1:
        unk1, part2 = part2 & 1, part2 >> 1

    if bit_length != 0:
        top_bit = 1 << ((bit_length - 1) & 0x1f)
        part2 |= top_bit

    if unk1 != 0:
        return (part2 ^ 0xff) + zero
    return part2
176
def uint8(val):
    """Wrap the integer *val* into the unsigned 8-bit range 0..255.

    Masking handles any magnitude and negative values. Subtracting 256 once
    is only correct for inputs below 512, and avg_bias is a 16-bit field, so
    the sums passed in here can be larger than that.
    """
    return val & 0xFF
182
def read_pair(bytes):
    """Advance the three-stage decoder by one step (two codes per stage).

    *bytes* is the BitStream for the current core. Each call:

    * Stage 1 reads the 3-bit ``part0`` of two new codes.
    * Stage 2 (once ``total_read >= 2``) reads ``part1`` for the two codes
      stage 1 handled on the previous call, resolves their symbol and works
      out how many ``part2`` bits they need.
    * Stage 3 (once ``total_read >= 4``) reads ``part2`` for the two codes
      stage 2 handled on the previous call and appends the decoded value(s)
      to the global ``weights`` list.

    The codes live in ``ring_buffer``, indexed relative to the global
    ``total_read``, which is advanced by two at the end of each call.
    """
    global total_read

    # Fill value for zero runs on VIPPICO_V3_PID0X99.
    zero_point = avg_bias if avg_bias > 0 else 0x80

    dbg(">>>>>>>>>> Stage 1: total_read %d" % total_read)
    for i in range(1, -1, -1):
        slot = (total_read - i) % RING_BUFFER_SIZE
        code = ring_buffer[slot]
        code.reset()
        code.part0 = bytes.read(3)
        # Only the escape prefix (part0 == 2) is followed by a part1 field.
        code.part1_len = 2 if code.part0 == 2 else 0
        dbg("code at %d has part0 %d part1_len %d" % (slot, code.part0, code.part1_len))

    if total_read >= 2:
        dbg(">>>>>>>>>> Stage 2")
        for i in range(3, 1, -1):
            code = ring_buffer[(total_read - i) % RING_BUFFER_SIZE]
            code.part1 = bytes.read(code.part1_len)
            symbol, code.unk1 = get_symbol(code.part0, code.part1)
            code.bit_len = symbol_map[symbol]

            if run_length_table_size == 0:
                code.part2_len = max(code.bit_len, 1)
            elif run_length_table_size <= 4:
                code.part2_len = 1
            elif run_length_table_size <= 6:
                code.part2_len = 2
            elif run_length_table_size <= 10:
                code.part2_len = 3
            else:
                code.part2_len = 4

            # When the prefix already supplied unk1, part2 is one bit shorter.
            if code.unk1 != -1:
                code.part2_len -= 1

            dbg("part1 %d bit_len %d part2_len %d" % (code.part1, code.bit_len, code.part2_len))

    if total_read >= 4:
        dbg(">>>>>>>>>> Stage 3")
        for i in range(5, 3, -1):
            code = ring_buffer[(total_read - i) % RING_BUFFER_SIZE]
            code.part2 = bytes.read(code.part2_len)

            if run_length_table_size == 0:
                char = decode_one_char(code.part2, code.bit_len, code.unk1, 0)
                value = uint8(char + avg_bias)
                weights.append(value)
                dbg("run_length_table_size == 0: uint8(char + avg_bias) %d" % value)
            elif code.bit_len == 7:
                if code.unk1 == -1:
                    char = code.part2
                else:
                    char = code.unk1 + code.part2 * 2
                value = uint8(char + avg_bias)
                weights.append(value)
                dbg("7: char %d uint8(char + avg_bias) %d" % (char, value))
            elif code.bit_len == 8:
                if code.unk1 == -1:
                    index = code.part2 + 2
                else:
                    index = code.part2 * 2 + code.unk1
                # Here char is a run length, not a value.
                char = run_length_table[index] + 1
                run = [0x0 + avg_bias] * char
                weights.extend(run)
                # Wrap the list in a tuple: "fmt % [x] * n" would format
                # first and then repeat the resulting string n times.
                dbg("8: [0x0 + avg_bias] * char %r" % (run,))
            elif code.bit_len == 0:
                if code.unk1 == -1:
                    symbol = code.part2
                else:
                    symbol = code.unk1
                char = run_length_table[symbol] + 1
                if VERSION == "VIPPICO_V3_PID0X99":
                    run = [zero_point] * char
                else:
                    run = [0x0] * char
                weights.extend(run)
                # Report the run actually appended (see tuple note above).
                dbg("0: [zero_point] * char %r" % (run,))
            else:
                char = decode_one_char(code.part2, code.bit_len, code.unk1, 0)
                value = uint8(char + avg_bias)
                weights.append(value)
                dbg("else: uint8(char + avg_bias) %d" % value)

            dbg("bit_len %d part2_len %d part2 %d char %02x" % (code.bit_len, code.part2_len, code.part2, char))

    dbg()
    total_read += 2
270
def align(num, alignment):
    """Round *num* up to the next multiple of *alignment*.

    Values that are already a multiple are returned unchanged.
    """
    remainder = num % alignment
    if remainder:
        return num + alignment - remainder
    return num
275
def pop_int32(weights):
    """Remove four values from the front of *weights* and pack them.

    The values are combined least-significant first into one integer. Each
    of the three upper values is incremented by one when it is greater than
    zero before being shifted into place; the lowest value is used as is.
    The list is modified in place.
    """
    lowest = weights.pop(0)

    upper = []
    for _ in range(3):
        value = weights.pop(0)
        upper.append(value + 1 if value > 0 else value)
    second, third, fourth = upper

    return lowest | (second << 8) | (third << 16) | (fourth << 24)
292
# Decode each core's stream in turn and print its kernels.
core = 0
# Number of weights per kernel per input channel; constant for the run.
kernel_size = weights_width * weights_height
for stream_size in stream_sizes:
    # A zero size marks the end of the used stream slots.
    if stream_size == 0:
        break

    # stream_size is in bits; each stream occupies a multiple of 64 bytes.
    payload_bytes = math.ceil(stream_size / 8.0)
    aligned_size = int(align(payload_bytes, 64))
    core_bytes = bytes.read_bytes(aligned_size)

    # Keep decoding until only the alignment padding is left unread.
    trailing_bytes = aligned_size - payload_bytes
    while len(core_bytes._bytes) > trailing_bytes:
        read_pair(core_bytes)

    print()
    print("Raw data for core %d: %r" % (core, weights))

    # The first two decoded values form a little-endian 16-bit kernel count.
    count_low = weights.pop(0)
    count_high = weights.pop(0)
    kernels_per_core = count_low | (count_high << 8)
    output_channels += kernels_per_core
    print("Kernels per core: %r" % (kernels_per_core))

    is_pico = VERSION == "VIPPICO_V3_PID0X99"
    add_0x80 = VERSION == "VIP8000NANOSI_PLUS_PID0X9F"

    for ic in range(input_channels):
        first_channel = ic == 0
        last_channel = ic == input_channels - 1

        for kernel in range(kernels_per_core):
            # On VIPPICO_V3_PID0X99 a 32-bit bias precedes each kernel's
            # weights for the first input channel.
            if is_pico and first_channel:
                bias = pop_int32(weights)
                print("Bias: 0x%x" % (bias))

            channel_weights = weights[:kernel_size]
            del weights[:kernel_size]

            if add_0x80:
                # Add 0x80 to each weight; results above 255 are wrapped.
                channel_weights = [
                    shifted & 0xFF if shifted > 255 else shifted
                    for shifted in (weight + 0x80 for weight in channel_weights)
                ]

            print("Weights: %r" % channel_weights)

            # ...and a 32-bit output offset follows each kernel's weights
            # for the last input channel.
            if last_channel and is_pico:
                out_offset = pop_int32(weights)
                print("Output offset: %r" % out_offset)

    # Start the next core with an empty output list and a fresh position.
    weights.clear()
    total_read = 0
    core += 1
342
# After the per-core streams, print one 32-bit word in hex for every output
# channel counted above.
for _ in range(output_channels):
    print("%x" % bytes.read32())
345