• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16configs for mslite bench
17"""
18from dataclasses import dataclass
19from typing import List, Tuple, Dict
20
21
22from mslite_bench.common.model_info_enum import FrameworkType
23
24
25@dataclass
26class Config:
27    """base config"""
28    device: str = 'cpu'
29    device_id: int = 0
30    log_path: str = None
31    batch_size: int = 1
32
33
34class ModelConfig(Config):
35    """model config"""
36    infer_framework: FrameworkType = FrameworkType.MSLITE.value
37    thread_num: int = 1
38    input_tensor_shapes: Dict[str, Tuple] = None
39    input_tensor_dtypes: Dict[str, str] = None
40    output_tensor_names: List[str] = None
41
42
43@dataclass
44class MsliteConfig(ModelConfig):
45    """mslite config"""
46    thread_affinity_mode: int = 2
47
48    ascend_provider: str = ''
49
50
51@dataclass
52class PaddleConfig(ModelConfig):
53    """paddle config"""
54    infer_framework = FrameworkType.PADDLE.value
55    is_fp16: bool = False
56    is_int8: bool = False
57
58    # for paddle infer
59    is_enable_tensorrt: bool = False
60    gpu_memory_size: int = 100
61    tensorrt_optim_input_shape: Dict[str, List[int]] = None
62    tensorrt_min_input_shape: Dict[str, List[int]] = None
63    tensorrt_max_input_shape: Dict[str, List[int]] = None
64
65
66@dataclass
67class OnnxConfig(ModelConfig):
68    """onnx config"""
69    # for onnx export
70    infer_framework = FrameworkType.ONNX.value
71
72
73@dataclass
74class TFConfig(ModelConfig):
75    """tensorflow config"""
76    infer_framework = FrameworkType.TF.value
77
78
79@dataclass
80class BenchConfig(Config):
81    """benchmark config"""
82    eps: float = 1e-5
83    random_input_flag: bool = False
84    cmp_model_file: str = None
85    input_data_file: str = None
86