#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# -----------------------------------------------------------------------------
# Copyright 2021 Arm Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# -----------------------------------------------------------------------------
"""
The ``astc_trace_analysis`` utility provides a tool to analyze trace files.

WARNING: Trace files are an engineering tool, and not part of the standard
product, so traces and their associated tools are volatile and may change
significantly without notice.
"""

import argparse
from collections import defaultdict as ddict
import json
import numpy as np
import sys

# Map of encoded weight quantization level to the number of quant steps.
QUANT_TABLE = {
    0: 2,
    1: 3,
    2: 4,
    3: 5,
    4: 6,
    5: 8,
    6: 10,
    7: 12,
    8: 16,
    9: 20,
    10: 24,
    11: 32
}

# Map of encoded color component index to its letter name.
CHANNEL_TABLE = {
    0: "R",
    1: "G",
    2: "B",
    3: "A"
}


class Trace:
    """
    A single compression trace: the compressor block configuration and the
    list of compressed blocks.

    Supports len(), indexing, and del-by-index over the contained blocks.
    """

    def __init__(self, block_x, block_y, block_z):
        """Store the block dimensions and start with an empty block list."""
        self.block_x = block_x
        self.block_y = block_y
        self.block_z = block_z
        self.blocks = []

    def add_block(self, block):
        """Append a Block to this trace."""
        self.blocks.append(block)

    def __getitem__(self, i):
        return self.blocks[i]

    def __delitem__(self, i):
        del self.blocks[i]

    def __len__(self):
        return len(self.blocks)


class Block:
    """
    A single compressed block: its position, color bounds, error target, and
    the list of trial passes the compressor ran.

    Supports len(), indexing, and del-by-index over the contained passes.
    """

    def __init__(self, pos_x, pos_y, pos_z, error_target):
        """Store the block position and its tuning error target."""
        self.pos_x = pos_x
        self.pos_y = pos_y
        self.pos_z = pos_z

        # Raw 16-bit color minimum/maximum, set by add_minimums/add_maximums
        self.raw_min = None
        self.raw_max = None

        # The same bounds rescaled to 8-bit LDR range
        self.ldr_min = None
        self.ldr_max = None

        self.error_target = error_target
        self.passes = []
        # NOTE(review): qualityHit is never assigned elsewhere in this file;
        # presumably reserved for future use - confirm before relying on it
        self.qualityHit = None

    def add_minimums(self, r, g, b, a):
        """Store the per-component color minimums (raw 16-bit values)."""
        self.raw_min = (r, g, b, a)

        def ldr(x):
            # Rescale a 16-bit component value to the 8-bit LDR range
            cmax = 65535.0
            # BUGFIX: previously divided the captured "r" rather than the
            # parameter "x", collapsing all four channels to the red value
            return int((x / cmax) * 255.0)

        self.ldr_min = (ldr(r), ldr(g), ldr(b), ldr(a))

    def add_maximums(self, r, g, b, a):
        """Store the per-component color maximums (raw 16-bit values)."""
        self.raw_max = (r, g, b, a)

        def ldr(x):
            # Rescale a 16-bit component value to the 8-bit LDR range
            cmax = 65535.0
            # BUGFIX: previously divided the captured "r" rather than the
            # parameter "x", collapsing all four channels to the red value
            return int((x / cmax) * 255.0)

        self.ldr_max = (ldr(r), ldr(g), ldr(b), ldr(a))

    def add_pass(self, pas):
        """Append a Pass to this block."""
        self.passes.append(pas)

    def __getitem__(self, i):
        return self.passes[i]

    def __delitem__(self, i):
        del self.passes[i]

    def __len__(self):
        return len(self.passes)


class Pass:
    """
    A single trial encoding pass: its partitioning and plane configuration,
    and the list of candidate weight grids that were tried.

    Supports len(), indexing, and del-by-index over the contained candidates.
    """

    def __init__(self, partitions, partition, planes, target_hit, mode, component):
        """Store the pass configuration; candidates start empty."""
        self.partitions = partitions
        # Partition index may be absent in the trace; default to 0
        self.partition_index = 0 if partition is None else partition
        self.planes = planes
        self.plane2_component = component
        self.target_hit = target_hit
        self.search_mode = mode
        self.candidates = []

    def add_candidate(self, candidate):
        """Append a Candidate to this pass."""
        self.candidates.append(candidate)

    def __getitem__(self, i):
        return self.candidates[i]

    def __delitem__(self, i):
        del self.candidates[i]

    def __len__(self):
        return len(self.candidates)


