1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/op_adapter_util.h"
18
19 #include <string>
20 #include <vector>
21 #include <algorithm>
22
23 #include "utils/utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "transform/graph_ir/op_adapter_base.h"
26 #include "transform/graph_ir/io_format_map.h"
27
28 namespace mindspore {
29 namespace transform {
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> &)30 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &) {
31 // To-DO the format may read from ME tensor
32 MS_EXCEPTION_IF_NULL(value);
33 auto me_tensor = value->cast<MeTensorPtr>();
34 auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND);
35 return ge_tensor == nullptr ? GeTensor() : *ge_tensor;
36 }
37
ConvertAnyUtil(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>>)38 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
39 const AnyTraits<std::vector<int64_t>>) {
40 MS_EXCEPTION_IF_NULL(value);
41 std::vector<int64_t> list;
42 if (name == "pad") {
43 if (!value->isa<ValueSequeue>()) {
44 MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
45 }
46 auto vec = value->cast<ValueSequeuePtr>();
47 list.resize(vec->value().size() + 2);
48 list[0] = 1;
49 list[1] = 1;
50 (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
51 [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int64_t>(val)); });
52 } else {
53 int64_t data = GetValue<int64_t>(value);
54 int size = 2; // 2 int in list
55 list = TransformUtil::ConvertIntToList(data, size);
56 }
57
58 return list;
59 }
60
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::string>)61 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>) {
62 MS_EXCEPTION_IF_NULL(value);
63 auto vec = value->cast<ValueTuplePtr>();
64 if (vec == nullptr) {
65 MS_LOG(EXCEPTION) << "not ValueTuplePtr";
66 }
67 std::ostringstream buffer;
68 int i = 0;
69 for (auto &it : vec->value()) {
70 if (i != 0) {
71 buffer << ",";
72 }
73 buffer << GetValue<int64_t>(it);
74 i++;
75 }
76 return buffer.str();
77 }
78
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<float>>,const AnyTraits<float>)79 std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>) {
80 MS_EXCEPTION_IF_NULL(value);
81 auto vec = value->cast<ValueTuplePtr>();
82 if (vec == nullptr) {
83 MS_LOG(EXCEPTION) << "not ValueTuplePtr";
84 }
85 std::vector<float> list;
86 list.resize(vec->value().size());
87 (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
88 [](const ValuePtr &val) { return static_cast<float>(GetValue<float>(val)); });
89 return list;
90 }
91
ConvertAnyUtil(const ValuePtr & value,const std::string & format,const AnyTraits<std::vector<int64_t>>,const AnyTraits<int64_t>)92 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
93 const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>) {
94 MS_EXCEPTION_IF_NULL(value);
95 auto vec = value->cast<ValueTuplePtr>();
96 if (vec == nullptr) {
97 MS_LOG(EXCEPTION) << "not ValueTuplePtr";
98 }
99 std::vector<int64_t> list;
100 list.resize(vec->value().size());
101 (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
102 [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int64_t>(val)); });
103 if (format == kOpFormat_NHWC) {
104 if (list.size() < 4) {
105 MS_LOG(EXCEPTION) << "The size of list is less than 4";
106 } else {
107 int64_t temp = list[1];
108 list[1] = list[2];
109 list[2] = list[3];
110 list[3] = temp;
111 }
112 }
113 return list;
114 }
115
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEType>)116 GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>) {
117 MS_EXCEPTION_IF_NULL(value);
118 if (!value->isa<Type>()) {
119 MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString()
120 << ", type: " << value->type_name() << ", value should be a Typeptr";
121 }
122 auto type = value->cast<TypePtr>();
123 MS_EXCEPTION_IF_NULL(type);
124 TypeId me_type = type->type_id();
125 if (kObjectTypeTensorType == me_type) {
126 me_type = dyn_cast<TensorType>(type)->element()->type_id();
127 }
128 return TransformUtil::ConvertDataType(me_type);
129 }
130
VectorToTensorUtil(const ValuePtr & value)131 GeTensor VectorToTensorUtil(const ValuePtr &value) {
132 // convert tuple or list to ge tensor, only supported one dim for now
133 MS_EXCEPTION_IF_NULL(value);
134 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
135 if (vec.empty()) {
136 MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor";
137 return GeTensor(GeTensorDesc(ge::Shape({0})));
138 }
139 MS_EXCEPTION_IF_NULL(vec[0]);
140 if (vec[0]->isa<Int32Imm>()) {
141 MS_LOG(INFO) << "convert value to tensor with data type = Int32";
142 auto data = ConvertAnyUtil(value, AnyTraits<int32_t>(), AnyTraits<std::vector<int32_t>>());
143 auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW);
144 if (desc == nullptr) {
145 MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
146 }
147 return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(int32_t));
148 } else if (vec[0]->isa<Int64Imm>()) {
149 MS_LOG(INFO) << "convert value to tensor with data type = Int64";
150 auto data = ConvertAnyUtil(value, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>());
151 auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeInt64, kOpFormat_NCHW);
152 if (desc == nullptr) {
153 MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
154 }
155 return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(int64_t));
156 } else if (vec[0]->isa<FP32Imm>()) {
157 MS_LOG(INFO) << "convert value to tensor with data type = Float32";
158 auto data = ConvertAnyUtil(value, AnyTraits<float>(), AnyTraits<std::vector<float>>());
159 auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW);
160 if (desc == nullptr) {
161 MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
162 }
163 return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(float));
164 } else if (vec[0]->isa<BoolImm>()) {
165 MS_LOG(INFO) << "convert value to tensor with data type = Bool";
166 // We use uint8_t to save bool type data
167 auto data = ConvertAnyUtil(value, AnyTraits<bool>(), AnyTraits<std::vector<uint8_t>>());
168 auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeBool, kOpFormat_NCHW);
169 if (desc == nullptr) {
170 MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
171 }
172 return GeTensor(*desc, static_cast<uint8_t *>(data.data()), data.size() * sizeof(uint8_t));
173 } else {
174 MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name();
175 }
176
177 return GeTensor();
178 }
179
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<AnyValue>)180 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>) {
181 MS_EXCEPTION_IF_NULL(value);
182 if (value->isa<MeTensor>()) {
183 // convert me tensor to ge tensor
184 return ConvertAnyUtil(value, AnyTraits<MeTensor>());
185 } else if (value->isa<ValueList>() || value->isa<ValueTuple>()) {
186 return VectorToTensorUtil(value);
187 } else if (value->isa<Int32Imm>()) {
188 // convert scalar Int to GeTensor
189 MS_LOG(INFO) << "convert scalar to tensor with data type = Int32";
190 GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
191 auto v = GetValue<int32_t>(value);
192 desc.SetRealDimCnt(0);
193 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int32_t));
194 } else if (value->isa<Int64Imm>()) {
195 // convert scalar Int64 to GeTensor
196 MS_LOG(INFO) << "convert scalar to tensor with data type = Int64";
197 GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
198 auto v = GetValue<int64_t>(value);
199 desc.SetRealDimCnt(0);
200 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int64_t));
201 } else if (value->isa<FP32Imm>()) {
202 // convert scalar FP32 to GeTensor
203 MS_LOG(INFO) << "convert scalar to tensor with data type = FP32";
204 GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
205 auto v = GetValue<float>(value);
206 desc.SetRealDimCnt(0);
207 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(float));
208 } else if (value->isa<BoolImm>()) {
209 // convert scalar FP32 to GeTensor
210 MS_LOG(INFO) << "convert scalar to tensor with data type = Bool";
211 GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL);
212 auto v = GetValue<bool>(value);
213 desc.SetRealDimCnt(0);
214 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(bool));
215 } else if (value->isa<StringImm>()) {
216 // convert String to GeTensor
217 MS_LOG(INFO) << "convert string to tensor with data type = String";
218 std::string v = GetValue<std::string>(value);
219 std::vector<int64_t> ge_shape;
220 GeShape shape(ge_shape);
221 GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING);
222 GeTensor str_tensor(desc);
223 str_tensor.SetData(v);
224 return str_tensor;
225 } else {
226 MS_LOG(WARNING) << "Unsupported value type: " << value->type_name()
227 << " to convert to tensor. Value: " << value->ToString();
228 }
229 return GeTensor();
230 }
231
IsCustomPrim(const PrimitivePtr & prim)232 bool IsCustomPrim(const PrimitivePtr &prim) {
233 if (prim == nullptr) {
234 return false;
235 }
236
237 ValuePtr flag = prim->GetAttr("_custom_op_flag");
238 if (flag == nullptr) {
239 return false;
240 }
241
242 bool is_custom_op = GetValue<bool>(flag);
243 if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) {
244 MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op "
245 "can not assign the op information config path.";
246 }
247
248 return is_custom_op;
249 }
250
IsCustomCNode(const AnfNodePtr & anf)251 bool IsCustomCNode(const AnfNodePtr &anf) {
252 if (anf == nullptr) {
253 return false;
254 }
255 auto node = anf->cast<CNodePtr>();
256 if (node == nullptr) {
257 return false;
258 }
259 if (node->inputs().empty()) {
260 MS_LOG(EXCEPTION) << "Length of node inputs is empty";
261 }
262 MS_EXCEPTION_IF_NULL(node->inputs()[0]);
263 if (!node->inputs()[0]->isa<ValueNode>()) {
264 return false;
265 }
266 auto cus_prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
267 if (cus_prim == nullptr) {
268 return false;
269 }
270
271 return IsCustomPrim(cus_prim);
272 }
273
GetOpIOFormat(const AnfNodePtr & anf)274 std::string GetOpIOFormat(const AnfNodePtr &anf) {
275 std::string ret;
276 if (anf == nullptr) {
277 MS_LOG(ERROR) << "The anf is nullptr";
278 return ret;
279 }
280 auto node = anf->cast<CNodePtr>();
281 if (node == nullptr) {
282 MS_LOG(ERROR) << "The anf is not a cnode.";
283 return ret;
284 }
285 if (node->inputs().empty()) {
286 MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
287 }
288 MS_EXCEPTION_IF_NULL(node->inputs()[0]);
289 if (!node->inputs()[0]->isa<ValueNode>()) {
290 MS_LOG(ERROR) << "The anf is not a value node.";
291 return ret;
292 }
293 auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
294 if (prim == nullptr) {
295 MS_LOG(ERROR) << "The anf is not a Primitive.";
296 return ret;
297 }
298 if (prim->HasAttr("io_format")) {
299 return GetValue<std::string>(prim->GetAttr("io_format"));
300 }
301 auto io_format_map = IOFormatMap::get();
302 auto iter = io_format_map.find(prim->name());
303 if (iter == io_format_map.end()) {
304 return "NCHW";
305 }
306 if (iter->second == "format") {
307 ValuePtr format = prim->GetAttr("format");
308 MS_EXCEPTION_IF_NULL(format);
309 if (format->isa<Int64Imm>()) {
310 bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &format);
311 if (converted) {
312 return GetValue<std::string>(format);
313 }
314 } else {
315 return GetValue<std::string>(format);
316 }
317 }
318 return iter->second;
319 }
320 } // namespace transform
321 } // namespace mindspore
322