import subprocess
import enum

from typing import Dict, Optional, Sequence

from model_info import ModelInfo, Fmk
from context import Context


class ExecRet(enum.Enum):
    """Outcome of one test step (convert, accuracy benchmark, performance benchmark)."""

    idle = 0
    success = 1
    failed = 2

    def __str__(self):
        # Report labels; note that `success` is rendered as "pass".
        labels = {
            ExecRet.idle: "idle",
            ExecRet.success: "pass",
            ExecRet.failed: "failed",
        }
        return labels.get(self, "")

    def __repr__(self):
        return self.__str__()


class TestInfo:
    """Options and per-step results for testing a single model.

    Holds the converter / benchmark options for one model, runs the
    corresponding commands, and records each step's outcome as an ExecRet.
    """

    def __init__(self):
        self.model: Optional[ModelInfo] = None

        self.convert_shapes = ""
        self.convert_output_name = ""

        self.need_performance = False
        self.benchmark_shapes = ""
        self.acc_threshold = 0.5
        self.warmup_loop = 3
        self.bench_loop = 10
        self.num_threads = 2

        # Copy the mapping so that setting ENABLE_MULTI_BACKEND_RUNTIME below
        # stays local to this instance instead of mutating the mapping shared
        # through Context.instance().
        self.cmd_envs = dict(Context.instance().export_lib_paths)
        self.cmd_envs["ENABLE_MULTI_BACKEND_RUNTIME"] = "on"

        self.convert_ret: ExecRet = ExecRet.idle
        self.bench_acc_ret: ExecRet = ExecRet.idle
        self.bench_perf_ret: ExecRet = ExecRet.idle

    def __str__(self):
        return (f"{{model: {self.model}, convert_shapes: {self.convert_shapes}, "
                f"convert_output_name: {self.convert_output_name}, need_performance: {self.need_performance}, "
                f"benchmark_shapes: {self.benchmark_shapes}, acc_threshold: {self.acc_threshold}, "
                f"warmup_loop: {self.warmup_loop}, bench_loop: {self.bench_loop}, num_threads: {self.num_threads}}}")

    def __repr__(self):
        return self.__str__()

    @classmethod
    def create(cls, model_fmk: Fmk, model_name, info: dict):
        """Build a TestInfo for one model from a config dict.

        Args:
            model_fmk: Framework of the model, passed to ModelInfo.
            model_name: Model name, passed to ModelInfo.
            info: Per-model settings. Recognised keys and their defaults:
                network_suffix (""), weight_suffix (""), input_number (1),
                input_suffix (""), output_suffix (""), convert_shapes (""),
                convert_output_name (""), need_performance (False),
                benchmark_shapes (""), acc_threshold (0.5), warmup_loop (3),
                bench_loop (10), num_threads (2).

        Returns:
            A populated TestInfo; all step results start as ExecRet.idle.
        """
        config = cls()
        network_suffix = info.get("network_suffix", "")
        weight_suffix = info.get("weight_suffix", "")
        config.model = ModelInfo(model_name, model_fmk, network_suffix, weight_suffix)
        input_number = info.get("input_number", 1)
        input_suffix = info.get("input_suffix", "")
        output_suffix = info.get("output_suffix", "")
        config.model.init(input_number, input_suffix, output_suffix)
        config.convert_shapes = info.get("convert_shapes", "")
        config.convert_output_name = info.get("convert_output_name", "")

        config.need_performance = info.get("need_performance", False)
        config.benchmark_shapes = info.get("benchmark_shapes", "")
        config.acc_threshold = info.get("acc_threshold", 0.5)
        config.warmup_loop = info.get("warmup_loop", 3)
        config.bench_loop = info.get("bench_loop", 10)
        config.num_threads = info.get("num_threads", 2)
        return config

    @staticmethod
    def exec_cmd(envs: Dict[str, str], cmd: str, args: Sequence[str]) -> bool:
        """Run `cmd` with `args` (no shell) and report whether it succeeded.

        NOTE: `env=envs` replaces the child's whole environment, so the child
        sees only the variables in `envs`.

        Args:
            envs: Complete environment for the child process.
            cmd: Executable to run.
            args: Arguments passed after `cmd`.

        Returns:
            True if the process exited with code 0. Otherwise False, after
            printing the command and its stderr (or the launch error).
        """
        cmds = [cmd, *args]
        try:
            # errors="replace" keeps non-UTF-8 tool output from raising
            # UnicodeDecodeError while the output is decoded.
            result = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8',
                                    errors='replace', shell=False, env=envs)
        except OSError as err:
            # e.g. the executable is missing or not executable: treat it as a failed step.
            print(f"Failed to exec {cmds} \r\n {err}")
            return False
        if result.returncode == 0:
            return True
        print(f"Failed to exec {cmds} \r\n {result.stderr}")
        return False

    def _run_step(self, cmd_spec) -> ExecRet:
        """Run one (cmd, args) pair with this instance's env and map the result to an ExecRet."""
        # cmd_spec presumably is a (cmd, args) pair from ModelInfo's *_cmd helpers -- it
        # is unpacked into exec_cmd's (cmd, args) parameters.
        return ExecRet.success if TestInfo.exec_cmd(self.cmd_envs, *cmd_spec) else ExecRet.failed

    def convert(self):
        """Run the model conversion command; record and return whether it succeeded."""
        self.convert_ret = self._run_step(self.model.convert_cmd(self.convert_shapes, self.convert_output_name))
        return self.convert_ret is ExecRet.success

    def benchmark_accuracy(self):
        """Run the accuracy benchmark command; record and return whether it succeeded."""
        self.bench_acc_ret = self._run_step(
            self.model.benchmark_accuracy_cmd(self.benchmark_shapes, self.acc_threshold))
        return self.bench_acc_ret is ExecRet.success

    def benchmark_performance(self):
        """Run the performance benchmark command; record and return whether it succeeded."""
        self.bench_perf_ret = self._run_step(
            self.model.benchmark_performance_cmd(self.benchmark_shapes, self.warmup_loop,
                                                 self.bench_loop, self.num_threads))
        return self.bench_perf_ret is ExecRet.success