/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/convert_utils.h"

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <utility>
#include <stack>
#include <unordered_set>
#include <cfloat>

#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "utils/ms_context.h"

namespace mindspore {
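// Convert a Value to a C++ bool. Scalar values are compared against zero; for a Tensor the
// first element is read after a device-to-host data_sync(). Returns false when the value type
// is unsupported.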
bool ValueToBool(const ValuePtr &v, bool *value) {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<BoolImm>()) {
    *value = v->cast<BoolImmPtr>()->value();
  } else if (v->isa<Int32Imm>()) {
    *value = v->cast<Int32ImmPtr>()->value() != 0;
  } else if (v->isa<UInt32Imm>()) {
    *value = v->cast<UInt32ImmPtr>()->value() != 0;
  } else if (v->isa<FP32Imm>()) {
    *value = v->cast<FP32ImmPtr>()->value() != 0;
  } else if (v->isa<FP64Imm>()) {
    *value = v->cast<FP64ImmPtr>()->value() != 0;
  } else if (v->isa<tensor::Tensor>()) {
    auto tensor = v->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->data_sync();
    bool *tensor_data = static_cast<bool *>(tensor->data_c());
    // May need to support the case where the tensor is a bool array.
    auto vb = tensor_data[0];
    *value = vb;
  } else {
    MS_LOG(WARNING) << "Value is not supported to be cast to bool.";
    return false;
  }
  return true;
}

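// Extract an int64_t index from a Value that is expected to hold an Int32 or Int64 tensor.
// Returns false (with an error log) if the value is not a tensor.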
bool BaseRefToInt(const ValuePtr &v, int64_t *value) {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<tensor::Tensor>()) {
    auto tensor = v->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->data_sync();
    if (tensor->Dtype()->ToString() == "Int32") {
      auto *tensor_data = static_cast<int32_t *>(tensor->data_c());
      auto vb = tensor_data[0];
      *value = static_cast<int64_t>(vb);
    } else if (tensor->Dtype()->ToString() == "Int64") {
      auto *tensor_data = static_cast<int64_t *>(tensor->data_c());
      auto vb = tensor_data[0];
      *value = vb;
    } else {
      MS_LOG(ERROR) << "Index must be Int type.";
    }
    return true;
  }
  MS_LOG(ERROR) << "Index must be tensor type.";
  return false;
}

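// Convert a BaseRef holding either a ValuePtr or a plain C++ scalar to bool. Floating-point
// values within machine epsilon of zero are treated as false.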
bool BaseRefToBool(const BaseRef &v, bool *value) {
  if (utils::isa<ValuePtr>(v)) {
    return ValueToBool(utils::cast<ValuePtr>(v), value);
  } else if (utils::isa<bool>(v)) {
    auto vb = utils::cast<bool>(v);
    *value = vb;
  } else if (utils::isa<int>(v)) {
    auto vb = utils::cast<int>(v);
    *value = vb != 0;
  } else if (utils::isa<unsigned int>(v)) {
    auto vb = utils::cast<unsigned int>(v);
    *value = vb != 0;
  } else if (utils::isa<float>(v)) {
    auto vb = utils::cast<float>(v);
    *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON);
  } else if (utils::isa<double>(v)) {
    auto vb = utils::cast<double>(v);
    *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON);
  } else {
    MS_LOG(DEBUG) << "Value is not supported to be cast to bool.";
    return false;
  }
  return true;
}

namespace {
// Isomorphism check helpers.
bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
              NodeMapEquiv *const equiv_node);
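// Compare two nodes without recursing into CNode inputs: value nodes are compared by value,
// parameters by name, and nested func graphs via Isomorphic().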
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                     NodeMapEquiv *const equiv_node) {
  if (equiv_node == nullptr) {
    MS_LOG(ERROR) << "Invalid equiv_node";
    return false;
  }
  if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
    return true;
  }
  if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
    return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
                      equiv_node);
  }
  if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
    auto a1 = GetValueNode(node1);
    auto a2 = GetValueNode(node2);
    if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
      return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
    } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
      return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
    } else {
      return *a1 == *a2;
    }
  }
  if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
    auto para1 = node1->cast<ParameterPtr>();
    auto para2 = node2->cast<ParameterPtr>();
    if (para1->name() == para2->name()) {
      return true;
    }
    MS_LOG(DEBUG) << "The two parameters are not equal.";
    return false;
  }
  if (node1->isa<CNode>() && node2->isa<CNode>()) {
    return SameNode(node1, node2, equiv_func_graph, equiv_node);
  }
  MS_LOG(ERROR) << "Type error: unsupported node type.";
  return false;
}

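// Compare two nodes; for a pair of CNodes, every input of node1 is compared with the
// corresponding input of node2 via SameNodeShallow (both CNodes are assumed to have the same
// number of inputs).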
bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
              NodeMapEquiv *const equiv_node) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  if (node1->isa<CNode>() && node2->isa<CNode>()) {
    auto &inputs1 = node1->cast<CNodePtr>()->inputs();
    auto &inputs2 = node2->cast<CNodePtr>()->inputs();
    for (std::size_t i = 0; i < inputs1.size(); ++i) {
      if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
        return false;
      }
    }
    return true;
  }
  return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
}

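// Iterative depth-first comparison of the subgraphs rooted at root1 and root2, recording each
// matched node pair in equiv_node.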
bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph,
                  NodeMapEquiv *const equiv_node) {
  std::unordered_set<AnfNodePtr> done;
  std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;

  todo.push(std::make_pair(root1, root2));
  while (!todo.empty()) {
    AnfNodePtr node1 = todo.top().first;
    if (done.count(node1) > 0) {
      todo.pop();
      continue;
    }
    AnfNodePtr node2 = todo.top().second;

    bool condition = false;
    const auto &s1 = GetInputs(node1);
    const auto &s2 = GetInputs(node2);

    if (s1.size() != s2.size()) {
      return false;
    }
    for (std::size_t i = 0; i < s1.size(); ++i) {
      if (done.count(s1[i]) == 0) {
        todo.push(std::make_pair(s1[i], s2[i]));
        condition = true;
      }
    }
    if (condition) {
      continue;
    }
    (void)done.insert(node1);

    auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
    if (res) {
      (*equiv_node)[node1] = node2;
    } else {
      return false;
    }
    todo.pop();
  }
  return true;
}
}  // namespace

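// Check whether two function graphs are isomorphic. Results are cached in equiv_func_graph;
// kPending marks a pair whose comparison is still in progress so that recursive calls through
// SameNodeShallow do not loop forever.
//
// Illustrative usage (the map variables are caller-provided; names are examples only):
//   FuncGraphPairMapEquiv equiv_graph;
//   NodeMapEquiv equiv_node;
//   bool same = Isomorphic(fg1, fg2, &equiv_graph, &equiv_node);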
bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph,
                NodeMapEquiv *const equiv_node) {
  auto fg1_fg2 = std::make_pair(fg1, fg2);
  if (equiv_func_graph == nullptr) {
    MS_LOG(ERROR) << "equiv_func_graph is not initialized.";
    return false;
  }
  if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
    return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
  }
  if (fg1 == nullptr || fg2 == nullptr) {
    MS_LOG(ERROR) << "Invalid function graph";
    return false;
  }
  if (fg1->parameters().size() != fg2->parameters().size()) {
    MS_LOG(DEBUG) << "Parameter sizes do not match.";
    return false;
  }
  if (equiv_node != nullptr) {
    for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
      (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
    }
    (*equiv_func_graph)[fg1_fg2] = kPending;
    auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
    (*equiv_func_graph)[fg1_fg2] = EquivState(result);
    return result;
  }

  MS_LOG(ERROR) << "equiv_node is not initialized.";
  return false;
}

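// Wrap a scalar value into a tensor of the same data type. Throws for unsupported scalar types.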
tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
  if (scalar == nullptr) {
    MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
  }
  TypePtr data_type = scalar->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeBool:
      return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type);
    case kNumberTypeInt8:
      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
    case kNumberTypeInt16:
      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
    case kNumberTypeInt32:
      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
    case kNumberTypeInt64:
      return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type);
    case kNumberTypeUInt8:
      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
    case kNumberTypeUInt16:
      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
    case kNumberTypeUInt32:
      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
    case kNumberTypeUInt64:
      return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type);
    case kNumberTypeFloat32:
      return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type);
    case kNumberTypeFloat64:
      return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type);
    default:
      MS_LOG(EXCEPTION) << "When converting a scalar to a tensor, the scalar type " << data_type << " is invalid.";
  }
}

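// Flatten a Value into a list of tensors: a Tensor is appended directly, a ValueTuple is
// traversed recursively, and other value types are ignored.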
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(tensors);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    for (size_t i = 0; i < value_tuple->size(); ++i) {
      ValuePtr element = value_tuple->value()[i];
      if (element->isa<tensor::Tensor>()) {
        auto tensor = element->cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        tensors->emplace_back(tensor);
      } else if (element->isa<ValueTuple>()) {
        TensorValueToTensor(element, tensors);
      }
    }
  } else if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensors->emplace_back(tensor);
  }
}

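// Count the non-None leaf values of a (possibly nested) ValueTuple.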
size_t CountValueNum(const ValueTuplePtr &value_tuple) {
  MS_EXCEPTION_IF_NULL(value_tuple);
  size_t cnt = 0;
  const auto &value_list = value_tuple->value();
  for (const auto &value : value_list) {
    if (value->isa<None>()) {
      continue;
    } else if (value->isa<ValueTuple>()) {
      cnt += CountValueNum(value->cast<ValueTuplePtr>());
    } else {
      cnt++;
    }
  }
  return cnt;
}
}  // namespace mindspore