• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import statistics
2
3import pandas as pd
4from tabulate import tabulate
5
6
class ProcessedMetricsPrinter:
    r"""
    Helper that summarizes, prints, merges and saves processed metrics.

    A "processed metrics" dictionary maps a metric name to a list of numeric
    samples recorded for that metric.
    """

    def print_data_frame(self, name, processed_metrics):
        r"""
        A method that prints a header line followed by a grid-formatted table
        of the summary statistics for ``processed_metrics``.

        Args:
            name (str): label inserted into the ``metrics for ...`` header line.
            processed_metrics (dict): metric names mapped to lists of samples.
        """
        print(f"metrics for {name}")
        data_frame = self.get_data_frame(processed_metrics)
        print(
            tabulate(
                data_frame, showindex=False, headers=data_frame.columns, tablefmt="grid"
            )
        )

    def combine_processed_metrics(self, processed_metrics_list):
        r"""
        A method that merges the value arrays of the keys in the dictionary
        of processed metrics.

        Args:
            processed_metrics_list (list): a list containing dictionaries with
                recorded metrics as keys, and the values are lists of elapsed times.

        Returns:
            A merged dictionary that is created from the list of dictionaries passed
                into the method. The merged lists are new objects; the input
                dictionaries and their lists are not modified.

        Examples::
            >>> instance = ProcessedMetricsPrinter()
            >>> dict_1 = trainer1.get_processed_metrics()
            >>> dict_2 = trainer2.get_processed_metrics()
            >>> print(dict_1)
            {
                "forward_metric_type,forward_pass" : [.0429, .0888]
            }
            >>> print(dict_2)
            {
                "forward_metric_type,forward_pass" : [.0111, .0222]
            }
            >>> processed_metrics_list = [dict_1, dict_2]
            >>> result = instance.combine_processed_metrics(processed_metrics_list)
            >>> print(result)
            {
                "forward_metric_type,forward_pass" : [.0429, .0888, .0111, .0222]
            }
        """
        processed_metric_totals = {}
        for processed_metrics in processed_metrics_list:
            for metric_name, values in processed_metrics.items():
                # setdefault creates a fresh list on first sight of a metric, so
                # extending it never touches the caller's lists.
                processed_metric_totals.setdefault(metric_name, []).extend(values)
        return processed_metric_totals

    def get_data_frame(self, processed_metrics):
        r"""
        A method that builds a table of summary statistics, one row per metric,
        with rows ordered by sorted metric name.

        Args:
            processed_metrics (dict): metric names mapped to lists of samples.

        Returns:
            A ``pandas.DataFrame`` with the columns ``name``, ``min``, ``max``,
                ``mean``, ``variance`` and ``stdev``. ``variance`` and ``stdev``
                are the sample statistics and are NaN for a metric that has
                fewer than two samples.

        Raises:
            ValueError: if a metric maps to an empty list of samples.
        """
        columns = ["name", "min", "max", "mean", "variance", "stdev"]
        rows = []
        for metric_name in sorted(processed_metrics):
            values = processed_metrics[metric_name]
            if len(values) > 1:
                variance = statistics.variance(values)
                stdev = statistics.stdev(values)
            else:
                # Sample variance is undefined for fewer than two data points;
                # statistics.variance/stdev would raise StatisticsError here.
                variance = stdev = float("nan")
            rows.append(
                {
                    "name": metric_name,
                    "min": min(values),
                    "max": max(values),
                    "mean": statistics.mean(values),
                    "variance": variance,
                    "stdev": stdev,
                }
            )
        # Build the frame once from all rows: DataFrame.append was removed in
        # pandas 2.0 and copied the whole frame on every call before that.
        return pd.DataFrame(rows, columns=columns)

    def print_metrics(self, name, rank_metrics_list):
        r"""
        A method that prints one statistics table per entry and then one table
        for all entries combined. Nothing is printed when the list is empty.

        Args:
            name (str): label used in the table headers (``{name}={rank}`` for
                each entry and ``all {name}`` for the combined table).
            rank_metrics_list (list): a list of ``(rank, processed_metrics)``
                pairs.
        """
        if rank_metrics_list:
            metrics_list = []
            for rank, metric in rank_metrics_list:
                self.print_data_frame(f"{name}={rank}", metric)
                metrics_list.append(metric)
            combined_metrics = self.combine_processed_metrics(metrics_list)
            self.print_data_frame(f"all {name}", combined_metrics)

    def save_to_file(self, data_frame, file_name):
        r"""
        A method that writes ``data_frame`` as UTF-8 CSV, without the index
        column, to ``data_frames/{file_name}.csv``.

        Args:
            data_frame (pandas.DataFrame): the table to save.
            file_name (str): file name without the ``.csv`` extension.
        """
        # Local import keeps this method self-contained.
        import os

        file_name = f"data_frames/{file_name}.csv"
        # to_csv does not create missing directories, so make sure the target
        # directory exists before writing.
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        data_frame.to_csv(file_name, encoding="utf-8", index=False)
84