• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import os
2
3from context import Context, Fmk
4
5
class ModelInfo:
    """File paths and command lines for converting and benchmarking one model.

    Usage visible in this class: construct with a model name and its source framework, call
    ``init()`` to fill in per-framework suffix defaults and resolve paths through
    ``Context.instance()``, then call the ``*_cmd`` builders, each of which returns an
    ``(executable, args)`` pair.
    """

    def __init__(self, model_name, fmk: "Fmk", network_suffix="", weight_suffix=""):
        self.model_name = model_name
        self.fmk: Fmk = fmk
        self.fmk_name = ""

        self.network_suffix = network_suffix
        self.weight_suffix = weight_suffix
        self.model_file = ""
        self.weight_file = ""
        # input output
        self.input_suffix = ""
        self.output_suffix = ""
        self.input_num = 0
        self.input_file = ""
        self.output_file = ""

        self.converted_model_file = ""
        # Value converted_model_file took after the "_graph" fallback was applied; lets later
        # benchmark_* calls recognise the fallback as done instead of appending "_graph" again.
        self._graph_fallback_file = None

    def __str__(self):
        return f"{self.model_name}: {{fmk: {self.fmk}, fmk_name: {self.fmk_name}, " \
               f"network_suffix: {self.network_suffix}, weight_suffix: {self.weight_suffix}, " \
               f"model_file: {self.model_file}, weight_file: {self.weight_file}, input_suffix: {self.input_suffix}, " \
               f"output_suffix: {self.output_suffix}, input_num: {self.input_num}, input_file: {self.input_file}, " \
               f"output_file: {self.output_file}}}"

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def _input_data_str(input_file_prefix, input_num):
        """Return the value for the benchmark ``--inDataFile`` flag.

        With exactly one input the prefix is returned unchanged. Otherwise the result is
        ``<prefix>_0,<prefix>_1,...,<prefix>_{input_num-1}``, which is an empty string when
        ``input_num`` is 0 or negative.
        """
        if input_num == 1:
            return input_file_prefix
        return ",".join(f"{input_file_prefix}_{i}" for i in range(input_num))

    def init(self, input_num=1, input_suffix="", output_suffix=""):
        """Apply per-framework suffix defaults and resolve model/input/output file paths.

        Empty ``network_suffix``/``weight_suffix`` (given to the constructor) and empty
        ``input_suffix``/``output_suffix`` (given here) are replaced by the defaults of
        ``self.fmk``. Paths are joined under the directories returned by the Context's
        ``model_dir_func``, ``input_dir_func`` and ``output_dir_func`` for that fmk.

        Raises:
            ValueError: if ``self.fmk`` is not mindir, caffe, onnx, tf or tflite.
        """
        self.input_num = input_num
        self.input_suffix = input_suffix
        self.output_suffix = output_suffix
        if self.fmk == Fmk.mindir:
            self.fmk_name = "mindir"
            if not self.network_suffix:
                self.network_suffix = ".mindir"
            if not self.input_suffix:
                self.input_suffix = ".mindir.bin"
            if not self.output_suffix:
                self.output_suffix = ".mindir.out"
        elif self.fmk == Fmk.caffe:
            self.fmk_name = "caffe"
            if not self.network_suffix:
                self.network_suffix = ".prototxt"
            if not self.weight_suffix:
                self.weight_suffix = ".caffemodel"
            if not self.input_suffix:
                self.input_suffix = ".ms.bin"
            if not self.output_suffix:
                self.output_suffix = ".ms.out"
        elif self.fmk == Fmk.onnx:
            self.fmk_name = "onnx"
            if not self.network_suffix:
                self.network_suffix = ".onnx"
            if not self.input_suffix:
                self.input_suffix = ".onnx.ms.bin"
            if not self.output_suffix:
                self.output_suffix = ".onnx.ms.out"
        elif self.fmk == Fmk.tf:
            self.fmk_name = "tf"
            if not self.network_suffix:
                self.network_suffix = ".pb"
            if not self.input_suffix:
                self.input_suffix = ".pb.ms.bin"
            if not self.output_suffix:
                self.output_suffix = ".pb.ms.out"
        elif self.fmk == Fmk.tflite:
            self.fmk_name = "tflite"
            if not self.network_suffix:
                self.network_suffix = ".tflite"
            if not self.input_suffix:
                self.input_suffix = ".tflite.ms.bin"
            if not self.output_suffix:
                self.output_suffix = ".tflite.ms.out"
        else:
            raise ValueError(f"model({self.model_name}) has unsupported fmk: {self.fmk}")
        context = Context.instance()
        model_dir = context.model_dir_func(self.fmk)
        self.model_file = os.path.join(model_dir, self.model_name + self.network_suffix)
        # Only the caffe branch sets a separate weight file; every other fmk leaves it empty.
        if self.fmk == Fmk.caffe:
            self.weight_file = os.path.join(model_dir, self.model_name + self.weight_suffix)
        else:
            self.weight_file = ""
        self.input_file = os.path.join(context.input_dir_func(self.fmk), self.model_name + self.input_suffix)
        self.output_file = os.path.join(context.output_dir_func(self.fmk), self.model_name + self.output_suffix)

    def convert_cmd(self, input_shapes="", output_name=""):
        """Build the converter command for this model.

        Sets ``converted_model_file`` to ``<work_dir>/<output_name or model_name>``, without an
        extension, and passes it as ``--outputFile``. ``--weightFile`` and ``--inputShape`` are
        only added when non-empty.

        Returns:
            ``(context.converter_file, args)``.
        """
        context = Context.instance()
        self.converted_model_file = os.path.join(context.work_dir, output_name or self.model_name)
        # A new output path: any earlier "_graph" fallback no longer applies.
        self._graph_fallback_file = None
        args = [f"--fmk={self.fmk_name.upper()}", f"--modelFile={self.model_file}",
                f"--outputFile={self.converted_model_file}"]
        if self.weight_file:
            args.append(f"--weightFile={self.weight_file}")
        if input_shapes:
            args.append(f"--inputShape={input_shapes}")
        return context.converter_file, args

    def _resolve_converted_model_file(self):
        """Return the converted-model path (without ``.mindir``) that the benchmark should load.

        Defaults ``converted_model_file`` to ``<work_dir>/<model_name>`` when ``convert_cmd`` has
        not set it. Switches to ``<name>_graph`` when ``<name>.mindir`` does not exist, which is
        the FuncGraph split-export case. The switch happens at most once per value, so building
        several benchmark commands no longer yields ``<name>_graph_graph``.
        """
        if not self.converted_model_file:
            self.converted_model_file = os.path.join(Context.instance().work_dir, self.model_name)
        already_applied = self.converted_model_file == getattr(self, "_graph_fallback_file", None)
        if not already_applied and not os.path.exists(f"{self.converted_model_file}.mindir"):
            self.converted_model_file += "_graph"  # when FuncGraph split-export
            self._graph_fallback_file = self.converted_model_file
        return self.converted_model_file

    def benchmark_accuracy_cmd(self, input_shapes="", acc_threshold=0.5):
        """Build a CPU benchmark command that checks outputs against ``self.output_file``.

        Returns:
            ``(context.benchmark_file, args)``.
        """
        context = Context.instance()
        input_bin_str = ModelInfo._input_data_str(self.input_file, self.input_num)
        model_file = self._resolve_converted_model_file()
        args = ["--enableParallelPredict=false", f"--modelFile={model_file}.mindir",
                f"--inDataFile={input_bin_str}", f"--benchmarkDataFile={self.output_file}",
                f"--inputShapes={input_shapes}", f"--accuracyThreshold={acc_threshold}", "--device=CPU"]
        return context.benchmark_file, args

    def benchmark_performance_cmd(self, input_shapes="", warmup_loop=3, loop=10, num_threads=2):
        """Build a CPU benchmark command with the given warm-up, loop and thread counts.

        Returns:
            ``(context.benchmark_file, args)``.
        """
        context = Context.instance()
        model_file = self._resolve_converted_model_file()
        args = ["--enableParallelPredict=false", f"--modelFile={model_file}.mindir",
                f"--inputShapes={input_shapes}", f"--warmUpLoopCount={warmup_loop}", f"--loopCount={loop}",
                f"--numThreads={num_threads}", "--device=CPU"]
        return context.benchmark_file, args
144