• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
for tensorflow infer session
17"""
18import os
19from abc import ABC
20from typing import Dict
21
22import tensorflow as tf
23import numpy as np
24
25from mslite_bench.infer_base.abs_infer_session import AbcInferSession
26
27
class TFSession(AbcInferSession, ABC):
    """TF infer session.

    Parses a serialized ``GraphDef`` from ``self.model_file``, imports it into
    the TensorFlow default graph and runs it through a
    ``tf.compat.v1.Session``.  Input and output tensors are looked up by node
    name, always using output index 0 of the named node (``"<name>:0"``).

    Attributes used from the base class (populated by ``super().__init__``;
    how they are derived from ``cfg`` is not visible here):
        model_file: path checked with ``os.path.exists`` and read via GFile.
        input_tensor_shapes: mapping of input node name -> shape.
        output_tensor_names: iterable of output node names.
        logger: logger used for debug output.
    """

    def __init__(self,
                 model_file,
                 cfg=None):
        """Create the session and resolve input/output tensors by name.

        Args:
            model_file: forwarded to the base class initializer.
            cfg: forwarded to the base class initializer.

        Raises:
            ValueError: if ``self.model_file`` does not exist.
        """
        super().__init__(model_file, cfg)
        # Populated by _create_infer_session() before the lookups below.
        self.graph = None
        self.model_session = self._create_infer_session()

        # Use the canonical "<node>:0" tensor name, matching
        # _get_tf_input_tensor_map (no space before the output index).
        self.input_tensor_map = {
            tensor_name: self.graph.get_tensor_by_name(f'{tensor_name}:0')
            for tensor_name in self.input_tensor_shapes
        }

        self.output_tensor_map = {
            tensor_name: self.graph.get_tensor_by_name(f'{tensor_name}:0')
            for tensor_name in self.output_tensor_names
        }

    def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run the model once and return all configured outputs.

        Args:
            input_data_map: input node name -> value to feed.  It must hold a
                non-None entry for every key of ``self.input_tensor_shapes``.

        Returns:
            Mapping of output node name -> value fetched for that output.

        Raises:
            ValueError: if a required input is missing from ``input_data_map``.
        """
        if not self.output_tensor_map:
            return {}

        missing_inputs = [
            name for name in self.input_tensor_map
            if input_data_map.get(name) is None
        ]
        if missing_inputs:
            raise ValueError(f'Missing input data for tensors: {missing_inputs}')

        feed_dict = {
            tensor: input_data_map[name]
            for name, tensor in self.input_tensor_map.items()
        }

        # Fetch every output in a single run() call so the graph is executed
        # once, instead of once per output tensor.
        output_names = list(self.output_tensor_map)
        output_values = self.model_session.run(
            [self.output_tensor_map[name] for name in output_names],
            feed_dict=feed_dict)

        return dict(zip(output_names, output_values))

    def _create_infer_session(self):
        """Load the GraphDef, import it and build the TF session.

        Side effects:
            Sets ``self.graph`` to the TensorFlow default graph after import.

        Returns:
            A ``tf.compat.v1.Session`` bound to ``self.graph``.

        Raises:
            ValueError: if ``self.model_file`` does not exist.
        """
        if not os.path.exists(self.model_file):
            raise ValueError(f'TF model {self.model_file} does not exist')

        # Only the read needs the file handle; release it before graph work.
        with tf.io.gfile.GFile(self.model_file, 'rb') as f:
            serialized_graph = f.read()

        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(serialized_graph)

        # NOTE(review): the GraphDef is imported into the default graph twice,
        # once inside _get_tf_input_tensor_map() and once here with
        # input_map. Kept as in the original; TODO confirm the second import
        # is required.
        input_tensor_map = self._get_tf_input_tensor_map(graph_def)
        tf.import_graph_def(graph_def, input_map=input_tensor_map, name='')
        self.logger.debug('Tensor map done')
        self.graph = tf.compat.v1.get_default_graph()

        return tf.compat.v1.Session(graph=self.graph)

    def _get_tf_input_tensor_map(self, graph_def):
        """Import ``graph_def`` and collect its shape-annotated input tensors.

        The GraphDef is imported into the TensorFlow default graph; for each
        entry of ``self.input_tensor_shapes`` the tensor ``"<name>:0"`` is
        looked up and given the configured shape via ``set_shape``.

        Args:
            graph_def: parsed ``GraphDef`` to import.

        Returns:
            Mapping of input node name -> tensor from the default graph.
        """
        input_tensor_map = {}
        # NOTE(review): this mutates the shared default graph, not a
        # per-instance graph - verify this is intended when several
        # sessions are created in one process.
        tf.import_graph_def(graph_def, name='')
        default_graph = tf.compat.v1.get_default_graph()
        for key, shape in self.input_tensor_shapes.items():
            tensor_name = f'{key}:0'
            input_tensor = default_graph.get_tensor_by_name(tensor_name)
            input_tensor.set_shape(shape)
            input_tensor_map[key] = input_tensor
        return input_tensor_map
85