• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# -----------------------------------------------------------------------------
4# Copyright 2021 Arm Limited
5#
6# Licensed under the Apache License, Version 2.0 (the "License"); you may not
7# use this file except in compliance with the License. You may obtain a copy
8# of the License at:
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15# License for the specific language governing permissions and limitations
16# under the License.
17# -----------------------------------------------------------------------------
18"""
19The ``astc_trace_analysis`` utility provides a tool to analyze trace files.
20
21WARNING: Trace files are an engineering tool, and not part of the standard
22product, so traces and their associated tools are volatile and may change
23significantly without notice.
24"""
25
26import argparse
27from collections import defaultdict as ddict
28import json
29import numpy as np
30import sys
31
# Maps an encoded weight quant method index to its number of quant levels.
QUANT_TABLE = {
    index: levels
    for index, levels in enumerate((2, 3, 4, 5, 6, 8, 10, 12, 16, 20, 24, 32))
}
46
# Maps a color channel index to its single-letter channel name.
CHANNEL_TABLE = dict(enumerate(("R", "G", "B", "A")))
53
class Trace:
    """A full compression trace: block footprint plus per-block records.

    Supports indexing, deletion, and len() over the stored blocks.
    """

    def __init__(self, block_x, block_y, block_z):
        """Store the compressor block footprint and start an empty block list."""
        self.block_x = block_x
        self.block_y = block_y
        self.block_z = block_z
        self.blocks = []

    def add_block(self, block):
        """Append one block record to this trace."""
        self.blocks.append(block)

    def __getitem__(self, i):
        return self.blocks[i]

    def __delitem__(self, i):
        del self.blocks[i]

    def __len__(self):
        return len(self.blocks)
73
class Block:
    """A single compressed block: position, color bounds, and trial passes.

    Supports indexing, deletion, and len() over the stored passes.
    """

    def __init__(self, pos_x, pos_y, pos_z, error_target):
        """Store the block position and the tuning error target."""
        # Block position in the image
        self.pos_x = pos_x
        self.pos_y = pos_y
        self.pos_z = pos_z

        # Raw 16-bit channel bounds, set by add_minimums()/add_maximums()
        self.raw_min = None
        self.raw_max = None

        # 8-bit (LDR) channel bounds derived from the raw values
        self.ldr_min = None
        self.ldr_max = None

        self.error_target = error_target
        self.passes = []
        self.qualityHit = None

    @staticmethod
    def _to_ldr(value):
        """Convert one raw 16-bit channel value to its 8-bit LDR equivalent."""
        return int((value / 65535.0) * 255.0)

    def add_minimums(self, r, g, b, a):
        """Record per-channel minimums, both raw and LDR-quantized."""
        self.raw_min = (r, g, b, a)
        # BUG FIX: the original local helper converted every channel from r,
        # so g/b/a LDR minimums were all copies of the red channel
        self.ldr_min = tuple(self._to_ldr(c) for c in (r, g, b, a))

    def add_maximums(self, r, g, b, a):
        """Record per-channel maximums, both raw and LDR-quantized."""
        self.raw_max = (r, g, b, a)
        # BUG FIX: same red-channel copy defect as add_minimums()
        self.ldr_max = tuple(self._to_ldr(c) for c in (r, g, b, a))

    def add_pass(self, pas):
        """Append one trial pass record to this block."""
        self.passes.append(pas)

    def __getitem__(self, i):
        return self.passes[i]

    def __delitem__(self, i):
        del self.passes[i]

    def __len__(self):
        return len(self.passes)
120
121
class Pass:
    """One trial encoding configuration and its candidate weight grids.

    Supports indexing, deletion, and len() over the stored candidates.
    """

    def __init__(self, partitions, partition, planes, target_hit, mode, component):
        """Store the pass configuration and start an empty candidate list."""
        self.partitions = partitions
        # A missing partition index means the single-partition case; use 0
        self.partition_index = partition if partition is not None else 0
        self.planes = planes
        self.plane2_component = component
        self.target_hit = target_hit
        self.search_mode = mode
        self.candidates = []

    def add_candidate(self, candidate):
        """Append one candidate encoding to this pass."""
        self.candidates.append(candidate)

    def __getitem__(self, i):
        return self.candidates[i]

    def __delitem__(self, i):
        del self.candidates[i]

    def __len__(self):
        return len(self.candidates)
144
145
class Candidate:
    """One candidate weight grid and its refinement error history."""

    def __init__(self, weight_x, weight_y, weight_z, weight_quant):
        """Store the weight grid dimensions and quant level count."""
        self.weight_x = weight_x
        self.weight_y = weight_y
        self.weight_z = weight_z
        self.weight_quant = weight_quant
        self.refinement_errors = []

    def add_refinement(self, errorval):
        """Append the error value after one refinement iteration."""
        self.refinement_errors.append(errorval)
