import os

from context import Context, Fmk


class ModelInfo:
    """Holds file locations and converter/benchmark command lines for one test model.

    Paths are resolved lazily in :meth:`init` via the project ``Context`` singleton;
    until then all path attributes are empty strings.
    """

    # Per-framework defaults applied by init() when the caller did not override them:
    # fmk -> (fmk_name, network_suffix, input_suffix, output_suffix, weight_suffix)
    # Only caffe carries a weight-file suffix; the empty string means "no default".
    _FMK_DEFAULTS = {
        Fmk.mindir: ("mindir", ".mindir", ".mindir.bin", ".mindir.out", ""),
        Fmk.caffe: ("caffe", ".prototxt", ".ms.bin", ".ms.out", ".caffemodel"),
        Fmk.onnx: ("onnx", ".onnx", ".onnx.ms.bin", ".onnx.ms.out", ""),
        Fmk.tf: ("tf", ".pb", ".pb.ms.bin", ".pb.ms.out", ""),
        Fmk.tflite: ("tflite", ".tflite", ".tflite.ms.bin", ".tflite.ms.out", ""),
    }

    def __init__(self, model_name, fmk: Fmk, network_suffix="", weight_suffix=""):
        """Record the model's name and framework; suffixes may pre-empt init() defaults.

        Args:
            model_name: base name of the model file (no extension).
            fmk: source framework enum value.
            network_suffix: optional override for the network-file extension.
            weight_suffix: optional override for the weight-file extension (caffe only).
        """
        self.model_name = model_name
        self.fmk: Fmk = fmk
        self.fmk_name = ""

        self.network_suffix = network_suffix
        self.weight_suffix = weight_suffix
        self.model_file = ""
        self.weight_file = ""
        # input output
        self.input_suffix = ""
        self.output_suffix = ""
        self.input_num = 0
        self.input_file = ""
        self.output_file = ""

        self.converted_model_file = ""

    def __str__(self):
        return f"{self.model_name}: {{fmk: {self.fmk}, fmk_name: {self.fmk_name}, " \
               f"network_suffix: {self.network_suffix}, weight_suffix: {self.weight_suffix}, " \
               f"model_file: {self.model_file}, weight_file: {self.weight_file}, input_suffix: {self.input_suffix}, " \
               f"output_suffix: {self.output_suffix}, input_num: {self.input_num}, input_file: {self.input_file}, " \
               f"output_file: {self.output_file}}}"

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def _input_data_str(input_file_prefix, input_num):
        """Return the --inDataFile value: the bare prefix for one input, else a
        comma-separated "<prefix>_0,<prefix>_1,..." list (empty string for 0 inputs)."""
        if input_num == 1:
            return input_file_prefix
        return ",".join(f"{input_file_prefix}_{i}" for i in range(input_num))

    def init(self, input_num=1, input_suffix="", output_suffix=""):
        """Resolve framework defaults and absolute model/input/output file paths.

        Args:
            input_num: number of input data files for this model.
            input_suffix: optional override for the input-data extension.
            output_suffix: optional override for the benchmark-output extension.

        Raises:
            ValueError: if ``self.fmk`` is not a supported framework.
        """
        self.input_num = input_num
        self.input_suffix = input_suffix
        self.output_suffix = output_suffix

        try:
            defaults = self._FMK_DEFAULTS[self.fmk]
        except KeyError:
            raise ValueError(f"model({self.model_name}) has unsupported fmk: {self.fmk}") from None
        self.fmk_name, network_default, input_default, output_default, weight_default = defaults

        # Caller-supplied suffixes win; empty ones fall back to framework defaults.
        if not self.network_suffix:
            self.network_suffix = network_default
        if weight_default and not self.weight_suffix:
            self.weight_suffix = weight_default
        if not self.input_suffix:
            self.input_suffix = input_default
        if not self.output_suffix:
            self.output_suffix = output_default

        context = Context.instance()
        self.model_file = os.path.join(context.model_dir_func(self.fmk), self.model_name + self.network_suffix)
        if self.fmk == Fmk.caffe:
            # Only caffe splits the network (.prototxt) from its weights (.caffemodel).
            self.weight_file = os.path.join(context.model_dir_func(self.fmk), self.model_name + self.weight_suffix)
        else:
            self.weight_file = ""
        input_name = self.model_name + self.input_suffix
        output_name = self.model_name + self.output_suffix
        self.input_file = os.path.join(context.input_dir_func(self.fmk), input_name)
        self.output_file = os.path.join(context.output_dir_func(self.fmk), output_name)

    def convert_cmd(self, input_shapes="", output_name=""):
        """Build the converter invocation.

        Args:
            input_shapes: optional value for ``--inputShape``.
            output_name: optional basename for the converted model (defaults to
                ``self.model_name``) placed under the context work dir.

        Returns:
            Tuple of (converter executable path, argument list).
        """
        context = Context.instance()
        self.converted_model_file = os.path.join(context.work_dir, output_name or self.model_name)
        args = [f"--fmk={self.fmk_name.upper()}", f"--modelFile={self.model_file}",
                f"--outputFile={self.converted_model_file}"]
        if self.weight_file:
            args.append(f"--weightFile={self.weight_file}")
        if input_shapes:
            args.append(f"--inputShape={input_shapes}")
        return context.converter_file, args

    def _resolve_converted_model_file(self, work_dir):
        """Ensure converted_model_file is set, falling back to the split-export name.

        NOTE(review): the ``_graph`` suffix accumulates if this is called repeatedly
        and the plain ``.mindir`` never exists — matches the original behavior.
        """
        if not self.converted_model_file:
            self.converted_model_file = os.path.join(work_dir, self.model_name)
        if not os.path.exists(f"{self.converted_model_file}.mindir"):
            self.converted_model_file += "_graph"  # when FuncGraph split-export

    def benchmark_accuracy_cmd(self, input_shapes="", acc_threshold=0.5):
        """Build the benchmark invocation that checks accuracy against golden outputs.

        Returns:
            Tuple of (benchmark executable path, argument list).
        """
        context = Context.instance()
        input_bin_str = ModelInfo._input_data_str(self.input_file, self.input_num)
        self._resolve_converted_model_file(context.work_dir)
        args = ["--enableParallelPredict=false", f"--modelFile={self.converted_model_file}.mindir",
                f"--inDataFile={input_bin_str}", f"--benchmarkDataFile={self.output_file}",
                f"--inputShapes={input_shapes}", f"--accuracyThreshold={acc_threshold}", "--device=CPU"]
        return context.benchmark_file, args

    def benchmark_performance_cmd(self, input_shapes="", warmup_loop=3, loop=10, num_threads=2):
        """Build the benchmark invocation that measures inference latency.

        Returns:
            Tuple of (benchmark executable path, argument list).
        """
        context = Context.instance()
        self._resolve_converted_model_file(context.work_dir)
        args = ["--enableParallelPredict=false", f"--modelFile={self.converted_model_file}.mindir",
                f"--inputShapes={input_shapes}", f"--warmUpLoopCount={warmup_loop}", f"--loopCount={loop}",
                f"--numThreads={num_threads}", "--device=CPU"]
        return context.benchmark_file, args