• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/op_adapter_util.h"
18 
19 #include <string>
20 #include <vector>
21 #include <algorithm>
22 
23 #include "include/common/utils/utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "transform/graph_ir/op_adapter_base.h"
26 #include "transform/graph_ir/io_format_map.h"
27 #include "ir/kernel_tensor_value.h"
28 #include "ops/op_utils.h"
29 
30 namespace mindspore {
GeDataTypeImm()31 GeDataTypeImm::GeDataTypeImm() : IntegerImm(kInt32), v_(::ge::DataType::DT_FLOAT) {}
GeDataTypeImm(::ge::DataType v)32 GeDataTypeImm::GeDataTypeImm(::ge::DataType v) : IntegerImm(kInt32), v_(v) {
33   hash_ = hash_combine({tid(), std::hash<int>{}(v_)});
34 }
operator ==(const Value & other) const35 bool GeDataTypeImm::operator==(const Value &other) const {
36   if (other.isa<GeDataTypeImm>()) {
37     auto &other_ = static_cast<const GeDataTypeImm &>(other);
38     return *this == other_;
39   } else {
40     return false;
41   }
42 }
operator ==(const GeDataTypeImm & other) const43 bool GeDataTypeImm::operator==(const GeDataTypeImm &other) const { return v_ == other.v_; }
DumpText() const44 std::string GeDataTypeImm::DumpText() const {
45   std::ostringstream oss;
46   oss << "GeDataType(" << int(v_) << ")";
47   return oss.str();
48 }
49 
50 namespace transform {
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> &)51 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &) {
52   // To-DO the format may read from ME tensor
53   MS_EXCEPTION_IF_NULL(value);
54   auto me_tensor = value->cast<MeTensorPtr>();
55   auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND);
56   return ge_tensor == nullptr ? GeTensor() : *ge_tensor;
57 }
58 
ConvertAnyUtil(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>>)59 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
60                                     const AnyTraits<std::vector<int64_t>>) {
61   MS_EXCEPTION_IF_NULL(value);
62   std::vector<int64_t> list;
63   if (name == "pad") {
64     if (!value->isa<ValueSequence>()) {
65       MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
66     }
67     auto vec = value->cast<ValueSequencePtr>();
68     list.resize(vec->value().size() + 2);
69     list[0] = 1;
70     list[1] = 1;
71     (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
72                          [](const ValuePtr &val) { return ops::GetValueWithCheck<int64_t>(val); });
73   } else {
74     int64_t data = ops::GetValueWithCheck<int64_t>(value);
75     int size = 2;  // 2 int in list
76     list = TransformUtil::ConvertIntToList(data, size);
77   }
78 
79   return list;
80 }
81 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::string>)82 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>) {
83   MS_EXCEPTION_IF_NULL(value);
84   auto vec = value->cast<ValueTuplePtr>();
85   if (vec == nullptr) {
86     MS_LOG(EXCEPTION) << "not ValueTuplePtr";
87   }
88   std::ostringstream buffer;
89   int i = 0;
90   for (auto &it : vec->value()) {
91     if (i != 0) {
92       buffer << ",";
93     }
94     buffer << ops::GetValueWithCheck<int64_t>(it);
95     i++;
96   }
97   return buffer.str();
98 }
99 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<float>>,const AnyTraits<float>)100 std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>) {
101   MS_EXCEPTION_IF_NULL(value);
102   auto vec = value->cast<ValueTuplePtr>();
103   if (vec == nullptr) {
104     MS_LOG(EXCEPTION) << "not ValueTuplePtr";
105   }
106   std::vector<float> list;
107   list.resize(vec->value().size());
108   (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
109                        [](const ValuePtr &val) { return ops::GetValueWithCheck<float>(val); });
110   return list;
111 }
112 
ConvertAnyUtil(const ValuePtr & value,const std::string & format,const AnyTraits<std::vector<int64_t>>,const AnyTraits<int64_t>)113 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
114                                     const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>) {
115   MS_EXCEPTION_IF_NULL(value);
116   auto vec = value->cast<ValueTuplePtr>();
117   if (vec == nullptr) {
118     MS_LOG(EXCEPTION) << "not ValueTuplePtr";
119   }
120   std::vector<int64_t> list;
121   list.resize(vec->value().size());
122   (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
123                        [](const ValuePtr &val) { return ops::GetValueWithCheck<int64_t>(val); });
124   if (format == kOpFormat_NHWC) {
125     if (list.size() < 4) {
126       MS_LOG(EXCEPTION) << "The size of list is less than 4";
127     } else {
128       int64_t temp = list[1];
129       list[1] = list[2];
130       list[2] = list[3];
131       list[3] = temp;
132     }
133   }
134   return list;
135 }
136 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEType>)137 GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>) {
138   MS_EXCEPTION_IF_NULL(value);
139   TypeId me_type;
140   if (value->isa<Type>()) {
141     auto type = value->cast<TypePtr>();
142     MS_EXCEPTION_IF_NULL(type);
143     me_type = type->type_id();
144     if (kObjectTypeTensorType == me_type) {
145       me_type = dyn_cast<TensorType>(type)->element()->type_id();
146     }
147   } else if (value->isa<Int32Imm>()) {
148     // type id
149     me_type = static_cast<TypeId>(GetValue<int32_t>(value));
150   } else if (value->isa<UInt64Imm>()) {
151     // type id
152     me_type = static_cast<TypeId>(GetValue<uint64_t>(value));
153   } else if (value->isa<Int64Imm>()) {
154     // type id
155     me_type = static_cast<TypeId>(GetValue<int64_t>(value));
156   } else if (value->isa<KernelTensorValue>()) {
157     // type id
158     auto value_opt = ops::GetScalarValue<int64_t>(value);
159     me_type = static_cast<TypeId>(value_opt.value());
160   } else {
161     MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString()
162                       << ", type: " << value->type_name() << ", value should be a Typeptr or TypeId";
163   }
164   return TransformUtil::ConvertDataType(me_type);
165 }
166 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<GEType>>)167 std::vector<GeDataType> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<GEType>>) {
168   MS_EXCEPTION_IF_NULL(value);
169   std::vector<GeDataType> data;
170   if (!value->isa<ValueTuple>() && !value->isa<ValueList>()) {
171     MS_LOG(WARNING) << "error convert Value to vector for value: " << value->ToString()
172                     << ", type: " << value->type_name() << ", value should be a tuple or list";
173     data.emplace_back(ConvertAnyUtil(value, AnyTraits<GEType>()));
174     return data;
175   }
176   auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
177   std::transform(vec.begin(), vec.end(), std::back_inserter(data),
178                  [](const ValuePtr &it) { return ConvertAnyUtil(it, AnyTraits<GEType>()); });
179   return data;
180 }
181 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEDataFormat>)182 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEDataFormat>) {
183   MS_EXCEPTION_IF_NULL(value);
184   if (value->isa<StringImm>()) {
185     return GetValue<std::string>(value);
186   }
187   int64_t format_id = GetCastIntegralValue<int64_t>(value);
188   return GEDataFormat::ConvertEnumToString(format_id);
189 }
190 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEPadMod>)191 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEPadMod>) {
192   MS_EXCEPTION_IF_NULL(value);
193   if (value->isa<StringImm>()) {
194     return GetValue<std::string>(value);
195   }
196   int64_t pad_id = GetCastIntegralValue<int64_t>(value);
197   return GEPadMod::ConvertEnumToString(pad_id);
198 }
199 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEReduction>)200 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEReduction>) {
201   MS_EXCEPTION_IF_NULL(value);
202   if (value->isa<StringImm>()) {
203     return GetValue<std::string>(value);
204   }
205   int64_t reduction_id = GetCastIntegralValue<int64_t>(value);
206   return GEReduction::ConvertEnumToString(reduction_id);
207 }
208 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<AscendQuantRoundMode>)209 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AscendQuantRoundMode>) {
210   MS_EXCEPTION_IF_NULL(value);
211   if (value->isa<StringImm>()) {
212     return GetValue<std::string>(value);
213   }
214   int64_t round_mode_id = GetCastIntegralValue<int64_t>(value);
215   return AscendQuantRoundMode::ConvertEnumToString(round_mode_id);
216 }
217 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<FASInputLayoutMode>)218 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FASInputLayoutMode>) {
219   MS_EXCEPTION_IF_NULL(value);
220   if (value->isa<StringImm>()) {
221     return GetValue<std::string>(value);
222   }
223   int64_t input_layout_id = GetCastIntegralValue<int64_t>(value);
224   return FASInputLayoutMode::ConvertEnumToString(input_layout_id);
225 }
226 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<FFNActivationMode>)227 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FFNActivationMode>) {
228   MS_EXCEPTION_IF_NULL(value);
229   if (value->isa<StringImm>()) {
230     return GetValue<std::string>(value);
231   }
232   int64_t activation_id = GetCastIntegralValue<int64_t>(value);
233   return FFNActivationMode::ConvertEnumToString(activation_id);
234 }
235 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<ScatterReduceMode>)236 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ScatterReduceMode>) {
237   MS_EXCEPTION_IF_NULL(value);
238   if (value->isa<StringImm>()) {
239     return GetValue<std::string>(value);
240   }
241   int64_t reduce_id = GetCastIntegralValue<int64_t>(value);
242   return ScatterReduceMode::ConvertEnumToString(reduce_id);
243 }
244 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GECoordinateTransformMode>)245 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GECoordinateTransformMode>) {
246   MS_EXCEPTION_IF_NULL(value);
247   if (value->isa<StringImm>()) {
248     return GetValue<std::string>(value);
249   }
250   int64_t mode_id = GetCastIntegralValue<int64_t>(value);
251   return GECoordinateTransformMode::ConvertEnumToString(mode_id);
252 }
253 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEEnumToStr>,const std::vector<std::string> & enum_string)254 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEEnumToStr>,
255                            const std::vector<std::string> &enum_string) {
256   MS_EXCEPTION_IF_NULL(value);
257 
258   if (value->isa<StringImm>()) {
259     return GetValue<std::string>(value);
260   }
261   int64_t id = GetCastIntegralValue<int64_t>(value);
262   if (id < 0 || id >= static_cast<int64_t>(enum_string.size())) {
263     MS_LOG(EXCEPTION) << "Invalid enum id " << id;
264     return "";
265   }
266   return enum_string[id];
267 }
268 
269 template <typename T1, typename T2>
NestedVectorToTensorImpl(const ValuePtrList & vec,const TypeId & type)270 GeTensor NestedVectorToTensorImpl(const ValuePtrList &vec, const TypeId &type) {
271   const auto &vec_item =
272     vec[0]->isa<ValueTuple>() ? vec[0]->cast<ValueTuplePtr>()->value() : vec[0]->cast<ValueListPtr>()->value();
273   size_t attr_size1 = vec.size();
274   size_t attr_size2 = vec_item.size();
275   std::vector<T1> attr_list;
276   for (const auto &item : vec) {
277     auto value_list = ops::GetValueWithCheck<std::vector<T1>>(item);
278     (void)std::copy(value_list.begin(), value_list.end(), std::back_inserter(attr_list));
279   }
280   auto attr_value = MakeValue(attr_list);
281   auto data = ConvertAnyUtil(attr_value, AnyTraits<T1>(), AnyTraits<std::vector<T2>>());
282   auto desc =
283     TransformUtil::GetGeTensorDesc({static_cast<int>(attr_size1), static_cast<int>(attr_size2)}, type, kOpFormat_NCHW);
284   if (desc == nullptr) {
285     MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
286   }
287   return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(T2));
288 }
289 
NestedVectorToTensor(const ValuePtr & value)290 GeTensor NestedVectorToTensor(const ValuePtr &value) {
291   MS_EXCEPTION_IF_NULL(value);
292   const auto &vec =
293     value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
294   const auto &vec_item =
295     vec[0]->isa<ValueTuple>() ? vec[0]->cast<ValueTuplePtr>()->value() : vec[0]->cast<ValueListPtr>()->value();
296   if (vec_item.empty()) {
297     MS_LOG(WARNING) << "Convert a none nested tuple to an empty ge tensor";
298     return GeTensor(GeTensorDesc(::ge::Shape({0})));
299   }
300   MS_EXCEPTION_IF_NULL(vec_item[0]);
301   TypeId type;
302   if (vec_item[0]->isa<Int32Imm>()) {
303     type = kNumberTypeInt32;
304     return NestedVectorToTensorImpl<int32_t, int32_t>(vec, type);
305   } else if (vec_item[0]->isa<Int64Imm>()) {
306     type = kNumberTypeInt64;
307     return NestedVectorToTensorImpl<int64_t, int64_t>(vec, type);
308   } else if (vec_item[0]->isa<FP32Imm>()) {
309     type = kNumberTypeFloat32;
310     return NestedVectorToTensorImpl<float, float>(vec, type);
311   } else if (vec_item[0]->isa<BoolImm>()) {
312     type = kNumberTypeBool;
313     return NestedVectorToTensorImpl<bool, uint8_t>(vec, type);
314   } else {
315     MS_LOG(EXCEPTION) << "Unsupported data type of nested tuple or list elements: " << vec_item[0]->type_name();
316   }
317 }
318 
319 template <typename T1, typename T2>
VectorToTensorImpl(const ValuePtr & value,const TypeId & type)320 GeTensor VectorToTensorImpl(const ValuePtr &value, const TypeId &type) {
321   const auto &vec =
322     value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
323   auto data = ConvertAnyUtil(value, AnyTraits<T1>(), AnyTraits<std::vector<T2>>());
324   auto format = vec.size() == kDim4 ? kOpFormat_NCHW : kOpFormat_ND;
325   auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, type, format);
326   if (desc == nullptr) {
327     MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
328   }
329   return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(T2));
330 }
331 
VectorToTensorUtil(const ValuePtr & value)332 GeTensor VectorToTensorUtil(const ValuePtr &value) {
333   MS_EXCEPTION_IF_NULL(value);
334   auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
335   if (vec.empty()) {
336     MS_LOG(INFO) << "Convert a none tuple to an empty ge tensor";
337     return GeTensor(GeTensorDesc(::ge::Shape({0}), ::ge::FORMAT_ND, ::ge::DT_INT64));
338   }
339   MS_EXCEPTION_IF_NULL(vec[0]);
340   TypeId type;
341   if (vec[0]->isa<Int32Imm>()) {
342     MS_LOG(INFO) << "convert value to tensor with data type = Int32";
343     type = kNumberTypeInt32;
344     return VectorToTensorImpl<int32_t, int32_t>(value, type);
345   } else if (vec[0]->isa<Int64Imm>()) {
346     MS_LOG(INFO) << "convert value to tensor with data type = Int64";
347     type = kNumberTypeInt64;
348     return VectorToTensorImpl<int64_t, int64_t>(value, type);
349   } else if (vec[0]->isa<FP32Imm>()) {
350     MS_LOG(INFO) << "convert value to tensor with data type = Float32";
351     type = kNumberTypeFloat32;
352     return VectorToTensorImpl<float, float>(value, type);
353   } else if (vec[0]->isa<BoolImm>()) {
354     MS_LOG(INFO) << "convert value to tensor with data type = Bool";
355     type = kNumberTypeBool;
356     return VectorToTensorImpl<bool, uint8_t>(value, type);
357   } else if (vec[0]->isa<ValueTuple>() || vec[0]->isa<ValueList>()) {
358     // convert nested tuple or list to ge tensor, supported two dims
359     MS_LOG(INFO) << "Convert nested tuple or list to ge tensor.";
360     return NestedVectorToTensor(value);
361   } else {
362     MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name();
363   }
364 }
365 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<ValueAny>)366 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ValueAny>) {
367   MS_EXCEPTION_IF_NULL(value);
368   if (value->isa<MeTensor>()) {
369     // convert me tensor to ge tensor
370     return ConvertAnyUtil(value, AnyTraits<MeTensor>());
371   } else if (value->isa<ValueList>() || value->isa<ValueTuple>()) {
372     return VectorToTensorUtil(value);
373   } else if (value->isa<Int32Imm>()) {
374     // convert scalar Int to GeTensor
375     MS_LOG(INFO) << "convert scalar to tensor with data type = Int32";
376     GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_INT32);
377     auto v = GetValue<int32_t>(value);
378     desc.SetRealDimCnt(0);
379     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int32_t));
380   } else if (value->isa<UInt32Imm>()) {
381     // convert scalar UInt to GeTensor
382     MS_LOG(INFO) << "Convert scalar to tensor with data type = UInt32";
383     GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_UINT32);
384     auto v = GetValue<uint32_t>(value);
385     desc.SetRealDimCnt(0);
386     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(uint32_t));
387   } else if (value->isa<Int64Imm>()) {
388     // convert scalar Int64 to GeTensor
389     MS_LOG(INFO) << "convert scalar to tensor with data type = Int64";
390     GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_INT64);
391     auto v = GetValue<int64_t>(value);
392     desc.SetRealDimCnt(0);
393     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int64_t));
394   } else if (value->isa<FP32Imm>()) {
395     // convert scalar FP32 to GeTensor
396     MS_LOG(INFO) << "convert scalar to tensor with data type = FP32";
397     GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_FLOAT);
398     auto v = GetValue<float>(value);
399     desc.SetRealDimCnt(0);
400     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(float));
401   } else if (value->isa<BoolImm>()) {
402     // convert scalar FP32 to GeTensor
403     MS_LOG(INFO) << "convert scalar to tensor with data type = Bool";
404     GeTensorDesc desc(GeShape(), ::ge::FORMAT_ND, ::ge::DT_BOOL);
405     auto v = GetValue<bool>(value);
406     desc.SetRealDimCnt(0);
407     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(bool));
408   } else if (value->isa<StringImm>()) {
409     // convert String to GeTensor
410     MS_LOG(INFO) << "convert string to tensor with data type = String";
411     std::string v = GetValue<std::string>(value);
412     std::vector<int64_t> ge_shape;
413     GeShape shape(ge_shape);
414     GeTensorDesc desc(shape, ::ge::FORMAT_ND, ::ge::DT_STRING);
415     GeTensor str_tensor(desc);
416     (void)str_tensor.SetData(v);
417     return str_tensor;
418   } else {
419     MS_LOG(INFO) << "Unsupported value type: " << value->type_name()
420                  << " to convert to tensor. Value: " << value->ToString();
421   }
422   return GeTensor();
423 }
424 
IsCustomPrim(const PrimitivePtr & prim)425 bool IsCustomPrim(const PrimitivePtr &prim) {
426   if (prim == nullptr) {
427     return false;
428   }
429 
430   if (prim->name() == "Custom") {
431     return true;
432   }
433   return false;
434 }
435 
IsNoNeedConstantFoldCNode(const PrimitivePtr & prim)436 bool IsNoNeedConstantFoldCNode(const PrimitivePtr &prim) {
437   // ON_THE_FLY Quantization node dont need constant folding.
438   return prim->GetAttr("no_need_constant_folding") != nullptr;
439 }
440 
IsCustomCNode(const AnfNodePtr & anf)441 bool IsCustomCNode(const AnfNodePtr &anf) {
442   if (anf == nullptr) {
443     return false;
444   }
445   auto node = anf->cast<CNodePtr>();
446   if (node == nullptr) {
447     return false;
448   }
449   if (node->inputs().empty()) {
450     MS_LOG(EXCEPTION) << "Length of node inputs is empty";
451   }
452   MS_EXCEPTION_IF_NULL(node->inputs()[0]);
453   if (!node->inputs()[0]->isa<ValueNode>()) {
454     return false;
455   }
456   auto cus_prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
457   if (cus_prim == nullptr) {
458     return false;
459   }
460 
461   return IsCustomPrim(cus_prim);
462 }
463 
GetOpIOFormat(const AnfNodePtr & anf)464 std::string GetOpIOFormat(const AnfNodePtr &anf) {
465   std::string ret;
466   if (anf == nullptr) {
467     MS_LOG(ERROR) << "The anf is nullptr";
468     return ret;
469   }
470   auto node = anf->cast<CNodePtr>();
471   if (node == nullptr) {
472     MS_LOG(ERROR) << "The anf is not a cnode.";
473     return ret;
474   }
475   if (node->inputs().empty()) {
476     MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
477   }
478   MS_EXCEPTION_IF_NULL(node->input(0));
479   auto &input = node->input(0);
480   AnfNodePtr prim_node = nullptr;
481   if (input->isa<ValueNode>()) {
482     prim_node = input;
483   } else if (input->isa<CNode>() && input->cast<CNodePtr>()->input(0)->isa<ValueNode>()) {
484     // process cnode1, its input(index 0) is a conde0(partial etc.)
485     prim_node = input->cast<CNodePtr>()->input(0);
486   } else {
487     MS_LOG(ERROR) << "The anf is not a value node or cnode.";
488     return ret;
489   }
490   MS_EXCEPTION_IF_NULL(prim_node);
491   auto prim = GetValueNode<PrimitivePtr>(prim_node);
492   if (prim == nullptr) {
493     MS_LOG(ERROR) << "The anf is not a Primitive.";
494     return ret;
495   }
496   if (prim->HasAttr("io_format")) {
497     return ops::GetValueWithCheck<std::string>(prim->GetAttr("io_format"));
498   }
499   auto io_format_map = IOFormatMap::get();
500   auto iter = io_format_map.find(prim->name());
501   if (iter == io_format_map.end()) {
502     return kOpFormat_DEFAULT;
503   }
504   if (iter->second == "format") {
505     ValuePtr format = prim->GetAttr("format");
506     MS_EXCEPTION_IF_NULL(format);
507     if (format->isa<Int64Imm>()) {
508       bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &format);
509       if (converted) {
510         return ops::GetValueWithCheck<std::string>(format);
511       }
512     } else {
513       return ops::GetValueWithCheck<std::string>(format);
514     }
515   }
516   return iter->second;
517 }
518 }  // namespace transform
519 }  // namespace mindspore
520