157
158
def get_attrib(data, name, multiple=False, hard_fail=True):
    """Find the value(s) of attribute *name* in a node's child list.

    Attributes are [name, value] pairs; three-element entries are nodes
    and are skipped.

    Args:
        data: The node child list to search.
        name: The attribute name to match.
        multiple: If True return a list of all matches, else a single value.
        hard_fail: If True, assert when the attribute is missing.

    Returns:
        The attribute value, a list of values, None, or [] depending on
        the multiple/hard_fail settings.
    """
    matches = [entry[1] for entry in data
               if len(entry) == 2 and entry[0] == name]

    if not matches:
        if hard_fail:
            print(json.dumps(data, indent=2))
            assert False, "Attribute %s not found" % name
        return list() if multiple else None

    if multiple:
        return matches

    if len(matches) > 1:
        print(json.dumps(data, indent=2))
        assert False, "Attribute %s found %u times" % (name, len(matches))

    return matches[0]
180
181
def rev_enumerate(seq):
    """Enumerate *seq* from the last element down, yielding (index, value)."""
    last = len(seq) - 1
    return ((last - i, value) for i, value in enumerate(reversed(seq)))
184
def foreach_block(data):
    """Yield every block in the trace database, in order."""
    yield from data
189
def foreach_pass(data):
    """Yield every (block, pass) pair in the trace database, in order."""
    for block in data:
        yield from ((block, pas) for pas in block)
195
def foreach_candidate(data):
    """Yield every (block, pass, candidate) triple in the trace database.

    Passes with no candidates (constant color blocks) yield exactly one
    triple with a None candidate.
    """
    for block in data:
        for pas in block:
            if len(pas) == 0:
                # Special case - None candidates for 0 partition
                yield (block, pas, None)
            else:
                for candidate in pas:
                    yield (block, pas, candidate)
206
def get_node(data, name, multiple=False, hard_fail=True):
    """Find the child node(s) called *name* in a node's child list.

    Nodes are ["node", name, children] triples; two-element entries are
    attributes and are skipped.

    Args:
        data: The node child list to search.
        name: The node name to match.
        multiple: If True return a list of all matches, else one child list.
        hard_fail: If True, assert when the node is missing.

    Returns:
        The node child list, a list of them, or None when absent and
        hard_fail is False (even when multiple is True).
    """
    found = [entry[2] for entry in data
             if len(entry) == 3 and entry[0] == "node" and entry[1] == name]

    if not found:
        if hard_fail:
            print(json.dumps(data, indent=2))
            assert False, "Node %s not found" % name
        return None

    if multiple:
        return found

    if len(found) > 1:
        print(json.dumps(data, indent=2))
        assert False, "Node %s found %u times" % (name, len(found))

    return found[0]
226
227
def find_best_pass_and_candidate(block):
    """Return the (pass, candidate) pair with the lowest final error.

    Constant color blocks (zero partitions with the quality target hit)
    have no trial candidates and return (pass, None).

    Args:
        block (Block): The block whose passes are searched.

    Returns:
        tuple: (best Pass, best Candidate or None).
    """
    # Removed the unused "explicit_pass" local from the original
    best_error = 1e30
    best_pass = None
    best_candidate = None

    for pas in block:
        # Special case for constant color blocks - no trial candidates
        if pas.target_hit and pas.partitions == 0:
            return (pas, None)

        for candidate in pas:
            # The final refinement error is the candidate's end state;
            # <= means later candidates win ties
            errorval = candidate.refinement_errors[-1]
            if errorval <= best_error:
                best_error = errorval
                best_pass = pas
                best_candidate = candidate

    # Every other return type must have both best pass and best candidate
    assert (best_pass and best_candidate)
    return (best_pass, best_candidate)