class Candidate:
    """
    A single candidate weight grid: its decimated dimensions, quantization
    level, and the sequence of error values seen during refinement.
    """

    def __init__(self, weight_x, weight_y, weight_z, weight_quant):
        """Store the weight grid configuration; refinements start empty."""
        self.weight_x = weight_x
        self.weight_y = weight_y
        self.weight_z = weight_z
        self.weight_quant = weight_quant
        self.refinement_errors = []

    def add_refinement(self, errorval):
        """Append one refinement iteration's error value."""
        self.refinement_errors.append(errorval)


def get_attrib(data, name, multiple=False, hard_fail=True):
    """
    Find attribute values (2-entry lists) with the given name in a trace node.

    Args:
        data: The trace node's child list.
        name: The attribute name to match.
        multiple: If True return all matches as a list, else a single value.
        hard_fail: If True, assert when the attribute is missing.

    Returns:
        The single value, the list of values, None (missing, single), or an
        empty list (missing, multiple).
    """
    results = []
    for attrib in data:
        if len(attrib) == 2 and attrib[0] == name:
            results.append(attrib[1])

    if not results:
        if hard_fail:
            print(json.dumps(data, indent=2))
            assert False, "Attribute %s not found" % name
        if multiple:
            return list()
        return None

    if not multiple:
        if len(results) > 1:
            print(json.dumps(data, indent=2))
            assert False, "Attribute %s found %u times" % (name, len(results))
        return results[0]

    return results


def rev_enumerate(seq):
    """Enumerate a sequence in reverse, yielding (index, value) pairs."""
    return zip(reversed(range(len(seq))), reversed(seq))


def foreach_block(data):
    """Yield every block in the database."""
    for block in data:
        yield block


def foreach_pass(data):
    """Yield every (block, pass) pair in the database."""
    for block in data:
        for pas in block:
            yield (block, pas)


def foreach_candidate(data):
    """
    Yield every (block, pass, candidate) triple in the database.

    Passes with no candidates (constant color blocks) yield a single triple
    with candidate set to None.
    """
    for block in data:
        for pas in block:
            # Special case - None candidates for 0 partition
            if not len(pas):
                yield (block, pas, None)

            for candidate in pas:
                yield (block, pas, candidate)


def get_node(data, name, multiple=False, hard_fail=True):
    """
    Find child nodes (3-entry "node" lists) with the given name in a trace
    node.

    Args:
        data: The trace node's child list.
        name: The node name to match.
        multiple: If True return all matches as a list, else a single value.
        hard_fail: If True, assert when the node is missing.

    Returns:
        The single payload, the list of payloads, None (missing, single), or
        an empty list (missing, multiple).
    """
    results = []
    for attrib in data:
        if len(attrib) == 3 and attrib[0] == "node" and attrib[1] == name:
            results.append(attrib[2])

    if not results:
        if hard_fail:
            print(json.dumps(data, indent=2))
            assert False, "Node %s not found" % name
        # Mirror get_attrib: an empty list for multiple lookups, else None
        if multiple:
            return list()
        return None

    if not multiple:
        if len(results) > 1:
            print(json.dumps(data, indent=2))
            assert False, "Node %s found %u times" % (name, len(results))
        return results[0]

    return results


def find_best_pass_and_candidate(block):
    """
    Find the winning pass and candidate for a block.

    Constant color blocks return (pass, None); all other blocks return the
    (pass, candidate) pair with the lowest final refinement error.
    """
    explicit_pass = None

    best_error = 1e30
    best_pass = None
    best_candidate = None

    for pas in block:
        # Special case for constant color blocks - no trial candidates
        if pas.target_hit and pas.partitions == 0:
            return (pas, None)

        for candidate in pas:
            errorval = candidate.refinement_errors[-1]
            if errorval <= best_error:
                best_error = errorval
                best_pass = pas
                best_candidate = candidate

    # Every other return type must have both best pass and best candidate
    assert (best_pass and best_candidate)
    return (best_pass, best_candidate)


