• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: Apache-2.0
2# -----------------------------------------------------------------------------
3# Copyright 2020 Arm Limited
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy
7# of the License at:
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations
15# under the License.
16# -----------------------------------------------------------------------------
17"""
18A ResultSet stores a set of results about the performance of a TestSet. Each
19set keeps result Records for each image and block size tested, that store the
20PSNR and coding time.
21
22ResultSets are often backed by a CSV file on disk, and a ResultSet can be
23compared against a set of reference results created by an earlier test run.
24"""
25
26
27import csv
28import enum
29import numpy as np
30import os
31
32
@enum.unique
class Result(enum.IntEnum):
    """
    Status values that a single test can end up in.

    Members are integers (this is an IntEnum), numbered 0 to 3 in the order
    NOTRUN, PASS, WARN, FAIL.

    Attributes:
        NOTRUN: No result is available because the test has not been run.
        PASS: Image quality met the pass threshold.
        WARN: Image quality fell below the pass threshold but stayed above
            the fail threshold.
        FAIL: Image quality fell below the fail threshold.
    """
    NOTRUN = 0  # status of a record before any result is assigned
    PASS = 1    # quality at or above the pass threshold
    WARN = 2    # between the fail and pass thresholds
    FAIL = 3    # below the fail threshold
49
50
class ResultSummary:
    """
    A result summary data container, storing the number of results of each
    type plus the per-record metrics used to summarize performance.

    Attributes:
        notruns: The number of tests that did not run.
        passes: The number of tests that passed.
        warnings: The number of tests that produced a warning.
        fails: The number of tests that failed.
        tTimesRel: Total time speedup vs reference, one entry per compared
            record (<1 is slower, >1 is faster).
        cTimesRel: Coding time speedup vs reference, one entry per compared
            record (<1 is slower, >1 is faster).
        psnrRel: Quality difference vs reference in dB, one entry per
            compared record (<0 is worse, >0 is better).
        cTime: Absolute coding time in seconds, one entry per compared record.
        psnr: Absolute image quality (PSNR) in dB, one entry per compared
            record.

    A record counts as "compared" when its ``tTimeRel`` is not None. The five
    per-record lists are always appended together, so they have equal length.
    """

    def __init__(self):
        """
        Create a new, empty result summary.
        """
        # Pass/fail counters
        self.notruns = 0
        self.passes = 0
        self.warnings = 0
        self.fails = 0

        # Relative results (vs reference)
        self.tTimesRel = []
        self.cTimesRel = []
        self.psnrRel = []

        # Absolute results
        self.cTime = []
        self.psnr = []

    def add_record(self, record):
        """
        Add a record to this summary.

        The record's status increments exactly one counter. Any status other
        than PASS, WARN, or FAIL is counted as not run. Metric values are only
        collected when the record carries reference comparison data.

        Args:
            record (Record): The Record to add.
        """
        if record.status == Result.PASS:
            self.passes += 1
        elif record.status == Result.WARN:
            self.warnings += 1
        elif record.status == Result.FAIL:
            self.fails += 1
        else:
            self.notruns += 1

        # Only records compared against a reference have relative metrics.
        # The absolute metrics are gated the same way so all lists line up.
        if record.tTimeRel is not None:
            self.tTimesRel.append(record.tTimeRel)
            self.cTimesRel.append(record.cTimeRel)
            self.psnrRel.append(record.psnrRel)

            self.cTime.append(record.cTime)
            self.psnr.append(record.psnr)

    def get_worst_result(self):
        """
        Get the worst result in this set.

        Returns:
            Result: FAIL if any test failed, else WARN if any test warned,
            else PASS if any test passed, else NOTRUN.
        """
        if self.fails:
            return Result.FAIL

        if self.warnings:
            return Result.WARN

        if self.passes:
            return Result.PASS

        return Result.NOTRUN

    def __str__(self):
        """
        Format the summary as a human-readable, multi-line report.

        The mean/std performance statistics are only included when at least
        one record carried reference comparison data.
        """
        # Overall pass/fail results
        overall = self.get_worst_result().name
        dat = (overall, self.passes, self.warnings, self.fails)
        result = ["\nSet Status: %s (Pass: %u | Warn: %u | Fail: %u)" % dat]

        if self.tTimesRel:
            # Performance summaries as (format, samples) pairs. The absolute
            # quality line uses one fewer space before the value, presumably
            # so that two-digit PSNR values stay column-aligned.
            stats = (
                ("\nTotal speed:   Mean:  %+0.3f x   Std: %0.2f x",
                 self.tTimesRel),
                ("Coding speed:  Mean:  %+0.3f x   Std: %0.2f x",
                 self.cTimesRel),
                ("Quality diff:  Mean:  %+0.3f dB  Std: %0.2f dB",
                 self.psnrRel),
                ("Coding time:   Mean:  %+0.3f s   Std: %0.2f s",
                 self.cTime),
                ("Quality:       Mean: %+0.3f dB  Std: %0.2f dB",
                 self.psnr),
            )
            for fmt, samples in stats:
                result.append(fmt % (np.mean(samples), np.std(samples)))

        return "\n".join(result)
