• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/parser/pytorch/pytorch_node_parser.h"
18 #include <unordered_map>
19 
20 namespace mindspore {
21 namespace lite {
22 namespace {
23 static std::unordered_map<at::ScalarType, TypeId> kTorchDataTypeTransferMap = {
24   {at::ScalarType::Bool, kNumberTypeBool},     {at::ScalarType::Byte, kNumberTypeUInt8},
25   {at::ScalarType::Char, kNumberTypeInt8},     {at::ScalarType::Int, kNumberTypeInt},
26   {at::ScalarType::Long, kNumberTypeInt},      {at::ScalarType::Half, kNumberTypeFloat16},
27   {at::ScalarType::Float, kNumberTypeFloat32}, {at::ScalarType::Double, kNumberTypeFloat32}};
28 }  // namespace
29 
GetTorchNodeType(const torch::jit::Node * torch_node)30 std::string PytorchNodeParser::GetTorchNodeType(const torch::jit::Node *torch_node) {
31   const auto &kind = torch_node->kind();
32   std::string node_type = kind.toUnqualString();
33   if (node_type.empty()) {
34     return node_type;
35   }
36   node_type = node_type.at(0) == '_' ? node_type.substr(1) : node_type;
37   node_type = node_type.at(node_type.size() - 1) == '_' ? node_type.substr(0, node_type.size() - 1) : node_type;
38   return node_type;
39 }
40 
GetDataTypeFromTorch(const at::ScalarType torch_data_type)41 TypeId PytorchNodeParser::GetDataTypeFromTorch(const at::ScalarType torch_data_type) {
42   auto iter = kTorchDataTypeTransferMap.find(torch_data_type);
43   if (iter == kTorchDataTypeTransferMap.end()) {
44     MS_LOG(ERROR) << "Unsupported torch data type: " << torch_data_type;
45     return kTypeUnknown;
46   }
47   return iter->second;
48 }
49 }  // namespace lite
50 }  // namespace mindspore
51