• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/op_adapter_util.h"
18 
19 #include <string>
20 #include <vector>
21 #include <algorithm>
22 
23 #include "utils/utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "transform/graph_ir/op_adapter_base.h"
26 #include "transform/graph_ir/io_format_map.h"
27 
28 namespace mindspore {
29 namespace transform {
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> &)30 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &) {
31   // To-DO the format may read from ME tensor
32   MS_EXCEPTION_IF_NULL(value);
33   auto me_tensor = value->cast<MeTensorPtr>();
34   auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND);
35   return ge_tensor == nullptr ? GeTensor() : *ge_tensor;
36 }
37 
ConvertAnyUtil(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>>)38 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
39                                     const AnyTraits<std::vector<int64_t>>) {
40   MS_EXCEPTION_IF_NULL(value);
41   std::vector<int64_t> list;
42   if (name == "pad") {
43     if (!value->isa<ValueSequeue>()) {
44       MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
45     }
46     auto vec = value->cast<ValueSequeuePtr>();
47     list.resize(vec->value().size() + 2);
48     list[0] = 1;
49     list[1] = 1;
50     (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
51                          [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int64_t>(val)); });
52   } else {
53     int64_t data = GetValue<int64_t>(value);
54     int size = 2;  // 2 int in list
55     list = TransformUtil::ConvertIntToList(data, size);
56   }
57 
58   return list;
59 }
60 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::string>)61 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>) {
62   MS_EXCEPTION_IF_NULL(value);
63   auto vec = value->cast<ValueTuplePtr>();
64   if (vec == nullptr) {
65     MS_LOG(EXCEPTION) << "not ValueTuplePtr";
66   }
67   std::ostringstream buffer;
68   int i = 0;
69   for (auto &it : vec->value()) {
70     if (i != 0) {
71       buffer << ",";
72     }
73     buffer << GetValue<int64_t>(it);
74     i++;
75   }
76   return buffer.str();
77 }
78 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<std::vector<float>>,const AnyTraits<float>)79 std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>) {
80   MS_EXCEPTION_IF_NULL(value);
81   auto vec = value->cast<ValueTuplePtr>();
82   if (vec == nullptr) {
83     MS_LOG(EXCEPTION) << "not ValueTuplePtr";
84   }
85   std::vector<float> list;
86   list.resize(vec->value().size());
87   (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
88                        [](const ValuePtr &val) { return static_cast<float>(GetValue<float>(val)); });
89   return list;
90 }
91 
ConvertAnyUtil(const ValuePtr & value,const std::string & format,const AnyTraits<std::vector<int64_t>>,const AnyTraits<int64_t>)92 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
93                                     const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>) {
94   MS_EXCEPTION_IF_NULL(value);
95   auto vec = value->cast<ValueTuplePtr>();
96   if (vec == nullptr) {
97     MS_LOG(EXCEPTION) << "not ValueTuplePtr";
98   }
99   std::vector<int64_t> list;
100   list.resize(vec->value().size());
101   (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
102                        [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int64_t>(val)); });
103   if (format == kOpFormat_NHWC) {
104     if (list.size() < 4) {
105       MS_LOG(EXCEPTION) << "The size of list is less than 4";
106     } else {
107       int64_t temp = list[1];
108       list[1] = list[2];
109       list[2] = list[3];
110       list[3] = temp;
111     }
112   }
113   return list;
114 }
115 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<GEType>)116 GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>) {
117   MS_EXCEPTION_IF_NULL(value);
118   if (!value->isa<Type>()) {
119     MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString()
120                       << ", type: " << value->type_name() << ", value should be a Typeptr";
121   }
122   auto type = value->cast<TypePtr>();
123   MS_EXCEPTION_IF_NULL(type);
124   TypeId me_type = type->type_id();
125   if (kObjectTypeTensorType == me_type) {
126     me_type = dyn_cast<TensorType>(type)->element()->type_id();
127   }
128   return TransformUtil::ConvertDataType(me_type);
129 }
130 
VectorToTensorUtil(const ValuePtr & value)131 GeTensor VectorToTensorUtil(const ValuePtr &value) {
132   // convert tuple or list to ge tensor, only supported one dim for now
133   MS_EXCEPTION_IF_NULL(value);
134   auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
135   if (vec.empty()) {
136     MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor";
137     return GeTensor(GeTensorDesc(ge::Shape({0})));
138   }
139   MS_EXCEPTION_IF_NULL(vec[0]);
140   if (vec[0]->isa<Int32Imm>()) {
141     MS_LOG(INFO) << "convert value to tensor with data type = Int32";
142     auto data = ConvertAnyUtil(value, AnyTraits<int32_t>(), AnyTraits<std::vector<int32_t>>());
143     auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW);
144     if (desc == nullptr) {
145       MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
146     }
147     return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(int32_t));
148   } else if (vec[0]->isa<Int64Imm>()) {
149     MS_LOG(INFO) << "convert value to tensor with data type = Int64";
150     auto data = ConvertAnyUtil(value, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>());
151     auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeInt64, kOpFormat_NCHW);
152     if (desc == nullptr) {
153       MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
154     }
155     return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(int64_t));
156   } else if (vec[0]->isa<FP32Imm>()) {
157     MS_LOG(INFO) << "convert value to tensor with data type = Float32";
158     auto data = ConvertAnyUtil(value, AnyTraits<float>(), AnyTraits<std::vector<float>>());
159     auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW);
160     if (desc == nullptr) {
161       MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
162     }
163     return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(float));
164   } else if (vec[0]->isa<BoolImm>()) {
165     MS_LOG(INFO) << "convert value to tensor with data type = Bool";
166     // We use uint8_t to save bool type data
167     auto data = ConvertAnyUtil(value, AnyTraits<bool>(), AnyTraits<std::vector<uint8_t>>());
168     auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeBool, kOpFormat_NCHW);
169     if (desc == nullptr) {
170       MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
171     }
172     return GeTensor(*desc, static_cast<uint8_t *>(data.data()), data.size() * sizeof(uint8_t));
173   } else {
174     MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name();
175   }
176 
177   return GeTensor();
178 }
179 
ConvertAnyUtil(const ValuePtr & value,const AnyTraits<AnyValue>)180 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>) {
181   MS_EXCEPTION_IF_NULL(value);
182   if (value->isa<MeTensor>()) {
183     // convert me tensor to ge tensor
184     return ConvertAnyUtil(value, AnyTraits<MeTensor>());
185   } else if (value->isa<ValueList>() || value->isa<ValueTuple>()) {
186     return VectorToTensorUtil(value);
187   } else if (value->isa<Int32Imm>()) {
188     // convert scalar Int to GeTensor
189     MS_LOG(INFO) << "convert scalar to tensor with data type = Int32";
190     GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
191     auto v = GetValue<int32_t>(value);
192     desc.SetRealDimCnt(0);
193     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int32_t));
194   } else if (value->isa<Int64Imm>()) {
195     // convert scalar Int64 to GeTensor
196     MS_LOG(INFO) << "convert scalar to tensor with data type = Int64";
197     GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
198     auto v = GetValue<int64_t>(value);
199     desc.SetRealDimCnt(0);
200     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int64_t));
201   } else if (value->isa<FP32Imm>()) {
202     // convert scalar FP32 to GeTensor
203     MS_LOG(INFO) << "convert scalar to tensor with data type = FP32";
204     GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
205     auto v = GetValue<float>(value);
206     desc.SetRealDimCnt(0);
207     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(float));
208   } else if (value->isa<BoolImm>()) {
209     // convert scalar FP32 to GeTensor
210     MS_LOG(INFO) << "convert scalar to tensor with data type = Bool";
211     GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL);
212     auto v = GetValue<bool>(value);
213     desc.SetRealDimCnt(0);
214     return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(bool));
215   } else if (value->isa<StringImm>()) {
216     // convert String to GeTensor
217     MS_LOG(INFO) << "convert string to tensor with data type = String";
218     std::string v = GetValue<std::string>(value);
219     std::vector<int64_t> ge_shape;
220     GeShape shape(ge_shape);
221     GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING);
222     GeTensor str_tensor(desc);
223     str_tensor.SetData(v);
224     return str_tensor;
225   } else {
226     MS_LOG(WARNING) << "Unsupported value type: " << value->type_name()
227                     << " to convert to tensor. Value: " << value->ToString();
228   }
229   return GeTensor();
230 }
231 
IsCustomPrim(const PrimitivePtr & prim)232 bool IsCustomPrim(const PrimitivePtr &prim) {
233   if (prim == nullptr) {
234     return false;
235   }
236 
237   ValuePtr flag = prim->GetAttr("_custom_op_flag");
238   if (flag == nullptr) {
239     return false;
240   }
241 
242   bool is_custom_op = GetValue<bool>(flag);
243   if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) {
244     MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op "
245                          "can not assign the op information config path.";
246   }
247 
248   return is_custom_op;
249 }
250 
IsCustomCNode(const AnfNodePtr & anf)251 bool IsCustomCNode(const AnfNodePtr &anf) {
252   if (anf == nullptr) {
253     return false;
254   }
255   auto node = anf->cast<CNodePtr>();
256   if (node == nullptr) {
257     return false;
258   }
259   if (node->inputs().empty()) {
260     MS_LOG(EXCEPTION) << "Length of node inputs is empty";
261   }
262   MS_EXCEPTION_IF_NULL(node->inputs()[0]);
263   if (!node->inputs()[0]->isa<ValueNode>()) {
264     return false;
265   }
266   auto cus_prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
267   if (cus_prim == nullptr) {
268     return false;
269   }
270 
271   return IsCustomPrim(cus_prim);
272 }
273 
GetOpIOFormat(const AnfNodePtr & anf)274 std::string GetOpIOFormat(const AnfNodePtr &anf) {
275   std::string ret;
276   if (anf == nullptr) {
277     MS_LOG(ERROR) << "The anf is nullptr";
278     return ret;
279   }
280   auto node = anf->cast<CNodePtr>();
281   if (node == nullptr) {
282     MS_LOG(ERROR) << "The anf is not a cnode.";
283     return ret;
284   }
285   if (node->inputs().empty()) {
286     MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
287   }
288   MS_EXCEPTION_IF_NULL(node->inputs()[0]);
289   if (!node->inputs()[0]->isa<ValueNode>()) {
290     MS_LOG(ERROR) << "The anf is not a value node.";
291     return ret;
292   }
293   auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
294   if (prim == nullptr) {
295     MS_LOG(ERROR) << "The anf is not a Primitive.";
296     return ret;
297   }
298   if (prim->HasAttr("io_format")) {
299     return GetValue<std::string>(prim->GetAttr("io_format"));
300   }
301   auto io_format_map = IOFormatMap::get();
302   auto iter = io_format_map.find(prim->name());
303   if (iter == io_format_map.end()) {
304     return "NCHW";
305   }
306   if (iter->second == "format") {
307     ValuePtr format = prim->GetAttr("format");
308     MS_EXCEPTION_IF_NULL(format);
309     if (format->isa<Int64Imm>()) {
310       bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &format);
311       if (converted) {
312         return GetValue<std::string>(format);
313       }
314     } else {
315       return GetValue<std::string>(format);
316     }
317   }
318   return iter->second;
319 }
320 }  // namespace transform
321 }  // namespace mindspore
322