• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/convert_utils.h"
18 
19 #include <vector>
20 #include <string>
21 #include <memory>
22 #include <algorithm>
23 #include <utility>
24 #include <cfloat>
25 
26 #include "ir/value.h"
27 #include "ir/tensor.h"
28 #include "ir/param_info.h"
29 #include "utils/ms_context.h"
30 
31 namespace mindspore {
ValueToBool(const ValuePtr & v,bool * value)32 bool ValueToBool(const ValuePtr &v, bool *value) {
33   MS_EXCEPTION_IF_NULL(v);
34   if (v->isa<BoolImm>()) {
35     *value = v->cast<BoolImmPtr>()->value();
36   } else if (v->isa<Int32Imm>()) {
37     *value = v->cast<Int32ImmPtr>()->value() != 0;
38   } else if (v->isa<UInt32Imm>()) {
39     *value = v->cast<UInt32ImmPtr>()->value() != 0;
40   } else if (v->isa<FP32Imm>()) {
41     *value = v->cast<FP32ImmPtr>()->value() != 0;
42   } else if (v->isa<FP64Imm>()) {
43     *value = v->cast<FP64ImmPtr>()->value() != 0;
44   } else if (v->isa<tensor::Tensor>()) {
45     auto tensor = v->cast<tensor::TensorPtr>();
46     MS_EXCEPTION_IF_NULL(tensor);
47     tensor->data_sync();
48     bool *tensor_data = static_cast<bool *>(tensor->data_c());
49     // maybe need to support if tensor is a bool array
50     auto vb = tensor_data[0];
51     *value = vb;
52   } else {
53     MS_LOG(WARNING) << "value is not supported to cast to be bool";
54     return false;
55   }
56   return true;
57 }
58 
BaseRefToInt(const ValuePtr & v,int64_t * value)59 bool BaseRefToInt(const ValuePtr &v, int64_t *value) {
60   MS_EXCEPTION_IF_NULL(v);
61   if (v->isa<tensor::Tensor>()) {
62     auto tensor = v->cast<tensor::TensorPtr>();
63     tensor->data_sync();
64     if (tensor->Dtype()->ToString() == "Int32") {
65       auto *tensor_data = static_cast<int32_t *>(tensor->data_c());
66       auto vb = tensor_data[0];
67       *value = static_cast<int64_t>(vb);
68     } else if (tensor->Dtype()->ToString() == "Int64") {
69       auto *tensor_data = static_cast<int64_t *>(tensor->data_c());
70       auto vb = tensor_data[0];
71       *value = vb;
72     } else {
73       MS_LOG(ERROR) << "Index must be Int type.";
74     }
75     return true;
76   }
77   MS_LOG(ERROR) << "Index must be tensor type.";
78   return false;
79 }
80 
BaseRefToBool(const BaseRef & v,bool * value)81 bool BaseRefToBool(const BaseRef &v, bool *value) {
82   if (utils::isa<ValuePtr>(v)) {
83     return ValueToBool(utils::cast<ValuePtr>(v), value);
84   } else if (utils::isa<bool>(v)) {
85     auto vb = utils::cast<bool>(v);
86     *value = vb;
87   } else if (utils::isa<int>(v)) {
88     auto vb = utils::cast<int>(v);
89     *value = vb != 0;
90   } else if (utils::isa<unsigned int>(v)) {
91     auto vb = utils::cast<unsigned int>(v);
92     *value = vb != 0;
93   } else if (utils::isa<float>(v)) {
94     auto vb = utils::cast<float>(v);
95     *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON);
96   } else if (utils::isa<double>(v)) {
97     auto vb = utils::cast<double>(v);
98     *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON);
99   } else {
100     MS_LOG(DEBUG) << "value is not supported to cast to be bool";
101     return false;
102   }
103   return true;
104 }
105 
106 namespace {
107 // Isomorphism
108 bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
109               NodeMapEquiv *const equiv_node);
SameNodeShallow(const AnfNodePtr & node1,const AnfNodePtr & node2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)110 bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
111                      NodeMapEquiv *const equiv_node) {
112   if (equiv_node == nullptr) {
113     MS_LOG(ERROR) << "Invalid equiv_node";
114     return false;
115   }
116   if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
117     return true;
118   }
119   if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
120     return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
121                       equiv_node);
122   }
123   if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
124     auto a1 = GetValueNode(node1);
125     auto a2 = GetValueNode(node2);
126     if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
127       return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
128     } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
129       return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
130     } else {
131       return *a1 == *a2;
132     }
133   }
134   if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
135     auto para1 = node1->cast<ParameterPtr>();
136     auto para2 = node2->cast<ParameterPtr>();
137     if (para1->name() == para2->name()) {
138       return true;
139     }
140     MS_LOG(DEBUG) << "two parameters are not equal.";
141     return false;
142   }
143   if (node1->isa<CNode>() && node2->isa<CNode>()) {
144     return SameNode(node1, node2, equiv_func_graph, equiv_node);
145   }
146   MS_LOG(ERROR) << "type error";
147   return false;
148 }
149 
SameNode(const AnfNodePtr & node1,const AnfNodePtr & node2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)150 bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
151               NodeMapEquiv *const equiv_node) {
152   MS_EXCEPTION_IF_NULL(node1);
153   MS_EXCEPTION_IF_NULL(node2);
154   if (node1->isa<CNode>() && node2->isa<CNode>()) {
155     auto &inputs1 = node1->cast<CNodePtr>()->inputs();
156     auto &inputs2 = node2->cast<CNodePtr>()->inputs();
157     for (std::size_t i = 0; i < inputs1.size(); ++i) {
158       if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
159         return false;
160       }
161     }
162     return true;
163   }
164   return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
165 }
166 
SameSubgraph(const AnfNodePtr & root1,const AnfNodePtr & root2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)167 bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph,
168                   NodeMapEquiv *const equiv_node) {
169   std::unordered_set<AnfNodePtr> done;
170   std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;
171 
172   todo.push(std::make_pair(root1, root2));
173   while (!todo.empty()) {
174     AnfNodePtr node1 = todo.top().first;
175     if (done.count(node1) > 0) {
176       todo.pop();
177       continue;
178     }
179     AnfNodePtr node2 = todo.top().second;
180 
181     bool condition = false;
182     const auto &s1 = GetInputs(node1);
183     const auto &s2 = GetInputs(node2);
184 
185     if (s1.size() != s2.size()) {
186       return false;
187     }
188     for (std::size_t i = 0; i < s1.size(); ++i) {
189       if (done.count(s1[i]) == 0) {
190         todo.push(std::make_pair(s1[i], s2[i]));
191         condition = true;
192       }
193     }
194     if (condition) {
195       continue;
196     }
197     (void)done.insert(node1);
198 
199     auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
200     if (res) {
201       (*equiv_node)[node1] = node2;
202     } else {
203       return false;
204     }
205     todo.pop();
206   }
207   return true;
208 }
209 }  // namespace
210 
Isomorphic(const FuncGraphPtr & fg1,const FuncGraphPtr & fg2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)211 bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph,
212                 NodeMapEquiv *const equiv_node) {
213   auto fg1_fg2 = std::make_pair(fg1, fg2);
214   if (equiv_func_graph == nullptr) {
215     MS_LOG(ERROR) << "equiv_func_graph not init";
216     return false;
217   }
218   if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
219     return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
220   }
221   if (fg1 == nullptr || fg2 == nullptr) {
222     MS_LOG(ERROR) << "Invalid function graph";
223     return false;
224   }
225   if (fg1->parameters().size() != fg2->parameters().size()) {
226     MS_LOG(DEBUG) << "parameters size not match";
227     return false;
228   }
229   if (equiv_node != nullptr) {
230     for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
231       (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
232     }
233     (*equiv_func_graph)[fg1_fg2] = kPending;
234     auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
235     (*equiv_func_graph)[fg1_fg2] = EquivState(result);
236     return result;
237   }
238 
239   MS_LOG(ERROR) << "equiv_node not init";
240   return false;
241 }
242 
ScalarToTensor(const ScalarPtr & scalar)243 tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
244   if (scalar == nullptr) {
245     MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
246   }
247   TypePtr data_type = scalar->type();
248   MS_EXCEPTION_IF_NULL(data_type);
249   TypeId type_id = data_type->type_id();
250   switch (type_id) {
251     case kNumberTypeBool:
252       return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type);
253     case kNumberTypeInt8:
254       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
255     case kNumberTypeInt16:
256       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
257     case kNumberTypeInt32:
258       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
259     case kNumberTypeInt64:
260       return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type);
261     case kNumberTypeUInt8:
262       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
263     case kNumberTypeUInt16:
264       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
265     case kNumberTypeUInt32:
266       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
267     case kNumberTypeUInt64:
268       return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type);
269     case kNumberTypeFloat32:
270       return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type);
271     case kNumberTypeFloat64:
272       return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type);
273     default:
274       MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << "is valid.";
275   }
276 }
277 
TensorValueToTensor(const ValuePtr & value,std::vector<tensor::TensorPtr> * tensors)278 void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
279   MS_EXCEPTION_IF_NULL(value);
280   MS_EXCEPTION_IF_NULL(tensors);
281   if (value->isa<ValueTuple>()) {
282     auto value_tuple = value->cast<ValueTuplePtr>();
283     MS_EXCEPTION_IF_NULL(value_tuple);
284     for (size_t i = 0; i < value_tuple->size(); ++i) {
285       ValuePtr element = value_tuple->value()[i];
286       if (element->isa<tensor::Tensor>()) {
287         auto tensor = element->cast<tensor::TensorPtr>();
288         MS_EXCEPTION_IF_NULL(tensor);
289         tensors->emplace_back(tensor);
290       } else if (element->isa<ValueTuple>()) {
291         TensorValueToTensor(element, tensors);
292       }
293     }
294   } else if (value->isa<tensor::Tensor>()) {
295     auto tensor = value->cast<tensor::TensorPtr>();
296     MS_EXCEPTION_IF_NULL(tensor);
297     tensors->emplace_back(tensor);
298   }
299 }
300 
CountValueNum(const ValueTuplePtr & value_tuple)301 size_t CountValueNum(const ValueTuplePtr &value_tuple) {
302   MS_EXCEPTION_IF_NULL(value_tuple);
303   size_t cnt = 0;
304   const auto &value_list = value_tuple->value();
305   for (const auto &value : value_list) {
306     if (value->isa<None>()) {
307       continue;
308     } else if (value->isa<ValueTuple>()) {
309       cnt += CountValueNum(value->cast<ValueTuplePtr>());
310     } else {
311       cnt++;
312     }
313   }
314   return cnt;
315 }
316 }  // namespace mindspore
317