• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/utils/convert_utils.h"
18 #include <algorithm>
19 #include <cfloat>
20 #include <cmath>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "include/common/utils/utils.h"
27 #include "ir/tensor.h"
28 #include "ir/value.h"
29 #include "mindspore/core/ops/sparse_ops.h"
30 #include "utils/anf_utils.h"
31 #include "utils/ms_context.h"
32 #include "utils/hashing.h"
33 
34 namespace mindspore {
ValueToBool(const ValuePtr & v,bool * value)35 bool ValueToBool(const ValuePtr &v, bool *value) {
36   MS_EXCEPTION_IF_NULL(v);
37   if (v->isa<BoolImm>()) {
38     *value = v->cast<BoolImmPtr>()->value();
39   } else if (v->isa<Int32Imm>()) {
40     *value = v->cast<Int32ImmPtr>()->value() != 0;
41   } else if (v->isa<UInt32Imm>()) {
42     *value = v->cast<UInt32ImmPtr>()->value() != 0;
43   } else if (v->isa<FP32Imm>()) {
44     *value = fabs(v->cast<FP32ImmPtr>()->value()) > FLT_EPSILON;
45   } else if (v->isa<FP64Imm>()) {
46     *value = fabs(v->cast<FP64ImmPtr>()->value()) > DBL_EPSILON;
47   } else if (v->isa<StringImm>()) {
48     std::string str = v->cast<StringImmPtr>()->value();
49     *value = str.length() != 0;
50   } else if (v->isa<tensor::Tensor>()) {
51     auto tensor = v->cast<tensor::TensorPtr>();
52     MS_EXCEPTION_IF_NULL(tensor);
53     tensor->data_sync();
54     bool *tensor_data = static_cast<bool *>(tensor->data_c());
55     // maybe need to support if tensor is a bool array
56     auto vb = tensor_data[0];
57     *value = vb;
58   } else {
59     MS_LOG(WARNING) << "value is not supported to cast to be bool";
60     return false;
61   }
62   return true;
63 }
64 
BaseRefToInt(const ValuePtr & v,int64_t * value)65 bool BaseRefToInt(const ValuePtr &v, int64_t *value) {
66   MS_EXCEPTION_IF_NULL(v);
67   if (v->isa<tensor::Tensor>()) {
68     auto tensor = v->cast<tensor::TensorPtr>();
69     tensor->data_sync();
70     if (tensor->Dtype()->ToString() == "Int32") {
71       auto *tensor_data = static_cast<int32_t *>(tensor->data_c());
72       auto vb = tensor_data[0];
73       *value = static_cast<int64_t>(vb);
74     } else if (tensor->Dtype()->ToString() == "Int64") {
75       auto *tensor_data = static_cast<int64_t *>(tensor->data_c());
76       auto vb = tensor_data[0];
77       *value = vb;
78     } else {
79       MS_LOG(ERROR) << "Index must be Int type.";
80     }
81     return true;
82   }
83   MS_LOG(ERROR) << "Index must be tensor type.";
84   return false;
85 }
86 
BaseRefToBool(const BaseRef & v,bool * value)87 bool BaseRefToBool(const BaseRef &v, bool *value) {
88   if (utils::isa<ValuePtr>(v)) {
89     return ValueToBool(utils::cast<ValuePtr>(v), value);
90   } else if (utils::isa<bool>(v)) {
91     auto vb = utils::cast<bool>(v);
92     *value = vb;
93   } else if (utils::isa<int>(v)) {
94     auto vb = utils::cast<int>(v);
95     *value = vb != 0;
96   } else if (utils::isa<unsigned int>(v)) {
97     auto vb = utils::cast<unsigned int>(v);
98     *value = vb != 0;
99   } else if (utils::isa<float>(v)) {
100     auto vb = utils::cast<float>(v);
101     *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON);
102   } else if (utils::isa<double>(v)) {
103     auto vb = utils::cast<double>(v);
104     *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON);
105   } else {
106     MS_LOG(DEBUG) << "value is not supported to cast to be bool";
107     return false;
108   }
109   return true;
110 }
111 
112 namespace {
113 // Isomorphism
114 bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
115               NodeMapEquiv *const equiv_node);
116 
SameValueNode(const AnfNodePtr & node1,const AnfNodePtr & node2)117 bool SameValueNode(const AnfNodePtr &node1, const AnfNodePtr &node2) {
118   auto a1 = GetValueNode(node1);
119   auto a2 = GetValueNode(node2);
120   if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
121     return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
122   } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
123     return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
124   }
125   return *a1 == *a2;
126 }
127 
SameNodeShallow(const AnfNodePtr & node1,const AnfNodePtr & node2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)128 bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
129                      NodeMapEquiv *const equiv_node) {
130   if (equiv_node == nullptr) {
131     MS_LOG(ERROR) << "Invalid equiv_node";
132     return false;
133   }
134   if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
135     return true;
136   }
137   if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
138     return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
139                       equiv_node);
140   }
141   if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
142     return SameValueNode(node1, node2);
143   }
144   if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
145     auto para1 = node1->cast<ParameterPtr>();
146     auto para2 = node2->cast<ParameterPtr>();
147     if (para1->name() == para2->name()) {
148       return true;
149     }
150     MS_LOG(DEBUG) << "two parameters are not equal.";
151     return false;
152   }
153   if (AnfUtils::IsCustomActorNode(node1) && AnfUtils::IsCustomActorNode(node2)) {
154     return AnfUtils::IsCutomActorNodeSame(node1, node2);
155   }
156   if (node1->isa<CNode>() && node2->isa<CNode>()) {
157     return SameNode(node1, node2, equiv_func_graph, equiv_node);
158   }
159   MS_LOG(ERROR) << "type error";
160   return false;
161 }
162 
SameNode(const AnfNodePtr & node1,const AnfNodePtr & node2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)163 bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
164               NodeMapEquiv *const equiv_node) {
165   MS_EXCEPTION_IF_NULL(node1);
166   MS_EXCEPTION_IF_NULL(node2);
167   if (node1->isa<CNode>() && node2->isa<CNode>()) {
168     auto &inputs1 = node1->cast<CNodePtr>()->inputs();
169     auto &inputs2 = node2->cast<CNodePtr>()->inputs();
170     for (std::size_t i = 0; i < inputs1.size(); ++i) {
171       if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
172         return false;
173       }
174     }
175     return true;
176   }
177   return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
178 }
179 
SameSubgraph(const AnfNodePtr & root1,const AnfNodePtr & root2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)180 bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph,
181                   NodeMapEquiv *const equiv_node) {
182   mindspore::HashSet<AnfNodePtr> done;
183   std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;
184 
185   todo.push(std::make_pair(root1, root2));
186   while (!todo.empty()) {
187     AnfNodePtr node1 = todo.top().first;
188     if (done.count(node1) > 0) {
189       todo.pop();
190       continue;
191     }
192     AnfNodePtr node2 = todo.top().second;
193 
194     bool condition = false;
195     const auto &s1 = GetInputs(node1);
196     const auto &s2 = GetInputs(node2);
197 
198     if (s1.size() != s2.size()) {
199       return false;
200     }
201     for (std::size_t i = 0; i < s1.size(); ++i) {
202       if (done.count(s1[i]) == 0) {
203         todo.push(std::make_pair(s1[i], s2[i]));
204         condition = true;
205       }
206     }
207     if (condition) {
208       continue;
209     }
210     (void)done.insert(node1);
211 
212     auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
213     if (res) {
214       (*equiv_node)[node1] = node2;
215     } else {
216       return false;
217     }
218     todo.pop();
219   }
220   return true;
221 }
222 }  // namespace
223 
Isomorphic(const FuncGraphPtr & fg1,const FuncGraphPtr & fg2,FuncGraphPairMapEquiv * equiv_func_graph,NodeMapEquiv * const equiv_node)224 bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph,
225                 NodeMapEquiv *const equiv_node) {
226   auto fg1_fg2 = std::make_pair(fg1, fg2);
227   if (equiv_func_graph == nullptr) {
228     MS_LOG(ERROR) << "equiv_func_graph not init";
229     return false;
230   }
231   if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
232     return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
233   }
234   if (fg1 == nullptr || fg2 == nullptr) {
235     MS_LOG(ERROR) << "Invalid function graph";
236     return false;
237   }
238   if (fg1->parameters().size() != fg2->parameters().size()) {
239     MS_LOG(DEBUG) << "parameters size not match";
240     return false;
241   }
242   if (equiv_node != nullptr) {
243     for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
244       (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
245     }
246     (*equiv_func_graph)[fg1_fg2] = kPending;
247     auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
248     (*equiv_func_graph)[fg1_fg2] = EquivState(result);
249     return result;
250   }
251 
252   MS_LOG(ERROR) << "equiv_node not init";
253   return false;
254 }
255 
ScalarToTensor(const ScalarPtr & scalar)256 tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
257   if (scalar == nullptr) {
258     MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
259   }
260   TypePtr data_type = scalar->type();
261   MS_EXCEPTION_IF_NULL(data_type);
262   TypeId type_id = data_type->type_id();
263   switch (type_id) {
264     case kNumberTypeBool:
265       return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type);
266     case kNumberTypeInt8:
267       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
268     case kNumberTypeInt16:
269       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
270     case kNumberTypeInt32:
271       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
272     case kNumberTypeInt64:
273       return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type);
274     case kNumberTypeUInt8:
275       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
276     case kNumberTypeUInt16:
277       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
278     case kNumberTypeUInt32:
279       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
280     case kNumberTypeUInt64:
281       return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type);
282     case kNumberTypeFloat32:
283       return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type);
284     case kNumberTypeFloat64:
285       return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type);
286     default:
287       MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << " is invalid.";
288   }
289 }
290 
291 template <typename T>
ConvertValueListToVector(const ValuePtrList & seq_values)292 std::vector<T> ConvertValueListToVector(const ValuePtrList &seq_values) {
293   size_t element_num = seq_values.size();
294   std::vector<T> array_data(element_num);
295   for (size_t i = 0; i < element_num; i++) {
296     const auto &element = seq_values[i];
297     MS_EXCEPTION_IF_NULL(element);
298     array_data[i] = GetValue<T>(element);
299   }
300   return array_data;
301 }
302 
SequenceToTensor(const ValueSequencePtr & sequence)303 tensor::TensorPtr SequenceToTensor(const ValueSequencePtr &sequence) {
304   MS_EXCEPTION_IF_NULL(sequence);
305   const auto &element_values = sequence->value();
306   if (element_values.empty()) {
307     std::vector<int32_t> array_data;
308     MS_LOG(WARNING) << "The value sequence is empty.";
309     return std::make_shared<tensor::Tensor>(std::move(array_data), TypeIdToType(kNumberTypeInt32));
310   }
311 
312   const auto &first_element = element_values[0];
313   if (!first_element->isa<Scalar>()) {
314     MS_LOG(EXCEPTION) << "For sequence value, only sequence of scalar can convert to TensorValue, but got: "
315                       << sequence->ToString();
316   }
317 
318   TypePtr data_type = first_element->type();
319   MS_EXCEPTION_IF_NULL(data_type);
320   TypeId type_id = data_type->type_id();
321   switch (type_id) {
322     case kNumberTypeInt32:
323       return std::make_shared<tensor::Tensor>(ConvertValueListToVector<int32_t>(element_values), data_type);
324     case kNumberTypeInt64:
325       return std::make_shared<tensor::Tensor>(ConvertValueListToVector<int64_t>(element_values), data_type);
326     case kNumberTypeFloat64:
327       return std::make_shared<tensor::Tensor>(ConvertValueListToVector<double>(element_values), data_type);
328     default:
329       MS_LOG(EXCEPTION) << "When convert sequence to tensor, the sequence type: " << data_type << " is invalid.";
330   }
331 }
332 
333 namespace {
ConvertScalarToKernelTensorValue(const ValuePtr & scalar)334 KernelTensorValuePtr ConvertScalarToKernelTensorValue(const ValuePtr &scalar) {
335   MS_EXCEPTION_IF_NULL(scalar);
336   TypePtr data_type = scalar->type();
337   MS_EXCEPTION_IF_NULL(data_type);
338   TypeId type_id = data_type->type_id();
339   switch (type_id) {
340     case kNumberTypeBool:
341       return std::make_shared<KernelTensorValue>(GetValue<bool>(scalar), data_type);
342     case kNumberTypeInt8:
343       return std::make_shared<KernelTensorValue>(GetValue<int8_t>(scalar), data_type);
344     case kNumberTypeInt16:
345       return std::make_shared<KernelTensorValue>(GetValue<int16_t>(scalar), data_type);
346     case kNumberTypeInt32:
347       return std::make_shared<KernelTensorValue>(GetValue<int32_t>(scalar), data_type);
348     case kNumberTypeInt64:
349       return std::make_shared<KernelTensorValue>(GetValue<int64_t>(scalar), data_type);
350     case kNumberTypeUInt8:
351       return std::make_shared<KernelTensorValue>(GetValue<uint8_t>(scalar), data_type);
352     case kNumberTypeUInt16:
353       return std::make_shared<KernelTensorValue>(GetValue<uint16_t>(scalar), data_type);
354     case kNumberTypeUInt32:
355       return std::make_shared<KernelTensorValue>(GetValue<uint32_t>(scalar), data_type);
356     case kNumberTypeUInt64:
357       return std::make_shared<KernelTensorValue>(GetValue<uint64_t>(scalar), data_type);
358     case kNumberTypeFloat32:
359       return std::make_shared<KernelTensorValue>(GetValue<float>(scalar), data_type);
360     case kNumberTypeFloat64:
361       return std::make_shared<KernelTensorValue>(GetValue<double>(scalar), data_type);
362     default:
363       MS_LOG(EXCEPTION) << "When convert scalar to KernelTensorValue, the scalar type: " << data_type->ToString()
364                         << " is invalid.";
365   }
366 }
367 
368 template <typename T>
ConvertValueListToKernelTensorValue(const ValuePtrList & seq_values,const TypePtr & type)369 KernelTensorValuePtr ConvertValueListToKernelTensorValue(const ValuePtrList &seq_values, const TypePtr &type) {
370   MS_EXCEPTION_IF_NULL(type);
371   size_t element_num = seq_values.size();
372   std::vector<uint8_t> array_data(element_num * sizeof(T));
373   T *array_data_ptr = reinterpret_cast<T *>(array_data.data());
374   MS_EXCEPTION_IF_NULL(array_data_ptr);
375 
376   for (size_t i = 0; i < element_num; i++) {
377     const auto &element = seq_values[i];
378     MS_EXCEPTION_IF_NULL(element);
379     array_data_ptr[i] = GetValue<T>(element);
380   }
381   return std::make_shared<KernelTensorValue>(std::move(array_data), type);
382 }
383 
ConvertSequenceToKernelTensorValue(const ValueSequencePtr & value_seq)384 KernelTensorValuePtr ConvertSequenceToKernelTensorValue(const ValueSequencePtr &value_seq) {
385   MS_EXCEPTION_IF_NULL(value_seq);
386   const auto &element_values = value_seq->value();
387   std::vector<uint8_t> array_data;
388   if (element_values.empty()) {
389     MS_LOG(INFO) << "The value sequence is empty.";
390     return std::make_shared<KernelTensorValue>(std::move(array_data), value_seq->type());
391   }
392 
393   const auto &first_element = element_values[0];
394   if (!first_element->isa<Scalar>()) {
395     MS_LOG(EXCEPTION) << "For sequence value, only sequence of scalar can convert to KernelTensorValue, but got: "
396                       << value_seq->ToString();
397   }
398 
399   TypePtr data_type = first_element->type();
400   MS_EXCEPTION_IF_NULL(data_type);
401   TypeId type_id = data_type->type_id();
402 
403   switch (type_id) {
404     case kNumberTypeBool:
405       return ConvertValueListToKernelTensorValue<bool>(element_values, value_seq->type());
406     case kNumberTypeInt8:
407       return ConvertValueListToKernelTensorValue<int8_t>(element_values, value_seq->type());
408     case kNumberTypeInt16:
409       return ConvertValueListToKernelTensorValue<int16_t>(element_values, value_seq->type());
410     case kNumberTypeInt32:
411       return ConvertValueListToKernelTensorValue<int32_t>(element_values, value_seq->type());
412     case kNumberTypeInt64:
413       return ConvertValueListToKernelTensorValue<int64_t>(element_values, value_seq->type());
414     case kNumberTypeUInt8:
415       return ConvertValueListToKernelTensorValue<uint8_t>(element_values, value_seq->type());
416     case kNumberTypeUInt16:
417       return ConvertValueListToKernelTensorValue<uint16_t>(element_values, value_seq->type());
418     case kNumberTypeUInt32:
419       return ConvertValueListToKernelTensorValue<uint32_t>(element_values, value_seq->type());
420     case kNumberTypeUInt64:
421       return ConvertValueListToKernelTensorValue<uint64_t>(element_values, value_seq->type());
422     case kNumberTypeFloat32:
423       return ConvertValueListToKernelTensorValue<float>(element_values, value_seq->type());
424     case kNumberTypeFloat64:
425       return ConvertValueListToKernelTensorValue<double>(element_values, value_seq->type());
426     default:
427       MS_LOG(EXCEPTION) << "When convert sequence to KernelTensorValue, the element type: " << data_type->ToString()
428                         << " is invalid.";
429   }
430 }
431 }  // namespace
432 
ConvertValueToKernelTensorValue(const ValuePtr & value)433 KernelTensorValuePtr ConvertValueToKernelTensorValue(const ValuePtr &value) {
434   MS_EXCEPTION_IF_NULL(value);
435   if (value->isa<Scalar>()) {
436     return ConvertScalarToKernelTensorValue(value);
437   } else if (value->isa<ValueSequence>()) {
438     auto value_seq = value->cast<ValueSequencePtr>();
439     return ConvertSequenceToKernelTensorValue(value_seq);
440   } else if (value->isa<tensor::BaseTensor>()) {
441     auto tensor_ptr = value->cast<tensor::BaseTensorPtr>();
442     MS_EXCEPTION_IF_NULL(tensor_ptr);
443     return std::make_shared<KernelTensorValue>(tensor_ptr->data_ptr(), tensor_ptr->type());
444   } else if (value->isa<StringImm>()) {
445     auto string_ptr = value->cast<StringImmPtr>();
446     MS_EXCEPTION_IF_NULL(string_ptr);
447     return std::make_shared<KernelTensorValue>(string_ptr, string_ptr->type());
448   } else if (value->isa<Type>()) {
449     return nullptr;
450   } else {
451     MS_LOG(WARNING) << "KernelTensorValue not support the value type: " << value->ToString();
452     return nullptr;
453   }
454 }
455 
456 template <typename T, typename Scalar>
GetTensorValue(const tensor::TensorPtr & tensor)457 ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
458   ValuePtr ret;
459   auto tensor_value = TensorValueToVector<T>(tensor);
460   if (tensor_value.size() == 1) {
461     ret = std::make_shared<Scalar>(tensor_value[0]);
462   } else {
463     std::vector<ValuePtr> value_vec;
464     for (const auto &elem : tensor_value) {
465       auto value = std::make_shared<Scalar>(elem);
466       MS_EXCEPTION_IF_NULL(value);
467       value_vec.push_back(value);
468     }
469     ret = std::make_shared<ValueTuple>(value_vec);
470   }
471   return ret;
472 }
473 
CreateValueFromTensor(const tensor::TensorPtr & tensor)474 ValuePtr CreateValueFromTensor(const tensor::TensorPtr &tensor) {
475   ValuePtr ret;
476   if (tensor->has_user_data(kTensorValueIsType)) {
477     ret = tensor->user_data<mindspore::Type>(kTensorValueIsType);
478     return ret;
479   }
480 
481   if (tensor->has_user_data(kTensorValueIsEmpty)) {
482     ret = tensor->user_data<mindspore::Value>(kTensorValueIsEmpty);
483     return ret;
484   }
485 
486   TypePtr data_type = tensor->Dtype();
487   MS_EXCEPTION_IF_NULL(data_type);
488   TypeId type_id = data_type->type_id();
489   switch (type_id) {
490     case kNumberTypeBool: {
491       ret = GetTensorValue<bool, BoolImm>(tensor);
492       break;
493     }
494 
495     case kNumberTypeInt8: {
496       ret = GetTensorValue<int8_t, Int8Imm>(tensor);
497       break;
498     }
499 
500     case kNumberTypeUInt8: {
501       ret = GetTensorValue<uint8_t, UInt8Imm>(tensor);
502       break;
503     }
504 
505     case kNumberTypeInt16: {
506       ret = GetTensorValue<int16_t, Int16Imm>(tensor);
507       break;
508     }
509 
510     case kNumberTypeUInt16: {
511       ret = GetTensorValue<uint16_t, UInt16Imm>(tensor);
512       break;
513     }
514 
515     case kNumberTypeInt32: {
516       ret = GetTensorValue<int32_t, Int32Imm>(tensor);
517       break;
518     }
519 
520     case kNumberTypeUInt32: {
521       ret = GetTensorValue<uint32_t, UInt32Imm>(tensor);
522       break;
523     }
524 
525     case kNumberTypeInt64: {
526       ret = GetTensorValue<int64_t, Int64Imm>(tensor);
527       break;
528     }
529 
530     case kNumberTypeUInt64: {
531       ret = GetTensorValue<uint64_t, UInt64Imm>(tensor);
532       break;
533     }
534 
535     case kNumberTypeFloat32: {
536       ret = GetTensorValue<float, FP32Imm>(tensor);
537       break;
538     }
539 
540     case kNumberTypeFloat64: {
541       ret = GetTensorValue<double, FP64Imm>(tensor);
542       break;
543     }
544 
545     default:
546       MS_LOG(EXCEPTION) << "Can't parse attr value :" << tensor->ToString() << ", Type:" << tensor->type_name();
547   }
548   return ret;
549 }
550 
TensorValueToTensor(const ValuePtr & value,std::vector<tensor::BaseTensorPtr> * tensors)551 void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::BaseTensorPtr> *tensors) {
552   MS_EXCEPTION_IF_NULL(value);
553   MS_EXCEPTION_IF_NULL(tensors);
554   if (value->isa<tensor::BaseTensor>()) {
555     auto tensor = value->cast<tensor::BaseTensorPtr>();
556     MS_EXCEPTION_IF_NULL(tensor);
557     tensors->emplace_back(tensor);
558   } else if (value->isa<Scalar>()) {
559     auto tensor = ScalarToTensor(value->cast<ScalarPtr>());
560     MS_EXCEPTION_IF_NULL(tensor);
561     tensors->emplace_back(tensor);
562   } else if (value->isa<ValueSequence>()) {
563     const auto &value_seq = value->cast<ValueSequencePtr>();
564     MS_EXCEPTION_IF_NULL(value_seq);
565     for (const auto &v : value_seq->value()) {
566       TensorValueToTensor(v, tensors);
567     }
568   }
569 }
570 
CountValueNum(const ValueSequencePtr & value_sequence)571 size_t CountValueNum(const ValueSequencePtr &value_sequence) {
572   MS_EXCEPTION_IF_NULL(value_sequence);
573   size_t cnt = 0;
574   const auto &value_list = value_sequence->value();
575   for (const auto &value : value_list) {
576     if (value->isa<ValueSequence>()) {
577       cnt += CountValueNum(value->cast<ValueSequencePtr>());
578     } else {
579       cnt++;
580     }
581   }
582   return cnt;
583 }
584 
IsAKGSparseOP(const AnfNodePtr & cnode)585 bool IsAKGSparseOP(const AnfNodePtr &cnode) {
586   MS_EXCEPTION_IF_NULL(cnode);
587   const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul,  prim::kPrimCSRMV,  prim::kPrimCSRGather,
588                            prim::kPrimCSR2COO,      prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
589   return IsOneOfPrimitiveCNode(cnode, prims);
590 }
591 
592 namespace {
ConvertTensorListToShapeVector(const tensor::TensorPtrList & tensor_list,size_t index)593 ShapeVector ConvertTensorListToShapeVector(const tensor::TensorPtrList &tensor_list, size_t index) {
594   ShapeVector shape;
595   if (index >= tensor_list.size()) {
596     MS_LOG(EXCEPTION) << "Index " << index << " is out of range of " << tensor_list.size();
597     return shape;
598   }
599 
600   auto converter = [](const tensor::TensorPtr tensorptr) {
601     MS_EXCEPTION_IF_NULL(tensorptr);
602     if (tensorptr->DataDim() != 0) {
603       MS_LOG(EXCEPTION) << "Element must be scalar!";
604     }
605     tensorptr->data_sync(false);
606     return *(static_cast<int64_t *>(tensorptr->data_c()));
607   };
608   std::transform(tensor_list.begin() + index, tensor_list.end(), std::back_inserter(shape), converter);
609   if (shape.empty()) {
610     MS_LOG(ERROR) << "ShapeVector is empty!";
611   }
612   return shape;
613 }
// Rebuilds a CSRTensor from its flattened components.
// The indptr, indices and values tensors sit at CSRTensor::kIndptrIdx, kIndicesIdx and
// kValuesIdx. The scalar shape tensors start at CSRTensor::kShapeIdx and run to the end.
// Raises an exception if tensor_list is too short for any of these indices.
tensor::CSRTensorPtr TensorListToCSRTensor(const tensor::TensorPtrList &tensor_list) {
  // Bounds-checked access, so a short list raises a clear exception instead of reading out of range.
  auto component = [&tensor_list](size_t idx) -> tensor::TensorPtr {
    if (idx >= tensor_list.size()) {
      MS_LOG(EXCEPTION) << "CSRTensor component index " << idx << " is out of range of " << tensor_list.size();
    }
    return tensor_list[idx];
  };
  tensor::TensorPtr indptr = component(tensor::CSRTensor::kIndptrIdx);
  tensor::TensorPtr indices = component(tensor::CSRTensor::kIndicesIdx);
  tensor::TensorPtr values = component(tensor::CSRTensor::kValuesIdx);
  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::CSRTensor::kShapeIdx);
  return std::make_shared<tensor::CSRTensor>(indptr, indices, values, shape);
}
622 
// Rebuilds a COOTensor from its flattened components.
// The indices and values tensors sit at COOTensor::kIndicesIdx and kValuesIdx. The scalar shape
// tensors start at COOTensor::kShapeIdx and run to the end.
// Raises an exception if tensor_list is too short for any of these indices.
tensor::COOTensorPtr TensorListToCOOTensor(const tensor::TensorPtrList &tensor_list) {
  // Bounds-checked access, so a short list raises a clear exception instead of reading out of range.
  auto component = [&tensor_list](size_t idx) -> tensor::TensorPtr {
    if (idx >= tensor_list.size()) {
      MS_LOG(EXCEPTION) << "COOTensor component index " << idx << " is out of range of " << tensor_list.size();
    }
    return tensor_list[idx];
  };
  tensor::TensorPtr indices = component(tensor::COOTensor::kIndicesIdx);
  tensor::TensorPtr values = component(tensor::COOTensor::kValuesIdx);
  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::COOTensor::kShapeIdx);
  return std::make_shared<tensor::COOTensor>(indices, values, shape);
}
630 }  // namespace
631 
TensorListToSparseTensor(const abstract::AbstractBasePtr & abs_sparse,const tensor::TensorPtrList & tensor_list)632 tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
633                                                      const tensor::TensorPtrList &tensor_list) {
634   if (abs_sparse->isa<abstract::AbstractCOOTensor>()) {
635     return TensorListToCOOTensor(tensor_list);
636   }
637   return TensorListToCSRTensor(tensor_list);
638 }
639 
BaseShapeToShapeVector(const abstract::BaseShapePtr & base_shape)640 std::vector<ShapeVector> BaseShapeToShapeVector(const abstract::BaseShapePtr &base_shape) {
641   MS_EXCEPTION_IF_NULL(base_shape);
642   if (base_shape->isa<abstract::Shape>()) {
643     const auto &shape = base_shape->cast<abstract::ShapePtr>();
644     MS_EXCEPTION_IF_NULL(shape);
645     return {shape->shape()};
646   } else if (base_shape->isa<abstract::SequenceShape>()) {
647     const auto &tuple_shape = base_shape->cast<abstract::SequenceShapePtr>();
648     MS_EXCEPTION_IF_NULL(tuple_shape);
649     if (tuple_shape->size() == 0) {
650       return {};
651     }
652     // If the shape is a tuple shape, all shapes need to be consistent.
653     auto element_base_shape = (*tuple_shape)[0];
654     if (element_base_shape->isa<abstract::Shape>()) {
655       const auto &element_shape = element_base_shape->cast<abstract::ShapePtr>();
656       MS_EXCEPTION_IF_NULL(element_shape);
657       return std::vector<ShapeVector>(tuple_shape->size(), element_shape->shape());
658     } else if (element_base_shape->isa<abstract::NoShape>()) {
659       return std::vector<ShapeVector>(tuple_shape->size(), {1});
660     }
661   } else if (base_shape->isa<abstract::NoShape>() || base_shape->isa<abstract::DynamicSequenceShape>()) {
662     return {};
663   }
664   MS_LOG(WARNING) << "Invalid shape:" << base_shape->ToString();
665   return {};
666 }
667 
BaseShapeToShape(const abstract::BaseShapePtr & base_shape)668 ShapeVector BaseShapeToShape(const abstract::BaseShapePtr &base_shape) {
669   MS_EXCEPTION_IF_NULL(base_shape);
670   if (base_shape->isa<abstract::Shape>()) {
671     const auto &shape = base_shape->cast<abstract::ShapePtr>();
672     MS_EXCEPTION_IF_NULL(shape);
673     return shape->shape();
674   } else if (base_shape->isa<abstract::NoShape>()) {
675     return {};
676   }
677   MS_LOG(WARNING) << "Invalid shape:" << base_shape->ToString();
678   return {};
679 }
680 
// Normalizes an attribute value against its declared data type.
// For the list kinds "listInt", "listStr", "listBool" and "listFloat", a non-sequence value is
// wrapped in a one-element ValueTuple. Every other case returns value unchanged.
// Throws if value is null and attr_data_type is one of the list kinds.
ValuePtr UpdateValueByAttrDataType(const ValuePtr &value, const std::string &attr_data_type) {
  // const: the lookup table is never modified after initialization.
  static const std::set<std::string> kListDataType = {"listInt", "listStr", "listBool", "listFloat"};
  if (kListDataType.count(attr_data_type) == 0) {
    return value;
  }
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueSequence>()) {
    return value;
  }
  std::vector<ValuePtr> value_vec{value};
  return std::make_shared<ValueTuple>(value_vec);
}
694 
695 namespace {
GetHashId(int a,int b)696 size_t GetHashId(int a, int b) { return a < b ? hash_combine(a, b) : hash_combine(b, a); }
697 
698 static const std::map<size_t, TypeId> tensor_tensor_convert_map = {
699   // Bool
700   {GetHashId(kNumberTypeBool, kNumberTypeBool), kNumberTypeBool},
701   {GetHashId(kNumberTypeBool, kNumberTypeInt8), kNumberTypeInt8},
702   {GetHashId(kNumberTypeBool, kNumberTypeInt16), kNumberTypeInt16},
703   {GetHashId(kNumberTypeBool, kNumberTypeInt32), kNumberTypeInt32},
704   {GetHashId(kNumberTypeBool, kNumberTypeInt64), kNumberTypeInt64},
705   {GetHashId(kNumberTypeBool, kNumberTypeUInt8), kNumberTypeUInt8},
706   {GetHashId(kNumberTypeBool, kNumberTypeUInt16), kNumberTypeUInt16},
707   {GetHashId(kNumberTypeBool, kNumberTypeUInt32), kNumberTypeUInt32},
708   {GetHashId(kNumberTypeBool, kNumberTypeUInt64), kNumberTypeUInt64},
709   {GetHashId(kNumberTypeBool, kNumberTypeFloat16), kNumberTypeFloat16},
710   {GetHashId(kNumberTypeBool, kNumberTypeBFloat16), kNumberTypeBFloat16},
711   {GetHashId(kNumberTypeBool, kNumberTypeFloat32), kNumberTypeFloat32},
712   {GetHashId(kNumberTypeBool, kNumberTypeFloat64), kNumberTypeFloat64},
713   {GetHashId(kNumberTypeBool, kNumberTypeComplex64), kNumberTypeComplex64},
714   {GetHashId(kNumberTypeBool, kNumberTypeComplex128), kNumberTypeComplex128},
715   // Int8
716   {GetHashId(kNumberTypeInt8, kNumberTypeInt8), kNumberTypeInt8},
717   {GetHashId(kNumberTypeInt8, kNumberTypeInt16), kNumberTypeInt16},
718   {GetHashId(kNumberTypeInt8, kNumberTypeInt32), kNumberTypeInt32},
719   {GetHashId(kNumberTypeInt8, kNumberTypeInt64), kNumberTypeInt64},
720   {GetHashId(kNumberTypeInt8, kNumberTypeUInt8), kNumberTypeInt16},
721   {GetHashId(kNumberTypeInt8, kNumberTypeFloat16), kNumberTypeFloat16},
722   {GetHashId(kNumberTypeInt8, kNumberTypeBFloat16), kNumberTypeBFloat16},
723   {GetHashId(kNumberTypeInt8, kNumberTypeFloat32), kNumberTypeFloat32},
724   {GetHashId(kNumberTypeInt8, kNumberTypeFloat64), kNumberTypeFloat64},
725   {GetHashId(kNumberTypeInt8, kNumberTypeComplex64), kNumberTypeComplex64},
726   {GetHashId(kNumberTypeInt8, kNumberTypeComplex128), kNumberTypeComplex128},
727   // Int16
728   {GetHashId(kNumberTypeInt16, kNumberTypeInt16), kNumberTypeInt16},
729   {GetHashId(kNumberTypeInt16, kNumberTypeInt32), kNumberTypeInt32},
730   {GetHashId(kNumberTypeInt16, kNumberTypeInt64), kNumberTypeInt64},
731   {GetHashId(kNumberTypeInt16, kNumberTypeUInt8), kNumberTypeInt16},
732   {GetHashId(kNumberTypeInt16, kNumberTypeFloat16), kNumberTypeFloat16},
733   {GetHashId(kNumberTypeInt16, kNumberTypeBFloat16), kNumberTypeBFloat16},
734   {GetHashId(kNumberTypeInt16, kNumberTypeFloat32), kNumberTypeFloat32},
735   {GetHashId(kNumberTypeInt16, kNumberTypeFloat64), kNumberTypeFloat64},
736   {GetHashId(kNumberTypeInt16, kNumberTypeComplex64), kNumberTypeComplex64},
737   {GetHashId(kNumberTypeInt16, kNumberTypeComplex128), kNumberTypeComplex128},
738   // Int32
739   {GetHashId(kNumberTypeInt32, kNumberTypeInt32), kNumberTypeInt32},
740   {GetHashId(kNumberTypeInt32, kNumberTypeInt64), kNumberTypeInt64},
741   {GetHashId(kNumberTypeInt32, kNumberTypeUInt8), kNumberTypeInt32},
742   {GetHashId(kNumberTypeInt32, kNumberTypeFloat16), kNumberTypeFloat16},
743   {GetHashId(kNumberTypeInt32, kNumberTypeBFloat16), kNumberTypeBFloat16},
744   {GetHashId(kNumberTypeInt32, kNumberTypeFloat32), kNumberTypeFloat32},
745   {GetHashId(kNumberTypeInt32, kNumberTypeFloat64), kNumberTypeFloat64},
746   {GetHashId(kNumberTypeInt32, kNumberTypeComplex64), kNumberTypeComplex64},
747   {GetHashId(kNumberTypeInt32, kNumberTypeComplex128), kNumberTypeComplex128},
748   // Int64
749   {GetHashId(kNumberTypeInt64, kNumberTypeInt64), kNumberTypeInt64},
750   {GetHashId(kNumberTypeInt64, kNumberTypeUInt8), kNumberTypeInt64},
751   {GetHashId(kNumberTypeInt64, kNumberTypeFloat16), kNumberTypeFloat16},
752   {GetHashId(kNumberTypeInt64, kNumberTypeBFloat16), kNumberTypeBFloat16},
753   {GetHashId(kNumberTypeInt64, kNumberTypeFloat32), kNumberTypeFloat32},
754   {GetHashId(kNumberTypeInt64, kNumberTypeFloat64), kNumberTypeFloat64},
755   {GetHashId(kNumberTypeInt64, kNumberTypeComplex64), kNumberTypeComplex64},
756   {GetHashId(kNumberTypeInt64, kNumberTypeComplex128), kNumberTypeComplex128},
757   // UInt8
758   {GetHashId(kNumberTypeUInt8, kNumberTypeUInt8), kNumberTypeUInt8},
759   {GetHashId(kNumberTypeUInt8, kNumberTypeFloat16), kNumberTypeFloat16},
760   {GetHashId(kNumberTypeUInt8, kNumberTypeBFloat16), kNumberTypeBFloat16},
761   {GetHashId(kNumberTypeUInt8, kNumberTypeFloat32), kNumberTypeFloat32},
762   {GetHashId(kNumberTypeUInt8, kNumberTypeFloat64), kNumberTypeFloat64},
763   {GetHashId(kNumberTypeUInt8, kNumberTypeComplex64), kNumberTypeComplex64},
764   {GetHashId(kNumberTypeUInt8, kNumberTypeComplex128), kNumberTypeComplex128},
765   // UInt16
766   {GetHashId(kNumberTypeUInt16, kNumberTypeUInt16), kNumberTypeUInt16},
767   // UInt32
768   {GetHashId(kNumberTypeUInt32, kNumberTypeUInt32), kNumberTypeUInt32},
769   // UInt64
770   {GetHashId(kNumberTypeUInt64, kNumberTypeUInt64), kNumberTypeUInt64},
771   // Float16
772   {GetHashId(kNumberTypeFloat16, kNumberTypeFloat16), kNumberTypeFloat16},
773   {GetHashId(kNumberTypeFloat16, kNumberTypeBFloat16), kNumberTypeFloat32},
774   {GetHashId(kNumberTypeFloat16, kNumberTypeFloat32), kNumberTypeFloat32},
775   {GetHashId(kNumberTypeFloat16, kNumberTypeFloat64), kNumberTypeFloat64},
776   {GetHashId(kNumberTypeFloat16, kNumberTypeComplex64), kNumberTypeComplex64},
777   {GetHashId(kNumberTypeFloat16, kNumberTypeComplex128), kNumberTypeComplex128},
778   // BFloat16
779   {GetHashId(kNumberTypeBFloat16, kNumberTypeBFloat16), kNumberTypeBFloat16},
780   {GetHashId(kNumberTypeBFloat16, kNumberTypeFloat32), kNumberTypeFloat32},
781   {GetHashId(kNumberTypeBFloat16, kNumberTypeFloat64), kNumberTypeFloat64},
782   {GetHashId(kNumberTypeBFloat16, kNumberTypeComplex64), kNumberTypeComplex64},
783   {GetHashId(kNumberTypeBFloat16, kNumberTypeComplex128), kNumberTypeComplex128},
784   // Float32
785   {GetHashId(kNumberTypeFloat32, kNumberTypeFloat32), kNumberTypeFloat32},
786   {GetHashId(kNumberTypeFloat32, kNumberTypeFloat64), kNumberTypeFloat64},
787   {GetHashId(kNumberTypeFloat32, kNumberTypeComplex64), kNumberTypeComplex64},
788   {GetHashId(kNumberTypeFloat32, kNumberTypeComplex128), kNumberTypeComplex128},
789   // Float64
790   {GetHashId(kNumberTypeFloat64, kNumberTypeFloat64), kNumberTypeFloat64},
791   {GetHashId(kNumberTypeFloat64, kNumberTypeComplex64), kNumberTypeComplex128},
792   {GetHashId(kNumberTypeFloat64, kNumberTypeComplex128), kNumberTypeComplex128},
793   // Complex64
794   {GetHashId(kNumberTypeComplex64, kNumberTypeComplex64), kNumberTypeComplex64},
795   {GetHashId(kNumberTypeComplex64, kNumberTypeComplex128), kNumberTypeComplex128},
796   // Complex128
797   {GetHashId(kNumberTypeComplex128, kNumberTypeComplex128), kNumberTypeComplex128},
798 };
799 
800 static const std::map<size_t, TypeId> scalar_tensor_convert_map = {
801   // Scalar is bool.
802   {GetHashId(kNumberTypeBool, kNumberTypeBool), kNumberTypeBool},
803   {GetHashId(kNumberTypeBool, kNumberTypeInt8), kNumberTypeInt8},
804   {GetHashId(kNumberTypeBool, kNumberTypeInt16), kNumberTypeInt16},
805   {GetHashId(kNumberTypeBool, kNumberTypeInt32), kNumberTypeInt32},
806   {GetHashId(kNumberTypeBool, kNumberTypeInt64), kNumberTypeInt64},
807   {GetHashId(kNumberTypeBool, kNumberTypeUInt8), kNumberTypeUInt8},
808   {GetHashId(kNumberTypeBool, kNumberTypeUInt16), kNumberTypeUInt16},
809   {GetHashId(kNumberTypeBool, kNumberTypeUInt32), kNumberTypeUInt32},
810   {GetHashId(kNumberTypeBool, kNumberTypeUInt64), kNumberTypeUInt64},
811   {GetHashId(kNumberTypeBool, kNumberTypeFloat16), kNumberTypeFloat16},
812   {GetHashId(kNumberTypeBool, kNumberTypeBFloat16), kNumberTypeBFloat16},
813   {GetHashId(kNumberTypeBool, kNumberTypeFloat32), kNumberTypeFloat32},
814   {GetHashId(kNumberTypeBool, kNumberTypeFloat64), kNumberTypeFloat64},
815   {GetHashId(kNumberTypeBool, kNumberTypeComplex64), kNumberTypeComplex64},
816   {GetHashId(kNumberTypeBool, kNumberTypeComplex128), kNumberTypeComplex128},
817   // Scalar is int.
818   {GetHashId(kNumberTypeInt64, kNumberTypeBool), kNumberTypeInt64},
819   {GetHashId(kNumberTypeInt64, kNumberTypeInt8), kNumberTypeInt8},
820   {GetHashId(kNumberTypeInt64, kNumberTypeInt16), kNumberTypeInt16},
821   {GetHashId(kNumberTypeInt64, kNumberTypeInt32), kNumberTypeInt32},
822   {GetHashId(kNumberTypeInt64, kNumberTypeInt64), kNumberTypeInt64},
823   {GetHashId(kNumberTypeInt64, kNumberTypeUInt8), kNumberTypeUInt8},
824   {GetHashId(kNumberTypeInt64, kNumberTypeFloat16), kNumberTypeFloat16},
825   {GetHashId(kNumberTypeInt64, kNumberTypeBFloat16), kNumberTypeBFloat16},
826   {GetHashId(kNumberTypeInt64, kNumberTypeFloat32), kNumberTypeFloat32},
827   {GetHashId(kNumberTypeInt64, kNumberTypeFloat64), kNumberTypeFloat64},
828   {GetHashId(kNumberTypeInt64, kNumberTypeComplex64), kNumberTypeComplex64},
829   {GetHashId(kNumberTypeInt64, kNumberTypeComplex128), kNumberTypeComplex128},
830   // Scalar is float.
831   {GetHashId(kNumberTypeFloat32, kNumberTypeBool), kNumberTypeFloat32},
832   {GetHashId(kNumberTypeFloat32, kNumberTypeInt8), kNumberTypeFloat32},
833   {GetHashId(kNumberTypeFloat32, kNumberTypeInt16), kNumberTypeFloat32},
834   {GetHashId(kNumberTypeFloat32, kNumberTypeInt32), kNumberTypeFloat32},
835   {GetHashId(kNumberTypeFloat32, kNumberTypeInt64), kNumberTypeFloat32},
836   {GetHashId(kNumberTypeFloat32, kNumberTypeUInt8), kNumberTypeFloat32},
837   {GetHashId(kNumberTypeFloat32, kNumberTypeFloat16), kNumberTypeFloat16},
838   {GetHashId(kNumberTypeFloat32, kNumberTypeBFloat16), kNumberTypeBFloat16},
839   {GetHashId(kNumberTypeFloat32, kNumberTypeFloat32), kNumberTypeFloat32},
840   {GetHashId(kNumberTypeFloat32, kNumberTypeFloat64), kNumberTypeFloat64},
841   {GetHashId(kNumberTypeFloat32, kNumberTypeComplex64), kNumberTypeComplex64},
842   {GetHashId(kNumberTypeFloat32, kNumberTypeComplex128), kNumberTypeComplex128},
843 };
844 
ConvertTypeForTensorsOrScalars(const TypeId & current,const TypeId & other,const size_t hash_id)845 TypeId ConvertTypeForTensorsOrScalars(const TypeId &current, const TypeId &other, const size_t hash_id) {
846   auto iter = tensor_tensor_convert_map.find(hash_id);
847   if (iter != tensor_tensor_convert_map.end()) {
848     return iter->second;
849   }
850   MS_EXCEPTION(TypeError) << "Type implicit conversion between " << TypeIdToString(current) << " and "
851                           << TypeIdToString(other) << " is not supported.";
852 }
853 
ConvertTypeBetweenTensorAndScalar(const TypeId & tensor_type_id,const TypeId & scalar_type_id,const size_t hash_id)854 TypeId ConvertTypeBetweenTensorAndScalar(const TypeId &tensor_type_id, const TypeId &scalar_type_id,
855                                          const size_t hash_id) {
856   auto iter = scalar_tensor_convert_map.find(hash_id);
857   if (iter != scalar_tensor_convert_map.end()) {
858     return iter->second;
859   }
860   MS_EXCEPTION(TypeError) << "Type implicit conversion between Tensor[" << TypeIdToString(tensor_type_id) << "] and "
861                           << TypeIdToString(scalar_type_id) << " is not supported.";
862 }
863 
GetConversionType(const TypeId & current,bool current_arg_is_tensor,bool is_parameter,const std::pair<TypeId,bool> & sig_type,const TypeId & ref_type_id)864 TypeId GetConversionType(const TypeId &current, bool current_arg_is_tensor, bool is_parameter,
865                          const std::pair<TypeId, bool> &sig_type, const TypeId &ref_type_id) {
866   TypeId saved_type_id = sig_type.first;
867   bool saved_has_tensor = sig_type.second;
868   if (current == saved_type_id) {
869     return current;
870   }
871 
872   if (current != kTypeUnknown && saved_type_id != kTypeUnknown) {
873     auto hash_id = GetHashId(current, saved_type_id);
874     // Tensor + Scalar, Scalar + Tensor
875     if (MS_UNLIKELY(current_arg_is_tensor ^ saved_has_tensor)) {
876       return ConvertTypeBetweenTensorAndScalar(current, saved_type_id, hash_id);
877     }
878     // Tensor + Tensor, Scalar + Scalar
879     if ((is_parameter || saved_type_id == ref_type_id) &&
880         hash_id == GetHashId(kNumberTypeFloat16, kNumberTypeBFloat16)) {
881       // "saved_type_id == ref_type_id": if Parameter exists, its type_id should be equal to the saved_type_id,
882       // otherwise it means that the wrong type cast will be performed on the Parameter.
883       static bool already_printed = false;
884       if (!already_printed) {
885         already_printed = true;
886         MS_LOG(WARNING) << "For operators with side effects, there is an implicit type conversion between "
887                         << TypeIdToString(current) << " and " << TypeIdToString(saved_type_id)
888                         << ", which may result in loss of precision. It is recommended to use Float32.";
889       }
890       return is_parameter ? current : saved_type_id;
891     }
892     return ConvertTypeForTensorsOrScalars(current, saved_type_id, hash_id);
893   }
894   return current != kTypeUnknown ? current : saved_type_id;
895 }
896 }  // namespace
897 
GetSignatureTypeMap(const std::vector<SignatureEnumDType> & dtypes,const std::vector<TypeId> & args_type_id,const std::vector<bool> & args_is_tensor,const std::set<size_t> & write_indices)898 std::map<SignatureEnumDType, std::pair<TypeId, bool>> GetSignatureTypeMap(const std::vector<SignatureEnumDType> &dtypes,
899                                                                           const std::vector<TypeId> &args_type_id,
900                                                                           const std::vector<bool> &args_is_tensor,
901                                                                           const std::set<size_t> &write_indices) {
902   // {T0: (target_type_id=Int32, has_tensor=true), T1: (target_type_id=Float32, has_tensor=false), ...}
903   std::map<SignatureEnumDType, std::pair<TypeId, bool>> sig_type_map;
904   std::map<SignatureEnumDType, TypeId> ref_type_map;
905   size_t args_size = args_type_id.size();
906   for (size_t i = 0; i < args_size; ++i) {
907     bool is_parameter = write_indices.find(i) != write_indices.end();
908     const auto &it = sig_type_map.find(dtypes[i]);
909     if (it == sig_type_map.end()) {
910       (void)sig_type_map.insert(std::make_pair(dtypes[i], std::make_pair(args_type_id[i], args_is_tensor[i])));
911       (void)ref_type_map.insert(std::make_pair(dtypes[i], is_parameter ? args_type_id[i] : kTypeUnknown));
912     } else {
913       it->second.first =
914         GetConversionType(args_type_id[i], args_is_tensor[i], is_parameter, it->second, ref_type_map[dtypes[i]]);
915       it->second.second = args_is_tensor[i] || it->second.second;
916       if (is_parameter && ref_type_map[dtypes[i]] == kTypeUnknown) {
917         ref_type_map[dtypes[i]] = args_type_id[i];
918       }
919     }
920   }
921   return sig_type_map;
922 }
923 
ValueSimpleInfoToString(const ValueSimpleInfo & value_simple_info)924 std::string ValueSimpleInfoToString(const ValueSimpleInfo &value_simple_info) {
925   std::ostringstream buf;
926   buf << "Value simple info element size : " << value_simple_info.size_;
927   for (size_t i = 0; i < value_simple_info.size_; ++i) {
928     buf << ". The " << i << "th shape: " << value_simple_info.shape_vector_[i] << ", dtype "
929         << value_simple_info.dtype_vector_[i];
930     if (!value_simple_info.object_type_vector_.empty()) {
931       buf << ", object type " << value_simple_info.object_type_vector_[i];
932     }
933   }
934   return buf.str();
935 }
936 
TransformValueSimpleInfoToAbstract(const ValueSimpleInfo & value_simple_info)937 abstract::AbstractBasePtr TransformValueSimpleInfoToAbstract(const ValueSimpleInfo &value_simple_info) {
938   if (value_simple_info.size_ < 1) {
939     MS_LOG(EXCEPTION) << "Simple infer info size must greater than 1, but got " << value_simple_info.size_;
940   }
941   abstract::AbstractBasePtr out_abs;
942   if (value_simple_info.size_ == 1 && !value_simple_info.is_tuple_output_) {
943     out_abs = std::make_shared<abstract::AbstractTensor>(value_simple_info.dtype_vector_[kIndex0],
944                                                          value_simple_info.shape_vector_[kIndex0]);
945   } else {
946     AbstractBasePtrList out_abs_list;
947     out_abs_list.resize(value_simple_info.size_);
948     for (size_t i = 0; i < value_simple_info.size_; ++i) {
949       out_abs_list[i] = std::make_shared<abstract::AbstractTensor>(value_simple_info.dtype_vector_[i],
950                                                                    value_simple_info.shape_vector_[i]);
951     }
952     out_abs = std::make_shared<abstract::AbstractTuple>(out_abs_list);
953   }
954   return out_abs;
955 }
956 }  // namespace mindspore
957