1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import argparse
31import os
32from uuid import UUID
33from collections import OrderedDict
34import pickle
35
36
37import torch
38import numpy as np
39
40import utils
41
42
43
44parser = argparse.ArgumentParser()
45parser.add_argument("input", type=str, help="input folder containing multi-run output")
46parser.add_argument("tag", type=str, help="tag for multi-run experiment")
47parser.add_argument("csv", type=str, help="name for output csv")
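
# Example invocation (illustrative only; the script name and paths below are
# placeholders, not part of the original file):
#
#   python collect_multirun_results.py multi_run_output/ my_experiment_tag results.csv
#
# `input` is expected to contain one sub-directory per run, named by the run's UUID.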


def is_uuid(val):
    """Return True if val parses as a UUID (multi-run sub-folders are named by UUID)."""
    try:
        UUID(val)
        return True
    except ValueError:
        return False


def collect_results(folder):

    training_folder = os.path.join(folder, 'training')
    testing_folder  = os.path.join(folder, 'testing')

    # validation loss
    checkpoint = torch.load(os.path.join(training_folder, 'checkpoints', 'checkpoint_finalize_epoch_1.pth'), map_location='cpu')
    validation_loss = checkpoint['validation_loss']

    # eval_warpq
    eval_warpq = utils.data.parse_warpq_scores(os.path.join(training_folder, 'out_finalize.txt'))[-1]

    # testing results
    testing_results = utils.data.collect_test_stats(os.path.join(testing_folder, 'final'))

    results = OrderedDict()
    results['eval_loss']          = validation_loss
    results['eval_warpq']         = eval_warpq
    results['pesq_mean']          = testing_results['pesq'][0]
    results['warpq_mean']         = testing_results['warpq'][0]
    results['pitch_error_mean']   = testing_results['pitch_error'][0]
    results['voicing_error_mean'] = testing_results['voicing_error'][0]

    return results

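# Per-run layout assumed by collect_results() above, inferred from the paths it reads
# (the training/testing pipeline that produces these files is not shown here):
#
#   <input>/<run-uuid>/
#       training/
#           checkpoints/checkpoint_finalize_epoch_1.pth   (holds 'validation_loss')
#           out_finalize.txt                              (WARP-Q scores parsed via utils)
#       testing/
#           final/                                        (per-run test statistics)
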
def print_csv(path, results, tag, ranks=None, header=True):

    metrics = next(iter(results.values())).keys()
    if ranks is not None:
        rank_keys = next(iter(ranks.values())).keys()
    else:
        rank_keys = []

    with open(path, 'w') as f:
        if header:
            f.write("uuid, tag")

            for metric in metrics:
                f.write(f", {metric}")

            for rank in rank_keys:
                f.write(f", {rank}")

            f.write("\n")

        for uuid, values in results.items():
            f.write(f"{uuid}, {tag}")

            for val in values.values():
                f.write(f", {val:10.8f}")

            for rank in rank_keys:
                f.write(f", {ranks[uuid][rank]:4d}")

            f.write("\n")

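# Sketch of the CSV written by print_csv() (the UUID and numbers are made up; the
# metric columns follow the order set up in collect_results() plus the 'mix' score,
# and the rank columns are appended when `ranks` is given):
#
#   uuid, tag, eval_loss, eval_warpq, ..., mix, rank_eval_loss, ..., rank_mix
#   1b9d6bcd-..., exp1, 0.01234567, 0.87654321, ..., 0.42000000,    2, ...,    1
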
def get_ranks(results):

    metrics = list(next(iter(results.values())).keys())

    # metrics for which larger values are better; flipping the sign below makes the
    # sort descending for them, so rank 1 is always the best-performing run
    positive = {'pesq_mean', 'mix'}

    ranks = OrderedDict()
    for key in results.keys():
        ranks[key] = OrderedDict()

    for metric in metrics:
        sign = -1 if metric in positive else 1

        x = sorted([(key, value[metric]) for key, value in results.items()], key=lambda x: sign * x[1])
        x = [y[0] for y in x]

        for key in results.keys():
            ranks[key]['rank_' + metric] = x.index(key) + 1

    return ranks

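# Shape of the dictionary returned by get_ranks() (rank values are illustrative;
# one 'rank_<metric>' entry is created per metric present in `results`):
#
#   ranks['<run-uuid>'] = OrderedDict([('rank_eval_loss', 3), ('rank_pesq_mean', 1), ...])
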
def analyse_metrics(results):
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']

    x = []
    for metric in metrics:
        x.append([val[metric] for val in results.values()])

    x = np.array(x)

    print(x)

def add_mix_metric(results):
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']

    x = []
    for metric in metrics:
        x.append([val[metric] for val in results.values()])

    x = np.array(x).transpose() * np.array([-1, 1, -1, -1, -1])

    z = (x - np.mean(x, axis=0)) / np.std(x, axis=0)

    print(f"covariance matrix for normalized scores of {metrics}:")
    print(np.cov(z.transpose()))

    score = np.mean(z, axis=1)

    for i, key in enumerate(results.keys()):
        results[key]['mix'] = score[i].item()

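# What add_mix_metric() computes, written out (this restates the code above rather
# than adding behaviour): each of the five listed metrics is oriented so that larger
# is better (every metric except pesq_mean is negated), standardised per metric, and
# the 'mix' score of a run is the mean of its standardised values:
#
#   z[i, m] = (x[i, m] - mean(x[:, m])) / std(x[:, m])
#   mix[i]  = mean over m of z[i, m]
#
# A larger 'mix' therefore indicates a better run overall.
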
if __name__ == "__main__":
    args = parser.parse_args()

    uuids = sorted([x for x in os.listdir(args.input) if os.path.isdir(os.path.join(args.input, x)) and is_uuid(x)])

    results = OrderedDict()
    for uuid in uuids:
        results[uuid] = collect_results(os.path.join(args.input, uuid))

    add_mix_metric(results)

    ranks = get_ranks(results)

    # normalize the output name once and use it for both the csv and the pickle
    csv = args.csv if args.csv.endswith('.csv') else args.csv + '.csv'

    print_csv(csv, results, args.tag, ranks=ranks)

    with open(csv[:-4] + '.pickle', 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
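
# The pickle written above mirrors the CSV and can be reloaded for further analysis,
# e.g. (file name and key are placeholders):
#
#   import pickle
#   with open('results.pickle', 'rb') as f:
#       results = pickle.load(f)
#   print(results['<run-uuid>']['mix'])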