def generate_database(data):
    """
    Build the structured Trace database from the raw JSON trace.

    Args:
        data: The decoded JSON trace root node.

    Returns:
        Trace: The populated database.
    """
    # Skip header
    assert(data[0] == "node")
    assert(data[1] == "root")
    data = data[2]

    bx = get_attrib(data, "block_x")
    by = get_attrib(data, "block_y")
    bz = get_attrib(data, "block_z")
    dbStruct = Trace(bx, by, bz)

    for block in get_node(data, "block", True):
        px = get_attrib(block, "pos_x")
        py = get_attrib(block, "pos_y")
        pz = get_attrib(block, "pos_z")

        minr = get_attrib(block, "min_r")
        ming = get_attrib(block, "min_g")
        minb = get_attrib(block, "min_b")
        mina = get_attrib(block, "min_a")

        maxr = get_attrib(block, "max_r")
        maxg = get_attrib(block, "max_g")
        maxb = get_attrib(block, "max_b")
        maxa = get_attrib(block, "max_a")

        et = get_attrib(block, "tune_error_threshold")

        blockStruct = Block(px, py, pz, et)
        blockStruct.add_minimums(minr, ming, minb, mina)
        blockStruct.add_maximums(maxr, maxg, maxb, maxa)
        dbStruct.add_block(blockStruct)

        for pas in get_node(block, "pass", True):
            # Don't copy across passes we skipped due to heuristics
            skipped = get_attrib(pas, "skip", False, False)
            if skipped:
                continue

            prts = get_attrib(pas, "partition_count")
            prti = get_attrib(pas, "partition_index", False, False)
            plns = get_attrib(pas, "plane_count")
            # BUGFIX: the second plane's component is mandatory exactly when
            # a second plane exists; "plns > 2" could never be True because
            # the plane count is at most 2, so validation never triggered
            chan = get_attrib(pas, "plane_component", False, plns > 1)
            mode = get_attrib(pas, "search_mode", False, False)
            ehit = get_attrib(pas, "exit", False, False) == "quality hit"

            passStruct = Pass(prts, prti, plns, ehit, mode, chan)
            blockStruct.add_pass(passStruct)

            # Constant color blocks don't have any candidates
            if prts == 0:
                continue

            for candidate in get_node(pas, "candidate", True):
                # Don't copy across candidates we couldn't encode
                failed = get_attrib(candidate, "failed", False, False)
                if failed:
                    continue

                wx = get_attrib(candidate, "weight_x")
                wy = get_attrib(candidate, "weight_y")
                wz = get_attrib(candidate, "weight_z")
                wq = QUANT_TABLE[get_attrib(candidate, "weight_quant")]
                epre = get_attrib(candidate, "error_prerealign", True, False)
                epst = get_attrib(candidate, "error_postrealign", True, False)

                candStruct = Candidate(wx, wy, wz, wq)
                passStruct.add_candidate(candStruct)
                for value in epre:
                    candStruct.add_refinement(value)
                for value in epst:
                    candStruct.add_refinement(value)

    return dbStruct


def filter_database(data):
    """
    Remove all but the winning pass and candidate from every block, in place.
    """
    for block in data:
        best_pass, best_candidate = find_best_pass_and_candidate(block)

        # Delete in reverse index order so earlier indices stay valid
        for i, pas in rev_enumerate(block):
            if pas != best_pass:
                del block[i]
                continue

            # Constant color blocks have no candidates to filter
            if best_candidate is None:
                continue

            for j, candidate in rev_enumerate(pas):
                if candidate != best_candidate:
                    del pas[j]


def generate_pass_statistics(data):
    """Placeholder for future per-pass statistics."""
    pass


