• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16for onnx infer session
17"""
18from abc import ABC
19from typing import Dict
20
21import onnx
22import onnxruntime
23import numpy as np
24
25from mslite_bench.infer_base.abs_infer_session import AbcInferSession
26
27
class OnnxSession(AbcInferSession, ABC):
    """Inference session that runs an ONNX model through onnxruntime.

    The model file is parsed with ``onnx.load`` so the graph can be
    inspected (inputs, outputs), and it is executed with an
    ``onnxruntime.InferenceSession`` restricted to the CPU execution
    provider.
    """
    def __init__(self,
                 model_file,
                 cfg=None):
        """Load the model, resolve the output names and create the session.

        Args:
            model_file: ONNX model, passed to ``onnx.load`` and, through
                ``self.model_file``, to ``onnxruntime.InferenceSession``.
            cfg: Optional configuration object forwarded unchanged to the
                base class.

        NOTE(review): the base-class ``__init__`` is expected to provide
        ``self.model_file`` and ``self.logger`` (both are used below but never
        assigned here), and possibly a pre-set ``self.output_tensor_names``
        -- confirm against ``AbcInferSession``.
        """
        super().__init__(model_file, cfg)
        self.model = onnx.load(model_file)
        self.output_nodes = self._get_all_output_nodes()
        self.output_tensor_names = self._get_output_tensor_names()
        self.model_session = self._create_infer_session()

    def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run the model once and return its outputs keyed by tensor name.

        Args:
            input_data_map: Mapping from input tensor name to its value.

        Returns:
            Mapping from each name in ``self.output_tensor_names`` to the
            corresponding value returned by the session, in the same order.
        """
        outputs = self.model_session.run(self.output_tensor_names,
                                         input_data_map)
        return dict(zip(self.output_tensor_names, outputs))

    def _create_infer_session(self):
        """Create the onnxruntime session for ``self.model_file`` (CPU only)."""
        model_session = onnxruntime.InferenceSession(self.model_file,
                                                     providers=['CPUExecutionProvider'])
        self.logger.debug('onnx Session create successfully')
        return model_session

    def _get_all_input_nodes(self):
        """Return the graph inputs that are not backed by an initializer.

        ``graph.input`` may also list tensors that have an initializer
        (weights/constants); those are excluded. The match is done by name
        because the two fields hold different protobuf message types and
        protobuf messages are not hashable, so a set difference on the
        messages themselves cannot work.

        Returns:
            A list of graph input entries, in graph order.
        """
        graph = self.model.graph
        initializer_names = {initializer.name for initializer in graph.initializer}
        return [node for node in graph.input if node.name not in initializer_names]

    def _get_all_output_nodes(self):
        """Return the output entries declared by the model graph."""
        return self.model.graph.output

    def _get_output_tensor_names(self):
        """Return the output tensor names to fetch during inference.

        Names that are already set on the instance are kept. When none are
        set (attribute missing, ``None`` or empty), the names of all graph
        outputs are used and stored on the instance.
        """
        # getattr: the attribute is only present at this point if the base
        # class assigned it; fall back gracefully when it did not.
        names = getattr(self, 'output_tensor_names', None)
        if not names:
            names = [node.name for node in self.output_nodes]
            self.output_tensor_names = names
        return names
71