/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/common/utils/convert_utils.h"
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "include/common/utils/utils.h"
#include "ir/tensor.h"
#include "ir/value.h"
#include "mindspore/core/ops/sparse_ops.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/hashing.h"

namespace mindspore {
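// Convert a Value (bool, int, float, string or tensor immediate) to a C++ bool.
// Returns false if the value type is not supported.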
bool ValueToBool(const ValuePtr &v, bool *value) {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<BoolImm>()) {
    *value = v->cast<BoolImmPtr>()->value();
  } else if (v->isa<Int32Imm>()) {
    *value = v->cast<Int32ImmPtr>()->value() != 0;
  } else if (v->isa<UInt32Imm>()) {
    *value = v->cast<UInt32ImmPtr>()->value() != 0;
  } else if (v->isa<FP32Imm>()) {
    *value = fabs(v->cast<FP32ImmPtr>()->value()) > FLT_EPSILON;
  } else if (v->isa<FP64Imm>()) {
    *value = fabs(v->cast<FP64ImmPtr>()->value()) > DBL_EPSILON;
  } else if (v->isa<StringImm>()) {
    std::string str = v->cast<StringImmPtr>()->value();
    *value = str.length() != 0;
  } else if (v->isa<tensor::Tensor>()) {
    auto tensor = v->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->data_sync();
    bool *tensor_data = static_cast<bool *>(tensor->data_c());
    // May need to support the case where the tensor is a bool array.
    auto vb = tensor_data[0];
    *value = vb;
  } else {
    MS_LOG(WARNING) << "The value is not supported to be cast to bool.";
    return false;
  }
  return true;
}

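// Extract an int64_t index from a tensor value. Only Int32 and Int64 tensors are supported.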
bool BaseRefToInt(const ValuePtr &v, int64_t *value) {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<tensor::Tensor>()) {
    auto tensor = v->cast<tensor::TensorPtr>();
    tensor->data_sync();
    if (tensor->Dtype()->ToString() == "Int32") {
      auto *tensor_data = static_cast<int32_t *>(tensor->data_c());
      auto vb = tensor_data[0];
      *value = static_cast<int64_t>(vb);
    } else if (tensor->Dtype()->ToString() == "Int64") {
      auto *tensor_data = static_cast<int64_t *>(tensor->data_c());
      auto vb = tensor_data[0];
      *value = vb;
    } else {
      MS_LOG(ERROR) << "Index must be Int type.";
    }
    return true;
  }
  MS_LOG(ERROR) << "Index must be tensor type.";
  return false;
}

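// Convert a BaseRef holding either a ValuePtr or a C++ scalar (bool/int/unsigned/float/double) to bool.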
bool BaseRefToBool(const BaseRef &v, bool *value) {
  if (utils::isa<ValuePtr>(v)) {
    return ValueToBool(utils::cast<ValuePtr>(v), value);
  } else if (utils::isa<bool>(v)) {
    auto vb = utils::cast<bool>(v);
    *value = vb;
  } else if (utils::isa<int>(v)) {
    auto vb = utils::cast<int>(v);
    *value = vb != 0;
  } else if (utils::isa<unsigned int>(v)) {
    auto vb = utils::cast<unsigned int>(v);
    *value = vb != 0;
  } else if (utils::isa<float>(v)) {
    auto vb = utils::cast<float>(v);
    *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON);
  } else if (utils::isa<double>(v)) {
    auto vb = utils::cast<double>(v);
    *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON);
  } else {
    MS_LOG(DEBUG) << "The value is not supported to be cast to bool.";
    return false;
  }
  return true;
}

namespace {
// Isomorphism
bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
              NodeMapEquiv *const equiv_node);

bool SameValueNode(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  auto a1 = GetValueNode(node1);
  auto a2 = GetValueNode(node2);
  if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
    return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
  } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
    return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
  }
  return *a1 == *a2;
}

bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                     NodeMapEquiv *const equiv_node) {
  if (equiv_node == nullptr) {
    MS_LOG(ERROR) << "Invalid equiv_node";
    return false;
  }
  if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
    return true;
  }
  if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
    return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
                      equiv_node);
  }
  if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
    return SameValueNode(node1, node2);
  }
  if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
    auto para1 = node1->cast<ParameterPtr>();
    auto para2 = node2->cast<ParameterPtr>();
    if (para1->name() == para2->name()) {
      return true;
    }
    MS_LOG(DEBUG) << "two parameters are not equal.";
    return false;
  }
  if (AnfUtils::IsCustomActorNode(node1) && AnfUtils::IsCustomActorNode(node2)) {
    return AnfUtils::IsCutomActorNodeSame(node1, node2);
  }
  if (node1->isa<CNode>() && node2->isa<CNode>()) {
    return SameNode(node1, node2, equiv_func_graph, equiv_node);
  }
  MS_LOG(ERROR) << "type error";
  return false;
}

bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
              NodeMapEquiv *const equiv_node) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  if (node1->isa<CNode>() && node2->isa<CNode>()) {
    auto &inputs1 = node1->cast<CNodePtr>()->inputs();
    auto &inputs2 = node2->cast<CNodePtr>()->inputs();
    for (std::size_t i = 0; i < inputs1.size(); ++i) {
      if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
        return false;
      }
    }
    return true;
  }
  return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
}

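// Iteratively compare two subgraphs rooted at root1 and root2, recording matched node pairs in equiv_node.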
bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph,
                  NodeMapEquiv *const equiv_node) {
  mindspore::HashSet<AnfNodePtr> done;
  std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;

  todo.push(std::make_pair(root1, root2));
  while (!todo.empty()) {
    AnfNodePtr node1 = todo.top().first;
    if (done.count(node1) > 0) {
      todo.pop();
      continue;
    }
    AnfNodePtr node2 = todo.top().second;

    bool condition = false;
    const auto &s1 = GetInputs(node1);
    const auto &s2 = GetInputs(node2);

    if (s1.size() != s2.size()) {
      return false;
    }
    for (std::size_t i = 0; i < s1.size(); ++i) {
      if (done.count(s1[i]) == 0) {
        todo.push(std::make_pair(s1[i], s2[i]));
        condition = true;
      }
    }
    if (condition) {
      continue;
    }
    (void)done.insert(node1);

    auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
    if (res) {
      (*equiv_node)[node1] = node2;
    } else {
      return false;
    }
    todo.pop();
  }
  return true;
}
} // namespace

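// Check whether two function graphs are isomorphic, caching the result in equiv_func_graph and
// recording matched node pairs in equiv_node.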
bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph,
                NodeMapEquiv *const equiv_node) {
  auto fg1_fg2 = std::make_pair(fg1, fg2);
  if (equiv_func_graph == nullptr) {
    MS_LOG(ERROR) << "equiv_func_graph not init";
    return false;
  }
  if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
    return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
  }
  if (fg1 == nullptr || fg2 == nullptr) {
    MS_LOG(ERROR) << "Invalid function graph";
    return false;
  }
  if (fg1->parameters().size() != fg2->parameters().size()) {
    MS_LOG(DEBUG) << "parameters size not match";
    return false;
  }
  if (equiv_node != nullptr) {
    for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
      (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
    }
    (*equiv_func_graph)[fg1_fg2] = kPending;
    auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
    (*equiv_func_graph)[fg1_fg2] = EquivState(result);
    return result;
  }

  MS_LOG(ERROR) << "equiv_node not init";
  return false;
}

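// Create a 0-dimensional tensor from a scalar value, keeping the scalar's data type.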
tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
  if (scalar == nullptr) {
    MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
  }
  TypePtr data_type = scalar->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeBool:
      return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type);
    case kNumberTypeInt8:
      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
    case kNumberTypeInt16:
      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
    case kNumberTypeInt32:
      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
    case kNumberTypeInt64:
      return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type);
    case kNumberTypeUInt8:
      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
    case kNumberTypeUInt16:
      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
    case kNumberTypeUInt32:
      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
    case kNumberTypeUInt64:
      return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type);
    case kNumberTypeFloat32:
      return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type);
    case kNumberTypeFloat64:
      return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type);
    default:
      MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << " is invalid.";
  }
}

template <typename T>
std::vector<T> ConvertValueListToVector(const ValuePtrList &seq_values) {
  size_t element_num = seq_values.size();
  std::vector<T> array_data(element_num);
  for (size_t i = 0; i < element_num; i++) {
    const auto &element = seq_values[i];
    MS_EXCEPTION_IF_NULL(element);
    array_data[i] = GetValue<T>(element);
  }
  return array_data;
}

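// Convert a sequence of scalars to a 1-D tensor. Only Int32, Int64 and Float64 elements are supported.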
tensor::TensorPtr SequenceToTensor(const ValueSequencePtr &sequence) {
  MS_EXCEPTION_IF_NULL(sequence);
  const auto &element_values = sequence->value();
  if (element_values.empty()) {
    std::vector<int32_t> array_data;
    MS_LOG(WARNING) << "The value sequence is empty.";
    return std::make_shared<tensor::Tensor>(std::move(array_data), TypeIdToType(kNumberTypeInt32));
  }

  const auto &first_element = element_values[0];
  if (!first_element->isa<Scalar>()) {
    MS_LOG(EXCEPTION) << "For sequence value, only sequence of scalar can convert to TensorValue, but got: "
                      << sequence->ToString();
  }

  TypePtr data_type = first_element->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeInt32:
      return std::make_shared<tensor::Tensor>(ConvertValueListToVector<int32_t>(element_values), data_type);
    case kNumberTypeInt64:
      return std::make_shared<tensor::Tensor>(ConvertValueListToVector<int64_t>(element_values), data_type);
    case kNumberTypeFloat64:
      return std::make_shared<tensor::Tensor>(ConvertValueListToVector<double>(element_values), data_type);
    default:
      MS_LOG(EXCEPTION) << "When convert sequence to tensor, the sequence type: " << data_type << " is invalid.";
  }
}

namespace {
KernelTensorValuePtr ConvertScalarToKernelTensorValue(const ValuePtr &scalar) {
  MS_EXCEPTION_IF_NULL(scalar);
  TypePtr data_type = scalar->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeBool:
      return std::make_shared<KernelTensorValue>(GetValue<bool>(scalar), data_type);
    case kNumberTypeInt8:
      return std::make_shared<KernelTensorValue>(GetValue<int8_t>(scalar), data_type);
    case kNumberTypeInt16:
      return std::make_shared<KernelTensorValue>(GetValue<int16_t>(scalar), data_type);
    case kNumberTypeInt32:
      return std::make_shared<KernelTensorValue>(GetValue<int32_t>(scalar), data_type);
    case kNumberTypeInt64:
      return std::make_shared<KernelTensorValue>(GetValue<int64_t>(scalar), data_type);
    case kNumberTypeUInt8:
      return std::make_shared<KernelTensorValue>(GetValue<uint8_t>(scalar), data_type);
    case kNumberTypeUInt16:
      return std::make_shared<KernelTensorValue>(GetValue<uint16_t>(scalar), data_type);
    case kNumberTypeUInt32:
      return std::make_shared<KernelTensorValue>(GetValue<uint32_t>(scalar), data_type);
    case kNumberTypeUInt64:
      return std::make_shared<KernelTensorValue>(GetValue<uint64_t>(scalar), data_type);
    case kNumberTypeFloat32:
      return std::make_shared<KernelTensorValue>(GetValue<float>(scalar), data_type);
    case kNumberTypeFloat64:
      return std::make_shared<KernelTensorValue>(GetValue<double>(scalar), data_type);
    default:
      MS_LOG(EXCEPTION) << "When convert scalar to KernelTensorValue, the scalar type: " << data_type->ToString()
                        << " is invalid.";
  }
}

template <typename T>
KernelTensorValuePtr ConvertValueListToKernelTensorValue(const ValuePtrList &seq_values, const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(type);
  size_t element_num = seq_values.size();
  std::vector<uint8_t> array_data(element_num * sizeof(T));
  T *array_data_ptr = reinterpret_cast<T *>(array_data.data());
  MS_EXCEPTION_IF_NULL(array_data_ptr);

  for (size_t i = 0; i < element_num; i++) {
    const auto &element = seq_values[i];
    MS_EXCEPTION_IF_NULL(element);
    array_data_ptr[i] = GetValue<T>(element);
  }
  return std::make_shared<KernelTensorValue>(std::move(array_data), type);
}

KernelTensorValuePtr ConvertSequenceToKernelTensorValue(const ValueSequencePtr &value_seq) {
  MS_EXCEPTION_IF_NULL(value_seq);
  const auto &element_values = value_seq->value();
  std::vector<uint8_t> array_data;
  if (element_values.empty()) {
    MS_LOG(INFO) << "The value sequence is empty.";
    return std::make_shared<KernelTensorValue>(std::move(array_data), value_seq->type());
  }

  const auto &first_element = element_values[0];
  if (!first_element->isa<Scalar>()) {
    MS_LOG(EXCEPTION) << "For sequence value, only sequence of scalar can convert to KernelTensorValue, but got: "
                      << value_seq->ToString();
  }

  TypePtr data_type = first_element->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();

  switch (type_id) {
    case kNumberTypeBool:
      return ConvertValueListToKernelTensorValue<bool>(element_values, value_seq->type());
    case kNumberTypeInt8:
      return ConvertValueListToKernelTensorValue<int8_t>(element_values, value_seq->type());
    case kNumberTypeInt16:
      return ConvertValueListToKernelTensorValue<int16_t>(element_values, value_seq->type());
    case kNumberTypeInt32:
      return ConvertValueListToKernelTensorValue<int32_t>(element_values, value_seq->type());
    case kNumberTypeInt64:
      return ConvertValueListToKernelTensorValue<int64_t>(element_values, value_seq->type());
    case kNumberTypeUInt8:
      return ConvertValueListToKernelTensorValue<uint8_t>(element_values, value_seq->type());
    case kNumberTypeUInt16:
      return ConvertValueListToKernelTensorValue<uint16_t>(element_values, value_seq->type());
    case kNumberTypeUInt32:
      return ConvertValueListToKernelTensorValue<uint32_t>(element_values, value_seq->type());
    case kNumberTypeUInt64:
      return ConvertValueListToKernelTensorValue<uint64_t>(element_values, value_seq->type());
    case kNumberTypeFloat32:
      return ConvertValueListToKernelTensorValue<float>(element_values, value_seq->type());
    case kNumberTypeFloat64:
      return ConvertValueListToKernelTensorValue<double>(element_values, value_seq->type());
    default:
      MS_LOG(EXCEPTION) << "When convert sequence to KernelTensorValue, the element type: " << data_type->ToString()
                        << " is invalid.";
  }
}
} // namespace

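// Wrap a Scalar, ValueSequence, Tensor or string value into a KernelTensorValue.
// Returns nullptr for Type values and other unsupported value types.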
KernelTensorValuePtr ConvertValueToKernelTensorValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<Scalar>()) {
    return ConvertScalarToKernelTensorValue(value);
  } else if (value->isa<ValueSequence>()) {
    auto value_seq = value->cast<ValueSequencePtr>();
    return ConvertSequenceToKernelTensorValue(value_seq);
  } else if (value->isa<tensor::BaseTensor>()) {
    auto tensor_ptr = value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    return std::make_shared<KernelTensorValue>(tensor_ptr->data_ptr(), tensor_ptr->type());
  } else if (value->isa<StringImm>()) {
    auto string_ptr = value->cast<StringImmPtr>();
    MS_EXCEPTION_IF_NULL(string_ptr);
    return std::make_shared<KernelTensorValue>(string_ptr, string_ptr->type());
  } else if (value->isa<Type>()) {
    return nullptr;
  } else {
    MS_LOG(WARNING) << "KernelTensorValue not support the value type: " << value->ToString();
    return nullptr;
  }
}

template <typename T, typename Scalar>
ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
  ValuePtr ret;
  auto tensor_value = TensorValueToVector<T>(tensor);
  if (tensor_value.size() == 1) {
    ret = std::make_shared<Scalar>(tensor_value[0]);
  } else {
    std::vector<ValuePtr> value_vec;
    for (const auto &elem : tensor_value) {
      auto value = std::make_shared<Scalar>(elem);
      MS_EXCEPTION_IF_NULL(value);
      value_vec.push_back(value);
    }
    ret = std::make_shared<ValueTuple>(value_vec);
  }
  return ret;
}

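// Convert a tensor back to a scalar value (single element) or a ValueTuple of scalars (multiple elements),
// unless the tensor carries user data recording its original type or value.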
ValuePtr CreateValueFromTensor(const tensor::TensorPtr &tensor) {
  ValuePtr ret;
  if (tensor->has_user_data(kTensorValueIsType)) {
    ret = tensor->user_data<mindspore::Type>(kTensorValueIsType);
    return ret;
  }

  if (tensor->has_user_data(kTensorValueIsEmpty)) {
    ret = tensor->user_data<mindspore::Value>(kTensorValueIsEmpty);
    return ret;
  }

  TypePtr data_type = tensor->Dtype();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeBool: {
      ret = GetTensorValue<bool, BoolImm>(tensor);
      break;
    }

    case kNumberTypeInt8: {
      ret = GetTensorValue<int8_t, Int8Imm>(tensor);
      break;
    }

    case kNumberTypeUInt8: {
      ret = GetTensorValue<uint8_t, UInt8Imm>(tensor);
      break;
    }

    case kNumberTypeInt16: {
      ret = GetTensorValue<int16_t, Int16Imm>(tensor);
      break;
    }

    case kNumberTypeUInt16: {
      ret = GetTensorValue<uint16_t, UInt16Imm>(tensor);
      break;
    }

    case kNumberTypeInt32: {
      ret = GetTensorValue<int32_t, Int32Imm>(tensor);
      break;
    }

    case kNumberTypeUInt32: {
      ret = GetTensorValue<uint32_t, UInt32Imm>(tensor);
      break;
    }

    case kNumberTypeInt64: {
      ret = GetTensorValue<int64_t, Int64Imm>(tensor);
      break;
    }

    case kNumberTypeUInt64: {
      ret = GetTensorValue<uint64_t, UInt64Imm>(tensor);
      break;
    }

    case kNumberTypeFloat32: {
      ret = GetTensorValue<float, FP32Imm>(tensor);
      break;
    }

    case kNumberTypeFloat64: {
      ret = GetTensorValue<double, FP64Imm>(tensor);
      break;
    }

    default:
      MS_LOG(EXCEPTION) << "Can't parse attr value :" << tensor->ToString() << ", Type:" << tensor->type_name();
  }
  return ret;
}

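// Flatten a value (tensor, scalar or nested sequence) into a list of tensors.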
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::BaseTensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(tensors);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensors->emplace_back(tensor);
  } else if (value->isa<Scalar>()) {
    auto tensor = ScalarToTensor(value->cast<ScalarPtr>());
    MS_EXCEPTION_IF_NULL(tensor);
    tensors->emplace_back(tensor);
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    MS_EXCEPTION_IF_NULL(value_seq);
    for (const auto &v : value_seq->value()) {
      TensorValueToTensor(v, tensors);
    }
  }
}

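// Recursively count the number of leaf (non-sequence) values in a value sequence.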
size_t CountValueNum(const ValueSequencePtr &value_sequence) {
  MS_EXCEPTION_IF_NULL(value_sequence);
  size_t cnt = 0;
  const auto &value_list = value_sequence->value();
  for (const auto &value : value_list) {
    if (value->isa<ValueSequence>()) {
      cnt += CountValueNum(value->cast<ValueSequencePtr>());
    } else {
      cnt++;
    }
  }
  return cnt;
}

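// Check whether the node is one of the AKG sparse primitives (CSR/COO operators).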
bool IsAKGSparseOP(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, prim::kPrimCSRGather,
                           prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
  return IsOneOfPrimitiveCNode(cnode, prims);
}

namespace {
ShapeVector ConvertTensorListToShapeVector(const tensor::TensorPtrList &tensor_list, size_t index) {
  ShapeVector shape;
  if (index >= tensor_list.size()) {
    MS_LOG(EXCEPTION) << "Index " << index << " is out of range of " << tensor_list.size();
    return shape;
  }

  auto converter = [](const tensor::TensorPtr tensorptr) {
    MS_EXCEPTION_IF_NULL(tensorptr);
    if (tensorptr->DataDim() != 0) {
      MS_LOG(EXCEPTION) << "Element must be scalar!";
    }
    tensorptr->data_sync(false);
    return *(static_cast<int64_t *>(tensorptr->data_c()));
  };
  std::transform(tensor_list.begin() + index, tensor_list.end(), std::back_inserter(shape), converter);
  if (shape.empty()) {
    MS_LOG(ERROR) << "ShapeVector is empty!";
  }
  return shape;
}

tensor::CSRTensorPtr TensorListToCSRTensor(const tensor::TensorPtrList &tensor_list) {
  tensor::TensorPtr indptr = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndptrIdx]);
  tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndicesIdx]);
  tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kValuesIdx]);
  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::CSRTensor::kShapeIdx);
  auto csr_tensor_ptr = std::make_shared<tensor::CSRTensor>(indptr, indices, values, shape);
  return csr_tensor_ptr;
}

tensor::COOTensorPtr TensorListToCOOTensor(const tensor::TensorPtrList &tensor_list) {
  tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kIndicesIdx]);
  tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kValuesIdx]);
  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::COOTensor::kShapeIdx);
  auto coo_tensor_ptr = std::make_shared<tensor::COOTensor>(indices, values, shape);
  return coo_tensor_ptr;
}
} // namespace

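// Assemble a COO or CSR tensor from a flat tensor list, depending on the sparse abstract type.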
tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
                                                     const tensor::TensorPtrList &tensor_list) {
  if (abs_sparse->isa<abstract::AbstractCOOTensor>()) {
    return TensorListToCOOTensor(tensor_list);
  }
  return TensorListToCSRTensor(tensor_list);
}

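// Expand a base shape (a single shape or a tuple of consistent shapes) into a vector of shape vectors.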
std::vector<ShapeVector> BaseShapeToShapeVector(const abstract::BaseShapePtr &base_shape) {
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    const auto &shape = base_shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape);
    return {shape->shape()};
  } else if (base_shape->isa<abstract::SequenceShape>()) {
    const auto &tuple_shape = base_shape->cast<abstract::SequenceShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (tuple_shape->size() == 0) {
      return {};
    }
    // If the shape is a tuple shape, all shapes need to be consistent.
    auto element_base_shape = (*tuple_shape)[0];
    if (element_base_shape->isa<abstract::Shape>()) {
      const auto &element_shape = element_base_shape->cast<abstract::ShapePtr>();
      MS_EXCEPTION_IF_NULL(element_shape);
      return std::vector<ShapeVector>(tuple_shape->size(), element_shape->shape());
    } else if (element_base_shape->isa<abstract::NoShape>()) {
      return std::vector<ShapeVector>(tuple_shape->size(), {1});
    }
  } else if (base_shape->isa<abstract::NoShape>() || base_shape->isa<abstract::DynamicSequenceShape>()) {
    return {};
  }
  MS_LOG(WARNING) << "Invalid shape:" << base_shape->ToString();
  return {};
}

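// Extract a single shape vector from a base shape; returns an empty shape for NoShape.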
ShapeVector BaseShapeToShape(const abstract::BaseShapePtr &base_shape) {
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    const auto &shape = base_shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape);
    return shape->shape();
  } else if (base_shape->isa<abstract::NoShape>()) {
    return {};
  }
  MS_LOG(WARNING) << "Invalid shape:" << base_shape->ToString();
  return {};
}

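// Wrap a non-sequence value into a ValueTuple when the attribute expects a list data type.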
ValuePtr UpdateValueByAttrDataType(const ValuePtr &value, const std::string &attr_data_type) {
  static std::set<std::string> kListDataType = {"listInt", "listStr", "listBool", "listFloat"};
  auto iter = kListDataType.find(attr_data_type);
  ValuePtr ret = value;
  if (iter != kListDataType.end()) {
    if (!value->isa<ValueSequence>()) {
      std::vector<ValuePtr> value_vec;
      value_vec.push_back(value);
      ret = std::make_shared<ValueTuple>(value_vec);
    }
  }
  return ret;
}

namespace {
size_t GetHashId(int a, int b) { return a < b ? hash_combine(a, b) : hash_combine(b, a); }

static const std::map<size_t, TypeId> tensor_tensor_convert_map = {
  // Bool
  {GetHashId(kNumberTypeBool, kNumberTypeBool), kNumberTypeBool},
  {GetHashId(kNumberTypeBool, kNumberTypeInt8), kNumberTypeInt8},
  {GetHashId(kNumberTypeBool, kNumberTypeInt16), kNumberTypeInt16},
  {GetHashId(kNumberTypeBool, kNumberTypeInt32), kNumberTypeInt32},
  {GetHashId(kNumberTypeBool, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt8), kNumberTypeUInt8},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt16), kNumberTypeUInt16},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt32), kNumberTypeUInt32},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt64), kNumberTypeUInt64},
  {GetHashId(kNumberTypeBool, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeBool, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeBool, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeBool, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeBool, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeBool, kNumberTypeComplex128), kNumberTypeComplex128},
  // Int8
  {GetHashId(kNumberTypeInt8, kNumberTypeInt8), kNumberTypeInt8},
  {GetHashId(kNumberTypeInt8, kNumberTypeInt16), kNumberTypeInt16},
  {GetHashId(kNumberTypeInt8, kNumberTypeInt32), kNumberTypeInt32},
  {GetHashId(kNumberTypeInt8, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt8, kNumberTypeUInt8), kNumberTypeInt16},
  {GetHashId(kNumberTypeInt8, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeInt8, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeInt8, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeInt8, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeInt8, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeInt8, kNumberTypeComplex128), kNumberTypeComplex128},
  // Int16
  {GetHashId(kNumberTypeInt16, kNumberTypeInt16), kNumberTypeInt16},
  {GetHashId(kNumberTypeInt16, kNumberTypeInt32), kNumberTypeInt32},
  {GetHashId(kNumberTypeInt16, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt16, kNumberTypeUInt8), kNumberTypeInt16},
  {GetHashId(kNumberTypeInt16, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeInt16, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeInt16, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeInt16, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeInt16, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeInt16, kNumberTypeComplex128), kNumberTypeComplex128},
  // Int32
  {GetHashId(kNumberTypeInt32, kNumberTypeInt32), kNumberTypeInt32},
  {GetHashId(kNumberTypeInt32, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt32, kNumberTypeUInt8), kNumberTypeInt32},
  {GetHashId(kNumberTypeInt32, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeInt32, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeInt32, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeInt32, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeInt32, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeInt32, kNumberTypeComplex128), kNumberTypeComplex128},
  // Int64
  {GetHashId(kNumberTypeInt64, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt64, kNumberTypeUInt8), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt64, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeInt64, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeInt64, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeInt64, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeInt64, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeInt64, kNumberTypeComplex128), kNumberTypeComplex128},
  // UInt8
  {GetHashId(kNumberTypeUInt8, kNumberTypeUInt8), kNumberTypeUInt8},
  {GetHashId(kNumberTypeUInt8, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeUInt8, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeUInt8, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeUInt8, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeUInt8, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeUInt8, kNumberTypeComplex128), kNumberTypeComplex128},
  // UInt16
  {GetHashId(kNumberTypeUInt16, kNumberTypeUInt16), kNumberTypeUInt16},
  // UInt32
  {GetHashId(kNumberTypeUInt32, kNumberTypeUInt32), kNumberTypeUInt32},
  // UInt64
  {GetHashId(kNumberTypeUInt64, kNumberTypeUInt64), kNumberTypeUInt64},
  // Float16
  {GetHashId(kNumberTypeFloat16, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeFloat16, kNumberTypeBFloat16), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat16, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat16, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeFloat16, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeFloat16, kNumberTypeComplex128), kNumberTypeComplex128},
  // BFloat16
  {GetHashId(kNumberTypeBFloat16, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeBFloat16, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeBFloat16, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeBFloat16, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeBFloat16, kNumberTypeComplex128), kNumberTypeComplex128},
  // Float32
  {GetHashId(kNumberTypeFloat32, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeFloat32, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeFloat32, kNumberTypeComplex128), kNumberTypeComplex128},
  // Float64
  {GetHashId(kNumberTypeFloat64, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeFloat64, kNumberTypeComplex64), kNumberTypeComplex128},
  {GetHashId(kNumberTypeFloat64, kNumberTypeComplex128), kNumberTypeComplex128},
  // Complex64
  {GetHashId(kNumberTypeComplex64, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeComplex64, kNumberTypeComplex128), kNumberTypeComplex128},
  // Complex128
  {GetHashId(kNumberTypeComplex128, kNumberTypeComplex128), kNumberTypeComplex128},
};

static const std::map<size_t, TypeId> scalar_tensor_convert_map = {
  // Scalar is bool.
  {GetHashId(kNumberTypeBool, kNumberTypeBool), kNumberTypeBool},
  {GetHashId(kNumberTypeBool, kNumberTypeInt8), kNumberTypeInt8},
  {GetHashId(kNumberTypeBool, kNumberTypeInt16), kNumberTypeInt16},
  {GetHashId(kNumberTypeBool, kNumberTypeInt32), kNumberTypeInt32},
  {GetHashId(kNumberTypeBool, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt8), kNumberTypeUInt8},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt16), kNumberTypeUInt16},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt32), kNumberTypeUInt32},
  {GetHashId(kNumberTypeBool, kNumberTypeUInt64), kNumberTypeUInt64},
  {GetHashId(kNumberTypeBool, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeBool, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeBool, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeBool, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeBool, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeBool, kNumberTypeComplex128), kNumberTypeComplex128},
  // Scalar is int.
  {GetHashId(kNumberTypeInt64, kNumberTypeBool), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt64, kNumberTypeInt8), kNumberTypeInt8},
  {GetHashId(kNumberTypeInt64, kNumberTypeInt16), kNumberTypeInt16},
  {GetHashId(kNumberTypeInt64, kNumberTypeInt32), kNumberTypeInt32},
  {GetHashId(kNumberTypeInt64, kNumberTypeInt64), kNumberTypeInt64},
  {GetHashId(kNumberTypeInt64, kNumberTypeUInt8), kNumberTypeUInt8},
  {GetHashId(kNumberTypeInt64, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeInt64, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeInt64, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeInt64, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeInt64, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeInt64, kNumberTypeComplex128), kNumberTypeComplex128},
  // Scalar is float.
  {GetHashId(kNumberTypeFloat32, kNumberTypeBool), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeInt8), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeInt16), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeInt32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeInt64), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeUInt8), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeFloat16), kNumberTypeFloat16},
  {GetHashId(kNumberTypeFloat32, kNumberTypeBFloat16), kNumberTypeBFloat16},
  {GetHashId(kNumberTypeFloat32, kNumberTypeFloat32), kNumberTypeFloat32},
  {GetHashId(kNumberTypeFloat32, kNumberTypeFloat64), kNumberTypeFloat64},
  {GetHashId(kNumberTypeFloat32, kNumberTypeComplex64), kNumberTypeComplex64},
  {GetHashId(kNumberTypeFloat32, kNumberTypeComplex128), kNumberTypeComplex128},
};

TypeId ConvertTypeForTensorsOrScalars(const TypeId &current, const TypeId &other, const size_t hash_id) {
  auto iter = tensor_tensor_convert_map.find(hash_id);
  if (iter != tensor_tensor_convert_map.end()) {
    return iter->second;
  }
  MS_EXCEPTION(TypeError) << "Type implicit conversion between " << TypeIdToString(current) << " and "
                          << TypeIdToString(other) << " is not supported.";
}

TypeId ConvertTypeBetweenTensorAndScalar(const TypeId &tensor_type_id, const TypeId &scalar_type_id,
                                         const size_t hash_id) {
  auto iter = scalar_tensor_convert_map.find(hash_id);
  if (iter != scalar_tensor_convert_map.end()) {
    return iter->second;
  }
  MS_EXCEPTION(TypeError) << "Type implicit conversion between Tensor[" << TypeIdToString(tensor_type_id) << "] and "
                          << TypeIdToString(scalar_type_id) << " is not supported.";
}

TypeId GetConversionType(const TypeId &current, bool current_arg_is_tensor, bool is_parameter,
                         const std::pair<TypeId, bool> &sig_type, const TypeId &ref_type_id) {
  TypeId saved_type_id = sig_type.first;
  bool saved_has_tensor = sig_type.second;
  if (current == saved_type_id) {
    return current;
  }

  if (current != kTypeUnknown && saved_type_id != kTypeUnknown) {
    auto hash_id = GetHashId(current, saved_type_id);
    // Tensor + Scalar, Scalar + Tensor
    if (MS_UNLIKELY(current_arg_is_tensor ^ saved_has_tensor)) {
      return ConvertTypeBetweenTensorAndScalar(current, saved_type_id, hash_id);
    }
    // Tensor + Tensor, Scalar + Scalar
    if ((is_parameter || saved_type_id == ref_type_id) &&
        hash_id == GetHashId(kNumberTypeFloat16, kNumberTypeBFloat16)) {
      // "saved_type_id == ref_type_id": if Parameter exists, its type_id should be equal to the saved_type_id,
      // otherwise it means that the wrong type cast will be performed on the Parameter.
      static bool already_printed = false;
      if (!already_printed) {
        already_printed = true;
        MS_LOG(WARNING) << "For operators with side effects, there is an implicit type conversion between "
                        << TypeIdToString(current) << " and " << TypeIdToString(saved_type_id)
                        << ", which may result in loss of precision. It is recommended to use Float32.";
      }
      return is_parameter ? current : saved_type_id;
    }
    return ConvertTypeForTensorsOrScalars(current, saved_type_id, hash_id);
  }
  return current != kTypeUnknown ? current : saved_type_id;
}
} // namespace

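// Compute, for each dtype signature group, the target type id and whether a tensor argument is involved,
// applying the implicit type conversion rules above.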
std::map<SignatureEnumDType, std::pair<TypeId, bool>> GetSignatureTypeMap(const std::vector<SignatureEnumDType> &dtypes,
                                                                          const std::vector<TypeId> &args_type_id,
                                                                          const std::vector<bool> &args_is_tensor,
                                                                          const std::set<size_t> &write_indices) {
  // {T0: (target_type_id=Int32, has_tensor=true), T1: (target_type_id=Float32, has_tensor=false), ...}
  std::map<SignatureEnumDType, std::pair<TypeId, bool>> sig_type_map;
  std::map<SignatureEnumDType, TypeId> ref_type_map;
  size_t args_size = args_type_id.size();
  for (size_t i = 0; i < args_size; ++i) {
    bool is_parameter = write_indices.find(i) != write_indices.end();
    const auto &it = sig_type_map.find(dtypes[i]);
    if (it == sig_type_map.end()) {
      (void)sig_type_map.insert(std::make_pair(dtypes[i], std::make_pair(args_type_id[i], args_is_tensor[i])));
      (void)ref_type_map.insert(std::make_pair(dtypes[i], is_parameter ? args_type_id[i] : kTypeUnknown));
    } else {
      it->second.first =
        GetConversionType(args_type_id[i], args_is_tensor[i], is_parameter, it->second, ref_type_map[dtypes[i]]);
      it->second.second = args_is_tensor[i] || it->second.second;
      if (is_parameter && ref_type_map[dtypes[i]] == kTypeUnknown) {
        ref_type_map[dtypes[i]] = args_type_id[i];
      }
    }
  }
  return sig_type_map;
}

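// Format ValueSimpleInfo (element count, shapes, dtypes and optional object types) as a readable string.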
std::string ValueSimpleInfoToString(const ValueSimpleInfo &value_simple_info) {
  std::ostringstream buf;
  buf << "Value simple info element size : " << value_simple_info.size_;
  for (size_t i = 0; i < value_simple_info.size_; ++i) {
    buf << ". The " << i << "th shape: " << value_simple_info.shape_vector_[i] << ", dtype "
        << value_simple_info.dtype_vector_[i];
    if (!value_simple_info.object_type_vector_.empty()) {
      buf << ", object type " << value_simple_info.object_type_vector_[i];
    }
  }
  return buf.str();
}

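// Build an AbstractTensor (single output) or an AbstractTuple of tensors from simple infer info.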
abstract::AbstractBasePtr TransformValueSimpleInfoToAbstract(const ValueSimpleInfo &value_simple_info) {
  if (value_simple_info.size_ < 1) {
    MS_LOG(EXCEPTION) << "Simple infer info size must be no less than 1, but got " << value_simple_info.size_;
  }
  abstract::AbstractBasePtr out_abs;
  if (value_simple_info.size_ == 1 && !value_simple_info.is_tuple_output_) {
    out_abs = std::make_shared<abstract::AbstractTensor>(value_simple_info.dtype_vector_[kIndex0],
                                                         value_simple_info.shape_vector_[kIndex0]);
  } else {
    AbstractBasePtrList out_abs_list;
    out_abs_list.resize(value_simple_info.size_);
    for (size_t i = 0; i < value_simple_info.size_; ++i) {
      out_abs_list[i] = std::make_shared<abstract::AbstractTensor>(value_simple_info.dtype_vector_[i],
                                                                   value_simple_info.shape_vector_[i]);
    }
    out_abs = std::make_shared<abstract::AbstractTuple>(out_abs_list);
  }
  return out_abs;
}
} // namespace mindspore