1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/parser/pytorch/pytorch_node_parser.h"
18 #include <unordered_map>
19
20 namespace mindspore {
21 namespace lite {
22 namespace {
23 static std::unordered_map<at::ScalarType, TypeId> kTorchDataTypeTransferMap = {
24 {at::ScalarType::Bool, kNumberTypeBool}, {at::ScalarType::Byte, kNumberTypeUInt8},
25 {at::ScalarType::Char, kNumberTypeInt8}, {at::ScalarType::Int, kNumberTypeInt},
26 {at::ScalarType::Long, kNumberTypeInt}, {at::ScalarType::Half, kNumberTypeFloat16},
27 {at::ScalarType::Float, kNumberTypeFloat32}, {at::ScalarType::Double, kNumberTypeFloat32}};
28 } // namespace
29
GetTorchNodeType(const torch::jit::Node * torch_node)30 std::string PytorchNodeParser::GetTorchNodeType(const torch::jit::Node *torch_node) {
31 const auto &kind = torch_node->kind();
32 std::string node_type = kind.toUnqualString();
33 if (node_type.empty()) {
34 return node_type;
35 }
36 node_type = node_type.at(0) == '_' ? node_type.substr(1) : node_type;
37 node_type = node_type.at(node_type.size() - 1) == '_' ? node_type.substr(0, node_type.size() - 1) : node_type;
38 return node_type;
39 }
40
GetDataTypeFromTorch(const at::ScalarType torch_data_type)41 TypeId PytorchNodeParser::GetDataTypeFromTorch(const at::ScalarType torch_data_type) {
42 auto iter = kTorchDataTypeTransferMap.find(torch_data_type);
43 if (iter == kTorchDataTypeTransferMap.end()) {
44 MS_LOG(ERROR) << "Unsupported torch data type: " << torch_data_type;
45 return kTypeUnknown;
46 }
47 return iter->second;
48 }
49 } // namespace lite
50 } // namespace mindspore
51