• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16infer session factory for unified api
17"""
18import importlib
19
20from mslite_bench.common.model_info_enum import FrameworkType
21from mslite_bench.utils.infer_log import InferLogger
22from mslite_bench.common.task_common_func import CommonFunc
23
24
# Module-level fallback logger: InferSessionFactory uses it when the caller
# supplies no logger, and to report backend import failures.
_logger = InferLogger().logger
26
27
class InferSessionFactory:
    """
    Factory that builds a framework-specific inference session.

    Backend session modules are imported lazily, only when the matching
    framework is requested, so a missing backend package only fails when
    that backend is actually selected.
    """
    @classmethod
    def create_infer_session_by_args(cls,
                                     args,
                                     logger=None):
        """
        Create a model session from parsed input arguments.

        params:
        args: input arguments; ``args.model_file`` and ``args.params_file``
              are read here, and ``args`` is also passed to
              ``CommonFunc.get_framework_config``
        logger: logger for mslite bench; the module-level logger is used
                when None
        return: model session
        """
        if logger is None:
            logger = _logger
        model_path = args.model_file
        param_path = args.params_file
        cfg = CommonFunc.get_framework_config(model_path,
                                              args)

        # Dispatch through ``cls`` so that subclass overrides are honoured.
        model_session = cls.create_infer_session(model_path,
                                                 cfg,
                                                 params_file=param_path)
        logger.debug('Create model session success')
        return model_session

    @classmethod
    def create_infer_session(cls,
                             model_file,
                             cfg,
                             params_file=None):
        """
        Create a model session for the framework named by ``cfg.infer_framework``.

        params:
        model_file: path to AI model
        cfg: framework related config; ``cfg.infer_framework`` is compared
             against the ``FrameworkType`` values to select the backend
        params_file: path to model weight file, for paddle, caffe etc.;
                     only forwarded to the paddle session
        return: model session
        raises:
        ImportError: the selected backend session module cannot be imported
        NotImplementedError: ``cfg.infer_framework`` matches no supported
                             framework
        """
        infer_framework_type = cfg.infer_framework
        if infer_framework_type == FrameworkType.TF.value:
            session_cls = cls._load_session_class(
                'tf', 'mslite_bench.infer_base.tf_infer_session', 'TFSession')
            return session_cls(model_file, cfg)
        if infer_framework_type == FrameworkType.ONNX.value:
            session_cls = cls._load_session_class(
                'onnx', 'mslite_bench.infer_base.onnx_infer_session', 'OnnxSession')
            return session_cls(model_file, cfg)
        if infer_framework_type == FrameworkType.PADDLE.value:
            session_cls = cls._load_session_class(
                'paddle', 'mslite_bench.infer_base.paddle_infer_session', 'PaddleSession')
            # Paddle is the only backend here that takes a separate weight file.
            return session_cls(model_file, cfg, params_file=params_file)
        if infer_framework_type == FrameworkType.MSLITE.value:
            session_cls = cls._load_session_class(
                'mslite', 'mslite_bench.infer_base.mslite_infer_session', 'MsliteSession')
            return session_cls(model_file, cfg)
        raise NotImplementedError(f'{infer_framework_type} is not supported yet')

    @classmethod
    def _load_session_class(cls, framework_label, module_name, class_name):
        """
        Import a backend session module and return its session class.

        params:
        framework_label: short framework name used in the failure log message
        module_name: absolute dotted path of the session module
        class_name: name of the session class inside that module
        return: the session class (not an instance)
        raises:
        ImportError: logged, then re-raised unchanged
        """
        try:
            infer_module = cls.import_module(module_name)
        except ImportError as e:
            _logger.info('import %s session failed: %s', framework_label, e)
            raise
        # Attribute lookup stays outside the try so that a missing class is
        # reported as AttributeError rather than being logged as an import error.
        return getattr(infer_module, class_name)

    @staticmethod
    def import_module(module_name, file_path=None):
        """
        Import and return a module by name.

        params:
        module_name: dotted module name handed to ``importlib.import_module``
        file_path: forwarded as importlib's ``package`` argument, i.e. the
                   anchor package for resolving a relative ``module_name``;
                   leave as None for absolute names
        return: the imported module
        """
        return importlib.import_module(module_name, package=file_path)
104