250
251
def generate_database(data):
    """Convert a decoded JSON trace into a Trace database.

    Skipped passes and failed candidates are dropped during conversion.

    Args:
        data: The decoded JSON trace root; a ["node", "root", children]
            triple.

    Returns:
        Trace: The populated trace database.
    """
    # Skip header
    assert data[0] == "node"
    assert data[1] == "root"
    data = data[2]

    # Global compressor configuration
    bx = get_attrib(data, "block_x")
    by = get_attrib(data, "block_y")
    bz = get_attrib(data, "block_z")
    dbStruct = Trace(bx, by, bz)

    for block in get_node(data, "block", True):
        px = get_attrib(block, "pos_x")
        py = get_attrib(block, "pos_y")
        pz = get_attrib(block, "pos_z")

        minr = get_attrib(block, "min_r")
        ming = get_attrib(block, "min_g")
        minb = get_attrib(block, "min_b")
        mina = get_attrib(block, "min_a")

        maxr = get_attrib(block, "max_r")
        maxg = get_attrib(block, "max_g")
        maxb = get_attrib(block, "max_b")
        maxa = get_attrib(block, "max_a")

        et = get_attrib(block, "tune_error_threshold")

        blockStruct = Block(px, py, pz, et)
        blockStruct.add_minimums(minr, ming, minb, mina)
        blockStruct.add_maximums(maxr, maxg, maxb, maxa)
        dbStruct.add_block(blockStruct)

        for pas in get_node(block, "pass", True):
            # Don't copy across passes we skipped due to heuristics
            skipped = get_attrib(pas, "skip", False, False)
            if skipped:
                continue

            prts = get_attrib(pas, "partition_count")
            prti = get_attrib(pas, "partition_index", False, False)
            plns = get_attrib(pas, "plane_count")
            # BUG FIX: was "plns > 2", which can never be true (the plane
            # count is at most 2), so a dual-plane pass missing its
            # plane_component was silently ignored; hard-fail for any
            # multi-plane pass instead
            chan = get_attrib(pas, "plane_component", False, plns > 1)
            mode = get_attrib(pas, "search_mode", False, False)
            ehit = get_attrib(pas, "exit", False, False) == "quality hit"

            passStruct = Pass(prts, prti, plns, ehit, mode, chan)
            blockStruct.add_pass(passStruct)

            # Constant color blocks don't have any candidates
            if prts == 0:
                continue

            for candidate in get_node(pas, "candidate", True):
                # Don't copy across candidates we couldn't encode
                failed = get_attrib(candidate, "failed", False, False)
                if failed:
                    continue

                wx = get_attrib(candidate, "weight_x")
                wy = get_attrib(candidate, "weight_y")
                wz = get_attrib(candidate, "weight_z")
                wq = QUANT_TABLE[get_attrib(candidate, "weight_quant")]
                epre = get_attrib(candidate, "error_prerealign", True, False)
                epst = get_attrib(candidate, "error_postrealign", True, False)

                candStruct = Candidate(wx, wy, wz, wq)
                passStruct.add_candidate(candStruct)
                # Refinement history: pre-realign errors first, then
                # post-realign errors
                for value in epre:
                    candStruct.add_refinement(value)
                for value in epst:
                    candStruct.add_refinement(value)

    return dbStruct
326
327
def filter_database(data):
    """Prune every block in place down to its best pass and candidate.

    After this runs each block holds exactly one pass, and that pass holds
    at most one candidate (constant color blocks keep none).
    """
    for block in data:
        best_pass, best_candidate = find_best_pass_and_candidate(block)

        # Iterate in reverse so deletions do not shift pending indices
        for i, pas in rev_enumerate(block):
            if pas != best_pass:
                del block[i]
            elif best_candidate is not None:
                for j, candidate in rev_enumerate(pas):
                    if candidate != best_candidate:
                        del pas[j]
344
345
def generate_pass_statistics(data):
    """Placeholder for per-pass statistics reporting; not yet implemented."""