150
151
class Record:
    """
    A single result record, storing results for a single image and block size.

    Attributes:
        blkSz: The block size.
        name: The test image name.
        psnr: The image quality (PSNR dB).
        tTime: The total compression time, in seconds.
        cTime: The coding compression time, in seconds.
        cRate: The coding compression rate, in MPix/s.
        status: The test Result.
        tTimeRel: The relative total time speedup vs reference, or None if
            no reference comparison data is available.
        cTimeRel: The relative coding time speedup vs reference, or None.
        cRateRel: The relative rate speedup vs reference, or None.
        psnrRel: The relative PSNR dB vs reference, or None.
    """

    def __init__(self, blkSz, name, psnr, tTime, cTime, cRate):
        """
        Create a result record, initially in the NOTRUN status.

        The relative-to-reference attributes (tTimeRel, cTimeRel, cRateRel,
        psnrRel) start as None. They are presumably filled in later by the
        code that compares results against a reference (not shown here).

        Args:
            blkSz (str): The block size.
            name (str): The test image name.
            psnr (float): The image quality PSNR, in dB.
            tTime (float): The total compression time, in seconds.
            cTime (float): The coding compression time, in seconds.
            cRate (float): The coding compression rate, in MPix/s.
        """
        self.blkSz = blkSz
        self.name = name
        self.psnr = psnr
        self.tTime = tTime
        self.cTime = cTime
        self.cRate = cRate
        self.status = Result.NOTRUN

        # None means "no reference comparison available for this record"
        self.tTimeRel = None
        self.cTimeRel = None
        self.cRateRel = None
        self.psnrRel = None

    def set_status(self, result):
        """
        Set the result status.

        Args:
            result (Result): The test result.
        """
        self.status = result

    def __str__(self):
        """
        Return a short "'blkSz' / 'name'" label identifying this record.
        """
        return "'%s' / '%s'" % (self.blkSz, self.name)
206
207
class ResultSet:
    """
    A set of results for a TestSet, across one or more block sizes.

    Attributes:
        testSet: The test set these results are linked to. It is written as
            the first column of every saved CSV row and compared against that
            column on load, so it presumably holds the test set name (a str).
            NOTE(review): confirm against callers.
        records: The list of test results.
    """

    def __init__(self, testSet):
        """
        Create a new empty ResultSet.

        Args:
            testSet: The test set these results are linked to (see the class
                docstring).
        """
        self.testSet = testSet
        self.records = []

    def add_record(self, record):
        """
        Add a new test record to this result set.

        Args:
            record (Record): The test record to add.
        """
        self.records.append(record)

    def _find_record(self, blkSz, name):
        """
        Return the first record with the given block size and name.

        Raises:
            KeyError: No match could be found.
        """
        for record in self.records:
            if record.blkSz == blkSz and record.name == name:
                return record

        raise KeyError((blkSz, name))

    def get_record(self, testSet, blkSz, name):
        """
        Get a record matching the arguments.

        Args:
            testSet: The test set to get results from.
            blkSz (str): The block size.
            name (str): The test name.

        Returns:
            Record: The test result, if present.

        Raises:
            KeyError: The test set differs from this set's, or no record
                matches the block size and name.
        """
        if testSet != self.testSet:
            raise KeyError(testSet)

        return self._find_record(blkSz, name)

    def get_matching_record(self, other):
        """
        Get a record matching the config of another record.

        Args:
            other (Record): The pattern result record to match; only its
                blkSz and name are compared.

        Returns:
            Record: The result, if present.

        Raises:
            KeyError: No match could be found.
        """
        return self._find_record(other.blkSz, other.name)

    def get_results_summary(self):
        """
        Get a results summary of all the records in this result set.

        Returns:
            ResultSummary: The result summary.
        """
        summary = ResultSummary()
        for record in self.records:
            summary.add_record(record)

        return summary

    def save_to_file(self, filePath):
        """
        Save this result set to a CSV file, creating parent directories.

        Args:
            filePath (str): The output file path.
        """
        dirName = os.path.dirname(filePath)
        # A bare file name has an empty directory part, and os.makedirs("")
        # raises; exist_ok avoids a check-then-create race.
        if dirName:
            os.makedirs(dirName, exist_ok=True)

        # newline="" is required by the csv module so its "\r\n" row
        # terminator is not translated to "\r\r\n" on Windows.
        with open(filePath, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            self._save_header(writer)
            for record in self.records:
                self._save_record(writer, record)

    @staticmethod
    def _save_header(writer):
        """
        Write the header to the CSV file.

        Args:
            writer (csv.writer): The CSV writer.
        """
        row = ["Image Set", "Block Size", "Name",
               "PSNR", "Total Time", "Coding Time", "Coding Rate"]
        writer.writerow(row)

    def _save_record(self, writer, record):
        """
        Write a record to the CSV file, with metrics at 4 decimal places.

        Args:
            writer (csv.writer): The CSV writer.
            record (Record): The record to write.
        """
        row = [self.testSet,
               record.blkSz,
               record.name,
               "%0.4f" % record.psnr,
               "%0.4f" % record.tTime,
               "%0.4f" % record.cTime,
               "%0.4f" % record.cRate]
        writer.writerow(row)

    def load_from_file(self, filePath):
        """
        Load a reference result set from a CSV file on disk.

        Args:
            filePath (str): The input file path.

        Raises:
            ValueError: A row belongs to a different test set, or a numeric
                field cannot be parsed.
        """
        with open(filePath, "r", newline="") as csvfile:
            reader = csv.reader(csvfile)
            # Skip the header; an empty file simply yields no records
            next(reader, None)
            for row in reader:
                # Tolerate blank lines, e.g. from files written on Windows
                # without newline=""
                if not row:
                    continue

                # Explicit check rather than assert, so that it still runs
                # under "python -O"
                if row[0] != self.testSet:
                    msg = "%s: row for test set '%s' does not match '%s'"
                    raise ValueError(msg % (filePath, row[0], self.testSet))

                record = Record(row[1], row[2],
                                float(row[3]), float(row[4]),
                                float(row[5]), float(row[6]))
                self.add_record(record)
355