• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# -----------------------------------------------------------------------------
4# Copyright 2022 Arm Limited
5#
6# Licensed under the Apache License, Version 2.0 (the "License"); you may not
7# use this file except in compliance with the License. You may obtain a copy
8# of the License at:
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15# License for the specific language governing permissions and limitations
16# under the License.
17# -----------------------------------------------------------------------------
18"""
19This script is a simple test result plotter for sweeps on multiple compressors.
20"""
21import csv
22import numpy as np
23import matplotlib.pyplot as plt
24import sys
25
# Path of the sweep-results CSV that main() passes to get_series(). It is
# opened as given, so a relative path resolves against the current working
# directory.
DATABASE = "competitive.csv"
27
28
class Series:
    """One labelled set of measurements to draw on a plot.

    Attributes:
        name: Label shown in the plot legend.
        perf: Performance values (plotted on the x axis).
        qual: Quality values (plotted on the y axis).

    Values are stored exactly as given; nothing is copied or validated.
    """

    def __init__(self, name, perf, qual):
        self.name, self.perf, self.qual = name, perf, qual
35
36
def get_series(database, compressor, quality, block_size, *,
               perf_col=2, qual_col=3):
    """Load one compressor sweep from the sectioned results CSV.

    The CSV is split into sections. A row with exactly one cell is a section
    header, and the data rows that follow it (up to the next header) belong
    to that section. This reads the section whose header is
    ``f"{compressor} {quality} {block_size}"``. Blank lines are ignored and
    do not end a section.

    Args:
        database: Path of the CSV file to read.
        compressor: Compressor name used in the section header.
        quality: Quality/profile key used in the section header.
        block_size: Block size string used in the section header.
        perf_col: Column index of the performance value (default 2).
        qual_col: Column index of the quality value (default 3).

    Returns:
        A ``(perf, qual)`` tuple of float lists in file order. Both lists are
        empty if no section matches. If the header occurs more than once, the
        rows of every occurrence are concatenated.

    Raises:
        ValueError: If a selected cell in a matching data row is not a number.
        IndexError: If a matching data row has too few columns.
    """
    title = f"{compressor} {quality} {block_size}"
    in_section = False

    perf = []
    qual = []

    # The csv module requires files to be opened with newline="" so that
    # quoted fields containing line breaks and "\r\n" endings parse correctly.
    with open(database, newline="") as csvfile:
        for row in csv.reader(csvfile):
            # A blank line parses as []; indexing it would raise IndexError.
            if not row:
                continue

            if len(row) == 1:
                in_section = row[0] == title
                continue

            if in_section:
                perf.append(float(row[perf_col]))
                qual.append(float(row[qual_col]))

    return (perf, qual)
56
57
def plot(block_size, series_set):
    """Draw every series on one speed-vs-PSNR scatter plot and save it.

    The figure is written to ``ASTC_v_ISPC_<block_size>.png`` in the current
    working directory. The current pyplot figure is then cleared, so the
    next plot starts from an empty canvas.

    Args:
        block_size: Block size string, used only in the output file name.
        series_set: Iterable of objects with ``name``, ``perf`` and ``qual``
            attributes. Each is drawn as its own scatter with ``name`` as the
            legend label.
    """
    plt.xlabel("Speed (MT/s)")
    plt.ylabel("PSNR dB")

    for entry in series_set:
        plt.scatter(entry.perf, entry.qual, s=2, label=entry.name)

    # The legend must come after the scatters so it picks up their labels.
    plt.legend(loc='lower right', prop={'size': 6})
    plt.tight_layout()

    out_path = f"ASTC_v_ISPC_{block_size}.png"
    plt.savefig(out_path)
    plt.clf()
70
71
def plot_diff(series_a, series_b):
    """Plot series_a relative to series_b and save the figure as a PNG.

    Points are paired by position. For each index i:
      * x is the speed ratio ``series_a.perf[i] / series_b.perf[i]``;
      * y is the PSNR difference ``series_a.qual[i] - series_b.qual[i]``.

    An orange star marks the arithmetic mean of x and y. Dotted red lines
    mark parity (ratio 1, difference 0). The output file is named after the
    legend label with spaces replaced by underscores, and the current pyplot
    figure is cleared afterwards.

    Args:
        series_a: Series being evaluated (``name``, ``perf``, ``qual``).
        series_b: Baseline Series. NOTE(review): pairing by position assumes
            both series list the same test images in the same order; this is
            not checked here — confirm against the data.

    Raises:
        ValueError: If the two series do not contain the same number of points.
    """
    lengths = {len(series_a.perf), len(series_a.qual),
               len(series_b.perf), len(series_b.qual)}
    # NumPy would silently broadcast a length-1 series against a longer one,
    # producing a plausible-looking but meaningless comparison.
    if len(lengths) != 1:
        raise ValueError(
            f"cannot compare {series_a.name!r} with {series_b.name!r}: "
            "series have different numbers of points")

    diff_perf = np.divide(series_a.perf, series_b.perf)
    diff_qual = np.subtract(series_a.qual, series_b.qual)
    label = f"{series_a.name} vs {series_b.name}"

    plt.scatter(diff_perf, diff_qual, s=2, c="#0091BD", label=label)
    # np.mean of an empty array is NaN (with a RuntimeWarning), so only mark
    # the mean when there is something to average.
    if diff_perf.size:
        plt.scatter(np.mean(diff_perf), np.mean(diff_qual), s=10, c="#FFA500", marker="*")

    plt.axhline(y=0, color="r", linestyle="dotted", lw=0.5)
    plt.axvline(x=1, color="r", linestyle="dotted", lw=0.5)

    plt.xlabel("Relative speed")
    plt.ylabel("PSNR diff (dB)")
    plt.legend(loc='lower right', prop={'size': 6})

    plt.tight_layout()
    file_name = label.replace(" ", "_") + ".png"
    plt.savefig(file_name)
    plt.clf()
92
93
def main():
    """Plot every configured compressor sweep for each block size.

    For each block size this:
      * loads one series per entry in the configuration table from DATABASE;
      * writes a combined speed-vs-PSNR plot of all of them (via plot());
      * writes a relative comparison of ASTC 50 against ISPC Slow
        (via plot_diff()).

    Returns:
        int: Process exit code, always 0.
    """
    block_sizes = ["4x4", "6x6", "8x8"]

    # (compressor, quality key in the CSV section header, legend suffix).
    # List order is the order series are drawn and listed in the legend.
    configs = [
        ("ISPC", "rgba", "ISPC Slow"),
        ("ISPC", "rgb", "ISPC Fast"),
        ("ASTC", "60", "ASTC 60"),
        ("ASTC", "50", "ASTC 50"),
        ("ASTC", "10", "ASTC 10"),
        ("ASTC", "8", "ASTC 8"),
    ]

    for block_size in block_sizes:
        # Keyed by legend suffix so the comparison below names its operands
        # instead of relying on list positions; dicts preserve insertion order.
        series = {}
        for compressor, quality, suffix in configs:
            perf, qual = get_series(DATABASE, compressor, quality, block_size)
            series[suffix] = Series(f"{block_size} {suffix}", perf, qual)

        plot(block_size, list(series.values()))

        plot_diff(series["ASTC 50"], series["ISPC Slow"])

    return 0
124
125
126if __name__ == "__main__":
127    sys.exit(main())
128