• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_UTILS_H_
18 #define MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_UTILS_H_
19 
20 #include <map>
21 #include "include/api/data_type.h"
22 #include "triton/core/tritonserver.h"
23 
24 namespace triton {
25 namespace backend {
26 namespace mslite {
GetMSDataTypeFromTritonServerDataType(TRITONSERVER_DataType data_type)27 static inline mindspore::DataType GetMSDataTypeFromTritonServerDataType(TRITONSERVER_DataType data_type) {
28   static const std::map<TRITONSERVER_DataType, mindspore::DataType> ms_types = {
29     {TRITONSERVER_TYPE_INVALID, mindspore::DataType::kInvalidType},
30     {TRITONSERVER_TYPE_BOOL, mindspore::DataType::kNumberTypeBool},
31     {TRITONSERVER_TYPE_UINT8, mindspore::DataType::kNumberTypeUInt8},
32     {TRITONSERVER_TYPE_UINT16, mindspore::DataType::kNumberTypeUInt16},
33     {TRITONSERVER_TYPE_UINT32, mindspore::DataType::kNumberTypeUInt32},
34     {TRITONSERVER_TYPE_UINT64, mindspore::DataType::kNumberTypeUInt64},
35     {TRITONSERVER_TYPE_INT8, mindspore::DataType::kNumberTypeInt8},
36     {TRITONSERVER_TYPE_INT16, mindspore::DataType::kNumberTypeInt16},
37     {TRITONSERVER_TYPE_INT32, mindspore::DataType::kNumberTypeInt32},
38     {TRITONSERVER_TYPE_INT64, mindspore::DataType::kNumberTypeInt64},
39     {TRITONSERVER_TYPE_FP16, mindspore::DataType::kNumberTypeFloat16},
40     {TRITONSERVER_TYPE_FP32, mindspore::DataType::kNumberTypeFloat32},
41     {TRITONSERVER_TYPE_FP64, mindspore::DataType::kNumberTypeFloat64},
42     {TRITONSERVER_TYPE_BYTES, mindspore::DataType::kNumberTypeUInt8},
43     {TRITONSERVER_TYPE_BF16, mindspore::DataType::kNumberTypeUInt16}};
44   return ms_types.find(data_type) != ms_types.end() ? ms_types.at(data_type) : mindspore::DataType::kTypeUnknown;
45 }
46 
GetTritonServerDataTypeFromMSDataType(mindspore::DataType data_type)47 static inline TRITONSERVER_DataType GetTritonServerDataTypeFromMSDataType(mindspore::DataType data_type) {
48   static const std::map<mindspore::DataType, TRITONSERVER_DataType> triton_types = {
49     {mindspore::DataType::kInvalidType, TRITONSERVER_TYPE_INVALID},
50     {mindspore::DataType::kNumberTypeBool, TRITONSERVER_TYPE_BOOL},
51     {mindspore::DataType::kNumberTypeUInt8, TRITONSERVER_TYPE_UINT8},
52     {mindspore::DataType::kNumberTypeUInt16, TRITONSERVER_TYPE_UINT16},
53     {mindspore::DataType::kNumberTypeUInt32, TRITONSERVER_TYPE_UINT32},
54     {mindspore::DataType::kNumberTypeUInt64, TRITONSERVER_TYPE_UINT64},
55     {mindspore::DataType::kNumberTypeInt8, TRITONSERVER_TYPE_INT8},
56     {mindspore::DataType::kNumberTypeInt16, TRITONSERVER_TYPE_INT16},
57     {mindspore::DataType::kNumberTypeInt32, TRITONSERVER_TYPE_INT32},
58     {mindspore::DataType::kNumberTypeInt64, TRITONSERVER_TYPE_INT64},
59     {mindspore::DataType::kNumberTypeFloat16, TRITONSERVER_TYPE_FP16},
60     {mindspore::DataType::kNumberTypeFloat32, TRITONSERVER_TYPE_FP32},
61     {mindspore::DataType::kNumberTypeFloat64, TRITONSERVER_TYPE_FP64},
62   };
63   return triton_types.find(data_type) != triton_types.end() ? triton_types.at(data_type) : TRITONSERVER_TYPE_INVALID;
64 }
65 }  // namespace mslite
66 }  // namespace backend
67 }  // namespace triton
68 #endif  // MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_UTILS_H_
69