1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/parser/tf/tf_util.h"
18 #include <string>
19 #include <vector>
20 #include <string_view>
21 #include <regex>
22 #include <unordered_map>
23 #include "src/common/log_adapter.h"
24
25 namespace mindspore {
26 namespace lite {
27 std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = {{tensorflow::DT_INT8, mindspore::kNumberTypeInt8},
28 {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8},
29 {tensorflow::DT_INT16, mindspore::kNumberTypeInt16},
30 {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16},
31 {tensorflow::DT_INT32, mindspore::kNumberTypeInt32},
32 {tensorflow::DT_INT64, mindspore::kNumberTypeInt64},
33 {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16},
34 {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32},
35 {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64},
36 {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64},
37 {tensorflow::DT_BOOL, mindspore::kNumberTypeBool},
38 {tensorflow::DT_STRING, mindspore::kObjectTypeString},
39 {tensorflow::DT_VARIANT, mindspore::kObjectTypeTensorType}};
40
GetTFDataType(const tensorflow::DataType & tf_data_type)41 TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) {
42 auto iter = TF_TYPE_MAP.find(tf_data_type);
43 if (iter == TF_TYPE_MAP.end()) {
44 MS_LOG(WARNING) << "unsupported TF data type: " << tf_data_type;
45 return kTypeUnknown;
46 }
47 return iter->second;
48 }
49
FindAttrValue(const tensorflow::NodeDef & node_def,const std::string & attr_name,tensorflow::AttrValue * attr_value)50 bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &node_def, const std::string &attr_name,
51 tensorflow::AttrValue *attr_value) {
52 const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = node_def.attr();
53 const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
54 if (it != attr.end()) {
55 *attr_value = it->second;
56 return true;
57 }
58 return false;
59 }
60
ParseAttrDataType(const tensorflow::NodeDef & node_def,const std::string & attr_name)61 TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name) {
62 tensorflow::AttrValue attr_value;
63 if (!FindAttrValue(node_def, attr_name, &attr_value)) {
64 MS_LOG(ERROR) << "Find attr failed: " << attr_name;
65 return kTypeUnknown;
66 }
67 return GetTFDataType(attr_value.type());
68 }
69
DecodeInt64(std::string_view * str_view,uint64_t * value)70 bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) {
71 if (value == nullptr) {
72 MS_LOG(ERROR) << "value is nullptr";
73 return false;
74 }
75 if (str_view == nullptr) {
76 *value = 0;
77 MS_LOG(ERROR) << "str_view is nullptr";
78 return false;
79 }
80 auto data = str_view->data();
81 const auto end = data + str_view->size();
82
83 const char *next = nullptr;
84 uint64_t result = 0;
85 for (uint32_t shift = 0; shift <= 63 && data < end; shift += 7) {
86 uint64_t byte = *(reinterpret_cast<const unsigned char *>(data));
87 data++;
88 if (byte & 128) {
89 result |= ((byte & 127) << shift);
90 } else {
91 result |= (byte << shift);
92 *value = result;
93 next = reinterpret_cast<const char *>(data);
94 break;
95 }
96 }
97
98 if (next == nullptr) {
99 return false;
100 } else {
101 *str_view = std::string_view(next, end - next);
102 return true;
103 }
104 }
105
106 // convert input_arg in subgraph to node_name[:index] format
GetFlattenNodeName(const std::string & input_name)107 std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) {
108 std::regex re("\\:+");
109 std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
110 std::sregex_token_iterator());
111 std::string ret = input_name;
112 if (input_splits.size() == 3) {
113 if (input_splits[0] == "RaggedRange") { // Both output index of RaggedRange is 0
114 if (input_splits[1] == "rt_nested_splits") {
115 ret = input_splits[0] + ":0";
116 } else if (input_splits[1] == "rt_dense_values") {
117 ret = input_splits[0] + ":1";
118 } else {
119 MS_LOG(ERROR) << "Failed to flatten RaggedRange node name!";
120 }
121 return ret;
122 } else if (input_splits[0].find("TopKV2") != std::string::npos) {
123 if (input_splits[1] == "values") {
124 return input_splits[0] + ":0";
125 } else if (input_splits[1] == "indices") {
126 return input_splits[0] + ":1";
127 }
128 }
129 if (input_splits[2] == "0") {
130 ret = input_splits[0];
131 } else {
132 ret = input_splits[0] + ":" + input_splits[2]; // multi output node
133 }
134 }
135 return ret;
136 }
137
138 // get referenced node name from input name
GetNodeName(const std::string & input_name)139 std::string TensorFlowUtils::GetNodeName(const std::string &input_name) {
140 std::regex re("\\:+");
141 std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
142 std::sregex_token_iterator());
143 if (input_splits.size() > 1) {
144 return input_splits[0];
145 }
146 return input_name;
147 }
148
ParseNodeFormat(const tensorflow::NodeDef & node_def)149 mindspore::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_def) {
150 tensorflow::AttrValue attr_value;
151 if (!FindAttrValue(node_def, "data_format", &attr_value)) {
152 MS_LOG(ERROR) << "Find attr data_format failed";
153 return mindspore::Format::NCHW;
154 }
155 if (attr_value.s() == "NHWC") {
156 return mindspore::Format::NHWC;
157 }
158 return mindspore::Format::NCHW;
159 }
160
OutputIsInputOp(const std::string & op_name)161 bool TensorFlowUtils::OutputIsInputOp(const std::string &op_name) {
162 return op_name == "Identity" || op_name == "StopGradient" || op_name == "NoOp" || op_name == "ReadVariableOp";
163 }
164 } // namespace lite
165 } // namespace mindspore
166