def generate_feature_statistics(data):
    """
    Print feature usage statistics for the (filtered) database: block
    configuration, channel range, partition, plane, decimation, and
    refinement usage.
    """
    # -------------------------------------------------------------------------
    # Config
    print("Compressor Config")
    print("=================")

    if data.block_z > 1:
        dat = (data.block_x, data.block_y, data.block_z)
        print(" - Block size: %ux%ux%u" % dat)
    else:
        dat = (data.block_x, data.block_y)
        print(" - Block size: %ux%u" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Block metrics
    result = ddict(int)

    RANGE_QUANT = 16

    for block in foreach_block(data):
        ranges = []
        for i in range(0, 4):
            ranges.append(block.ldr_max[i] - block.ldr_min[i])

        # Bucket the widest channel range into RANGE_QUANT-wide bins
        max_range = max(ranges)
        max_range = int(max_range / RANGE_QUANT) * RANGE_QUANT

        result[max_range] += 1

    print("Channel Range")
    print("=============")
    keys = sorted(result.keys())
    for key in keys:
        dat = (key, key + RANGE_QUANT - 1, result[key])
        print(" - %3u-%3u dynamic range = %6u blocks" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Partition usage
    result_totals = ddict(int)
    results = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        result_totals[pas.partitions] += 1
        results[pas.partitions][pas.partition_index] += 1

    print("Partition Count")
    print("===============")
    keys = sorted(result_totals.keys())
    for key in keys:
        dat = (key, result_totals[key], len(results[key]))
        print(" - %u partition(s) = %6u blocks / %4u indicies" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Plane usage
    result_count = ddict(lambda: ddict(int))
    result_channel = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        result_count[pas.partitions][pas.planes] += 1
        if (pas.planes > 1):
            result_channel[pas.partitions][pas.plane2_component] += 1

    print("Plane Usage")
    print("===========")
    keys = sorted(result_count.keys())
    for key in keys:
        keys2 = sorted(result_count[key])
        for key2 in keys2:
            val2 = result_count[key][key2]
            dat = (key, key2, val2)
            print(" - %u partition(s) %u plane(s) = %6u blocks" % dat)
            if key2 == 2:
                keys3 = sorted(result_channel[key])
                for key3 in keys3:
                    dat = (CHANNEL_TABLE[key3], result_channel[key][key3])
                    print(" - %s plane = %6u blocks" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Decimation usage
    decim_count = ddict(lambda: ddict(int))
    quant_count = ddict(lambda: ddict(lambda: ddict(int)))

    # Treat NxM and MxN weight grids as the same decimation when True
    MERGE_ROTATIONS = True

    for _, pas, can in foreach_candidate(data):
        # Skip constant color blocks
        if can is None:
            continue

        wx = can.weight_x
        wy = can.weight_y

        if MERGE_ROTATIONS and wx < wy:
            wx, wy = wy, wx

        decim_count[wx][wy] += 1
        quant_count[wx][wy][can.weight_quant] += 1

    print("Decimation Usage")
    print("================")

    if MERGE_ROTATIONS:
        print(" - Note: data merging grid rotations")

    x_keys = sorted(decim_count.keys())
    for x_key in x_keys:
        y_keys = sorted(decim_count[x_key])

        for y_key in y_keys:
            count = decim_count[x_key][y_key]
            dat = (x_key, y_key, count)
            print(" - %ux%u weights = %6u blocks" % dat)

            q_keys = sorted(quant_count[x_key][y_key])
            for q_key in q_keys:
                count = quant_count[x_key][y_key][q_key]
                dat = (q_key, count)
                print(" - %2u quant range = %6u blocks" % dat)

    print("")

    # -------------------------------------------------------------------------
    # Refinement usage
    total_count = 0
    better_count = 0
    could_have_count = 0
    success_count = 0

    refinement_step = []

    for block, pas, candidate in foreach_candidate(data):
        # Ignore zero partition blocks - they don't use refinement
        if not candidate:
            continue

        target_error = block.error_target
        start_error = candidate.refinement_errors[0]
        end_error = candidate.refinement_errors[-1]

        # Mean per-iteration improvement, relative to the starting error
        rpf = float(start_error - end_error) / float(len(candidate.refinement_errors))
        rpf = abs(rpf)
        refinement_step.append(rpf / start_error)

        total_count += 1
        if end_error <= start_error:
            better_count += 1

        if end_error <= target_error:
            success_count += 1
        else:
            # Refinement overshot: an intermediate step hit the target but
            # a later step regressed past it
            for refinement in candidate.refinement_errors:
                if refinement <= target_error:
                    could_have_count += 1
                    break

    print("Refinement Usage")
    print("================")
    print(" - %u refinements(s)" % total_count)
    print(" - %u refinements(s) improved" % better_count)
    print(" - %u refinements(s) worsened" % (total_count - better_count))
    print(" - %u refinements(s) could hit target, but didn't" % could_have_count)
    print(" - %u refinements(s) hit target" % success_count)
    print(" - %f mean step improvement" % np.mean(refinement_step))


def parse_command_line():
    """
    Parse the command line.

    Returns:
        Namespace: The parsed command line container.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("trace", type=argparse.FileType("r"),
                        help="The trace file to analyze")

    return parser.parse_args()


def main():
    """
    The main function.

    Returns:
        int: The process return code.
    """
    args = parse_command_line()

    data = json.load(args.trace)
    db = generate_database(data)
    filter_database(db)

    generate_feature_statistics(db)

    return 0


if __name__ == "__main__":
    sys.exit(main())