1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_UTILS_H_
18 #define MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_UTILS_H_
19
20 #include <map>
21 #include "include/api/data_type.h"
22 #include "triton/core/tritonserver.h"
23
24 namespace triton {
25 namespace backend {
26 namespace mslite {
GetMSDataTypeFromTritonServerDataType(TRITONSERVER_DataType data_type)27 static inline mindspore::DataType GetMSDataTypeFromTritonServerDataType(TRITONSERVER_DataType data_type) {
28 static const std::map<TRITONSERVER_DataType, mindspore::DataType> ms_types = {
29 {TRITONSERVER_TYPE_INVALID, mindspore::DataType::kInvalidType},
30 {TRITONSERVER_TYPE_BOOL, mindspore::DataType::kNumberTypeBool},
31 {TRITONSERVER_TYPE_UINT8, mindspore::DataType::kNumberTypeUInt8},
32 {TRITONSERVER_TYPE_UINT16, mindspore::DataType::kNumberTypeUInt16},
33 {TRITONSERVER_TYPE_UINT32, mindspore::DataType::kNumberTypeUInt32},
34 {TRITONSERVER_TYPE_UINT64, mindspore::DataType::kNumberTypeUInt64},
35 {TRITONSERVER_TYPE_INT8, mindspore::DataType::kNumberTypeInt8},
36 {TRITONSERVER_TYPE_INT16, mindspore::DataType::kNumberTypeInt16},
37 {TRITONSERVER_TYPE_INT32, mindspore::DataType::kNumberTypeInt32},
38 {TRITONSERVER_TYPE_INT64, mindspore::DataType::kNumberTypeInt64},
39 {TRITONSERVER_TYPE_FP16, mindspore::DataType::kNumberTypeFloat16},
40 {TRITONSERVER_TYPE_FP32, mindspore::DataType::kNumberTypeFloat32},
41 {TRITONSERVER_TYPE_FP64, mindspore::DataType::kNumberTypeFloat64},
42 {TRITONSERVER_TYPE_BYTES, mindspore::DataType::kNumberTypeUInt8},
43 {TRITONSERVER_TYPE_BF16, mindspore::DataType::kNumberTypeUInt16}};
44 return ms_types.find(data_type) != ms_types.end() ? ms_types.at(data_type) : mindspore::DataType::kTypeUnknown;
45 }
46
GetTritonServerDataTypeFromMSDataType(mindspore::DataType data_type)47 static inline TRITONSERVER_DataType GetTritonServerDataTypeFromMSDataType(mindspore::DataType data_type) {
48 static const std::map<mindspore::DataType, TRITONSERVER_DataType> triton_types = {
49 {mindspore::DataType::kInvalidType, TRITONSERVER_TYPE_INVALID},
50 {mindspore::DataType::kNumberTypeBool, TRITONSERVER_TYPE_BOOL},
51 {mindspore::DataType::kNumberTypeUInt8, TRITONSERVER_TYPE_UINT8},
52 {mindspore::DataType::kNumberTypeUInt16, TRITONSERVER_TYPE_UINT16},
53 {mindspore::DataType::kNumberTypeUInt32, TRITONSERVER_TYPE_UINT32},
54 {mindspore::DataType::kNumberTypeUInt64, TRITONSERVER_TYPE_UINT64},
55 {mindspore::DataType::kNumberTypeInt8, TRITONSERVER_TYPE_INT8},
56 {mindspore::DataType::kNumberTypeInt16, TRITONSERVER_TYPE_INT16},
57 {mindspore::DataType::kNumberTypeInt32, TRITONSERVER_TYPE_INT32},
58 {mindspore::DataType::kNumberTypeInt64, TRITONSERVER_TYPE_INT64},
59 {mindspore::DataType::kNumberTypeFloat16, TRITONSERVER_TYPE_FP16},
60 {mindspore::DataType::kNumberTypeFloat32, TRITONSERVER_TYPE_FP32},
61 {mindspore::DataType::kNumberTypeFloat64, TRITONSERVER_TYPE_FP64},
62 };
63 return triton_types.find(data_type) != triton_types.end() ? triton_types.at(data_type) : TRITONSERVER_TYPE_INVALID;
64 }
65 } // namespace mslite
66 } // namespace backend
67 } // namespace triton
68 #endif // MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_UTILS_H_
69