# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
for mslite infer session
"""
from abc import ABC
from typing import Dict

import mindspore_lite as mslite
from mindspore_lite import DataType
import numpy as np

from mslite_bench.infer_base.abs_infer_session import AbcInferSession


class MsliteSession(AbcInferSession, ABC):
    """
    MindSpore Lite inference session.

    Builds a ``mslite.Model`` from ``model_file`` and exposes a
    dict-in / dict-out :meth:`infer` API keyed by (stripped) tensor name.
    Input tensors are resized on the fly when the caller's data shapes
    differ from the model's current input shapes.
    """

    def __init__(self,
                 model_file,
                 cfg=None):
        """
        Args:
            model_file: path to the model file (``.ms`` or MindIR).
            cfg: benchmark config object; must provide ``thread_num``,
                ``device``, ``thread_affinity_mode`` and (for Ascend)
                ``ascend_provider``.

        Raises:
            ValueError: if ``cfg`` is None — every attribute below
                requires it.
        """
        if cfg is None:
            # cfg defaults to None for signature compatibility with the
            # base class, but this session cannot be configured without it.
            raise ValueError('cfg is required to create a MsliteSession')
        super().__init__(model_file, cfg)
        self.thread_num = cfg.thread_num
        mslite_model_type = self._set_ms_model_type()
        self.model_type = mslite.ModelType(mslite_model_type)
        self.device = cfg.device
        self.thread_affinity_mode = cfg.thread_affinity_mode
        self.context = self._init_context()
        self.model_session = self._create_infer_session()
        self.model_inputs = self.model_session.get_inputs()
        # Map mslite tensor dtypes to numpy dtypes so caller-provided
        # arrays can be converted before being copied into input tensors.
        self.dtype_map = {
            DataType.BOOL: np.bool_,
            DataType.INT8: np.int8,
            DataType.INT16: np.int16,
            DataType.INT32: np.int32,
            DataType.INT64: np.int64,
            DataType.UINT8: np.uint8,
            DataType.UINT16: np.uint16,
            DataType.UINT32: np.uint32,
            DataType.UINT64: np.uint64,
            DataType.FLOAT16: np.float16,
            DataType.FLOAT32: np.float32,
            DataType.FLOAT64: np.float64,
        }

    def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run one forward pass.

        Args:
            input_data_map: maps tensor name -> numpy array. Must contain
                an entry for every model input (enforced by
                :meth:`_check_and_resize_input_tensor`).

        Returns:
            Dict mapping output tensor name -> numpy array.
        """
        self._check_and_resize_input_tensor(input_data_map)
        for model_input in self.model_inputs:
            tensor_name = model_input.name.rstrip()
            input_data = input_data_map.get(tensor_name, None)
            if self.dtype_map[model_input.dtype] != input_data.dtype:
                # Silently feeding a mismatched dtype would corrupt the
                # tensor contents, so convert (with a warning) instead.
                self.logger.warning('Input data type %s is different '
                                    'from input tensor dtype %s, would convert '
                                    'input data type to %s ',
                                    input_data.dtype,
                                    model_input.dtype,
                                    model_input.dtype)
                input_data = input_data.astype(self.dtype_map[model_input.dtype])
            model_input.set_data_from_numpy(input_data)
        outputs = self.model_session.predict(self.model_inputs)
        predict_results = {
            tensor.name.rstrip(): tensor.get_data_to_numpy()
            for tensor in outputs
        }
        return predict_results

    def _check_and_resize_input_tensor(self, input_data_map):
        """Validate input presence and resize model inputs if shapes differ.

        Raises:
            ValueError: if any model input is missing from
                ``input_data_map``.
        """
        is_need_reshape = False
        input_shape_list = []

        for model_input in self.model_inputs:
            tensor_name = model_input.name.rstrip()
            input_data = input_data_map.get(tensor_name, None)
            if input_data is None:
                raise ValueError(f'{tensor_name} is not in model inputs')
            if model_input.shape != list(input_data.shape):
                self.logger.warning('model input shape: %s is not equal '
                                    'with input data shape: %s, model input shape '
                                    'would be reshaped', model_input.shape, input_data.shape)
                is_need_reshape = True
            # Collect a shape for EVERY input: Model.resize expects the
            # shape list to be parallel with the full input list.
            input_shape_list.append(list(input_data.shape))

        if is_need_reshape:
            self.model_session.resize(self.model_inputs, input_shape_list)
            # Refresh cached handles: resize invalidates the old tensors.
            self.model_inputs = self.model_session.get_inputs()

    def _create_infer_session(self):
        """Build and return the mslite Model from file, type and context."""
        model_session = mslite.Model()
        model_session.build_from_file(self.model_file,
                                      self.model_type,
                                      self.context)
        return model_session

    def _get_input_tensor_infos(self):
        """Collect (shape, dtype) per input; resolve dynamic shapes.

        Inputs with a dynamic shape (``-1`` in the shape, or an empty
        shape) are resolved from ``self.input_tensor_shapes`` and the
        model is resized accordingly.

        Returns:
            Dict mapping input tensor name -> (shape, dtype).

        Raises:
            ValueError: if a dynamic input has no configured shape.
        """
        input_tensor_infos = {}
        tensor_shape_list = []
        resize_tensor_list = []
        for input_tensor in self.model_inputs:
            tensor_name = input_tensor.name.rstrip()
            dtype = input_tensor.dtype
            shape = input_tensor.shape
            if -1 in shape or not shape:
                # Dynamic shape: must come from user-supplied config.
                shape = self.input_tensor_shapes.get(tensor_name, None)
                if shape is None:
                    raise ValueError(f'no input shape configured for dynamic '
                                     f'input tensor {tensor_name}')
                # Only dynamic tensors go into the resize lists, keeping
                # the tensor list and shape list parallel for resize().
                resize_tensor_list.append(input_tensor)
                tensor_shape_list.append(list(shape))
            input_tensor_infos[tensor_name] = (shape, dtype)

        # BUGFIX: original resized when the list was EMPTY
        # (`if not resize_tensor_list:`) and skipped it otherwise.
        if resize_tensor_list:
            self.model_session.resize(resize_tensor_list, tensor_shape_list)
            self.model_inputs = self.model_session.get_inputs()

        return input_tensor_infos

    def _init_context(self):
        """Create the mslite Context for the configured device.

        CPU thread settings are always applied; they also govern the CPU
        side of heterogeneous (e.g. Ascend) execution.
        """
        context = mslite.Context()
        context.target = [self.device]
        if self.device == 'ascend':
            context.ascend.device_id = 0
            context.provider = self.cfg.ascend_provider
        context.cpu.thread_num = self.thread_num
        context.cpu.thread_affinity_mode = self.thread_affinity_mode
        return context

    def _set_ms_model_type(self):
        """Return the mslite model-type code for ``self.model_file``.

        4 = MINDIR_LITE for ``.ms`` files, 0 = MINDIR otherwise.
        """
        # BUGFIX: original tested endswith('ms'), which also matched any
        # filename merely ending in the letters "ms" (e.g. 'params').
        if self.model_file.endswith('.ms'):
            mslite_model_type = 4
        else:
            mslite_model_type = 0
        return mslite_model_type