348
349
def generate_feature_statistics(data):
    """Print summary statistics about the feature usage in the trace.

    Sections printed: compressor config, per-block channel dynamic range,
    partition usage, plane usage, weight decimation/quantization usage,
    and refinement behavior.

    Args:
        data (Trace): The trace database to analyze. Counts are most
            meaningful after filter_database() has pruned each block to
            its best pass -- NOTE(review): confirm with callers.
    """
    # -------------------------------------------------------------------------
    # Config
    print("Compressor Config")
    print("=================")

    # 3D block sizes include the Z dimension; 2D blocks omit it
    if data.block_z > 1:
        dat = (data.block_x, data.block_y, data.block_z)
        print("  - Block size: %ux%ux%u" % dat)
    else:
        dat = (data.block_x, data.block_y)
        print("  - Block size: %ux%u" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Block metrics
    # Histogram of the widest per-channel LDR range, bucketed by RANGE_QUANT
    result = ddict(int)

    RANGE_QUANT = 16

    for block in foreach_block(data):
        ranges = []
        for i in range(0, 4):
            ranges.append(block.ldr_max[i] - block.ldr_min[i])

        # Quantize the widest channel range down to its bucket floor
        max_range = max(ranges)
        max_range = int(max_range / RANGE_QUANT) * RANGE_QUANT

        result[max_range] += 1

    print("Channel Range")
    print("=============")
    keys = sorted(result.keys())
    for key in keys:
        dat = (key, key + RANGE_QUANT - 1, result[key])
        print("  - %3u-%3u dynamic range = %6u blocks" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Partition usage
    # Blocks per partition count, and distinct partition indices per count
    result_totals = ddict(int)
    results = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        result_totals[pas.partitions] += 1
        results[pas.partitions][pas.partition_index] += 1

    print("Partition Count")
    print("===============")
    keys = sorted(result_totals.keys())
    for key in keys:
        dat = (key, result_totals[key], len(results[key]))
        print("  - %u partition(s) = %6u blocks / %4u indicies" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Plane usage
    # Plane counts per partition count, plus the second-plane channel choice
    result_count = ddict(lambda: ddict(int))
    result_channel = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        result_count[pas.partitions][pas.planes] += 1
        if (pas.planes > 1):
            result_channel[pas.partitions][pas.plane2_component] += 1

    print("Plane Usage")
    print("===========")
    keys = sorted(result_count.keys())
    for key in keys:
        keys2 = sorted(result_count[key])
        for key2 in keys2:
            val2 = result_count[key][key2]
            dat = (key, key2, val2)
            print("  - %u partition(s) %u plane(s) = %6u blocks" % dat)
            # For dual-plane passes also break down the plane 2 channel
            if key2 == 2:
                keys3 = sorted(result_channel[key])
                for key3 in keys3:
                    dat = (CHANNEL_TABLE[key3], result_channel[key][key3])
                    print("    - %s plane                 = %6u blocks" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Decimation usage
    # Weight grid size histogram, with a quant level sub-histogram per size
    decim_count = ddict(lambda: ddict(int))
    quant_count = ddict(lambda: ddict(lambda: ddict(int)))


    MERGE_ROTATIONS = True

    for _, pas, can in foreach_candidate(data):
        # Skip constant color blocks
        if can is None:
            continue

        wx = can.weight_x
        wy = can.weight_y

        # Fold NxM and MxN grids into one bucket when merging rotations
        if MERGE_ROTATIONS and wx < wy:
            wx, wy = wy, wx

        decim_count[wx][wy] += 1
        quant_count[wx][wy][can.weight_quant] += 1

    print("Decimation Usage")
    print("================")

    if MERGE_ROTATIONS:
        print("  - Note: data merging grid rotations")

    x_keys = sorted(decim_count.keys())
    for x_key in x_keys:
        y_keys = sorted(decim_count[x_key])

        for y_key in y_keys:
            count = decim_count[x_key][y_key]
            dat = (x_key, y_key, count)
            print("  - %ux%u weights      = %6u blocks" % dat)

            q_keys = sorted(quant_count[x_key][y_key])
            for q_key in q_keys:
                count = quant_count[x_key][y_key][q_key]
                dat = (q_key, count)
                print("    - %2u quant range = %6u blocks" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Refinement usage

    total_count = 0
    better_count = 0
    could_have_count = 0
    success_count = 0

    # Per-candidate mean error improvement per refinement iteration,
    # normalized by the starting error
    refinement_step = []

    for block, pas, candidate in foreach_candidate(data):
        # Ignore zero partition blocks - they don't use refinement
        if not candidate:
            continue

        target_error = block.error_target
        start_error = candidate.refinement_errors[0]
        end_error = candidate.refinement_errors[-1]

        # NOTE(review): divides by start_error; assumes it is non-zero
        rpf = float(start_error - end_error) / float(len(candidate.refinement_errors))
        rpf = abs(rpf)
        refinement_step.append(rpf / start_error)

        total_count += 1
        # NOTE(review): <= counts unchanged-error refinements as "improved"
        if end_error <= start_error:
            better_count += 1

        if end_error <= target_error:
            success_count += 1
        else:
            # An intermediate refinement hit the target but the final
            # result did not
            for refinement in candidate.refinement_errors:
                if refinement <= target_error:
                    could_have_count += 1
                    break


    print("Refinement Usage")
    print("================")
    print("  - %u refinements(s)" % total_count)
    print("  - %u refinements(s) improved" % better_count)
    print("  - %u refinements(s) worsened" % (total_count - better_count))
    print("  - %u refinements(s) could hit target, but didn't" % could_have_count)
    print("  - %u refinements(s) hit target" % success_count)
    print("  - %f mean step improvement" % np.mean(refinement_step))
524
525
def parse_command_line():
    """
    Parse the command line.

    Returns:
        Namespace: The parsed command line container.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("trace", type=argparse.FileType("r"),
                        help="The trace file to analyze")
    return parser.parse_args()
539
540
def main():
    """
    The main function.

    Returns:
        int: The process return code.
    """
    args = parse_command_line()

    # Build the database, prune to best passes, then report statistics
    database = generate_database(json.load(args.trace))
    filter_database(database)
    generate_feature_statistics(database)

    return 0
557
558
# Script entry point - the return code is propagated to the shell
if __name__ == "__main__":
    sys.exit(main())
561