• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
import enum
import subprocess

from typing import Dict, List, Optional

from context import Context
from model_info import ModelInfo, Fmk
8
9
class ExecRet(enum.Enum):
    """Tri-state result of one execution stage (convert / accuracy / perf).

    ``idle`` means the stage has not run yet; ``success`` / ``failed`` record
    the outcome once it has.
    """
    idle = 0
    success = 1
    failed = 2

    # Human-readable labels. NOTE: ``success`` deliberately renders as
    # "pass" (report wording), not "success".
    _LABELS = None  # placeholder comment only; labels live in __str__ below

    def __str__(self):
        # Dict dispatch on the member itself instead of chained
        # ``self.value == Member.value`` comparisons; the "" default
        # mirrors the original unreachable fallback.
        return {
            ExecRet.idle: "idle",
            ExecRet.success: "pass",
            ExecRet.failed: "failed",
        }.get(self, "")

    def __repr__(self):
        # repr matches str so containers of results print readably.
        return self.__str__()
26
27
class TestInfo:
    """Configuration and execution state for one model test case.

    A ``TestInfo`` is normally built via :meth:`create` from a config dict,
    then driven through :meth:`convert`, :meth:`benchmark_accuracy` and
    :meth:`benchmark_performance`; each stage records its outcome in the
    corresponding ``*_ret`` field.
    """

    def __init__(self):
        # Model under test; populated by ``create``.
        self.model: Optional[ModelInfo] = None

        # Converter options.
        self.convert_shapes = ""
        self.convert_output_name = ""

        # Benchmark options.
        self.need_performance = False
        self.benchmark_shapes = ""
        self.acc_threshold = 0.5
        self.warmup_loop = 3
        self.bench_loop = 10
        self.num_threads = 2

        # Copy the shared env mapping: the original code aliased the
        # Context singleton's dict and wrote ENABLE_MULTI_BACKEND_RUNTIME
        # through it, mutating state seen by every other consumer.
        self.cmd_envs = dict(Context.instance().export_lib_paths)
        self.cmd_envs["ENABLE_MULTI_BACKEND_RUNTIME"] = "on"

        # Per-stage outcomes; start idle until the stage actually runs.
        self.convert_ret: ExecRet = ExecRet.idle
        self.bench_acc_ret: ExecRet = ExecRet.idle
        self.bench_perf_ret: ExecRet = ExecRet.idle

    def __str__(self):
        return f"{{model: {self.model}, convert_shapes: {self.convert_shapes}, " \
               f"convert_output_name: {self.convert_output_name}, need_performance: {self.need_performance}, " \
               f"benchmark_shapes: {self.benchmark_shapes}, acc_threshold: {self.acc_threshold}, " \
               f"warmup_loop: {self.warmup_loop}, bench_loop: {self.bench_loop}, num_threads: {self.num_threads}}}"

    def __repr__(self):
        return self.__str__()

    @classmethod
    def create(cls, model_fmk: Fmk, model_name, info: dict):
        """Build a TestInfo from a per-model config dict.

        Unspecified keys fall back to the same defaults as ``__init__``.
        """
        config = cls()
        fmk = model_fmk
        network_suffix = info.get("network_suffix", "")
        weight_suffix = info.get("weight_suffix", "")
        config.model = ModelInfo(model_name, fmk, network_suffix, weight_suffix)
        input_number = info.get("input_number", 1)
        input_suffix = info.get("input_suffix", "")
        output_suffix = info.get("output_suffix", "")
        config.model.init(input_number, input_suffix, output_suffix)
        config.convert_shapes = info.get("convert_shapes", "")
        config.convert_output_name = info.get("convert_output_name", "")

        config.need_performance = info.get("need_performance", False)
        config.benchmark_shapes = info.get("benchmark_shapes", "")
        config.acc_threshold = info.get("acc_threshold", 0.5)
        config.warmup_loop = info.get("warmup_loop", 3)
        config.bench_loop = info.get("bench_loop", 10)
        config.num_threads = info.get("num_threads", 2)
        return config

    @staticmethod
    def exec_cmd(envs: Dict[str, str], cmd: str, args: List[str]):
        """Run ``cmd`` with ``args`` under ``envs``; return True on exit code 0.

        Runs with ``shell=False`` and an argv list (no shell injection);
        stderr is printed on failure. ``envs`` fully replaces the inherited
        environment.
        """
        cmds = [cmd] + args
        return_cmd = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8', shell=False,
                                    env=envs)
        if return_cmd.returncode == 0:
            return True
        print(f"Failed to exec {cmds} \r\n {return_cmd.stderr}")
        return False

    def convert(self):
        """Run the model converter; record and return success."""
        if TestInfo.exec_cmd(self.cmd_envs, *self.model.convert_cmd(self.convert_shapes, self.convert_output_name)):
            self.convert_ret = ExecRet.success
            return True
        self.convert_ret = ExecRet.failed
        return False

    def benchmark_accuracy(self):
        """Run the accuracy benchmark; record and return success."""
        if TestInfo.exec_cmd(self.cmd_envs,
                             *self.model.benchmark_accuracy_cmd(self.benchmark_shapes, self.acc_threshold)):
            self.bench_acc_ret = ExecRet.success
            return True
        self.bench_acc_ret = ExecRet.failed
        return False

    def benchmark_performance(self):
        """Run the performance benchmark; record and return success."""
        if TestInfo.exec_cmd(self.cmd_envs,
                             *self.model.benchmark_performance_cmd(self.benchmark_shapes, self.warmup_loop,
                                                                   self.bench_loop, self.num_threads)):
            self.bench_perf_ret = ExecRet.success
            return True
        self.bench_perf_ret = ExecRet.failed
        return False
113