• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16common functions
17"""
18import logging
19import os
20import stat
21from typing import Dict, Tuple
22import importlib
23
24import numpy as np
25
26from mslite_bench.common.model_info_enum import FrameworkType
27from mslite_bench.common.enum_class import NumpyDtype
28from mslite_bench.common.config import (
29    MsliteConfig, TFConfig, PaddleConfig, OnnxConfig
30)
31
32
33class CommonFunc:
34    """common functions"""
35    @classmethod
36    def get_framework_config(cls,
37                             model_path,
38                             args):
39        """
40        get framework config by model type and args
41        params:
42        model_path: path to model file
43        args: input arguments
44        return: model config
45        """
46        if not os.path.exists(model_path):
47            raise ValueError(f'Create model session failed: {model_path} does not exist')
48
49        if model_path.endswith('pb'):
50            cfg = cls.init_tf_cfg()
51        elif model_path.endswith('onnx'):
52            cfg = cls.init_onnx_cfg()
53        elif model_path.endswith('ms') or model_path.endswith('mindir'):
54            cfg = cls.init_mslite_cfg(args, model_path)
55        elif model_path.endswith('pdmodel'):
56            cfg = cls.init_paddle_cfg(args)
57        else:
58            raise ValueError(f'model {model_path} is not supported yet')
59
60        cfg.input_tensor_shapes = cls.get_tensor_shapes(args.input_tensor_shapes)
61        cfg.device = args.device
62        cfg.device_id = args.device_id
63        cfg.batch_size = args.batch_size
64        cfg.output_tensor_names = args.output_tensor_names
65        cfg.thread_num = args.thread_num
66
67        if cfg.input_tensor_shapes is None and args.input_data_file is not None:
68            input_data_map = cls.get_input_data_map_from_file(args.input_data_file)
69            cfg.input_tensor_shapes = {
70                key: val.shape for key, val in input_data_map.items()
71            }
72
73        return cfg
74
75    @classmethod
76    def create_numpy_data_map(cls,
77                              args):
78        """
79        create input tensor map, with key input tensor name,
80        value its numpy value
81        """
82        if args.input_data_file is not None:
83            input_data_map = np.load(args.input_data_file, allow_pickle=True).item()
84            return input_data_map
85
86        input_tensor_dtypes = CommonFunc.parse_dtype_infos(args.input_tensor_dtypes)
87        input_tensor_shapes = CommonFunc.get_tensor_shapes(args.input_tensor_shapes)
88        input_tensor_infos = {
89            key: (shape, input_tensor_dtypes.get(key))
90            for key, shape in input_tensor_shapes.items()
91        }
92        try:
93            input_tensor_map = cls.create_numpy_data_map_out(input_tensor_infos)
94        except ValueError as e:
95            raise e
96
97        return input_tensor_map
98
99    @classmethod
100    def init_onnx_cfg(cls):
101        """init onnx config"""
102        cfg = OnnxConfig()
103        return cfg
104
105    @classmethod
106    def init_mslite_cfg(cls, args, model_path):
107        """init mslite config"""
108        cfg = MsliteConfig()
109        cfg.infer_framework = FrameworkType.MSLITE.value
110        cfg.mslite_model_type = 4 if model_path.endswith('ms') else 0
111        cfg.thread_affinity_mode = args.thread_affinity_mode
112        cfg.ascend_provider = args.ascend_provider
113        return cfg
114
115    @classmethod
116    def init_paddle_cfg(cls, args):
117        """init paddle config"""
118        cfg = PaddleConfig()
119        cfg.infer_framework = FrameworkType.PADDLE.value
120        cfg.is_fp16 = args.is_fp16
121        cfg.is_int8 = args.is_int8
122        cfg.is_enable_tensorrt = args.is_enable_tensorrt
123        def tmp_func(x):
124            if x is None:
125                return None
126            return cls.get_tensor_shapes(x)
127        cfg.tensorrt_optim_input_shape = tmp_func(args.tensorrt_optim_input_shape)
128        cfg.tensorrt_min_input_shape = tmp_func(args.tensorrt_min_input_shape)
129        cfg.tensorrt_max_input_shape = tmp_func(args.tensorrt_max_input_shape)
130        if cfg.tensorrt_min_input_shape is None:
131            cfg.tensorrt_min_input_shape = cfg.tensorrt_optim_input_shape
132        if cfg.tensorrt_max_input_shape is None:
133            cfg.tensorrt_max_input_shape = cfg.tensorrt_optim_input_shape
134        return cfg
135
136    @staticmethod
137    def get_tensor_shapes(tensor_shapes: str) -> Dict[str, Tuple[int]]:
138        """parse tensor shapes string into dict"""
139        if tensor_shapes is None:
140            return {}
141
142        input_tensor_shape = {}
143        shape_list = tensor_shapes.split(';')
144
145        for shapes in shape_list:
146            name, shape = shapes.split(':')
147            shape = [int(i) for i in shape.split(',')]
148            input_tensor_shape[name] = shape
149
150        return input_tensor_shape
151
152    @staticmethod
153    def import_module(module_name, file_path=None):
154        """import module functions"""
155        return importlib.import_module(module_name, package=file_path)
156
157    @staticmethod
158    def get_input_data_map_from_file(input_data_file):
159        """get input data map from file"""
160        return np.load(input_data_file, allow_pickle=True).item()
161
162    @staticmethod
163    def create_numpy_data_map_out(tensor_infos):
164        """create numpy data dict"""
165        np_data_map = {}
166        for tensor_name, infos in tensor_infos.items():
167            if not isinstance(infos, tuple):
168                raise ValueError('input info shall contain tensor shape and tensor dtype')
169            shape, dtype = infos
170            np_dtype = getattr(NumpyDtype, dtype.upper()).value
171            tensor_data = np.random.rand(*shape).astype(np_dtype)
172            np_data_map[tensor_name] = tensor_data
173
174        return np_data_map
175
176    @staticmethod
177    def save_output_as_benchmark_txt(save_dir,
178                                     output_tensor):
179        """save output tensor as benchmark type text"""
180        for key, value in output_tensor.items():
181            save_path = f'{save_dir}_{"".join(key.split("/"))}.txt'
182            shape = value.shape
183            shape_str = ''
184            for val in shape:
185                shape_str = shape_str.join(f'{val} ')
186            dim = len(shape)
187            flags = os.O_WRONLY
188            mode = stat.S_IWUSR | stat.S_IRUSR
189            with os.fdopen(os.open(save_path, flags, mode), 'w') as fi:
190                fi.write(f'{key} {dim} {shape_str}\n')
191                np.savetxt(fi, value.flatten(), newline=' ')
192
193    @staticmethod
194    def init_tf_cfg():
195        """init tensorflow config"""
196        cfg = TFConfig()
197        return cfg
198
199    @staticmethod
200    def logging_level(level):
201        if level == 0:
202            return logging.DEBUG
203        if level == 1:
204            return logging.INFO
205        if level == 2:
206            return logging.WARNING
207        return logging.ERROR
208
209    @staticmethod
210    def parse_dtype_infos(dtype_infos):
211        """
212        parse input dtype infos string to dict, key is input tensor,
213        value is tensor dtype
214        params:
215        model_path: path to model file
216        args: input arguments
217        return: model config
218        """
219        infos = dtype_infos.split(';')
220        ret = {}
221        for info in infos:
222            key, dtype = info.split(':')
223            ret[key] = dtype.strip()
224
225        return ret
226