• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""
Inference session implementation for MindSpore Lite (mslite).
"""
18from abc import ABC
19from typing import Dict
20
21import mindspore_lite as mslite
22from mindspore_lite import DataType
23import numpy as np
24
25from mslite_bench.infer_base.abs_infer_session import AbcInferSession
26
27
class MsliteSession(AbcInferSession, ABC):
    """
    MindSpore Lite inference session.

    Builds a ``mindspore_lite.Model`` from ``model_file`` and exposes a
    name-keyed numpy-in / numpy-out ``infer`` method.
    """
    def __init__(self,
                 model_file,
                 cfg=None):
        """
        Args:
            model_file: path to the model file; a '.ms' suffix selects the
                mslite model type 4, anything else type 0.
            cfg: config object; must provide thread_num, device and
                thread_affinity_mode, plus ascend_provider when device
                is 'ascend'.
        """
        super().__init__(model_file, cfg)
        self.thread_num = cfg.thread_num
        mslite_model_type = self._set_ms_model_type()
        self.model_type = mslite.ModelType(mslite_model_type)
        self.device = cfg.device
        self.thread_affinity_mode = cfg.thread_affinity_mode
        self.context = self._init_context()
        self.model_session = self._create_infer_session()
        self.model_inputs = self.model_session.get_inputs()
        # mslite tensor dtype -> numpy dtype, used to coerce input arrays
        # to the dtype the model expects before set_data_from_numpy.
        self.dtype_map = {
            DataType.BOOL: np.bool_,
            DataType.INT8: np.int8,
            DataType.INT16: np.int16,
            DataType.INT32: np.int32,
            DataType.INT64: np.int64,
            DataType.UINT8: np.uint8,
            DataType.UINT16: np.uint16,
            DataType.UINT32: np.uint32,
            DataType.UINT64: np.uint64,
            DataType.FLOAT16: np.float16,
            DataType.FLOAT32: np.float32,
            DataType.FLOAT64: np.float64,
        }

    def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Run one inference pass.

        Args:
            input_data_map: maps (right-stripped) input tensor names to
                numpy arrays.

        Returns:
            Dict mapping output tensor names to result numpy arrays.

        Raises:
            ValueError: if a model input name is missing from input_data_map.
        """
        # Raises on missing inputs and resizes the model when a data shape
        # differs, so every lookup below is guaranteed to succeed.
        self._check_and_resize_input_tensor(input_data_map)
        for model_input in self.model_inputs:
            tensor_name = model_input.name.rstrip()
            input_data = input_data_map.get(tensor_name, None)
            if self.dtype_map[model_input.dtype] != input_data.dtype:
                # NOTE: log message spacing fixed; the concatenated literals
                # previously produced "would convertinput data type".
                self.logger.warning('Input data type %s is different '
                                    'from input tensor dtype %s, would convert '
                                    'input data type to %s',
                                    input_data.dtype,
                                    model_input.dtype,
                                    model_input.dtype)
                input_data = input_data.astype(self.dtype_map[model_input.dtype])
            model_input.set_data_from_numpy(input_data)
        outputs = self.model_session.predict(self.model_inputs)
        predict_results = {
            tensor.name.rstrip(): tensor.get_data_to_numpy()
            for tensor in outputs
        }
        return predict_results

    def _check_and_resize_input_tensor(self, input_data_map):
        """
        Validate that every model input is present in input_data_map and
        resize the model inputs when any data shape differs.

        Raises:
            ValueError: if a model input name is missing from input_data_map.
        """
        is_need_reshape = False
        input_shape_list = []

        for model_input in self.model_inputs:
            tensor_name = model_input.name.rstrip()
            input_data = input_data_map.get(tensor_name, None)
            if input_data is None:
                raise ValueError(f'{tensor_name} is not in model inputs')
            if model_input.shape != list(input_data.shape):
                # NOTE: log message spacing fixed; the concatenated literals
                # previously produced "is not equalwith" and "shapewould".
                self.logger.warning('model input shape: %s is not equal '
                                    'with input data shape: %s, model input shape '
                                    'would be reshaped', model_input.shape, input_data.shape)
                is_need_reshape = True
            input_shape_list.append(list(input_data.shape))

        if is_need_reshape:
            self.model_session.resize(self.model_inputs, input_shape_list)
            self.model_inputs = self.model_session.get_inputs()

    def _create_infer_session(self):
        """Build the mslite Model from model_file with the prepared context."""
        model_session = mslite.Model()
        model_session.build_from_file(self.model_file,
                                      self.model_type,
                                      self.context)
        return model_session

    def _get_input_tensor_infos(self):
        """
        Collect (shape, dtype) per model input; for dynamic inputs (empty
        shape or a -1 dim) take the shape from self.input_tensor_shapes and
        resize the model accordingly.

        Returns:
            Dict mapping input tensor name to a (shape, dtype) tuple.

        Raises:
            ValueError: if a dynamic input has no configured shape.
        """
        input_tensor_infos = {}
        tensor_shape_list = []
        resize_tensor_list = []
        for input_tensor in self.model_inputs:
            tensor_name = input_tensor.name.rstrip()
            dtype = input_tensor.dtype
            shape = input_tensor.shape
            if -1 in shape or not shape:
                resize_tensor_list.append(input_tensor)
                shape = self.input_tensor_shapes.get(tensor_name, None)
                if shape is None:
                    # Fail fast with a clear message instead of the bare
                    # TypeError that list(None) would raise.
                    raise ValueError(
                        f'no shape configured for dynamic input {tensor_name}')
                tensor_shape_list.append(list(shape))
            input_tensor_infos[tensor_name] = (shape, dtype)

        # BUGFIX: condition was inverted ('if not resize_tensor_list'), which
        # resized only when there was nothing to resize and skipped the resize
        # whenever dynamic-shape tensors had actually been collected.
        if resize_tensor_list:
            self.model_session.resize(resize_tensor_list, tensor_shape_list)
            self.model_inputs = self.model_session.get_inputs()

        return input_tensor_infos

    def _init_context(self):
        """Create an mslite Context for the configured device and threading."""
        context = mslite.Context()
        context.target = [self.device]
        if self.device == 'ascend':
            context.ascend.device_id = 0
            context.provider = self.cfg.ascend_provider
        # CPU threading options are set unconditionally, mirroring the
        # original behavior even when targeting a non-CPU device.
        context.cpu.thread_num = self.thread_num
        context.cpu.thread_affinity_mode = self.thread_affinity_mode
        return context

    def _set_ms_model_type(self):
        """
        Return the mslite model-type value for model_file: 4 for '.ms'
        files, 0 otherwise.
        """
        # BUGFIX: test the '.ms' extension rather than endswith('ms'), which
        # also matched any filename merely ending in the letters 'ms'.
        if self.model_file.endswith('.ms'):
            mslite_model_type = 4
        else:
            mslite_model_type = 0
        return mslite_model_type
149