• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/parser/tf/tf_util.h"
18 #include <string>
19 #include <vector>
20 #include <string_view>
21 #include <regex>
22 #include <unordered_map>
23 #include "src/common/log_adapter.h"
24 
25 namespace mindspore {
26 namespace lite {
27 std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = {{tensorflow::DT_INT8, mindspore::kNumberTypeInt8},
28                                                           {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8},
29                                                           {tensorflow::DT_INT16, mindspore::kNumberTypeInt16},
30                                                           {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16},
31                                                           {tensorflow::DT_INT32, mindspore::kNumberTypeInt32},
32                                                           {tensorflow::DT_INT64, mindspore::kNumberTypeInt64},
33                                                           {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16},
34                                                           {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32},
35                                                           {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64},
36                                                           {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64},
37                                                           {tensorflow::DT_BOOL, mindspore::kNumberTypeBool},
38                                                           {tensorflow::DT_STRING, mindspore::kObjectTypeString},
39                                                           {tensorflow::DT_VARIANT, mindspore::kObjectTypeTensorType}};
40 
GetTFDataType(const tensorflow::DataType & tf_data_type)41 TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) {
42   auto iter = TF_TYPE_MAP.find(tf_data_type);
43   if (iter == TF_TYPE_MAP.end()) {
44     MS_LOG(WARNING) << "unsupported TF data type: " << tf_data_type;
45     return kTypeUnknown;
46   }
47   return iter->second;
48 }
49 
FindAttrValue(const tensorflow::NodeDef & node_def,const std::string & attr_name,tensorflow::AttrValue * attr_value)50 bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &node_def, const std::string &attr_name,
51                                     tensorflow::AttrValue *attr_value) {
52   const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = node_def.attr();
53   const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
54   if (it != attr.end()) {
55     *attr_value = it->second;
56     return true;
57   }
58   return false;
59 }
60 
ParseAttrDataType(const tensorflow::NodeDef & node_def,const std::string & attr_name)61 TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name) {
62   tensorflow::AttrValue attr_value;
63   if (!FindAttrValue(node_def, attr_name, &attr_value)) {
64     MS_LOG(ERROR) << "Find attr failed: " << attr_name;
65     return kTypeUnknown;
66   }
67   return GetTFDataType(attr_value.type());
68 }
69 
DecodeInt64(std::string_view * str_view,uint64_t * value)70 bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) {
71   if (value == nullptr) {
72     MS_LOG(ERROR) << "value is nullptr";
73     return false;
74   }
75   if (str_view == nullptr) {
76     *value = 0;
77     MS_LOG(ERROR) << "str_view is nullptr";
78     return false;
79   }
80   auto data = str_view->data();
81   const auto end = data + str_view->size();
82 
83   const char *next = nullptr;
84   uint64_t result = 0;
85   for (uint32_t shift = 0; shift <= 63 && data < end; shift += 7) {
86     uint64_t byte = *(reinterpret_cast<const unsigned char *>(data));
87     data++;
88     if (byte & 128) {
89       result |= ((byte & 127) << shift);
90     } else {
91       result |= (byte << shift);
92       *value = result;
93       next = reinterpret_cast<const char *>(data);
94       break;
95     }
96   }
97 
98   if (next == nullptr) {
99     return false;
100   } else {
101     *str_view = std::string_view(next, end - next);
102     return true;
103   }
104 }
105 
106 // convert input_arg in subgraph to node_name[:index] format
GetFlattenNodeName(const std::string & input_name)107 std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) {
108   std::regex re("\\:+");
109   std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
110                                         std::sregex_token_iterator());
111   std::string ret = input_name;
112   if (input_splits.size() == 3) {
113     if (input_splits[0] == "RaggedRange") {  // Both output index of RaggedRange is 0
114       if (input_splits[1] == "rt_nested_splits") {
115         ret = input_splits[0] + ":0";
116       } else if (input_splits[1] == "rt_dense_values") {
117         ret = input_splits[0] + ":1";
118       } else {
119         MS_LOG(ERROR) << "Failed to flatten RaggedRange node name!";
120       }
121       return ret;
122     } else if (input_splits[0].find("TopKV2") != std::string::npos) {
123       if (input_splits[1] == "values") {
124         return input_splits[0] + ":0";
125       } else if (input_splits[1] == "indices") {
126         return input_splits[0] + ":1";
127       }
128     }
129     if (input_splits[2] == "0") {
130       ret = input_splits[0];
131     } else {
132       ret = input_splits[0] + ":" + input_splits[2];  // multi output node
133     }
134   }
135   return ret;
136 }
137 
138 // get referenced node name from input name
GetNodeName(const std::string & input_name)139 std::string TensorFlowUtils::GetNodeName(const std::string &input_name) {
140   std::regex re("\\:+");
141   std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
142                                         std::sregex_token_iterator());
143   if (input_splits.size() > 1) {
144     return input_splits[0];
145   }
146   return input_name;
147 }
148 
ParseNodeFormat(const tensorflow::NodeDef & node_def)149 mindspore::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_def) {
150   tensorflow::AttrValue attr_value;
151   if (!FindAttrValue(node_def, "data_format", &attr_value)) {
152     MS_LOG(ERROR) << "Find attr data_format failed";
153     return mindspore::Format::NCHW;
154   }
155   if (attr_value.s() == "NHWC") {
156     return mindspore::Format::NHWC;
157   }
158   return mindspore::Format::NCHW;
159 }
160 
OutputIsInputOp(const std::string & op_name)161 bool TensorFlowUtils::OutputIsInputOp(const std::string &op_name) {
162   return op_name == "Identity" || op_name == "StopGradient" || op_name == "NoOp" || op_name == "ReadVariableOp";
163 }
164 }  // namespace lite
165 }  // namespace mindspore
166