# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
infer session factory for unified api
"""
import importlib

from mslite_bench.common.model_info_enum import FrameworkType
from mslite_bench.utils.infer_log import InferLogger
from mslite_bench.common.task_common_func import CommonFunc


_logger = InferLogger().logger


class InferSessionFactory:
    """
    infer session factory

    Builds a framework-specific inference session (TF, ONNX, Paddle or
    MindSpore Lite) from a framework config. Framework backends are imported
    lazily, so only the backend actually requested needs to be installed.
    """
    @classmethod
    def create_infer_session_by_args(cls,
                                     args,
                                     logger=None):
        """
        Create a model session from parsed command-line style arguments.

        params:
            args: input arguments; must provide ``model_file`` and
                ``params_file`` attributes, and is also forwarded to
                ``CommonFunc.get_framework_config`` to build the config
            logger: logger for mslite bench; defaults to the module logger
        return: model session
        """
        if logger is None:
            logger = _logger
        model_path = args.model_file
        param_path = args.params_file
        cfg = CommonFunc.get_framework_config(model_path, args)

        # Dispatch through ``cls`` so subclasses overriding
        # ``create_infer_session`` are honoured.
        model_session = cls.create_infer_session(model_path,
                                                 cfg,
                                                 params_file=param_path)
        logger.debug('Create model session success')
        return model_session

    @classmethod
    def create_infer_session(cls,
                             model_file,
                             cfg,
                             params_file=None):
        """
        Create a model session for the framework named by ``cfg.infer_framework``.

        params:
            model_file: path to AI model
            cfg: framework related config; its ``infer_framework`` attribute is
                compared against ``FrameworkType`` member values
            params_file: path to model weight file, for paddle, caffe etc.;
                only forwarded to the Paddle session
        return: model session
        raises:
            ImportError: if the selected framework's session module cannot be
                imported (the failure is logged before re-raising)
            NotImplementedError: if ``cfg.infer_framework`` is not supported
        """
        infer_framework_type = cfg.infer_framework
        if infer_framework_type == FrameworkType.TF.value:
            infer_module = cls._import_session_module(
                'mslite_bench.infer_base.tf_infer_session', 'tf')
            infer_session = infer_module.TFSession(model_file, cfg)
        elif infer_framework_type == FrameworkType.ONNX.value:
            infer_module = cls._import_session_module(
                'mslite_bench.infer_base.onnx_infer_session', 'onnx')
            infer_session = infer_module.OnnxSession(model_file, cfg)
        elif infer_framework_type == FrameworkType.PADDLE.value:
            infer_module = cls._import_session_module(
                'mslite_bench.infer_base.paddle_infer_session', 'paddle')
            infer_session = infer_module.PaddleSession(model_file, cfg,
                                                       params_file=params_file)
        elif infer_framework_type == FrameworkType.MSLITE.value:
            infer_module = cls._import_session_module(
                'mslite_bench.infer_base.mslite_infer_session', 'mslite')
            infer_session = infer_module.MsliteSession(model_file, cfg)
        else:
            raise NotImplementedError(f'{infer_framework_type} is not supported yet')
        return infer_session

    @classmethod
    def _import_session_module(cls, module_name, framework_label):
        """
        Import a framework session module, logging any import failure.

        params:
            module_name: absolute dotted name of the session module
            framework_label: short framework name used in the failure log
        return: the imported module
        raises:
            ImportError: re-raised unchanged after being logged
        """
        try:
            return cls.import_module(module_name)
        except ImportError as e:
            _logger.info('import %s session failed: %s', framework_label, e)
            raise

    @staticmethod
    def import_module(module_name, file_path=None):
        """
        import module functions

        Thin wrapper over ``importlib.import_module``; ``file_path`` is passed
        as the ``package`` anchor used to resolve relative module names.
        """
        return importlib.import_module(module_name, package=file_path)