# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
for tensorflow infer session
"""
import os
from abc import ABC
from typing import Dict

import tensorflow as tf
import numpy as np

from mslite_bench.infer_base.abs_infer_session import AbcInferSession


class TFSession(AbcInferSession, ABC):
    """TF infer session backed by a frozen GraphDef file.

    Each instance imports its model into its own ``tf.Graph`` instead of the
    process-wide default graph, so several sessions can coexist in one
    process without their op names colliding.
    """

    def __init__(self,
                 model_file,
                 cfg=None):
        super().__init__(model_file, cfg)
        self.graph = None
        # _create_infer_session populates self.graph as a side effect.
        self.model_session = self._create_infer_session()

        # Input tensors with their configured static shapes pinned.
        self.input_tensor_map = self._get_tf_input_tensor_map()
        self.logger.debug('Tensor map done')

        self.output_tensor_map = {
            tensor_name: self.graph.get_tensor_by_name(f'{tensor_name}:0')
            for tensor_name in self.output_tensor_names
        }

    def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run the model once and return every configured output.

        Args:
            input_data_map: maps each configured input tensor name to the
                value fed for that tensor.

        Returns:
            Dict mapping each output tensor name to the value fetched by
            ``Session.run``.

        Raises:
            ValueError: if a configured input name is absent from
                ``input_data_map``.
        """
        missing = [name for name in self.input_tensor_map
                   if name not in input_data_map]
        if missing:
            # Feeding None would be coerced by TF into a NaN scalar; fail loudly.
            raise ValueError(f'Missing input data for tensors: {missing}')

        feed_dict = {
            tensor: input_data_map[name]
            for name, tensor in self.input_tensor_map.items()
        }
        output_names = list(self.output_tensor_map)
        # A single run fetches all outputs: the graph executes once rather than
        # once per output, and all outputs come from the same execution.
        outputs = self.model_session.run(
            [self.output_tensor_map[name] for name in output_names],
            feed_dict=feed_dict)
        return dict(zip(output_names, outputs))

    def _create_infer_session(self):
        """Load the frozen GraphDef into a private graph and open a session.

        Sets ``self.graph`` to the newly created graph.

        Returns:
            A ``tf.compat.v1.Session`` bound to ``self.graph``.

        Raises:
            ValueError: if ``self.model_file`` does not exist on disk.
        """
        if not os.path.exists(self.model_file):
            raise ValueError(f'TF model {self.model_file} does not exist')
        with tf.io.gfile.GFile(self.model_file, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        # Import exactly once into a graph owned by this session. Importing into
        # the shared default graph makes TF uniquify names on later imports, so
        # name lookups could resolve to another model's tensors.
        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        return tf.compat.v1.Session(graph=self.graph)

    def _get_tf_input_tensor_map(self, graph_def=None):
        """Return {input name: input tensor} from ``self.graph``.

        Each input tensor's static shape is pinned to the configured value via
        ``set_shape``, so ``Session.run`` rejects feeds of a mismatched shape.

        Args:
            graph_def: unused. The model is already imported into
                ``self.graph`` by ``_create_infer_session``. The argument is
                kept only so callers that pass a GraphDef keep working.

        Returns:
            Dict mapping each key of ``self.input_tensor_shapes`` to the
            tensor ``<key>:0`` of ``self.graph``.
        """
        del graph_def  # retained for signature compatibility only
        input_tensor_map = {}
        for key, shape in self.input_tensor_shapes.items():
            input_tensor = self.graph.get_tensor_by_name(f'{key}:0')
            input_tensor.set_shape(shape)
            input_tensor_map[key] = input_tensor
        return input_tensor_map