• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "abstract/abstract_value.h"
20 
21 #include <regex>
22 #include <algorithm>
23 #include <utility>
24 
25 #include "ir/value.h"
26 #include "utils/hash_map.h"
27 #include "utils/hashing.h"
28 #include "utils/log_adapter.h"
29 #include "utils/ms_utils.h"
30 #include "abstract/utils.h"
31 #include "utils/ms_context.h"
32 #include "utils/trace_base.h"
33 #include "utils/compile_config.h"
34 
35 namespace mindspore {
36 namespace abstract {
37 using mindspore::common::IsEqual;
38 
AbstractBase(const ValuePtr & value,const TypePtr & type,const BaseShapePtr & shape)39 AbstractBase::AbstractBase(const ValuePtr &value, const TypePtr &type, const BaseShapePtr &shape)
40     : value_(value), type_(type), shape_(shape) {}
41 
AbstractBase(const AbstractBase & other)42 AbstractBase::AbstractBase(const AbstractBase &other)
43     : Base(other),
44       value_(other.value_),
45       type_(other.type_),
46       shape_(other.shape_),
47       value_desc_(other.value_desc_),
48       name_(other.name_),
49       inplace_abstract_(other.inplace_abstract_) {
50   user_data_ = other.user_data_;
51 }
52 
operator =(const AbstractBase & other)53 AbstractBase &AbstractBase::operator=(const AbstractBase &other) {
54   if (&other != this) {
55     this->value_ = other.value_;
56     this->type_ = other.type_;
57     this->shape_ = other.shape_;
58     this->user_data_ = other.user_data_;
59     this->value_desc_ = other.value_desc_;
60     this->name_ = other.name_;
61     this->inplace_abstract_ = other.inplace_abstract_;
62   }
63   return *this;
64 }
65 
66 AbstractBase::TraceNodeProvider AbstractBase::trace_node_provider_ = nullptr;
67 
JoinSupplementaryInfo(const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)68 std::string JoinSupplementaryInfo(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
69   std::ostringstream oss;
70   oss << "#dmsg#Framework Error Message:#dmsg#This: " << abstract1->ToString() << ", other: " << abstract2->ToString();
71   // Get trace info of node.
72   AnfNodePtr node = nullptr;
73   if (AbstractBase::trace_node_provider_ != nullptr) {
74     AbstractBase::trace_node_provider_(&node);
75   }
76   if (node != nullptr) {
77     oss << ". Please check the node: " << node->DebugString() << trace::DumpSourceLines(node);
78   }
79   return oss.str();
80 }
81 
AbstractTypeJoinLogging(const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)82 inline void AbstractTypeJoinLogging(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
83   std::ostringstream oss;
84   oss << "Type Join Failed: Abstract type " << abstract1->type_name() << " cannot join with " << abstract2->type_name()
85       << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed\n";
86   oss << JoinSupplementaryInfo(abstract1, abstract2);
87   MS_EXCEPTION(TypeError) << oss.str();
88 }
89 
TypeJoinLogging(const TypePtr & type1,const TypePtr & type2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)90 inline void TypeJoinLogging(const TypePtr &type1, const TypePtr &type2, const AbstractBasePtr &abstract1,
91                             const AbstractBasePtr &abstract2) {
92   std::ostringstream oss;
93   oss << "Type Join Failed: dtype1 = " << type1->ToString() << ", dtype2 = " << type2->ToString()
94       << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed\n";
95   oss << JoinSupplementaryInfo(abstract1, abstract2);
96   MS_EXCEPTION(TypeError) << oss.str();
97 }
98 
ShapeJoinLogging(const BaseShapePtr & shape1,const BaseShapePtr & shape2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)99 inline void ShapeJoinLogging(const BaseShapePtr &shape1, const BaseShapePtr &shape2, const AbstractBasePtr &abstract1,
100                              const AbstractBasePtr &abstract2) {
101   std::ostringstream oss;
102   oss << "Shape Join Failed: shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString()
103       << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Shape%20Join%20Failed\n";
104   oss << JoinSupplementaryInfo(abstract1, abstract2);
105   MS_EXCEPTION(ValueError) << oss.str();
106 }
107 
// Extracts the "Type Join Failed" / "Shape Join Failed" summary from an error message.
//
// The match starts at the first occurrence of either keyword and ends at the documentation
// search key that follows it on a later line ("Type%20Join%20Failed" / "Shape%20Join%20Failed").
//
// @param info Full error message text.
// @return The matched span, or an empty string when the message contains no such summary.
std::string ExtractLoggingInfo(const std::string &info) {
  // Compiling a std::regex is expensive, so the pattern is built once. Function-local static
  // initialization is thread-safe, and regex_search only reads the const pattern.
  static const std::regex kJoinFailedPattern(
    "(Type Join Failed|Shape Join Failed).*?\n.*?(Type%20Join%20Failed|Shape%20Join%20Failed)");
  std::smatch match;
  if (!std::regex_search(info, match, kJoinFailedPattern)) {
    return "";
  }
  return match.str();
}
118 
IsUndeterminedType(const TypePtr & type)119 static inline bool IsUndeterminedType(const TypePtr &type) {
120   return (type != nullptr) && (type->type_id() == kObjectTypeUndeterminedType);
121 }
122 
set_value(const ValuePtr & value)123 void AbstractBase::set_value(const ValuePtr &value) {
124   MS_EXCEPTION_IF_NULL(value);
125   value_ = value;
126 }
127 
set_type(const TypePtr & type)128 void AbstractBase::set_type(const TypePtr &type) {
129   MS_EXCEPTION_IF_NULL(type);
130   type_ = type;
131 }
132 
set_shape(const BaseShapePtr & shape)133 void AbstractBase::set_shape(const BaseShapePtr &shape) {
134   MS_EXCEPTION_IF_NULL(shape);
135   shape_ = shape;
136 }
137 
value_desc() const138 const std::string &AbstractBase::value_desc() const { return value_desc_; }
139 
hash() const140 std::size_t AbstractBase::hash() const { return tid(); }
141 
set_value_desc(const std::string & desc)142 void AbstractBase::set_value_desc(const std::string &desc) { value_desc_ = desc; }
143 
GetTypeTrack() const144 const TypePtr &AbstractBase::GetTypeTrack() const { return type_; }
145 
GetValueTrack() const146 const ValuePtr &AbstractBase::GetValueTrack() const { return value_; }
147 
GetShapeTrack() const148 const BaseShapePtr &AbstractBase::GetShapeTrack() const { return shape_; }
149 
BuildShape() const150 BaseShapePtr AbstractBase::BuildShape() const { return kNoShape; }
151 
GetShape() const152 BaseShapePtr AbstractBase::GetShape() const { return BuildShape(); }
153 
GetType() const154 TypePtr AbstractBase::GetType() const { return BuildType(); }
155 
GetValue() const156 ValuePtr AbstractBase::GetValue() const { return BuildValue(); }
157 
set_trace_node_provider(const TraceNodeProvider & trace_node_provider)158 void AbstractBase::set_trace_node_provider(const TraceNodeProvider &trace_node_provider) {
159   trace_node_provider_ = trace_node_provider;
160 }
161 
// Default join: `other` must be non-null, and the result is always this abstract.
AbstractBasePtr AbstractBase::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  AbstractBasePtr self = shared_from_base<AbstractBase>();
  return self;
}
166 
IsBroaden() const167 bool AbstractBase::IsBroaden() const { return value_->ContainsValueAny(); }
168 
operator ==(const AbstractBase & other) const169 bool AbstractBase::operator==(const AbstractBase &other) const {
170   if (this == &other) {
171     // Same object.
172     return true;
173   }
174   // Check C++ type.
175   if (tid() != other.tid()) {
176     return false;
177   }
178   // If both are "undetermined" type, they are considered equal.
179   if (IsUndeterminedType(BuildType()) && IsUndeterminedType(other.BuildType())) {
180     return true;
181   }
182   // Check data type, shape and value.
183   return IsEqual(type_, other.type_) && IsEqual(shape_, other.shape_) && IsEqual(value_, other.value_);
184 }
185 
BuildValue() const186 ValuePtr AbstractBase::BuildValue() const {
187   if (value_ == nullptr) {
188     return RealBuildValue();
189   }
190   return value_;
191 }
192 
Broaden() const193 AbstractBasePtr AbstractBase::Broaden() const {
194   AbstractBasePtr clone = Clone();
195   MS_EXCEPTION_IF_NULL(clone);
196   clone->set_value(kValueAny);
197   return clone;
198 }
199 
PartialBroaden() const200 AbstractBasePtr AbstractBase::PartialBroaden() const { return Clone(); }
201 
ToString() const202 std::string AbstractBase::ToString() const {
203   std::ostringstream buffer;
204   std::string value = std::string("value is null");
205   if (value_ != nullptr) {
206     value = value_->ToString();
207   }
208   MS_EXCEPTION_IF_NULL(type_);
209   MS_EXCEPTION_IF_NULL(shape_);
210   buffer << type_name() << "("
211          << "Type: " << type_->ToString() << ", Value: " << value << ", Shape: " << shape_->ToString() << ")";
212   return buffer.str();
213 }
214 
ToString(bool verbose) const215 std::string AbstractBase::ToString(bool verbose) const {
216   if (verbose) {
217     return ToString();
218   }
219   std::ostringstream buffer;
220   auto tensor_value = BuildValue();
221   auto shape = GetShape();
222   auto type = BuildType();
223   if (shape != nullptr && type != nullptr) {
224     buffer << type << ", " << shape->ToString();
225     if (tensor_value != nullptr && !tensor_value->ContainsValueAny()) {
226       buffer << ", value=...";
227     }
228   } else if (type != nullptr) {
229     buffer << type;
230     if (tensor_value != nullptr && !tensor_value->ContainsValueAny()) {
231       buffer << ", value=...";
232     }
233   }
234   return buffer.str();
235 }
236 
set_interpret_bool_checker(InterpretBoolChecker checker)237 void AbstractBase::set_interpret_bool_checker(InterpretBoolChecker checker) { interpret_bool_checker_ = checker; }
238 
interpret_bool_checker()239 AbstractBase::InterpretBoolChecker AbstractBase::interpret_bool_checker() { return interpret_bool_checker_; }
240 
set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher)241 void AbstractBase::set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher) {
242   pyexecute_user_data_catcher_ = catcher;
243 }
244 
pyexecute_user_data_catcher()245 AbstractBase::PyExecuteUserDataCatcher AbstractBase::pyexecute_user_data_catcher() {
246   return pyexecute_user_data_catcher_;
247 }
248 
name() const249 std::string AbstractBase::name() const { return name_; }
250 
set_name(const std::string & name)251 void AbstractBase::set_name(const std::string &name) { name_ = name; }
252 
inplace_abstract() const253 AbstractBasePtr AbstractBase::inplace_abstract() const { return inplace_abstract_; }
254 
set_inplace_abstract(const AbstractBasePtr & inplace_abstract)255 void AbstractBase::set_inplace_abstract(const AbstractBasePtr &inplace_abstract) {
256   inplace_abstract_ = inplace_abstract;
257 }
258 
RealBuildValue() const259 ValuePtr AbstractBase::RealBuildValue() const { return kValueAny; }
260 
AbstractScalar()261 AbstractScalar::AbstractScalar() : AbstractBase(kValueAny, kTypeAny) {}
262 
AbstractScalar(const ValuePtr & value,const TypePtr & type)263 AbstractScalar::AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
264 
AbstractScalar(const ValuePtr & value)265 AbstractScalar::AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {}
266 
AbstractScalar(int value)267 AbstractScalar::AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {}
268 
AbstractScalar(int64_t value)269 AbstractScalar::AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {}
270 
AbstractScalar(float value)271 AbstractScalar::AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {}
272 
AbstractScalar(double value)273 AbstractScalar::AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {}
274 
AbstractScalar(bool value)275 AbstractScalar::AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {}
276 
AbstractScalar(const std::string & value)277 AbstractScalar::AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {}
278 
AbstractScalar(const TypePtr & type)279 AbstractScalar::AbstractScalar(const TypePtr &type) : AbstractBase(kValueAny, type) {}
280 
Broaden() const281 AbstractBasePtr AbstractScalar::Broaden() const {
282   if (is_variable_) {
283     return AbstractBase::Broaden();
284   }
285   auto context = MsContext::GetInstance();
286   MS_EXCEPTION_IF_NULL(context);
287   if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
288     return AbstractBase::Broaden();
289   }
290   auto type_id = GetTypeTrack()->type_id();
291   if (type_id == kObjectTypeEnvType) {
292     return AbstractBase::Broaden();
293   }
294   return Clone();
295 }
296 
// Joins this scalar with `other`.
//   * An equal abstract joins to this abstract.
//   * Types are joined first; a kTypeAny result raises a TypeError via TypeJoinLogging.
//   * A negligible `other` joined with a non-ValueAny value gives a fresh AbstractAny.
//   * Otherwise values are joined. This abstract is returned when the joined value is unchanged;
//     otherwise a new scalar with the joined value and type is returned.
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  const auto &lhs_type = GetTypeTrack();
  const auto &rhs_type = other->GetTypeTrack();
  TypePtr joined_type = TypeJoin(lhs_type, rhs_type);
  if (joined_type == kTypeAny) {
    TypeJoinLogging(lhs_type, rhs_type, shared_from_base<AbstractBase>(), other);
  }
  const auto &lhs_value = GetValueTrack();
  const auto &rhs_value = other->GetValueTrack();
  const bool other_is_negligible = other->isa<AbstractNegligible>();
  if (other_is_negligible && !lhs_value->isa<ValueAny>()) {
    return std::make_shared<AbstractAny>();
  }
  ValuePtr joined_value = ValueJoin(lhs_value, rhs_value);
  if (joined_value != lhs_value) {
    return std::make_shared<AbstractScalar>(joined_value, joined_type);
  }
  return shared_from_base<AbstractBase>();
}
320 
set_is_variable(bool is_variable)321 void AbstractScalar::set_is_variable(bool is_variable) { is_variable_ = is_variable; }
322 
hash() const323 std::size_t AbstractScalar::hash() const {
324   return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()});
325 }
326 
BuildType() const327 TypePtr AbstractScalar::BuildType() const { return GetTypeTrack(); }
328 
Clone() const329 AbstractBasePtr AbstractScalar::Clone() const {
330   auto abs = std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
331   abs->set_is_variable(is_variable_);
332   abs->SetSymbolicShape(this->GetSymbolicShape());
333   abs->SetSymbolicValue(this->GetSymbolicValue());
334   return abs;
335 }
336 
Clone() const337 AbstractBasePtr AbstractType::Clone() const {
338   ValuePtr this_value = GetValueTrack();
339   if (this_value == nullptr || !this_value->isa<Type>()) {
340     return nullptr;
341   }
342   auto this_type = this_value->cast_ptr<Type>();
343   return std::make_shared<AbstractType>(this_type->Clone());
344 }
345 
operator ==(const AbstractBase & other) const346 bool AbstractType::operator==(const AbstractBase &other) const {
347   if (this == &other) {
348     return true;
349   }
350   return tid() == other.tid() &&
351          IsEqual(dyn_cast_ptr<Type>(GetValueTrack()), dyn_cast_ptr<Type>(other.GetValueTrack()));
352 }
353 
ToString() const354 std::string AbstractType::ToString() const {
355   std::ostringstream buffer;
356   ValuePtr this_value = GetValueTrack();
357   if (this_value == nullptr) {
358     buffer << "AbstractType value: nullptr";
359     return buffer.str();
360   }
361   if (!this_value->isa<Type>()) {
362     buffer << type_name() << "(Value: nullptr)";
363     return buffer.str();
364   }
365   auto this_type = this_value->cast_ptr<Type>();
366   buffer << type_name() << "("
367          << "Value: " << this_type->ToString() << ")";
368   return buffer.str();
369 }
370 
Join(const AbstractBasePtr & other)371 AbstractBasePtr AbstractClass::Join(const AbstractBasePtr &other) {
372   MS_EXCEPTION_IF_NULL(other);
373   bool success = (*this == *other);
374   if (!success) {
375     const auto &this_type = GetTypeTrack();
376     const auto &other_type = other->GetTypeTrack();
377     TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
378   }
379   return shared_from_base<AbstractBase>();
380 }
381 
Clone() const382 AbstractBasePtr AbstractClass::Clone() const { return std::make_shared<AbstractClass>(GetValueTrack()); }
383 
operator ==(const AbstractBase & other) const384 bool AbstractClass::operator==(const AbstractBase &other) const {
385   if (this == &other) {
386     return true;
387   }
388   return tid() == other.tid() && IsEqual(GetValueTrack(), other.GetValueTrack());
389 }
390 
ToString() const391 std::string AbstractClass::ToString() const {
392   std::ostringstream buffer;
393   ValuePtr val = GetValueTrack();
394   MS_EXCEPTION_IF_NULL(val);
395   buffer << type_name() << "(" << val->ToString() << ")";
396   return buffer.str();
397 }
398 
BuildType() const399 TypePtr AbstractType::BuildType() const { return std::make_shared<TypeType>(); }
400 
Broaden() const401 AbstractBasePtr AbstractType::Broaden() const { return Clone(); }
402 
AbstractProblem(const ValueProblemPtr & err,const AnfNodePtr & node)403 AbstractProblem::AbstractProblem(const ValueProblemPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
404   if (err == nullptr || node == nullptr) {
405     MS_LOG(EXCEPTION) << "err or node is nullptr";
406   }
407 }
408 
ToString() const409 std::string AbstractProblem::ToString() const {
410   std::ostringstream buffer;
411   auto value_track = GetValueTrack();
412   MS_EXCEPTION_IF_NULL(value_track);
413   MS_EXCEPTION_IF_NULL(node_);
414   buffer << type_name() << "("
415          << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
416   return buffer.str();
417 }
418 
BuildType() const419 TypePtr AbstractProblem::BuildType() const { return std::make_shared<Problem>(); }
420 
Broaden() const421 AbstractBasePtr AbstractProblem::Broaden() const { return Clone(); }
422 
Clone() const423 AbstractBasePtr AbstractProblem::Clone() const {
424   return std::make_shared<AbstractProblem>(GetValueTrack()->cast<ValueProblemPtr>(), node_);
425 }
426 
AbstractScript()427 AbstractScript::AbstractScript() : AbstractBase(kValueAny, kTypeAny) {}
428 
AbstractScript(const ValuePtr & value,const TypePtr & type)429 AbstractScript::AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
430 
AbstractScript(const ValuePtr & value)431 AbstractScript::AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {}
432 
hash() const433 std::size_t AbstractScript::hash() const {
434   return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()});
435 }
436 
BuildType() const437 TypePtr AbstractScript::BuildType() const { return GetTypeTrack(); }
438 
Clone() const439 AbstractBasePtr AbstractScript::Clone() const {
440   return std::make_shared<AbstractScript>(GetValueTrack(), GetTypeTrack()->Clone());
441 }
442 
Broaden() const443 AbstractBasePtr AbstractScript::Broaden() const { return Clone(); }
444 
// Joins with a generic abstract.
//   * A negligible `other` leaves this function abstract unchanged.
//   * A non-function `other` raises a TypeError via AbstractTypeJoinLogging.
//   * Otherwise the call is dispatched to the AbstractFunctionPtr overload.
AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  const bool other_is_negligible = other->isa<AbstractNegligible>();
  if (other_is_negligible) {
    return shared_from_base<AbstractBase>();
  }
  AbstractFunctionPtr as_function = dyn_cast<AbstractFunction>(other);
  if (as_function == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  return Join(as_function);
}
456 
operator ==(const AbstractBase & other) const457 bool AbstractFunction::operator==(const AbstractBase &other) const {
458   if (this == &other) {
459     return true;
460   }
461   if (!other.isa<AbstractFunction>()) {
462     return false;
463   }
464   return *this == static_cast<const AbstractFunction &>(other);
465 }
466 
BuildType() const467 TypePtr AbstractFunction::BuildType() const { return std::make_shared<Function>(); }
468 
Clone() const469 AbstractBasePtr AbstractFunction::Clone() const { return Copy(); }
470 
Broaden() const471 AbstractBasePtr AbstractFunction::Broaden() const {
472   return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
473 }
474 
tracking_id() const475 std::uintptr_t AbstractFunction::tracking_id() const { return 0; }
476 
CopyWithoutTrackingId() const477 AbstractFunctionPtr AbstractFunction::CopyWithoutTrackingId() const { return Copy(); }
478 
context() const479 AnalysisContextPtr AbstractFunction::context() const { return nullptr; }
480 
ToTrackingId(const AnfNodePtr & node)481 std::uintptr_t AbstractFunction::ToTrackingId(const AnfNodePtr &node) {
482   return reinterpret_cast<std::uintptr_t>(node.get());
483 }
484 
AbstractKeywordArg(const std::string & key,const AbstractBasePtr & argument)485 AbstractKeywordArg::AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument)
486     : arg_name_(key), arg_value_(argument) {}
487 
get_key() const488 std::string AbstractKeywordArg::get_key() const { return arg_name_; }
489 
get_arg() const490 AbstractBasePtr AbstractKeywordArg::get_arg() const { return arg_value_; }
491 
AbstractUndetermined()492 AbstractUndetermined::AbstractUndetermined() : AbstractBase(kValueAny) {}
493 
AbstractUndetermined(const AbstractBasePtr & element,const BaseShapePtr & shape)494 AbstractUndetermined::AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape)
495     : AbstractBase(kValueAny), element_(element) {
496   if (element == nullptr) {
497     MS_LOG(EXCEPTION) << "element is nullptr";
498   }
499   if (element->isa<AbstractUndetermined>()) {
500     MS_LOG(EXCEPTION) << "element type error";
501   }
502   MS_EXCEPTION_IF_NULL(shape);
503   if (shape->isa<NoShape>()) {
504     MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
505   }
506   AbstractBase::set_shape(shape);
507 }
508 
AbstractUndetermined(const TypePtr & element_type,const ShapeVector & shape)509 AbstractUndetermined::AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape)
510     : AbstractBase(kValueAny), element_(std::make_shared<AbstractScalar>(kValueAny, element_type)) {
511   if (element_type == nullptr) {
512     MS_LOG(EXCEPTION) << "element_type is nullptr";
513   }
514   AbstractBase::set_shape(std::make_shared<Shape>(shape));
515 }
516 
AbstractUndetermined(const TypePtr & element_type,const BaseShapePtr & shape)517 AbstractUndetermined::AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape)
518     : AbstractBase(kValueAny), element_(std::make_shared<AbstractScalar>(kValueAny, element_type)) {
519   if (element_type == nullptr) {
520     MS_LOG(EXCEPTION) << "element_type is nullptr";
521   }
522   MS_EXCEPTION_IF_NULL(shape);
523   if (shape->isa<NoShape>()) {
524     MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
525   }
526   AbstractBase::set_shape(shape);
527 }
528 
BuildType() const529 TypePtr AbstractUndetermined::BuildType() const { return std::make_shared<UndeterminedType>(); }
530 
Clone() const531 AbstractBasePtr AbstractUndetermined::Clone() const { return std::make_shared<AbstractUndetermined>(); }
532 
element() const533 AbstractBasePtr AbstractUndetermined::element() const { return element_; }
534 
AbstractTensor(const AbstractBasePtr & element,const BaseShapePtr & shape)535 AbstractTensor::AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape)
536     : AbstractUndetermined(element, shape) {}
537 
AbstractTensor(const TypePtr & element_type,const ShapeVector & shape)538 AbstractTensor::AbstractTensor(const TypePtr &element_type, const ShapeVector &shape)
539     : AbstractUndetermined(element_type, shape) {}
540 
AbstractTensor(const tensor::TensorPtr & tensor)541 AbstractTensor::AbstractTensor(const tensor::TensorPtr &tensor)
542     : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
543 
AbstractTensor(const TypePtr & element_type,const BaseShapePtr & shape)544 AbstractTensor::AbstractTensor(const TypePtr &element_type, const BaseShapePtr &shape)
545     : AbstractUndetermined(element_type, shape) {}
546 
hash() const547 std::size_t AbstractTensor::hash() const {
548   // We have to exclude value pointer from hash, because CSE (Common Subexpression Elimination)
549   // will use this hash to find duplicate ValueNodes that Tensor values are equal.
550   auto hash_sum = hash_combine(tid(), element_->hash());
551   const auto &shape = GetShapeTrack();
552   if (shape != nullptr) {
553     hash_sum = hash_combine(hash_sum, shape->hash());
554   }
555   return hash_sum;
556 }
557 
is_adapter() const558 bool AbstractTensor::is_adapter() const { return is_adapter_; }
559 
set_is_adapter(bool is_adapter)560 void AbstractTensor::set_is_adapter(bool is_adapter) { is_adapter_ = is_adapter; }
561 
AbstractAny()562 AbstractAny::AbstractAny()
563     : AbstractTensor(DefaultDtype(), std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny}))) {}
564 
Join(const AbstractBasePtr & other)565 AbstractBasePtr AbstractAny::Join(const AbstractBasePtr &other) {
566   MS_EXCEPTION_IF_NULL(other);
567   return std::make_shared<AbstractAny>();
568 }
569 
Broaden() const570 AbstractBasePtr AbstractAny::Broaden() const { return Clone(); }
571 
Clone() const572 AbstractBasePtr AbstractAny::Clone() const {
573   auto any_abstract = std::make_shared<AbstractAny>();
574   if (supposed_tensor_dtype()) {
575     MS_EXCEPTION_IF_NULL(element());
576     const auto &dtype = element()->BuildType();
577     MS_EXCEPTION_IF_NULL(any_abstract->element());
578     any_abstract->element()->set_type(dtype);
579     any_abstract->set_supposed_tensor_dtype(true);
580   }
581   return any_abstract;
582 }
583 
ToString() const584 std::string AbstractAny::ToString() const { return type_name(); }
585 
supposed_tensor_dtype() const586 bool AbstractAny::supposed_tensor_dtype() const { return supposed_tensor_dtype_; }
587 
set_supposed_tensor_dtype(bool flag)588 void AbstractAny::set_supposed_tensor_dtype(bool flag) { supposed_tensor_dtype_ = flag; }
589 
DefaultDtype()590 TypePtr AbstractAny::DefaultDtype() { return kFloat64; }
591 
message() const592 const std::string &AbstractJoinedAny::message() const { return message_; }
593 
set_message(const std::string & message)594 void AbstractJoinedAny::set_message(const std::string &message) { message_ = message; }
595 
exception() const596 AbstractJoinedAny::ExceptionType AbstractJoinedAny::exception() const { return exception_; }
597 
set_exception(ExceptionType exception)598 void AbstractJoinedAny::set_exception(ExceptionType exception) { exception_ = exception; }
599 
ThrowException() const600 void AbstractJoinedAny::ThrowException() const {
601   if (exception_ == kTypeError) {
602     MS_EXCEPTION(TypeError) << message_;
603   } else if (exception_ == kValueError) {
604     MS_EXCEPTION(ValueError) << message_;
605   } else {
606     MS_LOG(EXCEPTION) << message_;
607   }
608 }
609 
610 namespace {
CollectSequenceNodes(const AnfNodeWeakPtrList & source_sequence_nodes,AnfNodeWeakPtrList * sequence_nodes_ptr)611 void CollectSequenceNodes(const AnfNodeWeakPtrList &source_sequence_nodes, AnfNodeWeakPtrList *sequence_nodes_ptr) {
612   AnfNodeWeakPtrList &sequence_nodes = *sequence_nodes_ptr;
613   auto sequence_nodes_size = source_sequence_nodes.size();
614   for (size_t i = 0; i < sequence_nodes_size; ++i) {
615     // Lock sequence nodes of this.
616     auto &source_weak_node = source_sequence_nodes[i];
617     auto source_sequence_node = source_weak_node.lock();
618     if (source_sequence_node == nullptr) {
619       continue;
620     }
621     // Check and emplace sequence node for this.
622     auto this_iter = std::find_if(
623       sequence_nodes.begin(), sequence_nodes.end(),
624       [&source_sequence_node](const AnfNodeWeakPtr &weak_node) { return source_sequence_node == weak_node.lock(); });
625     if (this_iter == sequence_nodes.end()) {
626       (void)sequence_nodes.emplace_back(AnfNodeWeakPtr(source_sequence_node));
627     }
628   }
629 }
630 
SynchronizeSequenceNodesElementsUseFlagsInner(const AnfNodeWeakPtrList & sequence_nodes)631 void SynchronizeSequenceNodesElementsUseFlagsInner(const AnfNodeWeakPtrList &sequence_nodes) {
632   // Choose the candidate sequence node, that we use its flags as unique one.
633   AnfNodePtr candidate_sequence_node = sequence_nodes[0].lock();
634   MS_EXCEPTION_IF_NULL(candidate_sequence_node);
635   size_t candidate_index = 0;
636   for (size_t i = 1; i < sequence_nodes.size(); ++i) {
637     auto current_sequence_node = sequence_nodes[i].lock();
638     MS_EXCEPTION_IF_NULL(current_sequence_node);
639     if (candidate_sequence_node == current_sequence_node) {
640       continue;
641     }
642     auto candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
643     MS_EXCEPTION_IF_NULL(candidate_flags);
644     auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
645     MS_EXCEPTION_IF_NULL(current_flags);
646     if (candidate_flags == current_flags) {
647       continue;
648     }
649 
650     // Find the sequence node whose flags are most used.
651     auto candidate_count = candidate_flags.use_count();
652     auto current_count = current_flags.use_count();
653     if (candidate_count < current_count) {
654       candidate_sequence_node = current_sequence_node;
655       candidate_index = i;
656     }
657   }
658 
659   // Synchronize the elements use flags for all sequence nodes with candidate sequence node.
660   // We set the same 'elements_use_flags' for them after here.
661   auto candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
662   MS_LOG(DEBUG) << "Sequence nodes size: " << sequence_nodes.size() << ", candidate node index: " << candidate_index
663                 << ", candidate node: " << candidate_sequence_node->DebugString() << ", flags: " << candidate_flags;
664   for (size_t i = 0; i < sequence_nodes.size(); ++i) {
665     auto current_sequence_node = sequence_nodes[i].lock();
666     MS_EXCEPTION_IF_NULL(current_sequence_node);
667     if (candidate_sequence_node == current_sequence_node) {
668       continue;
669     }
670     auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
671     if (candidate_flags == current_flags) {
672       continue;
673     }
674 
675     // Merge the use flags, set true if either is true.
676     for (size_t j = 0; j < candidate_flags->size(); ++j) {
677       MS_LOG(DEBUG) << "Check elements_use_flags[" << j << "], this_flag: " << (*candidate_flags)[j]
678                     << ", other_flag: " << (*current_flags)[j];
679       (*candidate_flags)[j] = ((*candidate_flags)[j] || (*current_flags)[j]);
680     }
681     // Use the candidate sequence node flags.
682     SetSequenceNodeElementsUseFlags(current_sequence_node, candidate_flags);
683     MS_LOG(DEBUG) << "Reset flags for sequence node[" << i << "]: " << current_sequence_node->DebugString()
684                   << ", flags: " << candidate_flags;
685   }
686 }
687 
CheckSequenceNodesValid(const AnfNodeWeakPtrList & sequence_nodes)688 void CheckSequenceNodesValid(const AnfNodeWeakPtrList &sequence_nodes) {
689   if (!IS_OUTPUT_ON(MsLogLevel::kDebug)) {
690     return;
691   }
692   if (sequence_nodes.size() <= 1) {
693     return;
694   }
695   AnfNodePtr candidate_sequence_node = sequence_nodes[0].lock();
696   if (candidate_sequence_node == nullptr) {
697     MS_LOG(DEBUG) << "candidate_sequence_node is null.";
698     return;
699   }
700   auto candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
701   if (candidate_flags == nullptr) {
702     MS_LOG(DEBUG) << "The candidate_flags is null, sequence_nodes[0]: " << candidate_sequence_node->DebugString();
703     return;
704   }
705   for (size_t i = 0; i < sequence_nodes.size(); ++i) {
706     auto current_sequence_node = sequence_nodes[i].lock();
707     if (current_sequence_node == nullptr) {
708       MS_LOG(DEBUG) << "current_sequence_node is null.";
709       return;
710     }
711     MS_LOG(DEBUG) << "sequence_nodes[" << i << "]: " << current_sequence_node << "/"
712                   << current_sequence_node->DebugString()
713                   << ", flags: " << GetSequenceNodeElementsUseFlags(current_sequence_node);
714   }
715   for (size_t i = 1; i < sequence_nodes.size(); ++i) {
716     auto current_sequence_node = sequence_nodes[i].lock();
717     if (current_sequence_node == nullptr) {
718       MS_LOG(DEBUG) << "current_sequence_node is null.";
719       return;
720     }
721     if (candidate_sequence_node == current_sequence_node) {
722       continue;
723     }
724     candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
725     MS_EXCEPTION_IF_NULL(candidate_flags);
726     auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
727     MS_EXCEPTION_IF_NULL(current_flags);
728     if (candidate_flags == current_flags) {
729       continue;
730     }
731     MS_LOG(DEBUG) << "Should use same flags pointer, candidate_node: " << candidate_sequence_node->DebugString()
732                   << ", current_node: " << current_sequence_node->DebugString();
733 
734     if (candidate_flags->size() != current_flags->size()) {
735       MS_LOG(INTERNAL_EXCEPTION) << "Flag not same size";
736     }
737     for (size_t j = 0; j < candidate_flags->size(); ++j) {
738       if ((*candidate_flags)[j] != (*current_flags)[j]) {
739         MS_LOG(INTERNAL_EXCEPTION) << "Not equal elements_use_flags[" << j << "], this_flag: " << (*candidate_flags)[j]
740                                    << ", other_flag: " << (*current_flags)[j];
741       }
742     }
743   }
744 }
745 
// Collects the sequence nodes of both lists (lhs first, then rhs) into one list. When that list holds more than one
// node, their elements-use flags are synchronized and the result is checked (the check is DEBUG-only).
// Returns the collected list in every case.
AnfNodeWeakPtrList SynchronizeSequenceNodesElementsUseFlags(const AnfNodeWeakPtrList &lhs_sequence_nodes,
                                                            const AnfNodeWeakPtrList &rhs_sequence_nodes) {
  AnfNodeWeakPtrList merged_nodes;
  for (const AnfNodeWeakPtrList *source : {&lhs_sequence_nodes, &rhs_sequence_nodes}) {
    CollectSequenceNodes(*source, &merged_nodes);
  }
  if (merged_nodes.size() > 1) {
    SynchronizeSequenceNodesElementsUseFlagsInner(merged_nodes);
    CheckSequenceNodesValid(merged_nodes);
  } else {
    MS_LOG(DEBUG) << "Sequence nodes size should exceed 1.";
  }
  return merged_nodes;
}
761 
AbstractCanJoin(const AbstractBasePtr & abs1,const AbstractBasePtr & abs2)762 bool AbstractCanJoin(const AbstractBasePtr &abs1, const AbstractBasePtr &abs2) {
763   try {
764     MS_LOG_TRY_CATCH_SCOPE;
765     (void)abs1->Join(abs2);
766   } catch (std::exception &) {
767     return false;
768   }
769   return true;
770 }
771 
CheckElementAbstractSame(const AbstractBasePtr & first_element,const AbstractBasePtr & cur_element,const size_t i,bool raise_exception)772 bool CheckElementAbstractSame(const AbstractBasePtr &first_element, const AbstractBasePtr &cur_element, const size_t i,
773                               bool raise_exception) {
774   MS_EXCEPTION_IF_NULL(first_element);
775   MS_EXCEPTION_IF_NULL(cur_element);
776   if (first_element->isa<abstract::AbstractAny>() || cur_element->isa<abstract::AbstractAny>()) {
777     return true;
778   }
779   auto first_element_type_id = first_element->BuildType()->generic_type_id();
780   auto cur_element_type_id = cur_element->BuildType()->generic_type_id();
781   if (first_element_type_id != cur_element_type_id) {
782     if (!raise_exception) {
783       return false;
784     }
785     MS_EXCEPTION(ValueError) << "In graph mode, the element type of dynamic length array must be the same."
786                              << "The element type do not match, can not convert to dynamic length sequence. "
787                              << "The 0th element type is: " << TypeIdToString(first_element_type_id) << ". The " << i
788                              << "th element type is: " << TypeIdToString(cur_element_type_id);
789   }
790   auto first_element_shape = first_element->GetShape();
791   MS_EXCEPTION_IF_NULL(first_element_shape);
792   auto cur_element_shape = cur_element->GetShape();
793   MS_EXCEPTION_IF_NULL(cur_element_shape);
794   if (*first_element_shape != *cur_element_shape) {
795     return false;
796   }
797   if (!AbstractCanJoin(first_element, cur_element)) {
798     if (!raise_exception) {
799       return false;
800     }
801     MS_EXCEPTION(TypeError) << "In graph mode, the element shape of dynamic length array must be the same."
802                             << "The element do not match, can not convert to dynamic length sequence. "
803                             << "The 0th element is: " << first_element->ToString() << ". The " << i
804                             << "th element shape is: " << cur_element->ToString();
805   }
806   return true;
807 }
808 }  // namespace
809 
AbstractSequence(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & sequence_nodes)810 AbstractSequence::AbstractSequence(AbstractBasePtrList &&elements,
811                                    const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes)
812     : elements_(std::move(elements)), sequence_nodes_(sequence_nodes) {
813   if (sequence_nodes != nullptr) {
814     CheckSequenceNodesValid(*sequence_nodes);
815   }
816 }
817 
AbstractSequence(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & sequence_nodes)818 AbstractSequence::AbstractSequence(const AbstractBasePtrList &elements,
819                                    const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes)
820     : elements_(elements), sequence_nodes_(sequence_nodes) {
821   if (sequence_nodes != nullptr) {
822     CheckSequenceNodesValid(*sequence_nodes);
823   }
824 }
825 
operator [](const std::size_t & dim) const826 const AbstractBasePtr AbstractSequence::operator[](const std::size_t &dim) const {
827   if (dynamic_len_) {
828     MS_LOG(EXCEPTION) << "Can not get element from dynamic length sequence " << ToString();
829   }
830   if (dim >= size()) {
831     MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list.";
832   }
833   return elements_[dim];
834 }
835 
ToStringInternal() const836 std::string AbstractSequence::ToStringInternal() const {
837   std::ostringstream buffer;
838   if (dynamic_len_) {
839     buffer << "dynamic_len_element_abs: "
840            << (dynamic_len_element_abs_ == nullptr ? "nullptr" : dynamic_len_element_abs_->ToString());
841     return buffer.str();
842   }
843   size_t i = 0;
844   size_t size = elements_.size();
845   auto prefix_space = std::string(space_num_, ' ');
846   for (const auto &element : elements_) {
847     MS_EXCEPTION_IF_NULL(element);
848     if (element->isa<AbstractSequence>()) {
849       constexpr auto kIndent = 4;
850       element->cast<AbstractSequencePtr>()->space_num_ = space_num_ + kIndent;
851     }
852     buffer << "\n" << prefix_space << "element[" << i << "]: " << element->ToString();
853     if (i < size - 1) {
854       buffer << ", ";
855     }
856     i++;
857   }
858   return buffer.str();
859 }
860 
ToString() const861 std::string AbstractSequence::ToString() const {
862   std::stringstream ss;
863   ss << "\n";
864   ss << type_name();
865   ss << "{";
866   ss << ToStringInternal();
867   if (!dynamic_len_ && sequence_nodes() != nullptr && !sequence_nodes()->empty()) {
868     ss << ", " << std::string(space_num_, ' ') << "sequence_nodes: {";
869     for (size_t i = 0; i < sequence_nodes()->size(); ++i) {
870       auto sequence_node = (*sequence_nodes())[i].lock();
871       if (sequence_node == nullptr) {
872         ss << "<freed node>";
873         continue;
874       } else {
875         ss << sequence_node->DebugString();
876       }
877       auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
878       if (flags != nullptr) {
879         ss << ", elements_use_flags: {ptr: " << flags << ", value: " << (*flags) << "}";
880       }
881       if (i != sequence_nodes()->size() - 1) {
882         ss << ", ";
883       }
884     }
885     ss << "}";
886   }
887   ss << ", dynamic_len:" << dynamic_len_ << ", is dyn arg:" << dyn_len_arg_;
888   ss << "}";
889   ss << "\n";
890   return ss.str();
891 }
892 
ToString(bool verbose) const893 std::string AbstractSequence::ToString(bool verbose) const {
894   if (verbose) {
895     return ToString();
896   }
897   std::ostringstream buffer;
898   size_t i = 0;
899   size_t size = elements_.size();
900   buffer << type_name() << " {";
901   for (const auto &element : elements_) {
902     MS_EXCEPTION_IF_NULL(element);
903     buffer << element->ToString(false);
904     if (i < size - 1) {
905       buffer << ", ";
906     }
907     i++;
908   }
909   buffer << "}";
910   return buffer.str();
911 }
912 
SequenceNodesJoin(const AbstractBasePtr & other)913 AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &other) {
914   AnfNodeWeakPtrList sequence_nodes;
915   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
916   if (!enable_eliminate_unused_element || this->sequence_nodes() == nullptr) {
917     return sequence_nodes;
918   }
919   auto other_sequence = dyn_cast_ptr<AbstractSequence>(other);
920   if (other_sequence == nullptr) {
921     return sequence_nodes;
922   }
923   auto this_sequence_nodes_size = (this->sequence_nodes() == nullptr ? 0 : this->sequence_nodes()->size());
924   auto other_sequence_nodes_size =
925     (other_sequence->sequence_nodes() == nullptr ? 0 : other_sequence->sequence_nodes()->size());
926   // The tuple or list output which has sequence_nodes may be joined with a tuple output node like top_k,
927   // we should return the branch which has sequence_nodes.
928   if (this_sequence_nodes_size == 0 && other_sequence_nodes_size == 0) {
929     return sequence_nodes;
930   } else if (this_sequence_nodes_size == 0) {
931     return *(other_sequence->sequence_nodes());
932   } else if (other_sequence_nodes_size == 0) {
933     return *(this->sequence_nodes());
934   }
935   // Collect this and other sequence nodes.
936   if (this->sequence_nodes() != nullptr) {
937     CollectSequenceNodes(*this->sequence_nodes(), &sequence_nodes);
938   }
939   if (other_sequence->sequence_nodes() != nullptr) {
940     CollectSequenceNodes(*other_sequence->sequence_nodes(), &sequence_nodes);
941   }
942   if (sequence_nodes.empty()) {
943     MS_LOG(INFO) << "Sequence nodes size should not be empty.";
944     return sequence_nodes;
945   }
946   // Synchronize the elements use flags for all sequence nodes.
947   SynchronizeSequenceNodesElementsUseFlagsInner(sequence_nodes);
948 
949   CheckSequenceNodesValid(sequence_nodes);
950   this->InsertSequenceNodes(sequence_nodes);
951   other_sequence->InsertSequenceNodes(sequence_nodes);
952   return sequence_nodes;
953 }
954 
SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr & lhs_sequence,const AbstractSequencePtr & rhs_sequence)955 void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
956                                                     const AbstractSequencePtr &rhs_sequence) {
957   if (lhs_sequence->sequence_nodes() == nullptr || rhs_sequence->sequence_nodes() == nullptr) {
958     return;
959   }
960   auto sequence_nodes =
961     SynchronizeSequenceNodesElementsUseFlags(*lhs_sequence->sequence_nodes(), *rhs_sequence->sequence_nodes());
962   lhs_sequence->InsertSequenceNodes(sequence_nodes);
963   rhs_sequence->InsertSequenceNodes(sequence_nodes);
964   if (lhs_sequence->elements().size() != rhs_sequence->elements().size()) {
965     MS_LOG(INTERNAL_EXCEPTION) << "The elements size should be equal, " << lhs_sequence->ToString() << ", "
966                                << rhs_sequence->ToString();
967   }
968   for (size_t i = 0; i < lhs_sequence->elements().size(); ++i) {
969     auto lhs_inner_sequence = dyn_cast<AbstractSequence>(lhs_sequence->elements()[i]);
970     if (lhs_inner_sequence == nullptr) {
971       continue;
972     }
973     auto rhs_inner_sequence = dyn_cast<AbstractSequence>(rhs_sequence->elements()[i]);
974     if (rhs_inner_sequence == nullptr) {
975       continue;
976     }
977     SynchronizeSequenceElementsUseFlagsRecursively(lhs_inner_sequence, rhs_inner_sequence);
978   }
979 }
980 
InsertSequenceNodes(const AnfNodeWeakPtrList & sequence_nodes)981 void AbstractSequence::InsertSequenceNodes(const AnfNodeWeakPtrList &sequence_nodes) {
982   if (dynamic_len_) {
983     MS_LOG(INTERNAL_EXCEPTION) << "Can not insert sequence nodes for dynamic length sequence " << ToString();
984   }
985   if (sequence_nodes_ == nullptr) {
986     MS_LOG(DEBUG) << "The sequence_nodes is null.";
987     sequence_nodes_ = std::make_shared<AnfNodeWeakPtrList>();
988   }
989   for (auto &weak_node : sequence_nodes) {
990     auto sequence_node = weak_node.lock();
991     InsertSequenceNode(sequence_node);
992   }
993 }
994 
InsertSequenceNode(const AnfNodePtr & sequence_node)995 void AbstractSequence::InsertSequenceNode(const AnfNodePtr &sequence_node) {
996   if (dynamic_len_) {
997     MS_LOG(INTERNAL_EXCEPTION) << "Can not insert sequence node for dynamic length sequence " << ToString();
998   }
999   if (sequence_nodes_ == nullptr) {
1000     MS_LOG(DEBUG) << "The sequence_nodes is null.";
1001     sequence_nodes_ = std::make_shared<AnfNodeWeakPtrList>();
1002   }
1003   auto iter =
1004     std::find_if(sequence_nodes_->begin(), sequence_nodes_->end(),
1005                  [&sequence_node](const AnfNodeWeakPtr &weak_node) { return sequence_node == weak_node.lock(); });
1006   if (iter == sequence_nodes_->end()) {
1007     (void)sequence_nodes_->emplace_back(sequence_node);
1008     CheckSequenceNodesValid(*sequence_nodes_);
1009   } else {
1010     MS_LOG(DEBUG) << "Fail to insert node \'" << sequence_node->DebugString() << "\' into sequence nodes.";
1011   }
1012 }
1013 
UpdateSequenceNode(const AnfNodePtr & old_sequence_node,const AnfNodePtr & new_sequence_node)1014 void AbstractSequence::UpdateSequenceNode(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node) {
1015   if (dynamic_len_) {
1016     MS_LOG(INTERNAL_EXCEPTION) << "Can not update sequence node for dynamic length sequence " << ToString();
1017   }
1018   if (sequence_nodes_ == nullptr) {
1019     MS_LOG(DEBUG) << "The sequence_nodes is null.";
1020     sequence_nodes_ = std::make_shared<AnfNodeWeakPtrList>();
1021   }
1022   auto iter = std::find_if(
1023     sequence_nodes_->begin(), sequence_nodes_->end(),
1024     [&old_sequence_node](const AnfNodeWeakPtr &weak_node) { return old_sequence_node == weak_node.lock(); });
1025   if (iter != sequence_nodes_->end()) {
1026     *iter = new_sequence_node;
1027     CheckSequenceNodesValid(*sequence_nodes_);
1028     return;
1029   }
1030   MS_LOG(INTERNAL_EXCEPTION) << "Not found old node \'" << old_sequence_node->DebugString() << "\' in sequence nodes.";
1031 }
1032 
PurifyElements()1033 bool AbstractSequence::PurifyElements() {
1034   if (dynamic_len_ || sequence_nodes_ == nullptr || sequence_nodes_->empty()) {
1035     return true;
1036   }
1037   // Just use any sequence node's elements_use_flags.
1038   AnfNodePtr not_free_node = nullptr;
1039   std::shared_ptr<std::vector<bool>> elements_use_flags_ptr = nullptr;
1040   for (auto &weak_node : *sequence_nodes_) {
1041     auto sequence_node = weak_node.lock();
1042     if (sequence_node == nullptr) {
1043       MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
1044       continue;
1045     }
1046     not_free_node = sequence_node;
1047     auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
1048     if (flags != nullptr) {
1049       elements_use_flags_ptr = flags;
1050       break;
1051     }
1052   }
1053   if (elements_use_flags_ptr == nullptr) {
1054     if (not_free_node == nullptr) {
1055       MS_LOG(INFO) << "Check if all sequence nodes are released, or none elements use flags in them. nodes size: "
1056                    << sequence_nodes_->size();
1057     } else {
1058       MS_LOG(INFO) << "Check if none elements use flags in sequence ndoes. one of node: "
1059                    << not_free_node->DebugString();
1060     }
1061     return false;
1062   }
1063 
1064   // Purify the elements.
1065   auto &elements_use_flags = *elements_use_flags_ptr;
1066   if (elements_use_flags.size() < elements_.size()) {
1067     MS_LOG(INTERNAL_EXCEPTION) << "Elements size should not be greater to elements use flags size. " << ToString();
1068   }
1069   for (size_t i = 0; i < elements_.size(); ++i) {
1070     MS_EXCEPTION_IF_NULL(elements_[i]);
1071     if (!elements_use_flags[i]) {
1072       const auto unuse_node_none = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(0));
1073       if (elements_[i]->isa<AbstractProblem>()) {
1074         unuse_node_none->set_type(std::make_shared<Problem>());
1075       }
1076       elements_[i] = unuse_node_none;
1077       MS_LOG(DEBUG) << "Erase elements[" << i << "] abstract as Zero for " << ToString();
1078     } else {
1079       MS_LOG(DEBUG) << "Keep elements[" << i << "] abstract as " << elements_[i]->ToString();
1080     }
1081   }
1082   return true;
1083 }
1084 
1085 // Convert self from a fixed length sequence to dynamic length sequence.
CheckAndConvertToDynamicLenSequence(bool raise_exception)1086 void AbstractSequence::CheckAndConvertToDynamicLenSequence(bool raise_exception) {
1087   // Can not use size() since it will raise error when sequence is already dynamic length.
1088   const size_t input_len = elements_.size();
1089   if (input_len > 1) {
1090     auto first_element = elements()[0];
1091     MS_EXCEPTION_IF_NULL(first_element);
1092     for (size_t i = 1; i < input_len; ++i) {
1093       auto cur_element = elements()[i];
1094       MS_EXCEPTION_IF_NULL(cur_element);
1095       bool ret = CheckElementAbstractSame(first_element, cur_element, i, raise_exception);
1096       if (!ret) {
1097         return;
1098       }
1099     }
1100     set_dynamic_len_element_abs(first_element);
1101   } else if (input_len == 1) {
1102     set_dynamic_len_element_abs(elements()[0]);
1103   }
1104   set_value(kValueAny);
1105   // Set sequence nodes to nulltpr to disable DDE.
1106   sequence_nodes_ = nullptr;
1107   set_dynamic_len(true);
1108 }
1109 
1110 // Convert self's cloned abstract from a fixed length sequence to dynamic length sequence, just like tensor broaden.
BroadenToDynamicLenSequence()1111 AbstractSequencePtr AbstractSequence::BroadenToDynamicLenSequence() {
1112   if (dynamic_len()) {
1113     return shared_from_base<AbstractSequence>();
1114   }
1115   if (isa<AbstractSparseTensor>()) {
1116     return shared_from_base<AbstractSequence>();
1117   }
1118   auto clone_sequence = Clone()->cast<AbstractSequencePtr>();
1119   clone_sequence->CheckAndConvertToDynamicLenSequence(false);
1120   if (clone_sequence->dynamic_len()) {
1121     set_dyn_len_arg();
1122     // Set all sequence inputs as used if trans from fixed length to dynamic length.
1123     SetSequenceElementsUseFlagsRecursively(shared_from_base<AbstractSequence>(), true);
1124     return clone_sequence;
1125   }
1126   // Convert to dynamic len failed, return original abstract.
1127   return shared_from_base<AbstractSequence>();
1128 }
1129 
ElementsType() const1130 TypePtrList AbstractSequence::ElementsType() const {
1131   TypePtrList element_type_list;
1132   for (const auto &element : elements_) {
1133     MS_EXCEPTION_IF_NULL(element);
1134     TypePtr element_type = element->BuildType();
1135     element_type_list.push_back(element_type);
1136   }
1137   return element_type_list;
1138 }
1139 
ElementsShape() const1140 BaseShapePtrList AbstractSequence::ElementsShape() const {
1141   BaseShapePtrList element_shape_list;
1142   for (const auto &element : elements_) {
1143     MS_EXCEPTION_IF_NULL(element);
1144     BaseShapePtr element_shape = element->GetShape();
1145     element_shape_list.push_back(element_shape);
1146   }
1147   return element_shape_list;
1148 }
1149 
ElementsClone() const1150 AbstractBasePtrList AbstractSequence::ElementsClone() const {
1151   AbstractBasePtrList element_list;
1152   for (const auto &element : elements_) {
1153     MS_EXCEPTION_IF_NULL(element);
1154     AbstractBasePtr clone = element->Clone();
1155     element_list.push_back(clone);
1156   }
1157   return element_list;
1158 }
1159 
ElementsBroaden() const1160 AbstractBasePtrList AbstractSequence::ElementsBroaden() const {
1161   AbstractBasePtrList element_list;
1162   for (const auto &element : elements_) {
1163     MS_EXCEPTION_IF_NULL(element);
1164     AbstractBasePtr broadend = element->Broaden();
1165     element_list.push_back(broadend);
1166   }
1167   return element_list;
1168 }
1169 
ElementsPartialBroaden() const1170 AbstractBasePtrList AbstractSequence::ElementsPartialBroaden() const {
1171   AbstractBasePtrList element_list;
1172   for (const auto &element : elements_) {
1173     MS_EXCEPTION_IF_NULL(element);
1174     AbstractBasePtr broadend = element->PartialBroaden();
1175     element_list.push_back(broadend);
1176   }
1177   return element_list;
1178 }
1179 
GetValueFromUserData(const AbstractBasePtr & element_abs)1180 std::pair<bool, ValuePtr> GetValueFromUserData(const AbstractBasePtr &element_abs) {
1181   MS_EXCEPTION_IF_NULL(element_abs);
1182   if (abstract::AbstractBase::pyexecute_user_data_catcher()) {
1183     return abstract::AbstractBase::pyexecute_user_data_catcher()(element_abs);
1184   }
1185   return {false, nullptr};
1186 }
1187 
1188 template <typename T>
ElementsBuildValue() const1189 ValuePtr AbstractSequence::ElementsBuildValue() const {
1190   std::vector<ValuePtr> element_value_list;
1191   for (const auto &element : elements_) {
1192     MS_EXCEPTION_IF_NULL(element);
1193     auto [has_user_data, element_value] = GetValueFromUserData(element);
1194     if (has_user_data && element_value != nullptr) {
1195       element_value_list.push_back(element_value);
1196       continue;
1197     }
1198     element_value = element->BuildValue();
1199     MS_EXCEPTION_IF_NULL(element_value);
1200     if (element_value->isa<ValueAny>()) {
1201       return kValueAny;
1202     }
1203     element_value_list.push_back(element_value);
1204   }
1205   return std::make_shared<T>(element_value_list);
1206 }
1207 template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue<ValueTuple>() const;
1208 template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue<ValueList>() const;
1209 
1210 template <typename T>
ElementsJoin(const AbstractSequencePtr & other)1211 AbstractBasePtr AbstractSequence::ElementsJoin(const AbstractSequencePtr &other) {
1212   MS_EXCEPTION_IF_NULL(other);
1213   auto joined_list = AbstractJoin(elements_, other->elements_);
1214   bool changes = false;
1215   for (std::size_t i = 0; i < elements_.size(); i++) {
1216     if (elements_[i] != joined_list[i]) {
1217       changes = true;
1218       break;
1219     }
1220   }
1221   if (!changes) {
1222     return shared_from_base<AbstractBase>();
1223   }
1224   return std::make_shared<T>(joined_list);
1225 }
1226 template AbstractBasePtr AbstractSequence::ElementsJoin<AbstractList>(const AbstractSequencePtr &);
1227 template AbstractBasePtr AbstractSequence::ElementsJoin<AbstractTuple>(const AbstractSequencePtr &);
1228 
hash() const1229 std::size_t AbstractSequence::hash() const {
1230   if (dynamic_len_) {
1231     size_t hash_val = hash_combine(tid(), static_cast<size_t>(dynamic_len_));
1232     if (dynamic_len_element_abs_ != nullptr) {
1233       return hash_combine(hash_val, static_cast<size_t>(dynamic_len_element_abs_->hash()));
1234     }
1235     return hash_val;
1236   }
1237   return hash_combine(tid(), AbstractBasePtrListHash(elements_));
1238 }
1239 
size() const1240 std::size_t AbstractSequence::size() const {
1241   if (dynamic_len_) {
1242     if (dynamic_len_element_abs_ == nullptr) {
1243       return 0;
1244     }
1245     MS_LOG(INTERNAL_EXCEPTION) << "Can not get size for dynamic length sequence " << ToString();
1246   }
1247   return elements_.size();
1248 }
1249 
empty() const1250 bool AbstractSequence::empty() const {
1251   if (dynamic_len_) {
1252     if (dynamic_len_element_abs_ == nullptr) {
1253       return true;
1254     }
1255     MS_LOG(INTERNAL_EXCEPTION) << "Can not call function empty() for dynamic length sequence " << ToString();
1256   }
1257   return elements_.empty();
1258 }
1259 
set_dynamic_len(bool dynamic_len)1260 void AbstractSequence::set_dynamic_len(bool dynamic_len) {
1261   if (dynamic_len) {
1262     sequence_nodes_ = nullptr;
1263   }
1264   dynamic_len_ = dynamic_len;
1265 }
1266 
set_dynamic_len_element_abs(const AbstractBasePtr & dynamic_len_element_abs)1267 void AbstractSequence::set_dynamic_len_element_abs(const AbstractBasePtr &dynamic_len_element_abs) {
1268   if (dynamic_len_element_abs == nullptr) {
1269     return;
1270   }
1271   if (dynamic_len_element_abs->isa<abstract::AbstractDictionary>()) {
1272     MS_EXCEPTION(TypeError) << "DynamicSequence does not support dictionary type as element type now.";
1273   }
1274   // dynamic_len_element_abs should ignore value.
1275   dynamic_len_element_abs_ = AbstractBroaden(dynamic_len_element_abs);
1276 }
1277 
operator ==(const AbstractBase & other) const1278 bool AbstractSequence::operator==(const AbstractBase &other) const {
1279   if (this == &other) {
1280     return true;
1281   }
1282   if (tid() != other.tid()) {
1283     return false;
1284   }
1285   const auto &other_sequence = dynamic_cast<const AbstractSequence &>(other);
1286   if (dynamic_len_ != other_sequence.dynamic_len()) {
1287     // Variable length sequence and constant length sequence can not be the same.
1288     return false;
1289   }
1290 
1291   if (dynamic_len_) {
1292     // If the abstract of element for two variable sequence is the same, these two sequence is the same.
1293     return IsEqual(dynamic_len_element_abs_, other_sequence.dynamic_len_element_abs());
1294   }
1295 
1296   if (elements_.size() != other_sequence.elements_.size()) {
1297     return false;
1298   }
1299   for (size_t i = 0; i < elements_.size(); ++i) {
1300     if (!IsEqual(elements_[i], other_sequence.elements_[i])) {
1301       return false;
1302     }
1303   }
1304   return true;
1305 }
1306 
elements() const1307 const AbstractBasePtrList &AbstractSequence::elements() const { return elements_; }
1308 
sequence_nodes() const1309 const std::shared_ptr<AnfNodeWeakPtrList> &AbstractSequence::sequence_nodes() const { return sequence_nodes_; }
1310 
set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> & sequence_nodes)1311 void AbstractSequence::set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes) {
1312   sequence_nodes_ = sequence_nodes;
1313 }
1314 
dynamic_len() const1315 bool AbstractSequence::dynamic_len() const { return dynamic_len_; }
1316 
dynamic_len_element_abs() const1317 AbstractBasePtr AbstractSequence::dynamic_len_element_abs() const { return dynamic_len_element_abs_; }
1318 
set_dyn_len_arg()1319 void AbstractSequence::set_dyn_len_arg() { dyn_len_arg_ = true; }
1320 
dyn_len_arg() const1321 bool AbstractSequence::dyn_len_arg() const { return dyn_len_arg_; }
1322 
ContainsAllBroadenTensors() const1323 bool AbstractSequence::ContainsAllBroadenTensors() const {
1324   if (dynamic_len_) {
1325     if (dynamic_len_element_abs_ != nullptr && dynamic_len_element_abs_->isa<AbstractTensor>()) {
1326       return true;
1327     }
1328     return false;
1329   }
1330   if (elements_.empty()) {
1331     return false;
1332   }
1333   auto exist_not_broadened_tensor = [](const AbstractBasePtr &abs) {
1334     bool is_broaden_tensor = abs->isa<abstract::AbstractUndetermined>() && abs->IsBroaden();
1335     bool is_broaden_sequence = abs->isa<abstract::AbstractSequence>() &&
1336                                abs->cast_ptr<abstract::AbstractSequence>()->ContainsAllBroadenTensors();
1337     return !is_broaden_tensor && !is_broaden_sequence;
1338   };
1339   return !std::any_of(elements_.cbegin(), elements_.cend(), exist_not_broadened_tensor);
1340 }
1341 
AbstractTuple(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)1342 AbstractTuple::AbstractTuple(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
1343     : AbstractSequence(std::move(elements), tuple_nodes) {}
1344 
AbstractTuple(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)1345 AbstractTuple::AbstractTuple(const AbstractBasePtrList &elements,
1346                              const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
1347     : AbstractSequence(elements, tuple_nodes) {}
1348 
BuildType() const1349 TypePtr AbstractTuple::BuildType() const {
1350   auto ret = std::make_shared<Tuple>(ElementsType());
1351   if (dynamic_len_) {
1352     ret->set_dynamic_len(dynamic_len_);
1353     if (dynamic_len_element_abs_ != nullptr) {
1354       ret->set_dynamic_element_type(dynamic_len_element_abs_->BuildType());
1355     }
1356   }
1357   return ret;
1358 }
1359 
BuildShape() const1360 BaseShapePtr AbstractTuple::BuildShape() const {
1361   if (dynamic_len_) {
1362     if (dynamic_len_element_abs_ == nullptr) {
1363       return std::make_shared<DynamicSequenceShape>(nullptr);
1364     }
1365     return std::make_shared<DynamicSequenceShape>(dynamic_len_element_abs_->GetShape());
1366   }
1367   return std::make_shared<TupleShape>(ElementsShape());
1368 }
1369 
Clone() const1370 AbstractBasePtr AbstractTuple::Clone() const {
1371   auto ret = std::make_shared<AbstractTuple>(ElementsClone(), sequence_nodes());
1372   ret->dyn_len_arg_ = dyn_len_arg_;
1373   ret->set_dynamic_len(dynamic_len_);
1374   ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1375   ret->SetSymbolicShape(this->GetSymbolicShape());
1376   ret->SetSymbolicValue(this->GetSymbolicValue());
1377   return ret;
1378 }
1379 
Broaden() const1380 AbstractBasePtr AbstractTuple::Broaden() const {
1381   auto ret = std::make_shared<AbstractTuple>(ElementsBroaden(), sequence_nodes());
1382   ret->set_dynamic_len(dynamic_len_);
1383   ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1384   return ret;
1385 }
1386 
PartialBroaden() const1387 AbstractBasePtr AbstractTuple::PartialBroaden() const {
1388   auto ret = std::make_shared<AbstractTuple>(ElementsPartialBroaden(), sequence_nodes());
1389   ret->set_dynamic_len(dynamic_len_);
1390   ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1391   return ret;
1392 }
1393 
RealBuildValue() const1394 ValuePtr AbstractTuple::RealBuildValue() const {
1395   if (dynamic_len_) {
1396     return kValueAny;
1397   }
1398   return ElementsBuildValue<ValueTuple>();
1399 }
1400 
set_shape(const BaseShapePtr & shape)1401 void AbstractTuple::set_shape(const BaseShapePtr &shape) {
1402   auto tuple_shape = dyn_cast_ptr<TupleShape>(shape);
1403   MS_EXCEPTION_IF_NULL(tuple_shape);
1404   if (tuple_shape->shape().size() != elements_.size()) {
1405     MS_LOG(INTERNAL_EXCEPTION) << "Size mismatch: " << tuple_shape->shape().size() << " vs " << elements_.size();
1406   }
1407 
1408   for (size_t i = 0; i < elements_.size(); ++i) {
1409     MS_EXCEPTION_IF_NULL(elements_[i]);
1410     elements_[i]->set_shape(tuple_shape->shape()[i]);
1411   }
1412 }
1413 
operator ==(const AbstractBase & other) const1414 bool AbstractTuple::operator==(const AbstractBase &other) const {
1415   if (this == &other) {
1416     return true;
1417   }
1418   if (!other.isa<AbstractTuple>()) {
1419     return false;
1420   }
1421   return AbstractSequence::operator==(static_cast<const AbstractSequence &>(other));
1422 }
1423 
AbstractList(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & list_nodes)1424 AbstractList::AbstractList(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes)
1425     : AbstractSequence(std::move(elements), list_nodes) {}
1426 
AbstractList(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & list_nodes)1427 AbstractList::AbstractList(const AbstractBasePtrList &elements, const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes)
1428     : AbstractSequence(elements, list_nodes) {}
1429 
operator ==(const AbstractBase & other) const1430 bool AbstractList::operator==(const AbstractBase &other) const {
1431   if (this == &other) {
1432     return true;
1433   }
1434   if (!other.isa<AbstractList>()) {
1435     return false;
1436   }
1437   auto other_extra_info = static_cast<const AbstractList &>(other).extra_info();
1438   if (extra_info_->size() != 0 && other_extra_info->size() != 0 && extra_info_ != other_extra_info) {
1439     return false;
1440   }
1441   return AbstractSequence::operator==(static_cast<const AbstractSequence &>(other));
1442 }
1443 
BuildType() const1444 TypePtr AbstractList::BuildType() const {
1445   auto ret = std::make_shared<List>(ElementsType());
1446   if (dynamic_len_) {
1447     ret->set_dynamic_len(dynamic_len_);
1448     if (dynamic_len_element_abs_ != nullptr) {
1449       ret->set_dynamic_element_type(dynamic_len_element_abs_->BuildType());
1450     }
1451   }
1452   return ret;
1453 }
1454 
BuildShape() const1455 BaseShapePtr AbstractList::BuildShape() const {
1456   if (dynamic_len_) {
1457     if (dynamic_len_element_abs_ == nullptr) {
1458       return std::make_shared<DynamicSequenceShape>(nullptr);
1459     }
1460     return std::make_shared<DynamicSequenceShape>(dynamic_len_element_abs_->GetShape());
1461   }
1462   return std::make_shared<ListShape>(ElementsShape());
1463 }
1464 
Clone() const1465 AbstractBasePtr AbstractList::Clone() const {
1466   auto ret = std::make_shared<AbstractList>(ElementsClone(), sequence_nodes());
1467   ret->dyn_len_arg_ = dyn_len_arg_;
1468   ret->set_dynamic_len(dynamic_len_);
1469   ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1470   ret->set_extra_info(extra_info_);
1471   ret->SetSymbolicShape(this->GetSymbolicShape());
1472   ret->SetSymbolicValue(this->GetSymbolicValue());
1473   return ret;
1474 }
1475 
Broaden() const1476 AbstractBasePtr AbstractList::Broaden() const {
1477   auto ret = std::make_shared<AbstractList>(ElementsBroaden(), sequence_nodes());
1478   ret->set_dynamic_len(dynamic_len_);
1479   ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1480   ret->set_extra_info(extra_info_);
1481   return ret;
1482 }
1483 
PartialBroaden() const1484 AbstractBasePtr AbstractList::PartialBroaden() const {
1485   auto ret = std::make_shared<AbstractList>(ElementsPartialBroaden(), sequence_nodes());
1486   ret->set_dynamic_len(dynamic_len_);
1487   ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1488   ret->set_extra_info(extra_info_);
1489   return ret;
1490 }
1491 
RealBuildValue() const1492 ValuePtr AbstractList::RealBuildValue() const {
1493   if (dynamic_len_) {
1494     return kValueAny;
1495   }
1496   return ElementsBuildValue<ValueList>();
1497 }
1498 
CheckAndConvertToDynamicLenSequence(bool raise_exception)1499 void AbstractList::CheckAndConvertToDynamicLenSequence(bool raise_exception) {
1500   AbstractSequence::CheckAndConvertToDynamicLenSequence(raise_exception);
1501   ClearExtraInfo();
1502 }
1503 
DynamicLenSequenceJoin(const AbstractSequencePtr & other)1504 std::shared_ptr<AbstractSequence> AbstractSequence::DynamicLenSequenceJoin(const AbstractSequencePtr &other) {
1505   auto other_dyn_sequence_abs = other;
1506   if (!dynamic_len() || !other->dynamic_len()) {
1507     MS_LOG(EXCEPTION) << "Can't join fixed length sequence. \nthis:" << ToString() << ", \nother:" << other->ToString();
1508   }
1509   auto element_abs1 = dynamic_len_element_abs_;
1510   auto element_abs2 = other_dyn_sequence_abs->dynamic_len_element_abs();
1511   AbstractBasePtr join_element_abs = nullptr;
1512   // When two element abstracts are not nullptr, join them to get the new element abstract.
1513   // When one or none of the element abstract is nullptr, the result element abstract is another.
1514   if (element_abs1 == nullptr) {
1515     join_element_abs = element_abs2;
1516   } else if (element_abs2 == nullptr) {
1517     join_element_abs = element_abs1;
1518   } else {
1519     join_element_abs = element_abs1->Join(element_abs2);
1520   }
1521   auto ret = Clone()->cast<AbstractSequencePtr>();
1522   ret->set_dynamic_len_element_abs(join_element_abs);
1523   return ret;
1524 }
1525 
// Joins this tuple with `other`.
// - An AbstractNegligible `other` is absorbed and this tuple is returned.
// - `other` must be an AbstractTuple; otherwise AbstractTypeJoinLogging reports the mismatch.
// - If either side is dynamic-length, or the lengths differ, the sides go through
//   BroadenToDynamicLenSequence() and are joined as dynamic-length sequences.
// - Otherwise the elements are joined pairwise and both sides' sequence nodes are merged into the result.
// Exceptions thrown while joining as dynamic-length sequences are logged and then reported through
// AbstractTypeJoinLogging.
AbstractBasePtr AbstractTuple::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_sequence = other->cast<AbstractTuplePtr>();
  if (other_sequence == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  try {
    if (dynamic_len_) {
      return DynamicLenSequenceJoin(other_sequence->BroadenToDynamicLenSequence());
    }
    if (other_sequence->dynamic_len()) {
      return other_sequence->DynamicLenSequenceJoin(BroadenToDynamicLenSequence());
    }
    if (other_sequence->size() != size()) {
      auto dyn_len_sequence = BroadenToDynamicLenSequence();
      return dyn_len_sequence->Join(other_sequence->BroadenToDynamicLenSequence());
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << e.what();
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractTuple>(other_sequence));
  MS_EXCEPTION_IF_NULL(res);
  res->InsertSequenceNodes(SequenceNodesJoin(other));
  return res;
}
1554 
// Joins this list with `other`.
// - An AbstractNegligible `other` is absorbed and this list is returned.
// - `other` must be an AbstractList; otherwise AbstractTypeJoinLogging reports the mismatch.
// - If either side is dynamic-length, or the lengths differ, the sides go through
//   BroadenToDynamicLenSequence() and are combined with DynamicLenSequenceJoin.
// - Otherwise the elements are joined pairwise and both sides' sequence nodes are merged into the result.
// Exceptions thrown while joining as dynamic-length sequences are logged and then reported through
// AbstractTypeJoinLogging.
AbstractBasePtr AbstractList::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_sequence = other->cast<AbstractListPtr>();
  if (other_sequence == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  try {
    if (dynamic_len_) {
      return DynamicLenSequenceJoin(other_sequence->BroadenToDynamicLenSequence());
    }
    if (other_sequence->dynamic_len()) {
      return other_sequence->DynamicLenSequenceJoin(BroadenToDynamicLenSequence());
    }
    if (other_sequence->size() != size()) {
      auto dyn_len_sequence = BroadenToDynamicLenSequence();
      return dyn_len_sequence->DynamicLenSequenceJoin(other_sequence->BroadenToDynamicLenSequence());
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << e.what();
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }

  auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractList>(other_sequence));
  MS_EXCEPTION_IF_NULL(res);
  res->InsertSequenceNodes(SequenceNodesJoin(other));
  return res;
}
1584 
RealBuildValue() const1585 ValuePtr AbstractNamedTuple::RealBuildValue() const {
1586   std::vector<ValuePtr> element_value_list;
1587   for (const auto &element : elements_) {
1588     MS_EXCEPTION_IF_NULL(element);
1589     auto element_value = element->BuildValue();
1590     MS_EXCEPTION_IF_NULL(element_value);
1591     element_value_list.push_back(element_value);
1592   }
1593   std::vector<ValuePtr> key_value_list;
1594   for (const auto &key : keys_) {
1595     MS_EXCEPTION_IF_NULL(key);
1596     auto key_value = key->BuildValue();
1597     MS_EXCEPTION_IF_NULL(key_value);
1598     key_value_list.push_back(key_value);
1599   }
1600   return std::make_shared<ValueNamedTuple>(sub_class_name_, key_value_list, element_value_list);
1601 }
1602 
BuildType() const1603 TypePtr AbstractSlice::BuildType() const {
1604   MS_EXCEPTION_IF_NULL(start_);
1605   MS_EXCEPTION_IF_NULL(stop_);
1606   MS_EXCEPTION_IF_NULL(step_);
1607   TypePtr start = start_->BuildType();
1608   TypePtr stop = stop_->BuildType();
1609   TypePtr step = step_->BuildType();
1610   return std::make_shared<Slice>(start, stop, step);
1611 }
1612 
AbstractDictionary(const std::vector<AbstractElementPair> & key_values)1613 AbstractDictionary::AbstractDictionary(const std::vector<AbstractElementPair> &key_values) : key_values_(key_values) {}
1614 
size() const1615 std::size_t AbstractDictionary::size() const { return key_values_.size(); }
1616 
elements() const1617 const std::vector<AbstractElementPair> &AbstractDictionary::elements() const { return key_values_; }
1618 
AbstractSlice(const AbstractBasePtr & start,const AbstractBasePtr & stop,const AbstractBasePtr & step)1619 AbstractSlice::AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step)
1620     : start_(start), stop_(stop), step_(step) {}
1621 
operator ==(const AbstractBase & other) const1622 bool AbstractSlice::operator==(const AbstractBase &other) const {
1623   if (this == &other) {
1624     return true;
1625   }
1626   if (!other.isa<AbstractSlice>()) {
1627     return false;
1628   }
1629   const auto &other_slice = dynamic_cast<const AbstractSlice &>(other);
1630   return IsEqual(start_, other_slice.start_) && IsEqual(stop_, other_slice.stop_) && IsEqual(step_, other_slice.step_);
1631 }
1632 
Clone() const1633 AbstractBasePtr AbstractSlice::Clone() const {
1634   MS_EXCEPTION_IF_NULL(start_);
1635   MS_EXCEPTION_IF_NULL(stop_);
1636   MS_EXCEPTION_IF_NULL(step_);
1637   AbstractBasePtr start = start_->Clone();
1638   AbstractBasePtr stop = stop_->Clone();
1639   AbstractBasePtr step = step_->Clone();
1640   return std::make_shared<AbstractSlice>(start, stop, step);
1641 }
1642 
Broaden() const1643 AbstractBasePtr AbstractSlice::Broaden() const {
1644   MS_EXCEPTION_IF_NULL(start_);
1645   MS_EXCEPTION_IF_NULL(stop_);
1646   MS_EXCEPTION_IF_NULL(step_);
1647   AbstractBasePtr start = start_->Broaden();
1648   AbstractBasePtr stop = stop_->Broaden();
1649   AbstractBasePtr step = step_->Broaden();
1650   return std::make_shared<AbstractSlice>(start, stop, step);
1651 }
1652 
ToString() const1653 std::string AbstractSlice::ToString() const {
1654   std::ostringstream buffer;
1655   buffer << type_name() << "[";
1656   MS_EXCEPTION_IF_NULL(start_);
1657   buffer << start_->ToString() << " : ";
1658   MS_EXCEPTION_IF_NULL(stop_);
1659   buffer << stop_->ToString() << " : ";
1660   MS_EXCEPTION_IF_NULL(step_);
1661   buffer << step_->ToString();
1662   buffer << "]";
1663   return buffer.str();
1664 }
1665 
RealBuildValue() const1666 ValuePtr AbstractSlice::RealBuildValue() const {
1667   MS_EXCEPTION_IF_NULL(start_);
1668   MS_EXCEPTION_IF_NULL(stop_);
1669   MS_EXCEPTION_IF_NULL(step_);
1670   ValuePtr start = start_->BuildValue();
1671   ValuePtr stop = stop_->BuildValue();
1672   ValuePtr step = step_->BuildValue();
1673   if (start->isa<ValueAny>() || stop->isa<ValueAny>() || step->isa<ValueAny>()) {
1674     return kValueAny;
1675   }
1676   return std::make_shared<ValueSlice>(start, stop, step);
1677 }
1678 
hash() const1679 std::size_t AbstractSlice::hash() const {
1680   MS_EXCEPTION_IF_NULL(start_);
1681   MS_EXCEPTION_IF_NULL(stop_);
1682   MS_EXCEPTION_IF_NULL(step_);
1683   return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
1684 }
1685 
start() const1686 AbstractBasePtr AbstractSlice::start() const { return start_; }
1687 
stop() const1688 AbstractBasePtr AbstractSlice::stop() const { return stop_; }
1689 
step() const1690 AbstractBasePtr AbstractSlice::step() const { return step_; }
1691 
shape() const1692 ShapePtr AbstractUndetermined::shape() const {
1693   auto shp = dyn_cast<Shape>(GetShapeTrack());
1694   if (shp == nullptr) {
1695     MS_LOG(INTERNAL_EXCEPTION) << "Tensor should have a shape.";
1696   }
1697   return shp;
1698 }
1699 
set_shape(const BaseShapePtr & shape)1700 void AbstractUndetermined::set_shape(const BaseShapePtr &shape) {
1701   MS_EXCEPTION_IF_NULL(shape);
1702   if (shape->isa<NoShape>()) {
1703     MS_LOG(INTERNAL_EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
1704   }
1705   AbstractBase::set_shape(shape);
1706 }
1707 
BuildType() const1708 TypePtr AbstractTensor::BuildType() const {
1709   MS_EXCEPTION_IF_NULL(element_);
1710   TypePtr element_type = element_->BuildType();
1711   return std::make_shared<TensorType>(element_type);
1712 }
1713 
BuildShape() const1714 BaseShapePtr AbstractTensor::BuildShape() const {
1715   auto shape = GetShapeTrack();
1716   // Guard from using set_shape(nullptr)
1717   if (shape == nullptr) {
1718     return kNoShape;
1719   }
1720   return shape;
1721 }
1722 
// Joins this tensor abstract with `other`.
// - An AbstractNegligible `other` is absorbed and this abstract is returned.
// - If `other`'s type id is kObjectTypeUndeterminedType, the result is an AbstractUndetermined built from
//   the joined shape and the joined element.
// - Otherwise `other` must be an AbstractTensor; AbstractTypeJoinLogging is called if it is not.
//   - If the two abstracts are equal, this abstract is returned.
//   - Otherwise a new AbstractTensor with the joined shape and element is returned, keeping this tensor's
//     adapter flag.
// A failed shape join (null result from ShapeJoin) is reported through ShapeJoinLogging.
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_type = other->BuildType();
  MS_EXCEPTION_IF_NULL(other_type);
  MS_EXCEPTION_IF_NULL(element_);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }

  const bool join_with_undetermined = (other_type->type_id() == kObjectTypeUndeterminedType);
  if (join_with_undetermined) {
    auto other_undetermined_tensor = dyn_cast_ptr<AbstractUndetermined>(other);
    MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
    auto joined_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());
    if (joined_shape == nullptr) {
      ShapeJoinLogging(shape(), other_undetermined_tensor->shape(), shared_from_base<AbstractBase>(), other);
    }
    auto element = element_->Join(other_undetermined_tensor->element());
    MS_EXCEPTION_IF_NULL(element);
    return std::make_shared<AbstractUndetermined>(element, joined_shape);
  }

  auto other_tensor = dyn_cast_ptr<AbstractTensor>(other);
  if (other_tensor == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto joined_shape = ShapeJoin(shape(), other_tensor->shape());
  if (joined_shape == nullptr) {
    ShapeJoinLogging(shape(), other_tensor->shape(), shared_from_base<AbstractBase>(), other);
  }
  auto element = element_->Join(other_tensor->element_);
  MS_EXCEPTION_IF_NULL(element);
  auto joined = std::make_shared<AbstractTensor>(element, joined_shape);
  joined->set_is_adapter(is_adapter_);
  return joined;
}
1766 
equal_to(const AbstractTensor & other) const1767 bool AbstractTensor::equal_to(const AbstractTensor &other) const {
1768   if (this == &other) {
1769     return true;
1770   }
1771   // Check if both Tensor or both AdapterTensor.
1772   if (is_adapter() != other.is_adapter()) {
1773     return false;
1774   }
1775   const auto &v1 = GetValueTrack();
1776   const auto &v2 = other.GetValueTrack();
1777   MS_EXCEPTION_IF_NULL(v1);
1778   MS_EXCEPTION_IF_NULL(v2);
1779   // Check if both point to same specific value.
1780   if (!v1->isa<ValueAny>()) {
1781     return v1 == v2;
1782   }
1783   // Check if both are ValueAny.
1784   if (!v2->isa<ValueAny>()) {
1785     return false;
1786   }
1787   // Check element type.
1788   if (!IsEqual(element_, other.element_)) {
1789     return false;
1790   }
1791   // Check shape.
1792   return IsEqual(shape(), other.shape());
1793 }
1794 
operator ==(const AbstractTensor & other) const1795 bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
1796 
operator ==(const AbstractBase & other) const1797 bool AbstractTensor::operator==(const AbstractBase &other) const {
1798   if (this == &other) {
1799     return true;
1800   }
1801   if (tid() != other.tid()) {
1802     return false;
1803   }
1804   return equal_to(static_cast<const AbstractTensor &>(other));
1805 }
1806 
Clone() const1807 AbstractBasePtr AbstractTensor::Clone() const {
1808   MS_EXCEPTION_IF_NULL(element_);
1809   auto clone = std::make_shared<AbstractTensor>(element_->Clone());
1810   ShapePtr shp = shape();
1811   clone->set_shape(shp->Clone());
1812   clone->set_value(GetValueTrack());
1813   clone->set_is_adapter(is_adapter());
1814   clone->SetSymbolicShape(this->GetSymbolicShape());
1815   clone->SetSymbolicValue(this->GetSymbolicValue());
1816   return clone;
1817 }
1818 
Broaden() const1819 AbstractBasePtr AbstractTensor::Broaden() const {
1820   MS_EXCEPTION_IF_NULL(element_);
1821   auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
1822   auto shp = shape();
1823   MS_EXCEPTION_IF_NULL(shp);
1824   broaden->set_shape(shp->Clone());
1825   broaden->set_value(kValueAny);
1826   broaden->set_is_adapter(is_adapter());
1827   return broaden;
1828 }
1829 
BroadenWithShape() const1830 AbstractBasePtr AbstractTensor::BroadenWithShape() const {
1831   MS_EXCEPTION_IF_NULL(element_);
1832   auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
1833   auto shp = shape()->Clone();
1834   MS_EXCEPTION_IF_NULL(shp);
1835   shp->Broaden();
1836   broaden->set_shape(shp);
1837   broaden->set_value(kValueAny);
1838   broaden->set_is_adapter(is_adapter());
1839   return broaden;
1840 }
1841 
PartialBroaden() const1842 AbstractBasePtr AbstractTensor::PartialBroaden() const { return Broaden(); }
1843 
ToString() const1844 std::string AbstractTensor::ToString() const {
1845   std::ostringstream buffer;
1846   BaseShapePtr shape_track = GetShapeTrack();
1847   MS_EXCEPTION_IF_NULL(shape_track);
1848   MS_EXCEPTION_IF_NULL(element_);
1849   auto value_track = GetValueTrack();
1850   MS_EXCEPTION_IF_NULL(value_track);
1851   std::string is_adapter = this->is_adapter() ? "True" : "False";
1852   buffer << type_name() << "("
1853          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
1854          << ", is adapter: " << is_adapter << ", value_ptr: " << value_track << ", value: " << value_track->ToString()
1855          << ")";
1856   return buffer.str();
1857 }
1858 
BuildType() const1859 TypePtr AbstractAny::BuildType() const {
1860   MS_EXCEPTION_IF_NULL(element_);
1861   TypePtr element_type = element_->BuildType();
1862   return std::make_shared<AnyType>(element_type);
1863 }
1864 
// Joins a negligible abstract with `other`.
// Joining with a scalar that holds a concrete (non-ValueAny) value yields a fresh AbstractAny.
// In every other case `other` is returned unchanged.
AbstractBasePtr AbstractNegligible::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractScalar>()) {
    const auto &value_other = other->GetValueTrack();
    // The value track is dereferenced next; fail with a clear message rather than crash on null.
    MS_EXCEPTION_IF_NULL(value_other);
    if (!value_other->isa<ValueAny>()) {
      return std::make_shared<AbstractAny>();
    }
  }
  return other;
}
1875 
Clone() const1876 AbstractBasePtr AbstractNegligible::Clone() const { return std::make_shared<AbstractNegligible>(); }
1877 
Broaden() const1878 AbstractBasePtr AbstractNegligible::Broaden() const { return Clone(); }
1879 
ToString() const1880 std::string AbstractNegligible::ToString() const { return type_name(); }
1881 
BuildType() const1882 TypePtr AbstractNegligible::BuildType() const {
1883   MS_EXCEPTION_IF_NULL(element_);
1884   TypePtr element_type = element_->BuildType();
1885   return std::make_shared<NegligibleType>(element_type);
1886 }
1887 
BuildType() const1888 TypePtr AbstractDictionary::BuildType() const {
1889   std::vector<std::pair<ValuePtr, TypePtr>> key_values;
1890   for (const auto &item : key_values_) {
1891     MS_EXCEPTION_IF_NULL(item.first);
1892     MS_EXCEPTION_IF_NULL(item.second);
1893     ValuePtr key_type = item.first->BuildValue();
1894     TypePtr value_type = item.second->BuildType();
1895     (void)key_values.emplace_back(key_type, value_type);
1896   }
1897   return std::make_shared<Dictionary>(key_values);
1898 }
1899 
operator ==(const AbstractBase & other) const1900 bool AbstractDictionary::operator==(const AbstractBase &other) const {
1901   if (this == &other) {
1902     return true;
1903   }
1904   if (!other.isa<AbstractDictionary>()) {
1905     return false;
1906   }
1907   const auto &other_dict = dynamic_cast<const AbstractDictionary &>(other);
1908   if (key_values_.size() != other_dict.key_values_.size()) {
1909     return false;
1910   }
1911   for (size_t index = 0; index < key_values_.size(); ++index) {
1912     auto &kv1 = key_values_[index];
1913     auto &kv2 = other_dict.key_values_[index];
1914     if (!IsEqual(kv1.first, kv2.first) || !IsEqual(kv1.second, kv2.second)) {
1915       return false;
1916     }
1917   }
1918   return true;
1919 }
1920 
Clone() const1921 AbstractBasePtr AbstractDictionary::Clone() const {
1922   std::vector<AbstractElementPair> kv;
1923   (void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
1924                        [](const AbstractElementPair &item) {
1925                          MS_EXCEPTION_IF_NULL(item.first);
1926                          MS_EXCEPTION_IF_NULL(item.second);
1927                          return std::make_pair(item.first->Clone(), item.second->Clone());
1928                        });
1929   auto ret = std::make_shared<AbstractDictionary>(kv);
1930   ret->set_extra_info(extra_info_);
1931   return ret;
1932 }
1933 
Broaden() const1934 AbstractBasePtr AbstractDictionary::Broaden() const {
1935   std::vector<AbstractElementPair> kv;
1936   (void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
1937                        [](const AbstractElementPair &item) {
1938                          MS_EXCEPTION_IF_NULL(item.second);
1939                          return std::make_pair(item.first, item.second->Broaden());
1940                        });
1941   auto ret = std::make_shared<AbstractDictionary>(kv);
1942   ret->set_extra_info(extra_info_);
1943   return ret;
1944 }
1945 
ToString() const1946 std::string AbstractDictionary::ToString() const {
1947   std::ostringstream buffer;
1948   buffer << type_name() << "{ ";
1949   for (const auto &kv : key_values_) {
1950     MS_EXCEPTION_IF_NULL(kv.first);
1951     MS_EXCEPTION_IF_NULL(kv.second);
1952     buffer << "(" << kv.first->ToString() << ": " << kv.second->ToString() << ") ";
1953   }
1954   buffer << "}";
1955   return buffer.str();
1956 }
1957 
hash() const1958 std::size_t AbstractDictionary::hash() const {
1959   std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(),
1960                                          [](std::size_t hash_sum, const AbstractElementPair &item) {
1961                                            MS_EXCEPTION_IF_NULL(item.first);
1962                                            MS_EXCEPTION_IF_NULL(item.second);
1963                                            hash_sum = hash_combine(hash_sum, item.first->hash());
1964                                            hash_sum = hash_combine(hash_sum, item.second->hash());
1965                                            return hash_sum;
1966                                          });
1967   return hash_sum;
1968 }
1969 
RealBuildValue() const1970 ValuePtr AbstractDictionary::RealBuildValue() const {
1971   std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
1972   for (const auto &item : key_values_) {
1973     MS_EXCEPTION_IF_NULL(item.first);
1974     MS_EXCEPTION_IF_NULL(item.second);
1975     auto key_element_value = item.first->BuildValue();
1976     auto value_element_value = item.second->BuildValue();
1977     MS_EXCEPTION_IF_NULL(key_element_value);
1978     MS_EXCEPTION_IF_NULL(value_element_value);
1979     if (value_element_value->isa<ValueAny>()) {
1980       return kValueAny;
1981     }
1982     (void)key_values.emplace_back(key_element_value, value_element_value);
1983   }
1984   return std::make_shared<ValueDictionary>(key_values);
1985 }
1986 
// Joins this dictionary with `other` entry by entry. Keys are joined with keys and values with values,
// position by position.
// - An AbstractNegligible `other` is absorbed and this dictionary is returned.
// - `other` must be an AbstractDictionary (AbstractTypeJoinLogging is called otherwise) of the same size;
//   a size mismatch raises an exception.
// - If no joined key or value differs (by pointer) from this dictionary's own, this dictionary is returned;
//   otherwise a new dictionary of the joined pairs is returned.
AbstractBasePtr AbstractDictionary::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_dict = other->cast<AbstractDictionaryPtr>();
  if (other_dict == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  const auto &self_elements = elements();
  const auto &other_elements = other_dict->elements();
  if (self_elements.size() != other_elements.size()) {
    MS_LOG(EXCEPTION)
      << "Join failed as dict don't have the same size. spec1: " << ::mindspore::ToString(self_elements)
      << ", spec2: " << ::mindspore::ToString(other_elements)
      << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed\n";
  }

  auto JoinElement = [](const AbstractBasePtr &abs1, const AbstractBasePtr &abs2) -> AbstractBasePtr {
    MS_EXCEPTION_IF_NULL(abs1);
    auto joined_res = abs1->Join(abs2);
    MS_EXCEPTION_IF_NULL(joined_res);
    return joined_res;
  };

  std::vector<AbstractElementPair> joined_key_values;
  joined_key_values.reserve(self_elements.size());
  bool changes = false;
  for (std::size_t i = 0; i < self_elements.size(); ++i) {
    // Bind by reference: the previous by-value copies of both pairs were unused and only churned refcounts.
    const auto &self_pair = self_elements[i];
    const auto &other_pair = other_elements[i];
    auto joined_key = JoinElement(self_pair.first, other_pair.first);
    if (joined_key != self_pair.first) {
      changes = true;
    }
    auto joined_value = JoinElement(self_pair.second, other_pair.second);
    if (joined_value != self_pair.second) {
      changes = true;
    }
    (void)joined_key_values.emplace_back(joined_key, joined_value);
  }
  if (!changes) {
    return shared_from_base<AbstractBase>();
  }
  return std::make_shared<AbstractDictionary>(joined_key_values);
}
2029 
AbstractJTagged(const AbstractBasePtr & element)2030 AbstractJTagged::AbstractJTagged(const AbstractBasePtr &element) : element_(element) {}
2031 
BuildType() const2032 TypePtr AbstractJTagged::BuildType() const {
2033   MS_EXCEPTION_IF_NULL(element_);
2034   TypePtr subtype = element_->BuildType();
2035   return std::make_shared<JTagged>(subtype);
2036 }
2037 
// Joins with another J-tagged abstract by joining the wrapped elements.
// An AbstractNegligible `other` is absorbed. Any other non-JTagged `other` is reported through
// AbstractTypeJoinLogging.
AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_jtagged = dyn_cast_ptr<AbstractJTagged>(other);
  if (other_jtagged == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  MS_EXCEPTION_IF_NULL(element_);
  return std::make_shared<AbstractJTagged>(element_->Join(other_jtagged->element_));
}
2051 
operator ==(const AbstractBase & other) const2052 bool AbstractJTagged::operator==(const AbstractBase &other) const {
2053   if (this == &other) {
2054     return true;
2055   }
2056   if (!other.isa<AbstractJTagged>()) {
2057     return false;
2058   }
2059   const auto &other_jtagged = dynamic_cast<const AbstractJTagged &>(other);
2060   return IsEqual(element_, other_jtagged.element_);
2061 }
2062 
ToString() const2063 std::string AbstractJTagged::ToString() const {
2064   std::ostringstream buffer;
2065   MS_EXCEPTION_IF_NULL(element_);
2066   buffer << type_name() << "("
2067          << "element: " << element_->ToString() << ")";
2068   return buffer.str();
2069 }
2070 
Clone() const2071 AbstractBasePtr AbstractJTagged::Clone() const { return std::make_shared<AbstractJTagged>(element_->Clone()); }
2072 
Broaden() const2073 AbstractBasePtr AbstractJTagged::Broaden() const { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
2074 
element()2075 AbstractBasePtr AbstractJTagged::element() { return element_; }
2076 
hash() const2077 std::size_t AbstractJTagged::hash() const { return hash_combine(tid(), element_->hash()); }
2078 
AbstractRefTensor(const AbstractTensorPtr & ref_value,const ValuePtr & ref_key_value)2079 AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value)
2080     : AbstractTensor(*ref_value), ref_key_value_(ref_key_value) {
2081   set_type(std::make_shared<RefType>());
2082   set_is_adapter(ref_value->is_adapter());
2083   MS_EXCEPTION_IF_NULL(ref_key_value);
2084   if (ref_key_value != kValueAny && !ref_key_value->isa<RefKey>()) {
2085     MS_LOG(INTERNAL_EXCEPTION) << "ref_key_value must be kValueAny or RefKey, but got:" << ref_key_value->ToString();
2086   }
2087 }
2088 
BuildType() const2089 TypePtr AbstractRefTensor::BuildType() const {
2090   auto type = AbstractTensor::BuildType();
2091   auto subtype = dyn_cast_ptr<TensorType>(type);
2092   MS_EXCEPTION_IF_NULL(subtype);
2093   return std::make_shared<RefType>(subtype);
2094 }
2095 
operator ==(const AbstractBase & other) const2096 bool AbstractRefTensor::operator==(const AbstractBase &other) const {
2097   if (this == &other) {
2098     return true;
2099   }
2100   if (!other.isa<AbstractRefTensor>()) {
2101     return false;
2102   }
2103   return AbstractTensor::equal_to(dynamic_cast<const AbstractTensor &>(other));
2104 }
2105 
Join(const std::shared_ptr<AbstractRefTensor> & other)2106 AbstractBasePtr AbstractRefTensor::Join(const std::shared_ptr<AbstractRefTensor> &other) {
2107   if (*this == *other) {
2108     return shared_from_base<AbstractRefTensor>();
2109   }
2110   if (other->isa<AbstractNegligible>()) {
2111     return shared_from_base<AbstractBase>();
2112   }
2113   // Firstly, join the ref_key_value.
2114   auto joined_ref_key = ValueJoin(ref_key_value_, other->ref_key_value_);
2115   // Secondly , join the tensor value.
2116   auto joined_tensor = AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
2117   MS_EXCEPTION_IF_NULL(joined_tensor);
2118   return std::make_shared<AbstractRefTensor>(joined_tensor, joined_ref_key);
2119 }
2120 
// Joins a ref tensor with an arbitrary abstract.
// - Another ref tensor: handled by the ref-to-ref overload, which also joins the ref keys.
// - AbstractNegligible: absorbed, this ref is returned.
// - Anything else: the result of AbstractTensor::Join, which must be an AbstractTensor.
AbstractBasePtr AbstractRefTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractRefTensor>()) {
    return AbstractRefTensor::Join(other->cast<AbstractRefPtr>());
  }
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  AbstractBasePtr joined_tensor = AbstractTensor::Join(other);
  const bool is_tensor = joined_tensor->isa<AbstractTensor>();
  if (!is_tensor) {
    MS_LOG(INTERNAL_EXCEPTION) << "Expect an AbstractTensor, but got:" << joined_tensor->ToString()
                               << ", other:" << other->ToString();
  }
  return joined_tensor;
}
2138 
Clone() const2139 AbstractBasePtr AbstractRefTensor::Clone() const {
2140   auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
2141   return std::make_shared<AbstractRefTensor>(abs_tensor, ref_key_value_);
2142 }
2143 
Broaden() const2144 AbstractBasePtr AbstractRefTensor::Broaden() const {
2145   // Always broaden for ref
2146   auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
2147   // Broaden the tensor value and keep the ref_key_value.
2148   auto ret = std::make_shared<AbstractRefTensor>(abs_tensor, ref_key_value_);
2149   return ret;
2150 }
2151 
ToString() const2152 std::string AbstractRefTensor::ToString() const {
2153   std::ostringstream buffer;
2154   MS_EXCEPTION_IF_NULL(ref_key_value_);
2155   buffer << type_name() << "("
2156          << "key: " << ref_key_value_->ToString() << " ref_value: " << AbstractTensor::ToString();
2157   auto value = GetValueTrack();
2158   if (value != nullptr) {
2159     buffer << ", value: " << value->ToString();
2160   }
2161   buffer << ")";
2162   return buffer.str();
2163 }
2164 
PartialBroaden() const2165 AbstractBasePtr AbstractRefTensor::PartialBroaden() const { return Clone(); }
2166 
AbstractNone()2167 AbstractNone::AbstractNone() : AbstractBase() { set_type(std::make_shared<TypeNone>()); }
2168 
BuildType() const2169 TypePtr AbstractNone::BuildType() const { return std::make_shared<TypeNone>(); }
2170 
Clone() const2171 AbstractBasePtr AbstractNone::Clone() const { return std::make_shared<AbstractNone>(); }
2172 
operator ==(const AbstractBase & other) const2173 bool AbstractNone::operator==(const AbstractBase &other) const { return other.isa<AbstractNone>(); }
2174 
ToString() const2175 std::string AbstractNone::ToString() const {
2176   std::ostringstream buffer;
2177   buffer << type_name() << "(Value: None)";
2178   return buffer.str();
2179 }
2180 
Join(const AbstractBasePtr & other)2181 AbstractBasePtr AbstractNone::Join(const AbstractBasePtr &other) {
2182   MS_EXCEPTION_IF_NULL(other);
2183   if (other->isa<AbstractNegligible>()) {
2184     return shared_from_base<AbstractBase>();
2185   }
2186   if (!other->isa<AbstractNone>()) {
2187     AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
2188   }
2189   return shared_from_base<AbstractNone>();
2190 }
2191 
AbstractNull()2192 AbstractNull::AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
2193 
BuildType() const2194 TypePtr AbstractNull::BuildType() const { return std::make_shared<TypeNull>(); }
2195 
Clone() const2196 AbstractBasePtr AbstractNull::Clone() const { return std::make_shared<AbstractNull>(); }
2197 
RealBuildValue() const2198 ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
2199 
operator ==(const AbstractBase & other) const2200 bool AbstractNull::operator==(const AbstractBase &other) const { return other.isa<AbstractNull>(); }
2201 
ToString() const2202 std::string AbstractNull::ToString() const {
2203   std::ostringstream buffer;
2204   buffer << type_name() << "(Value: Null)";
2205   return buffer.str();
2206 }
2207 
AbstractTimeOut()2208 AbstractTimeOut::AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
2209 
BuildType() const2210 TypePtr AbstractTimeOut::BuildType() const { return std::make_shared<TypeNull>(); }
2211 
Clone() const2212 AbstractBasePtr AbstractTimeOut::Clone() const { return std::make_shared<AbstractTimeOut>(); }
2213 
operator ==(const AbstractBase & other) const2214 bool AbstractTimeOut::operator==(const AbstractBase &other) const { return other.isa<AbstractTimeOut>(); }
2215 
ToString() const2216 std::string AbstractTimeOut::ToString() const {
2217   std::ostringstream buffer;
2218   buffer << "AbstractTimeOut "
2219          << "(Value: Null)";
2220   return buffer.str();
2221 }
2222 
AbstractEllipsis()2223 AbstractEllipsis::AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }
2224 
BuildType() const2225 TypePtr AbstractEllipsis::BuildType() const { return std::make_shared<TypeEllipsis>(); }
2226 
Clone() const2227 AbstractBasePtr AbstractEllipsis::Clone() const { return std::make_shared<AbstractEllipsis>(); }
2228 
operator ==(const AbstractBase & other) const2229 bool AbstractEllipsis::operator==(const AbstractBase &other) const { return other.isa<AbstractEllipsis>(); }
2230 
ToString() const2231 std::string AbstractEllipsis::ToString() const {
2232   std::ostringstream buffer;
2233   buffer << type_name() << "(Value: Ellipsis)";
2234   return buffer.str();
2235 }
2236 
CloneAsTensor() const2237 AbstractBasePtr AbstractRefTensor::CloneAsTensor() const { return AbstractTensor::Clone(); }
2238 
ref()2239 AbstractTensorPtr AbstractRefTensor::ref() { return shared_from_base<AbstractTensor>(); }
2240 
ref_key_value() const2241 ValuePtr AbstractRefTensor::ref_key_value() const { return ref_key_value_; }
2242 
AbstractSparseTensor(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2243 AbstractSparseTensor::AbstractSparseTensor(AbstractBasePtrList &&elements,
2244                                            const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2245     : AbstractTuple(std::move(elements), tuple_nodes) {}
2246 
AbstractSparseTensor(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2247 AbstractSparseTensor::AbstractSparseTensor(const AbstractBasePtrList &elements,
2248                                            const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2249     : AbstractTuple(elements, tuple_nodes) {}
2250 
BuildShape() const2251 BaseShapePtr AbstractSparseTensor::BuildShape() const {
2252   return std::make_shared<TupleShape>(ElementsShapeTupleRecursive());
2253 }
2254 
AbstractRowTensor(const AbstractBasePtr & element,const BaseShapePtr & shape)2255 AbstractRowTensor::AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape)
2256     : AbstractUndetermined(element, shape) {}
2257 
AbstractRowTensor(const TypePtr & element_type,const ShapeVector & shape)2258 AbstractRowTensor::AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape)
2259     : AbstractUndetermined(element_type, shape) {}
2260 
indices() const2261 const AbstractTensorPtr AbstractRowTensor::indices() const { return indices_; }
2262 
set_indices(const AbstractTensorPtr & indices)2263 void AbstractRowTensor::set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
2264 
values() const2265 const AbstractTensorPtr AbstractRowTensor::values() const { return values_; }
2266 
set_values(const AbstractTensorPtr & values)2267 void AbstractRowTensor::set_values(const AbstractTensorPtr &values) { values_ = values; }
2268 
dense_shape() const2269 const AbstractTuplePtr AbstractRowTensor::dense_shape() const { return dense_shape_; }
2270 
set_dense_shape(const AbstractTuplePtr & dense_shape)2271 void AbstractRowTensor::set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
2272 
AbstractCOOTensor(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2273 AbstractCOOTensor::AbstractCOOTensor(AbstractBasePtrList &&elements,
2274                                      const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2275     : AbstractSparseTensor(std::move(elements), tuple_nodes) {}
2276 
AbstractCOOTensor(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2277 AbstractCOOTensor::AbstractCOOTensor(const AbstractBasePtrList &elements,
2278                                      const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2279     : AbstractSparseTensor(elements, tuple_nodes) {}
2280 
AbstractCSRTensor(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2281 AbstractCSRTensor::AbstractCSRTensor(AbstractBasePtrList &&elements,
2282                                      const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2283     : AbstractSparseTensor(std::move(elements), tuple_nodes) {}
2284 
AbstractCSRTensor(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2285 AbstractCSRTensor::AbstractCSRTensor(const AbstractBasePtrList &elements,
2286                                      const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2287     : AbstractSparseTensor(elements, tuple_nodes) {}
2288 
hash() const2289 std::size_t AbstractMonad::hash() const { return hash_combine({tid()}); }
2290 
BuildType() const2291 TypePtr AbstractMonad::BuildType() const { return GetTypeTrack(); }
2292 
Broaden() const2293 AbstractBasePtr AbstractMonad::Broaden() const { return AbstractBase::Broaden(); }
2294 
ToString() const2295 std::string AbstractMonad::ToString() const {
2296   std::ostringstream buffer;
2297   buffer << type_name() << "(" << GetValueTrack()->ToString() << ")";
2298   return buffer.str();
2299 }
2300 
AbstractMonad(const ValuePtr & value,const TypePtr & type)2301 AbstractMonad::AbstractMonad(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
2302 
AbstractUMonad(const ValuePtr & value)2303 AbstractUMonad::AbstractUMonad(const ValuePtr &value) : AbstractMonad(value, kUMonadType) {}
2304 
Clone() const2305 AbstractBasePtr AbstractUMonad::Clone() const { return std::make_shared<AbstractUMonad>(GetValueTrack()); }
2306 
AbstractIOMonad(const ValuePtr & value)2307 AbstractIOMonad::AbstractIOMonad(const ValuePtr &value) : AbstractMonad(value, kIOMonadType) {}
2308 
Clone() const2309 AbstractBasePtr AbstractIOMonad::Clone() const { return std::make_shared<AbstractIOMonad>(GetValueTrack()); }
2310 
map_tensor_type() const2311 MapTensorTypePtr AbstractMapTensor::map_tensor_type() const { return dyn_cast<MapTensorType>(GetTypeTrack()); }
2312 
shape() const2313 ShapePtr AbstractMapTensor::shape() const { return dyn_cast<Shape>(GetShapeTrack()); }
2314 
value_shape() const2315 const ShapePtr &AbstractMapTensor::value_shape() const { return value_shape_; }
2316 
ref_key_value() const2317 const ValuePtr &AbstractMapTensor::ref_key_value() const { return ref_key_value_; }
2318 
default_value() const2319 const ValuePtr &AbstractMapTensor::default_value() const { return default_value_; }
2320 
permit_filter_value() const2321 const ValuePtr &AbstractMapTensor::permit_filter_value() const { return permit_filter_value_; }
2322 
evict_filter_value() const2323 const ValuePtr &AbstractMapTensor::evict_filter_value() const { return evict_filter_value_; }
2324 
BuildType() const2325 TypePtr AbstractMapTensor::BuildType() const { return GetTypeTrack(); }
2326 
BuildShape() const2327 BaseShapePtr AbstractMapTensor::BuildShape() const { return GetShapeTrack(); }
2328 
BuildType() const2329 TypePtr AbstractKeywordArg::BuildType() const {
2330   MS_EXCEPTION_IF_NULL(arg_value_);
2331   TypePtr type = arg_value_->BuildType();
2332   return std::make_shared<Keyword>(arg_name_, type);
2333 }
2334 
Clone() const2335 AbstractBasePtr AbstractKeywordArg::Clone() const {
2336   MS_EXCEPTION_IF_NULL(arg_value_);
2337   return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
2338 }
2339 
Broaden() const2340 AbstractBasePtr AbstractKeywordArg::Broaden() const {
2341   MS_EXCEPTION_IF_NULL(arg_value_);
2342   return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
2343 }
2344 
hash() const2345 std::size_t AbstractKeywordArg::hash() const {
2346   MS_EXCEPTION_IF_NULL(arg_value_);
2347   return hash_combine({tid(), std::hash<std::string>{}(arg_name_), arg_value_->hash()});
2348 }
2349 
ToString() const2350 std::string AbstractKeywordArg::ToString() const {
2351   std::ostringstream buffer;
2352   MS_EXCEPTION_IF_NULL(arg_value_);
2353   buffer << type_name() << "(";
2354   buffer << "key: " << arg_name_;
2355   buffer << ", value: " << arg_value_->ToString();
2356   buffer << ")";
2357   return buffer.str();
2358 }
2359 
operator ==(const AbstractBase & other) const2360 bool AbstractKeywordArg::operator==(const AbstractBase &other) const {
2361   if (this == &other) {
2362     return true;
2363   }
2364   if (!other.isa<AbstractKeywordArg>()) {
2365     return false;
2366   }
2367   return *this == static_cast<const AbstractKeywordArg &>(other);
2368 }
2369 
operator ==(const AbstractKeywordArg & other) const2370 bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const {
2371   if (this == &other) {
2372     return true;
2373   }
2374   return other.arg_name_ == arg_name_ && IsEqual(other.arg_value_, arg_value_);
2375 }
2376 
RealBuildValue() const2377 ValuePtr AbstractKeywordArg::RealBuildValue() const {
2378   MS_EXCEPTION_IF_NULL(arg_value_);
2379   ValuePtr value = arg_value_->BuildValue();
2380   MS_EXCEPTION_IF_NULL(value);
2381   if (value->isa<ValueAny>()) {
2382     return kValueAny;
2383   }
2384   return std::make_shared<KeywordArg>(arg_name_, value);
2385 }
2386 
AbstractBasePtrListHash(const AbstractBasePtrList & args_abs_list)2387 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_abs_list) {
2388   // Hash for empty list is zero.
2389   if (args_abs_list.empty()) {
2390     return 0;
2391   }
2392   // Hashing all elements is costly, we only calculate hash from
2393   // the first element and last few elements base on some experiments.
2394   // In some scenarios, this may lead high hash conflicts. Therefore,
2395   // we should use this hash function in hash tables that can tolerate
2396   // high hash conflicts, such as std::unordered_map.
2397   constexpr size_t kMaxLastElements = 4;
2398   const size_t n_args = args_abs_list.size();
2399   // Hash from list size and the first element.
2400   std::size_t hash_value = hash_combine(n_args, args_abs_list[0]->hash());
2401   // Hash from last few elements.
2402   const size_t start = ((n_args > kMaxLastElements) ? (n_args - kMaxLastElements) : 1);
2403   for (size_t i = start; i < n_args; ++i) {
2404     hash_value = hash_combine(hash_value, args_abs_list[i]->hash());
2405   }
2406   return hash_value;
2407 }
2408 
AbstractBasePtrListDeepEqual(const AbstractBasePtrList & lhs,const AbstractBasePtrList & rhs)2409 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
2410   const std::size_t size = lhs.size();
2411   if (size != rhs.size()) {
2412     return false;
2413   }
2414   for (std::size_t i = 0; i < size; ++i) {
2415     if (!IsEqual(lhs[i], rhs[i])) {
2416       return false;
2417     }
2418   }
2419   return true;
2420 }
2421 
2422 // SparseTensor
2423 template <typename T>
GetAbsPtrAt(size_t index) const2424 const T AbstractSparseTensor::GetAbsPtrAt(size_t index) const {
2425   if (index >= size()) {
2426     MS_LOG(EXCEPTION) << "Index should be in range of [0, " << size() << "), but got " << index
2427                       << " for abstract: " << type_name();
2428   }
2429   AbstractBasePtr base = elements()[index];
2430   MS_EXCEPTION_IF_NULL(base);
2431   return base->cast<T>();
2432 }
2433 
ElementsShapeTupleRecursive() const2434 BaseShapePtrList AbstractSparseTensor::ElementsShapeTupleRecursive() const {
2435   BaseShapePtrList element_shape_list;
2436   for (const auto &element : elements()) {
2437     MS_EXCEPTION_IF_NULL(element);
2438     auto abs_tuple = element->cast_ptr<AbstractTuple>();
2439     if (abs_tuple == nullptr) {
2440       element_shape_list.push_back(element->GetShape());
2441     } else {
2442       for (const auto &scalar : abs_tuple->elements()) {
2443         MS_EXCEPTION_IF_NULL(scalar);
2444         element_shape_list.push_back(scalar->GetShape());
2445       }
2446     }
2447   }
2448   return element_shape_list;
2449 }
2450 
GetTensorTypeIdAt(size_t index) const2451 const TypeId AbstractSparseTensor::GetTensorTypeIdAt(size_t index) const {
2452   size_t shape_idx = size() - 1;
2453   if (index >= shape_idx) {
2454     MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape_idx << "), but got " << index << " for "
2455                       << ToString();
2456   }
2457   auto abs_tensor = GetAbsPtrAt<abstract::AbstractTensorPtr>(index);
2458   MS_EXCEPTION_IF_NULL(abs_tensor);
2459   return abs_tensor->element()->BuildType()->type_id();
2460 }
2461 
GetShapeTypeIdAt(size_t index) const2462 const TypeId AbstractSparseTensor::GetShapeTypeIdAt(size_t index) const {
2463   if (index >= shape()->size()) {
2464     MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape()->size() << "), but got " << index << " for "
2465                       << ToString();
2466   }
2467   return shape()->elements()[index]->BuildType()->type_id();
2468 }
2469 
BuildType() const2470 TypePtr AbstractSparseTensor::BuildType() const { return std::make_shared<SparseTensorType>(); }
2471 
shape() const2472 const AbstractTuplePtr AbstractSparseTensor::shape() const {
2473   auto res = GetAbsPtrAt<abstract::AbstractTuplePtr>(size() - 1);
2474   if (res == nullptr) {
2475     MS_LOG(INTERNAL_EXCEPTION) << "Get shape nullptr in AbstractSparseTensor: " << ToString();
2476   }
2477   return res;
2478 }
2479 
2480 // RowTensor
BuildType() const2481 TypePtr AbstractRowTensor::BuildType() const {
2482   MS_EXCEPTION_IF_NULL(element());
2483   TypePtr element_type = element()->BuildType();
2484   return std::make_shared<RowTensorType>(element_type);
2485 }
2486 
Clone() const2487 AbstractBasePtr AbstractRowTensor::Clone() const {
2488   MS_EXCEPTION_IF_NULL(element());
2489   auto clone = std::make_shared<AbstractRowTensor>(element()->Clone());
2490   ShapePtr shp = shape();
2491   MS_EXCEPTION_IF_NULL(shp);
2492   clone->set_shape(shp->Clone());
2493   clone->set_value(GetValueTrack());
2494   MS_EXCEPTION_IF_NULL(indices_);
2495   MS_EXCEPTION_IF_NULL(values_);
2496   MS_EXCEPTION_IF_NULL(dense_shape_);
2497   auto indices_clone = indices_->Clone();
2498   auto value_clone = values_->Clone();
2499   auto dense_clone = dense_shape_->Clone();
2500   MS_EXCEPTION_IF_NULL(indices_clone);
2501   MS_EXCEPTION_IF_NULL(value_clone);
2502   MS_EXCEPTION_IF_NULL(dense_clone);
2503   clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
2504   clone->set_values(value_clone->cast<AbstractTensorPtr>());
2505   clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
2506   return clone;
2507 }
2508 
MakeAbstract(const BaseShapePtr & shp) const2509 AbstractRowTensorPtr AbstractRowTensor::MakeAbstract(const BaseShapePtr &shp) const {
2510   MS_EXCEPTION_IF_NULL(element());
2511   auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
2512   broaden->set_shape(shp);
2513   broaden->set_value(kValueAny);
2514   MS_EXCEPTION_IF_NULL(indices_);
2515   MS_EXCEPTION_IF_NULL(values_);
2516   MS_EXCEPTION_IF_NULL(dense_shape_);
2517   auto indices_clone = indices_->Clone();
2518   auto value_clone = values_->Clone();
2519   auto dense_clone = dense_shape_->Clone();
2520   MS_EXCEPTION_IF_NULL(indices_clone);
2521   MS_EXCEPTION_IF_NULL(value_clone);
2522   MS_EXCEPTION_IF_NULL(dense_clone);
2523   broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
2524   broaden->set_values(value_clone->cast<AbstractTensorPtr>());
2525   broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
2526   return broaden;
2527 }
2528 
Broaden() const2529 AbstractBasePtr AbstractRowTensor::Broaden() const {
2530   auto shp = shape()->Clone();
2531   MS_EXCEPTION_IF_NULL(shp);
2532   return MakeAbstract(shp);
2533 }
2534 
BroadenWithShape() const2535 AbstractBasePtr AbstractRowTensor::BroadenWithShape() const {
2536   auto shp = shape()->Clone();
2537   MS_EXCEPTION_IF_NULL(shp);
2538   shp->Broaden();
2539   return MakeAbstract(shp);
2540 }
2541 
ToString() const2542 std::string AbstractRowTensor::ToString() const {
2543   std::ostringstream buffer;
2544   BaseShapePtr shape_track = GetShapeTrack();
2545   MS_EXCEPTION_IF_NULL(shape_track);
2546   MS_EXCEPTION_IF_NULL(element());
2547   auto value_track = GetValueTrack();
2548   MS_EXCEPTION_IF_NULL(value_track);
2549   MS_EXCEPTION_IF_NULL(indices_);
2550   MS_EXCEPTION_IF_NULL(values_);
2551   MS_EXCEPTION_IF_NULL(dense_shape_);
2552   buffer << type_name() << "("
2553          << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
2554          << ", value_ptr: " << value_track << ", value: " << value_track->ToString()
2555          << ", indices: " << indices_->ToString() << ", values: " << values_->ToString()
2556          << ", dense_shape: " << dense_shape_->ToString() << ")";
2557   return buffer.str();
2558 }
2559 
2560 // COOTensor
BuildType() const2561 TypePtr AbstractCOOTensor::BuildType() const {
2562   MS_EXCEPTION_IF_NULL(indices());
2563   MS_EXCEPTION_IF_NULL(values());
2564   MS_EXCEPTION_IF_NULL(shape());
2565   TypePtrList elements{indices()->element()->BuildType(), values()->element()->BuildType()};
2566   (void)std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
2567                        [](const AbstractBasePtr &p) { return p->BuildType(); });
2568   return std::make_shared<COOTensorType>(elements);
2569 }
2570 
Clone() const2571 AbstractBasePtr AbstractCOOTensor::Clone() const {
2572   AbstractBasePtrList element_list;
2573   for (const auto &element : elements()) {
2574     MS_EXCEPTION_IF_NULL(element);
2575     AbstractBasePtr clone = element->Clone();
2576     element_list.push_back(clone);
2577   }
2578   return std::make_shared<abstract::AbstractCOOTensor>(element_list);
2579 }
2580 
Broaden() const2581 AbstractBasePtr AbstractCOOTensor::Broaden() const {
2582   return std::make_shared<abstract::AbstractCOOTensor>(ElementsBroaden());
2583 }
2584 
PartialBroaden() const2585 AbstractBasePtr AbstractCOOTensor::PartialBroaden() const { return Broaden(); }
2586 
ToString() const2587 std::string AbstractCOOTensor::ToString() const {
2588   std::ostringstream buffer;
2589   auto indices_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2590   auto values_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2591   auto shape_ = GetAbsPtrAt<abstract::AbstractTuplePtr>(size() - 1);
2592   MS_EXCEPTION_IF_NULL(indices_);
2593   MS_EXCEPTION_IF_NULL(values_);
2594   MS_EXCEPTION_IF_NULL(shape_);
2595   buffer << type_name() << "("
2596          << "indices: " << indices_->ToString() << ", values" << values_->ToString()
2597          << ", dense_shape: " << shape_->ToString();
2598   return buffer.str();
2599 }
2600 
indices() const2601 const AbstractTensorPtr AbstractCOOTensor::indices() const {
2602   auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2603   if (res == nullptr) {
2604     MS_LOG(INTERNAL_EXCEPTION) << "Get indices nullptr in AbstractCOOTensor: " << ToString();
2605   }
2606   return res;
2607 }
2608 
values() const2609 const AbstractTensorPtr AbstractCOOTensor::values() const {
2610   auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2611   if (res == nullptr) {
2612     MS_LOG(INTERNAL_EXCEPTION) << "Get values nullptr in AbstractCOOTensor: " << ToString();
2613   }
2614   return res;
2615 }
2616 
2617 // CSRTensor
BuildType() const2618 TypePtr AbstractCSRTensor::BuildType() const {
2619   MS_EXCEPTION_IF_NULL(indptr());
2620   MS_EXCEPTION_IF_NULL(indices());
2621   MS_EXCEPTION_IF_NULL(values());
2622   MS_EXCEPTION_IF_NULL(shape());
2623   TypePtrList elements{indptr()->element()->BuildType(), indices()->element()->BuildType(),
2624                        values()->element()->BuildType()};
2625   (void)std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
2626                        [](const AbstractBasePtr &p) { return p->BuildType(); });
2627   return std::make_shared<CSRTensorType>(elements);
2628 }
2629 
Clone() const2630 AbstractBasePtr AbstractCSRTensor::Clone() const {
2631   AbstractBasePtrList element_list;
2632   for (const auto &element : elements()) {
2633     MS_EXCEPTION_IF_NULL(element);
2634     AbstractBasePtr clone = element->Clone();
2635     element_list.push_back(clone);
2636   }
2637   return std::make_shared<abstract::AbstractCSRTensor>(element_list);
2638 }
2639 
Broaden() const2640 AbstractBasePtr AbstractCSRTensor::Broaden() const {
2641   return std::make_shared<abstract::AbstractCSRTensor>(ElementsBroaden());
2642 }
2643 
PartialBroaden() const2644 AbstractBasePtr AbstractCSRTensor::PartialBroaden() const { return Broaden(); }
2645 
ToString() const2646 std::string AbstractCSRTensor::ToString() const {
2647   std::ostringstream buffer;
2648   auto indptr_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndptrIdx);
2649   auto indices_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2650   auto values_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2651   auto shape_ = GetAbsPtrAt<abstract::AbstractTuplePtr>(size() - 1);
2652   MS_EXCEPTION_IF_NULL(indptr_);
2653   MS_EXCEPTION_IF_NULL(indices_);
2654   MS_EXCEPTION_IF_NULL(values_);
2655   MS_EXCEPTION_IF_NULL(shape_);
2656   buffer << type_name() << "("
2657          << "indptr: " << indptr_->ToString() << ", indices: " << indices_->ToString() << ", values"
2658          << values_->ToString() << ", dense_shape: " << shape_->ToString();
2659   return buffer.str();
2660 }
2661 
indptr() const2662 const AbstractTensorPtr AbstractCSRTensor::indptr() const {
2663   auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndptrIdx);
2664   if (res == nullptr) {
2665     MS_LOG(INTERNAL_EXCEPTION) << "Get indptr nullptr in AbstractCSRTensor: " << ToString();
2666   }
2667   return res;
2668 }
2669 
indices() const2670 const AbstractTensorPtr AbstractCSRTensor::indices() const {
2671   auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2672   if (res == nullptr) {
2673     MS_LOG(INTERNAL_EXCEPTION) << "Get indices nullptr in AbstractCSRTensor: " << ToString();
2674   }
2675   return res;
2676 }
2677 
values() const2678 const AbstractTensorPtr AbstractCSRTensor::values() const {
2679   auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2680   if (res == nullptr) {
2681     MS_LOG(INTERNAL_EXCEPTION) << "Get values nullptr in AbstractCSRTensor: " << ToString();
2682   }
2683   return res;
2684 }
2685 
AbstractMapTensor(const MapTensorPtr & map_tensor)2686 AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor)
2687     : AbstractBase(map_tensor, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
2688                    std::make_shared<Shape>(map_tensor->shape())),
2689       ref_key_value_(kValueAny),
2690       default_value_(map_tensor->default_value()),
2691       permit_filter_value_(map_tensor->permit_filter_value()),
2692       evict_filter_value_(map_tensor->evict_filter_value()),
2693       value_shape_(std::make_shared<Shape>(map_tensor->value_shape())) {}
2694 
AbstractMapTensor(const MapTensorPtr & map_tensor,const ValuePtr & ref_key_value)2695 AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value)
2696     : AbstractBase(kValueAny, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
2697                    std::make_shared<Shape>(map_tensor->shape())),
2698       ref_key_value_(ref_key_value),
2699       default_value_(map_tensor->default_value()),
2700       permit_filter_value_(map_tensor->permit_filter_value()),
2701       evict_filter_value_(map_tensor->evict_filter_value()),
2702       value_shape_(std::make_shared<Shape>(map_tensor->value_shape())) {}
2703 
AbstractMapTensor(const AbstractMapTensor & other)2704 AbstractMapTensor::AbstractMapTensor(const AbstractMapTensor &other)
2705     : AbstractBase(other),
2706       ref_key_value_(other.ref_key_value_),
2707       default_value_(other.default_value_),
2708       permit_filter_value_(other.permit_filter_value()),
2709       evict_filter_value_(other.evict_filter_value()),
2710       value_shape_(other.value_shape_) {
2711   set_shape(other.shape());
2712 }
2713 
AbstractMapTensor(const TypePtr & type,const ShapePtr & value_shape,const ValuePtr & value,const ValuePtr & ref_key_value,const ValuePtr & default_value,const ValuePtr & permit_filter_value,const ValuePtr & evict_filter_value)2714 AbstractMapTensor::AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value,
2715                                      const ValuePtr &ref_key_value, const ValuePtr &default_value,
2716                                      const ValuePtr &permit_filter_value, const ValuePtr &evict_filter_value) {
2717   set_value(value);
2718   set_type(type);
2719   ref_key_value_ = ref_key_value;
2720   default_value_ = default_value;
2721   permit_filter_value_ = permit_filter_value;
2722   evict_filter_value_ = evict_filter_value;
2723   ShapeVector shape = {abstract::Shape::kShapeDimAny};
2724   (void)shape.insert(shape.end(), value_shape->shape().begin(), value_shape->shape().end());
2725   set_shape(std::make_shared<mindspore::abstract::Shape>(shape));
2726 }
2727 
operator =(const AbstractMapTensor & other)2728 AbstractMapTensor &AbstractMapTensor::operator=(const AbstractMapTensor &other) {
2729   if (this == &other) {
2730     return *this;
2731   }
2732   this->ref_key_value_ = other.ref_key_value();
2733   this->default_value_ = other.default_value();
2734   this->permit_filter_value_ = other.permit_filter_value();
2735   this->evict_filter_value_ = other.evict_filter_value();
2736   this->value_shape_ = other.value_shape_;
2737   this->set_shape(other.shape());
2738   return *this;
2739 }
2740 
Clone() const2741 AbstractBasePtr AbstractMapTensor::Clone() const { return std::make_shared<AbstractMapTensor>(*this); }
2742 
Join(const AbstractBasePtr & other)2743 AbstractBasePtr AbstractMapTensor::Join(const AbstractBasePtr &other) {
2744   MS_EXCEPTION_IF_NULL(other);
2745   if (other->isa<AbstractNegligible>()) {
2746     return shared_from_base<AbstractBase>();
2747   }
2748   // Same pointer.
2749   if (this == other.get()) {
2750     return shared_from_base<AbstractMapTensor>();
2751   }
2752 
2753   // Check class.
2754   auto other_abs = dyn_cast<AbstractMapTensor>(other);
2755   if (other_abs == nullptr) {
2756     AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
2757   }
2758 
2759   // Join type.
2760   auto joined_type = TypeJoin(GetTypeTrack(), other_abs->GetTypeTrack());
2761   if (joined_type == kTypeAny) {
2762     TypeJoinLogging(GetTypeTrack(), other_abs->GetTypeTrack(), shared_from_base<AbstractBase>(), other);
2763   }
2764 
2765   // Join shape
2766   auto joined_shape = ShapeJoin(value_shape(), other_abs->value_shape());
2767   if (joined_shape == nullptr) {
2768     ShapeJoinLogging(value_shape(), other_abs->value_shape(), shared_from_base<AbstractBase>(), other);
2769   }
2770 
2771   // Join value.
2772   auto joined_value = (GetValueTrack() == other_abs->GetValueTrack() ? GetValueTrack() : kValueAny);
2773 
2774   // Join the ref_key_value.
2775   auto joined_ref_key = ValueJoin(ref_key_value_, other_abs->ref_key_value_);
2776 
2777   // Join the default_value.
2778   auto joined_default_value = ValueJoin(default_value_, other_abs->default_value_);
2779   if (joined_default_value->ContainsValueAny()) {
2780     MS_EXCEPTION(ValueError) << "Join default value failed for MapTensor. " << default_value_->ToString()
2781                              << " != " << other_abs->default_value_->ToString();
2782   }
2783 
2784   // Join the permit_filter_value.
2785   auto joined_permit_filter_value = ValueJoin(permit_filter_value_, other_abs->permit_filter_value_);
2786   if (joined_permit_filter_value->ContainsValueAny()) {
2787     MS_EXCEPTION(ValueError) << "Join default value failed for MapTensor. " << permit_filter_value_->ToString()
2788                              << " != " << other_abs->permit_filter_value_->ToString();
2789   }
2790 
2791   // Join the evict_filter_value.
2792   auto joined_evict_filter_value = ValueJoin(evict_filter_value_, other_abs->evict_filter_value_);
2793   if (joined_evict_filter_value->ContainsValueAny()) {
2794     MS_EXCEPTION(ValueError) << "Join evict_filter_value failed for MapTensor. " << evict_filter_value_->ToString()
2795                              << " != " << other_abs->evict_filter_value_->ToString();
2796   }
2797 
2798   return std::make_shared<AbstractMapTensor>(joined_type, joined_shape, joined_value, joined_ref_key,
2799                                              joined_default_value, joined_permit_filter_value,
2800                                              joined_evict_filter_value);
2801 }
2802 
operator ==(const AbstractBase & other) const2803 bool AbstractMapTensor::operator==(const AbstractBase &other) const {
2804   if (this == &other) {
2805     return true;
2806   }
2807   if (!other.isa<AbstractMapTensor>()) {
2808     return false;
2809   }
2810   const auto &v1 = GetValueTrack();
2811   const auto &v2 = other.GetValueTrack();
2812   MS_EXCEPTION_IF_NULL(v1);
2813   MS_EXCEPTION_IF_NULL(v2);
2814   // Check if both point to same specific value.
2815   if (!v1->isa<ValueAny>()) {
2816     return v1 == v2;
2817   }
2818   // Check if both are ValueAny.
2819   if (!v2->isa<ValueAny>()) {
2820     return false;
2821   }
2822   const auto &other_map_tensor = dynamic_cast<const AbstractMapTensor &>(other);
2823   return common::IsEqual(GetTypeTrack(), other_map_tensor.GetTypeTrack()) &&
2824          common::IsEqual(GetShapeTrack(), other_map_tensor.GetShapeTrack()) &&
2825          common::IsEqual(default_value(), other_map_tensor.default_value());
2826 }
2827 
hash() const2828 std::size_t AbstractMapTensor::hash() const {
2829   const auto &type = GetTypeTrack();
2830   const auto &value_shape = GetShapeTrack();
2831   MS_EXCEPTION_IF_NULL(type);
2832   MS_EXCEPTION_IF_NULL(value_shape);
2833   MS_EXCEPTION_IF_NULL(default_value_);
2834   std::size_t hash_value = hash_combine(tid(), type->hash());
2835   hash_value = hash_combine(hash_value, value_shape->hash());
2836   return hash_combine(hash_value, default_value_->hash());
2837 }
2838 
ToString() const2839 std::string AbstractMapTensor::ToString() const {
2840   const auto &type = GetTypeTrack();
2841   const auto &value = GetValueTrack();
2842   const auto &value_shape = GetShapeTrack();
2843   MS_EXCEPTION_IF_NULL(type);
2844   MS_EXCEPTION_IF_NULL(value);
2845   MS_EXCEPTION_IF_NULL(value_shape);
2846   return type_name() + "(" + type->ToString() + " " + value_shape->ToString() +
2847          " key: " + (ref_key_value_ == nullptr ? "<null>" : ref_key_value_->ToString()) +
2848          " value: " + value->ToString() + ")";
2849 }
2850 
/// Joins this UMonad abstract with `other`.
///
/// If `other` is also an AbstractUMonad, the join is trivially this abstract. Otherwise
/// both type tracks are handed to TypeJoinLogging, which is defined elsewhere. It is used
/// after failed type joins, so presumably it reports or raises the mismatch; confirm
/// whether it throws.
/// @param other The abstract to join with; must not be null.
/// @return This abstract (as AbstractBasePtr).
/// @throws (via MS_EXCEPTION_IF_NULL) if `other` is null, or, on the mismatch path, if
///         either abstract's type track is null.
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractUMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    // `other` itself was checked on entry; it is the type fetched from it that must be
    // validated before being passed to TypeJoinLogging.
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
2862 
operator ==(const AbstractBase & other) const2863 bool AbstractUMonad::operator==(const AbstractBase &other) const { return other.isa<AbstractUMonad>(); }
2864 
/// Joins this IOMonad abstract with `other`.
///
/// If `other` is also an AbstractIOMonad, the join is trivially this abstract. Otherwise
/// both type tracks are handed to TypeJoinLogging, which is defined elsewhere. It is used
/// after failed type joins, so presumably it reports or raises the mismatch; confirm
/// whether it throws.
/// @param other The abstract to join with; must not be null.
/// @return This abstract (as AbstractBasePtr).
/// @throws (via MS_EXCEPTION_IF_NULL) if `other` is null, or, on the mismatch path, if
///         either abstract's type track is null.
AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractIOMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    // `other` itself was checked on entry; it is the type fetched from it that must be
    // validated before being passed to TypeJoinLogging.
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
2876 
operator ==(const AbstractBase & other) const2877 bool AbstractIOMonad::operator==(const AbstractBase &other) const { return other.isa<AbstractIOMonad>(); }
2878 
GetRefKeyValue(const AbstractBasePtr & abs)2879 ValuePtr GetRefKeyValue(const AbstractBasePtr &abs) {
2880   auto abs_ref = abs->cast_ptr<AbstractRefTensor>();
2881   if (abs_ref != nullptr) {
2882     return abs_ref->ref_key_value();
2883   }
2884   auto abs_map_tensor = abs->cast_ptr<AbstractMapTensor>();
2885   if (abs_map_tensor != nullptr) {
2886     return abs_map_tensor->ref_key_value();
2887   }
2888   return nullptr;
2889 }
2890 }  // namespace abstract
2891 }  // namespace mindspore
2892