1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/op_adapter_util.h"
18
19 #include <string>
20 #include <vector>
21 #include <algorithm>
22
23 #include "include/common/utils/utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "transform/graph_ir/op_adapter_base.h"
26 #include "transform/graph_ir/io_format_map.h"
27 #include "ir/kernel_tensor_value.h"
28 #include "ops/op_utils.h"
29
30 namespace mindspore {
GeDataTypeImm()31 GeDataTypeImm::GeDataTypeImm() : IntegerImm(kInt32), v_(::ge::DataType::DT_FLOAT) {}
GeDataTypeImm(::ge::DataType v)32 GeDataTypeImm::GeDataTypeImm(::ge::DataType v) : IntegerImm(kInt32), v_(v) {
33 hash_ = hash_combine({tid(), std::hash<int>{}(v_)});
34 }
operator ==(const Value & other) const35 bool GeDataTypeImm::operator==(const Value &other) const {
36 if (other.isa<GeDataTypeImm>()) {
37 auto &other_ = static_cast<const GeDataTypeImm &>(other);
38 return *this == other_;
39 } else {
40 return false;
41 }
42 }
operator ==(const GeDataTypeImm & other) const43 bool GeDataTypeImm::operator==(const GeDataTypeImm &other) const { return v_ == other.v_; }
DumpText() const44 std::string GeDataTypeImm::DumpText() const {
45 std::ostringstream oss;
46 oss << "GeDataType(" << int(v_) << ")";
47 return oss.str();
48 }
49
50 namespace transform {
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> &)51 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &) {
52 // To-DO the format may read from ME tensor
53 MS_EXCEPTION_IF_NULL(value);
54 auto me_tensor = value->cast<MeTensorPtr>();
55 auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND);
56 return ge_tensor == nullptr ? GeTensor() : *ge_tensor;
57 }
58
ConvertAnyUtil(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>>)59 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
60 const AnyTraits<std::vector<int64_t>>) {
61 MS_EXCEPTION_IF_NULL(value);
62 std::vector<int64_t> list;
63 if (name == "pad") {
64 if (!value->isa<ValueSequence>()) {
65 MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
66 }
67 auto vec = value->cast<ValueSequencePtr>();
68 list.resize(vec->value().size() + 2);
69 list[0] = 1;
70 list[1] = 1;
71 (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
72 [](const ValuePtr &val) { return ops::GetValueWithCheck<int64_t>(val); });
73 } else {
74 int64_t data = ops::GetValueWithCheck<int64_t>(value);
75 int size = 2; // 2 int in list
76 list = TransformUtil::ConvertIntToList(data, size);
77 }
78
79 return list;
80 }
81
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::string>)82 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>) {
83 MS_EXCEPTION_IF_NULL(value);
84 auto vec = value->cast<ValueTuplePtr>();
85 if (vec == nullptr) {
86 MS_LOG(EXCEPTION) << "not ValueTuplePtr";
87 }
88 std::ostringstream buffer;
89 int i = 0;
90 for (auto &it : vec->value()) {
91 if (i != 0) {
92 buffer << ",";
93 }
94 buffer << ops::GetValueWithCheck<int64_t>(it);
95 i++;
96 }
97 return buffer.str();
98 }
99
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<float>>,const AnyTraits<float>)100 std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>) {
101 MS_EXCEPTION_IF_NULL(value);
102 auto vec = value->cast<ValueTuplePtr>();
103 if (vec == nullptr) {
104 MS_LOG(EXCEPTION) << "not ValueTuplePtr";
105 }
106 std::vector<float> list;
107 list.resize(vec->value().size());
108 (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
109 [](const ValuePtr &val) { return ops::GetValueWithCheck<float>(val); });
110 return list;
111 }
112
ConvertAnyUtil(const ValuePtr & value,const std::string & format,const AnyTraits<std::vector<int64_t>>,const AnyTraits<int64_t>)113 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
114 const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>) {
115 MS_EXCEPTION_IF_NULL(value);
116 auto vec = value->cast<ValueTuplePtr>();
117 if (vec == nullptr) {
118 MS_LOG(EXCEPTION) << "not ValueTuplePtr";
119 }
120 std::vector<int64_t> list;
121 list.resize(vec->value().size());
122 (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
123 [](const ValuePtr &val) { return ops::GetValueWithCheck<int64_t>(val); });
124 if (format == kOpFormat_NHWC) {
125 if (list.size() < 4) {
126 MS_LOG(EXCEPTION) << "The size of list is less than 4";
127 } else {
128 int64_t temp = list[1];
129 list[1] = list[2];
130 list[2] = list[3];
131 list[3] = temp;
132 }
133 }
134 return list;
135 }
136
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEType>)137 GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>) {
138 MS_EXCEPTION_IF_NULL(value);
139 TypeId me_type;
140 if (value->isa<Type>()) {
141 auto type = value->cast<TypePtr>();
142 MS_EXCEPTION_IF_NULL(type);
143 me_type = type->type_id();
144 if (kObjectTypeTensorType == me_type) {
145 me_type = dyn_cast<TensorType>(type)->element()->type_id();
146 }
147 } else if (value->isa<Int32Imm>()) {
148 // type id
149 me_type = static_cast<TypeId>(GetValue<int32_t>(value));
150 } else if (value->isa<UInt64Imm>()) {
151 // type id
152 me_type = static_cast<TypeId>(GetValue<uint64_t>(value));
153 } else if (value->isa<Int64Imm>()) {
154 // type id
155 me_type = static_cast<TypeId>(GetValue<int64_t>(value));
156 } else if (value->isa<KernelTensorValue>()) {
157 // type id
158 auto value_opt = ops::GetScalarValue<int64_t>(value);
159 me_type = static_cast<TypeId>(value_opt.value());
160 } else {
161 MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString()
162 << ", type: " << value->type_name() << ", value should be a Typeptr or TypeId";
163 }
164 return TransformUtil::ConvertDataType(me_type);
165 }
166
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<GEType>>)167 std::vector<GeDataType> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<GEType>>) {
168 MS_EXCEPTION_IF_NULL(value);
169 std::vector<GeDataType> data;
170 if (!value->isa<ValueTuple>() && !value->isa<ValueList>()) {
171 MS_LOG(WARNING) << "error convert Value to vector for value: " << value->ToString()
172 << ", type: " << value->type_name() << ", value should be a tuple or list";
173 data.emplace_back(ConvertAnyUtil(value, AnyTraits<GEType>()));
174 return data;
175 }
176 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
177 std::transform(vec.begin(), vec.end(), std::back_inserter(data),
178 [](const ValuePtr &it) { return ConvertAnyUtil(it, AnyTraits<GEType>()); });
179 return data;
180 }
181
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEDataFormat>)182 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEDataFormat>) {
183 MS_EXCEPTION_IF_NULL(value);
184 if (value->isa<StringImm>()) {
185 return GetValue<std::string>(value);
186 }
187 int64_t format_id = GetCastIntegralValue<int64_t>(value);
188 return GEDataFormat::ConvertEnumToString(format_id);
189 }
190
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEPadMod>)191 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEPadMod>) {
192 MS_EXCEPTION_IF_NULL(value);
193 if (value->isa<StringImm>()) {
194 return GetValue<std::string>(value);
195 }
196 int64_t pad_id = GetCastIntegralValue<int64_t>(value);
197 return GEPadMod::ConvertEnumToString(pad_id);
198 }
199
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEReduction>)200 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEReduction>) {
201 MS_EXCEPTION_IF_NULL(value);
202 if (value->isa<StringImm>()) {
203 return GetValue<std::string>(value);
204 }
205 int64_t reduction_id = GetCastIntegralValue<int64_t>(value);
206 return GEReduction::ConvertEnumToString(reduction_id);
207 }
208
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<AscendQuantRoundMode>)209 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AscendQuantRoundMode>) {
210 MS_EXCEPTION_IF_NULL(value);
211 if (value->isa<StringImm>()) {
212 return GetValue<std::string>(value);
213 }
214 int64_t round_mode_id = GetCastIntegralValue<int64_t>(value);
215 return AscendQuantRoundMode::ConvertEnumToString(round_mode_id);
216 }
217
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<FASInputLayoutMode>)218 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FASInputLayoutMode>) {
219 MS_EXCEPTION_IF_NULL(value);
220 if (value->isa<StringImm>()) {
221 return GetValue<std::string>(value);
222 }
223 int64_t input_layout_id = GetCastIntegralValue<int64_t>(value);
224 return FASInputLayoutMode::ConvertEnumToString(input_layout_id);
225 }
226
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<FFNActivationMode>)227 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FFNActivationMode>) {
228 MS_EXCEPTION_IF_NULL(value);
229 if (value->isa<StringImm>()) {
230 return GetValue<std::string>(value);
231 }
232 int64_t activation_id = GetCastIntegralValue<int64_t>(value);
233 return FFNActivationMode::ConvertEnumToString(activation_id);
234 }
235
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<ScatterReduceMode>)236 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ScatterReduceMode>) {
237 MS_EXCEPTION_IF_NULL(value);
238 if (value->isa<StringImm>()) {
239 return GetValue<std::string>(value);
240 }
241 int64_t reduce_id = GetCastIntegralValue<int64_t>(value);
242 return ScatterReduceMode::ConvertEnumToString(reduce_id);
243 }
244
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GECoordinateTransformMode>)245 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GECoordinateTransformMode>) {
246 MS_EXCEPTION_IF_NULL(value);
247 if (value->isa<StringImm>()) {
248 return GetValue<std::string>(value);
249 }
250 int64_t mode_id = GetCastIntegralValue<int64_t>(value);
251 return GECoordinateTransformMode::ConvertEnumToString(mode_id);
252 }
253
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEEnumToStr>,const std::vector<std::string> & enum_string)254 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEEnumToStr>,
255 const std::vector<std::string> &enum_string) {
256 MS_EXCEPTION_IF_NULL(value);
257
258 if (value->isa<StringImm>()) {
259 return GetValue<std::string>(value);
260 }
261 int64_t id = GetCastIntegralValue<int64_t>(value);
262 if (id < 0 || id >= static_cast<int64_t>(enum_string.size())) {
263 MS_LOG(EXCEPTION) << "Invalid enum id " << id;
264 return "";
265 }
266 return enum_string[id];
267 }
268
269 template <typename T1, typename T2>
NestedVectorToTensorImpl(const ValuePtrList & vec,const TypeId & type)270 GeTensor NestedVectorToTensorImpl(const ValuePtrList &vec, const TypeId &type) {
271 const auto &vec_item =
272 vec[0]->isa<ValueTuple>() ? vec[0]->cast<ValueTuplePtr>()->value() : vec[0]->cast<ValueListPtr>()->value();
273 size_t attr_size1 = vec.size();
274 size_t attr_size2 = vec_item.size();
275 std::vector<T1> attr_list;
276 for (const auto &item : vec) {
277 auto value_list = ops::GetValueWithCheck<std::vector<T1>>(item);
278 (void)std::copy(value_list.begin(), value_list.end(), std::back_inserter(attr_list));
279 }
280 auto attr_value = MakeValue(attr_list);
281 auto data = ConvertAnyUtil(attr_value, AnyTraits<T1>(), AnyTraits<std::vector<T2>>());
282 auto desc =
283 TransformUtil::GetGeTensorDesc({static_cast<int>(attr_size1), static_cast<int>(attr_size2)}, type, kOpFormat_NCHW);
284 if (desc == nullptr) {
285 MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
286 }
287 return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(T2));
288 }
289
NestedVectorToTensor(const ValuePtr & value)290 GeTensor NestedVectorToTensor(const ValuePtr &value) {
291 MS_EXCEPTION_IF_NULL(value);
292 const auto &vec =
293 value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
294 const auto &vec_item =
295 vec[0]->isa<ValueTuple>() ? vec[0]->cast<ValueTuplePtr>()->value() : vec[0]->cast<ValueListPtr>()->value();
296 if (vec_item.empty()) {
297 MS_LOG(WARNING) << "Convert a none nested tuple to an empty ge tensor";
298 return GeTensor(GeTensorDesc(::ge::Shape({0})));
299 }
300 MS_EXCEPTION_IF_NULL(vec_item[0]);
301 TypeId type;
302 if (vec_item[0]->isa<Int32Imm>()) {
303 type = kNumberTypeInt32;
304 return NestedVectorToTensorImpl<int32_t, int32_t>(vec, type);
305 } else if (vec_item[0]->isa<Int64Imm>()) {
306 type = kNumberTypeInt64;
307 return NestedVectorToTensorImpl<int64_t, int64_t>(vec, type);
308 } else if (vec_item[0]->isa<FP32Imm>()) {
309 type = kNumberTypeFloat32;
310 return NestedVectorToTensorImpl<float, float>(vec, type);
311 } else if (vec_item[0]->isa<BoolImm>()) {
312 type = kNumberTypeBool;
313 return NestedVectorToTensorImpl<bool, uint8_t>(vec, type);
314 } else {
315 MS_LOG(EXCEPTION) << "Unsupported data type of nested tuple or list elements: " << vec_item[0]->type_name();
316 }
317 }
318
319 template <typename T1, typename T2>
VectorToTensorImpl(const ValuePtr & value,const TypeId & type)320 GeTensor VectorToTensorImpl(const ValuePtr &value, const TypeId &type) {
321 const auto &vec =
322 value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
323 auto data = ConvertAnyUtil(value, AnyTraits<T1>(), AnyTraits<std::vector<T2>>());
324 auto format = vec.size() == kDim4 ? kOpFormat_NCHW : kOpFormat_ND;
325 auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, type, format);
326 if (desc == nullptr) {
327 MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
328 }
329 return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(T2));
330 }
331
VectorToTensorUtil(const ValuePtr & value)332 GeTensor VectorToTensorUtil(const ValuePtr &value) {
333 MS_EXCEPTION_IF_NULL(value);
334 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
335 if (vec.empty()) {
336 MS_LOG(INFO) << "Convert a none tuple to an empty ge tensor";
337 return GeTensor(GeTensorDesc(::ge::Shape({0}), ::ge::FORMAT_ND, ::ge::DT_INT64));
338 }
339 MS_EXCEPTION_IF_NULL(vec[0]);
340 TypeId type;
341 if (vec[0]->isa<Int32Imm>()) {
342 MS_LOG(INFO) << "convert value to tensor with data type = Int32";
343 type = kNumberTypeInt32;
344 return VectorToTensorImpl<int32_t, int32_t>(value, type);
345 } else if (vec[0]->isa<Int64Imm>()) {
346 MS_LOG(INFO) << "convert value to tensor with data type = Int64";
347 type = kNumberTypeInt64;
348 return VectorToTensorImpl<int64_t, int64_t>(value, type);
349 } else if (vec[0]->isa<FP32Imm>()) {
350 MS_LOG(INFO) << "convert value to tensor with data type = Float32";
351 type = kNumberTypeFloat32;
352 return VectorToTensorImpl<float, float>(value, type);
353 } else if (vec[0]->isa<BoolImm>()) {
354 MS_LOG(INFO) << "convert value to tensor with data type = Bool";
355 type = kNumberTypeBool;
356 return VectorToTensorImpl<bool, uint8_t>(value, type);
357 } else if (vec[0]->isa<ValueTuple>() || vec[0]->isa<ValueList>()) {
358 // convert nested tuple or list to ge tensor, supported two dims
359 MS_LOG(INFO) << "Convert nested tuple or list to ge tensor.";
360 return NestedVectorToTensor(value);
361 } else {
362 MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name();
363 }
364 }
365
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<ValueAny>)366 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ValueAny>) {
367 MS_EXCEPTION_IF_NULL(value);
368 if (value->isa<MeTensor>()) {
369 // convert me tensor to ge tensor
370 return ConvertAnyUtil(value, AnyTraits<MeTensor>());
371 } else if (value->isa<ValueList>() || value->isa<ValueTuple>()) {
372 return VectorToTensorUtil(value);
373 } else if (value->isa<Int32Imm>()) {
374 // convert scalar Int to GeTensor
375 MS_LOG(INFO) << "convert scalar to tensor with data type = Int32";
376 GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_INT32);
377 auto v = GetValue<int32_t>(value);
378 desc.SetRealDimCnt(0);
379 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int32_t));
380 } else if (value->isa<UInt32Imm>()) {
381 // convert scalar UInt to GeTensor
382 MS_LOG(INFO) << "Convert scalar to tensor with data type = UInt32";
383 GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_UINT32);
384 auto v = GetValue<uint32_t>(value);
385 desc.SetRealDimCnt(0);
386 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(uint32_t));
387 } else if (value->isa<Int64Imm>()) {
388 // convert scalar Int64 to GeTensor
389 MS_LOG(INFO) << "convert scalar to tensor with data type = Int64";
390 GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_INT64);
391 auto v = GetValue<int64_t>(value);
392 desc.SetRealDimCnt(0);
393 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int64_t));
394 } else if (value->isa<FP32Imm>()) {
395 // convert scalar FP32 to GeTensor
396 MS_LOG(INFO) << "convert scalar to tensor with data type = FP32";
397 GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_FLOAT);
398 auto v = GetValue<float>(value);
399 desc.SetRealDimCnt(0);
400 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(float));
401 } else if (value->isa<BoolImm>()) {
402 // convert scalar FP32 to GeTensor
403 MS_LOG(INFO) << "convert scalar to tensor with data type = Bool";
404 GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_BOOL);
405 auto v = GetValue<bool>(value);
406 desc.SetRealDimCnt(0);
407 return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(bool));
408 } else if (value->isa<StringImm>()) {
409 // convert String to GeTensor
410 MS_LOG(INFO) << "convert string to tensor with data type = String";
411 std::string v = GetValue<std::string>(value);
412 std::vector<int64_t> ge_shape;
413 GeShape shape(ge_shape);
414 GeTensorDesc desc(shape, ::ge::FORMAT_ND, ::ge::DT_STRING);
415 GeTensor str_tensor(desc);
416 (void)str_tensor.SetData(v);
417 return str_tensor;
418 } else {
419 MS_LOG(INFO) << "Unsupported value type: " << value->type_name()
420 << " to convert to tensor. Value: " << value->ToString();
421 }
422 return GeTensor();
423 }
424
IsCustomPrim(const PrimitivePtr & prim)425 bool IsCustomPrim(const PrimitivePtr &prim) {
426 if (prim == nullptr) {
427 return false;
428 }
429
430 if (prim->name() == "Custom") {
431 return true;
432 }
433 return false;
434 }
435
IsNoNeedConstantFoldCNode(const PrimitivePtr & prim)436 bool IsNoNeedConstantFoldCNode(const PrimitivePtr &prim) {
437 // ON_THE_FLY Quantization node dont need constant folding.
438 return prim->GetAttr("no_need_constant_folding") != nullptr;
439 }
440
IsCustomCNode(const AnfNodePtr & anf)441 bool IsCustomCNode(const AnfNodePtr &anf) {
442 if (anf == nullptr) {
443 return false;
444 }
445 auto node = anf->cast<CNodePtr>();
446 if (node == nullptr) {
447 return false;
448 }
449 if (node->inputs().empty()) {
450 MS_LOG(EXCEPTION) << "Length of node inputs is empty";
451 }
452 MS_EXCEPTION_IF_NULL(node->inputs()[0]);
453 if (!node->inputs()[0]->isa<ValueNode>()) {
454 return false;
455 }
456 auto cus_prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
457 if (cus_prim == nullptr) {
458 return false;
459 }
460
461 return IsCustomPrim(cus_prim);
462 }
463
GetOpIOFormat(const AnfNodePtr & anf)464 std::string GetOpIOFormat(const AnfNodePtr &anf) {
465 std::string ret;
466 if (anf == nullptr) {
467 MS_LOG(ERROR) << "The anf is nullptr";
468 return ret;
469 }
470 auto node = anf->cast<CNodePtr>();
471 if (node == nullptr) {
472 MS_LOG(ERROR) << "The anf is not a cnode.";
473 return ret;
474 }
475 if (node->inputs().empty()) {
476 MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
477 }
478 MS_EXCEPTION_IF_NULL(node->input(0));
479 auto &input = node->input(0);
480 AnfNodePtr prim_node = nullptr;
481 if (input->isa<ValueNode>()) {
482 prim_node = input;
483 } else if (input->isa<CNode>() && input->cast<CNodePtr>()->input(0)->isa<ValueNode>()) {
484 // process cnode1, its input(index 0) is a conde0(partial etc.)
485 prim_node = input->cast<CNodePtr>()->input(0);
486 } else {
487 MS_LOG(ERROR) << "The anf is not a value node or cnode.";
488 return ret;
489 }
490 MS_EXCEPTION_IF_NULL(prim_node);
491 auto prim = GetValueNode<PrimitivePtr>(prim_node);
492 if (prim == nullptr) {
493 MS_LOG(ERROR) << "The anf is not a Primitive.";
494 return ret;
495 }
496 if (prim->HasAttr("io_format")) {
497 return ops::GetValueWithCheck<std::string>(prim->GetAttr("io_format"));
498 }
499 auto io_format_map = IOFormatMap::get();
500 auto iter = io_format_map.find(prim->name());
501 if (iter == io_format_map.end()) {
502 return kOpFormat_DEFAULT;
503 }
504 if (iter->second == "format") {
505 ValuePtr format = prim->GetAttr("format");
506 MS_EXCEPTION_IF_NULL(format);
507 if (format->isa<Int64Imm>()) {
508 bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &format);
509 if (converted) {
510 return ops::GetValueWithCheck<std::string>(format);
511 }
512 } else {
513 return ops::GetValueWithCheck<std::string>(format);
514 }
515 }
516 return iter->second;
517 }
518 } // namespace transform
519 } // namespace mindspore
520