1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "abstract/abstract_value.h"
20
21 #include <regex>
22 #include <algorithm>
23 #include <utility>
24
25 #include "ir/value.h"
26 #include "utils/hash_map.h"
27 #include "utils/hashing.h"
28 #include "utils/log_adapter.h"
29 #include "utils/ms_utils.h"
30 #include "abstract/utils.h"
31 #include "utils/ms_context.h"
32 #include "utils/trace_base.h"
33 #include "utils/compile_config.h"
34
35 namespace mindspore {
36 namespace abstract {
37 using mindspore::common::IsEqual;
38
AbstractBase(const ValuePtr & value,const TypePtr & type,const BaseShapePtr & shape)39 AbstractBase::AbstractBase(const ValuePtr &value, const TypePtr &type, const BaseShapePtr &shape)
40 : value_(value), type_(type), shape_(shape) {}
41
// Copy constructor: copies every track and attribute held by `other`.
AbstractBase::AbstractBase(const AbstractBase &other)
    : Base(other),
      value_(other.value_),
      type_(other.type_),
      shape_(other.shape_),
      value_desc_(other.value_desc_),
      name_(other.name_),
      inplace_abstract_(other.inplace_abstract_) {
  // user_data_ is copied in the body rather than in the initializer list.
  this->user_data_ = other.user_data_;
}
52
// Copy assignment: takes over the tracks and attributes of `other`.
AbstractBase &AbstractBase::operator=(const AbstractBase &other) {
  // Self-assignment leaves the object untouched.
  if (this == &other) {
    return *this;
  }
  value_ = other.value_;
  type_ = other.type_;
  shape_ = other.shape_;
  user_data_ = other.user_data_;
  value_desc_ = other.value_desc_;
  name_ = other.name_;
  inplace_abstract_ = other.inplace_abstract_;
  return *this;
}
65
66 AbstractBase::TraceNodeProvider AbstractBase::trace_node_provider_ = nullptr;
67
JoinSupplementaryInfo(const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)68 std::string JoinSupplementaryInfo(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
69 std::ostringstream oss;
70 oss << "#dmsg#Framework Error Message:#dmsg#This: " << abstract1->ToString() << ", other: " << abstract2->ToString();
71 // Get trace info of node.
72 AnfNodePtr node = nullptr;
73 if (AbstractBase::trace_node_provider_ != nullptr) {
74 AbstractBase::trace_node_provider_(&node);
75 }
76 if (node != nullptr) {
77 oss << ". Please check the node: " << node->DebugString() << trace::DumpSourceLines(node);
78 }
79 return oss.str();
80 }
81
AbstractTypeJoinLogging(const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)82 inline void AbstractTypeJoinLogging(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
83 std::ostringstream oss;
84 oss << "Type Join Failed: Abstract type " << abstract1->type_name() << " cannot join with " << abstract2->type_name()
85 << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed\n";
86 oss << JoinSupplementaryInfo(abstract1, abstract2);
87 MS_EXCEPTION(TypeError) << oss.str();
88 }
89
TypeJoinLogging(const TypePtr & type1,const TypePtr & type2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)90 inline void TypeJoinLogging(const TypePtr &type1, const TypePtr &type2, const AbstractBasePtr &abstract1,
91 const AbstractBasePtr &abstract2) {
92 std::ostringstream oss;
93 oss << "Type Join Failed: dtype1 = " << type1->ToString() << ", dtype2 = " << type2->ToString()
94 << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed\n";
95 oss << JoinSupplementaryInfo(abstract1, abstract2);
96 MS_EXCEPTION(TypeError) << oss.str();
97 }
98
ShapeJoinLogging(const BaseShapePtr & shape1,const BaseShapePtr & shape2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)99 inline void ShapeJoinLogging(const BaseShapePtr &shape1, const BaseShapePtr &shape2, const AbstractBasePtr &abstract1,
100 const AbstractBasePtr &abstract2) {
101 std::ostringstream oss;
102 oss << "Shape Join Failed: shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString()
103 << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Shape%20Join%20Failed\n";
104 oss << JoinSupplementaryInfo(abstract1, abstract2);
105 MS_EXCEPTION(ValueError) << oss.str();
106 }
107
// Extracts the user-facing part of a join failure message.
//
// Searches `info` for the first span that starts at "Type Join Failed" or "Shape Join Failed", continues
// over one line break, and ends at the "Type%20Join%20Failed" / "Shape%20Join%20Failed" suffix of the
// documentation link. Returns that span, or an empty string when `info` contains no such span.
std::string ExtractLoggingInfo(const std::string &info) {
  // Compiling a std::regex is expensive, so the pattern is built once and reused by every call.
  static const std::regex kJoinFailedPattern(
    "(Type Join Failed|Shape Join Failed).*?\n.*?(Type%20Join%20Failed|Shape%20Join%20Failed)");
  std::smatch match;
  if (std::regex_search(info, match, kJoinFailedPattern)) {
    return match.str();
  }
  return "";
}
118
IsUndeterminedType(const TypePtr & type)119 static inline bool IsUndeterminedType(const TypePtr &type) {
120 return (type != nullptr) && (type->type_id() == kObjectTypeUndeterminedType);
121 }
122
set_value(const ValuePtr & value)123 void AbstractBase::set_value(const ValuePtr &value) {
124 MS_EXCEPTION_IF_NULL(value);
125 value_ = value;
126 }
127
set_type(const TypePtr & type)128 void AbstractBase::set_type(const TypePtr &type) {
129 MS_EXCEPTION_IF_NULL(type);
130 type_ = type;
131 }
132
set_shape(const BaseShapePtr & shape)133 void AbstractBase::set_shape(const BaseShapePtr &shape) {
134 MS_EXCEPTION_IF_NULL(shape);
135 shape_ = shape;
136 }
137
value_desc() const138 const std::string &AbstractBase::value_desc() const { return value_desc_; }
139
hash() const140 std::size_t AbstractBase::hash() const { return tid(); }
141
set_value_desc(const std::string & desc)142 void AbstractBase::set_value_desc(const std::string &desc) { value_desc_ = desc; }
143
GetTypeTrack() const144 const TypePtr &AbstractBase::GetTypeTrack() const { return type_; }
145
GetValueTrack() const146 const ValuePtr &AbstractBase::GetValueTrack() const { return value_; }
147
GetShapeTrack() const148 const BaseShapePtr &AbstractBase::GetShapeTrack() const { return shape_; }
149
BuildShape() const150 BaseShapePtr AbstractBase::BuildShape() const { return kNoShape; }
151
GetShape() const152 BaseShapePtr AbstractBase::GetShape() const { return BuildShape(); }
153
GetType() const154 TypePtr AbstractBase::GetType() const { return BuildType(); }
155
GetValue() const156 ValuePtr AbstractBase::GetValue() const { return BuildValue(); }
157
set_trace_node_provider(const TraceNodeProvider & trace_node_provider)158 void AbstractBase::set_trace_node_provider(const TraceNodeProvider &trace_node_provider) {
159 trace_node_provider_ = trace_node_provider;
160 }
161
// Base join: rejects a null `other` and otherwise returns this abstract unchanged.
AbstractBasePtr AbstractBase::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  AbstractBasePtr self = shared_from_base<AbstractBase>();
  return self;
}
166
IsBroaden() const167 bool AbstractBase::IsBroaden() const { return value_->ContainsValueAny(); }
168
operator ==(const AbstractBase & other) const169 bool AbstractBase::operator==(const AbstractBase &other) const {
170 if (this == &other) {
171 // Same object.
172 return true;
173 }
174 // Check C++ type.
175 if (tid() != other.tid()) {
176 return false;
177 }
178 // If both are "undetermined" type, they are considered equal.
179 if (IsUndeterminedType(BuildType()) && IsUndeterminedType(other.BuildType())) {
180 return true;
181 }
182 // Check data type, shape and value.
183 return IsEqual(type_, other.type_) && IsEqual(shape_, other.shape_) && IsEqual(value_, other.value_);
184 }
185
BuildValue() const186 ValuePtr AbstractBase::BuildValue() const {
187 if (value_ == nullptr) {
188 return RealBuildValue();
189 }
190 return value_;
191 }
192
Broaden() const193 AbstractBasePtr AbstractBase::Broaden() const {
194 AbstractBasePtr clone = Clone();
195 MS_EXCEPTION_IF_NULL(clone);
196 clone->set_value(kValueAny);
197 return clone;
198 }
199
PartialBroaden() const200 AbstractBasePtr AbstractBase::PartialBroaden() const { return Clone(); }
201
ToString() const202 std::string AbstractBase::ToString() const {
203 std::ostringstream buffer;
204 std::string value = std::string("value is null");
205 if (value_ != nullptr) {
206 value = value_->ToString();
207 }
208 MS_EXCEPTION_IF_NULL(type_);
209 MS_EXCEPTION_IF_NULL(shape_);
210 buffer << type_name() << "("
211 << "Type: " << type_->ToString() << ", Value: " << value << ", Shape: " << shape_->ToString() << ")";
212 return buffer.str();
213 }
214
ToString(bool verbose) const215 std::string AbstractBase::ToString(bool verbose) const {
216 if (verbose) {
217 return ToString();
218 }
219 std::ostringstream buffer;
220 auto tensor_value = BuildValue();
221 auto shape = GetShape();
222 auto type = BuildType();
223 if (shape != nullptr && type != nullptr) {
224 buffer << type << ", " << shape->ToString();
225 if (tensor_value != nullptr && !tensor_value->ContainsValueAny()) {
226 buffer << ", value=...";
227 }
228 } else if (type != nullptr) {
229 buffer << type;
230 if (tensor_value != nullptr && !tensor_value->ContainsValueAny()) {
231 buffer << ", value=...";
232 }
233 }
234 return buffer.str();
235 }
236
set_interpret_bool_checker(InterpretBoolChecker checker)237 void AbstractBase::set_interpret_bool_checker(InterpretBoolChecker checker) { interpret_bool_checker_ = checker; }
238
interpret_bool_checker()239 AbstractBase::InterpretBoolChecker AbstractBase::interpret_bool_checker() { return interpret_bool_checker_; }
240
set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher)241 void AbstractBase::set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher) {
242 pyexecute_user_data_catcher_ = catcher;
243 }
244
pyexecute_user_data_catcher()245 AbstractBase::PyExecuteUserDataCatcher AbstractBase::pyexecute_user_data_catcher() {
246 return pyexecute_user_data_catcher_;
247 }
248
name() const249 std::string AbstractBase::name() const { return name_; }
250
set_name(const std::string & name)251 void AbstractBase::set_name(const std::string &name) { name_ = name; }
252
inplace_abstract() const253 AbstractBasePtr AbstractBase::inplace_abstract() const { return inplace_abstract_; }
254
set_inplace_abstract(const AbstractBasePtr & inplace_abstract)255 void AbstractBase::set_inplace_abstract(const AbstractBasePtr &inplace_abstract) {
256 inplace_abstract_ = inplace_abstract;
257 }
258
RealBuildValue() const259 ValuePtr AbstractBase::RealBuildValue() const { return kValueAny; }
260
AbstractScalar()261 AbstractScalar::AbstractScalar() : AbstractBase(kValueAny, kTypeAny) {}
262
AbstractScalar(const ValuePtr & value,const TypePtr & type)263 AbstractScalar::AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
264
AbstractScalar(const ValuePtr & value)265 AbstractScalar::AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {}
266
AbstractScalar(int value)267 AbstractScalar::AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {}
268
AbstractScalar(int64_t value)269 AbstractScalar::AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {}
270
AbstractScalar(float value)271 AbstractScalar::AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {}
272
AbstractScalar(double value)273 AbstractScalar::AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {}
274
AbstractScalar(bool value)275 AbstractScalar::AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {}
276
AbstractScalar(const std::string & value)277 AbstractScalar::AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {}
278
AbstractScalar(const TypePtr & type)279 AbstractScalar::AbstractScalar(const TypePtr &type) : AbstractBase(kValueAny, type) {}
280
Broaden() const281 AbstractBasePtr AbstractScalar::Broaden() const {
282 if (is_variable_) {
283 return AbstractBase::Broaden();
284 }
285 auto context = MsContext::GetInstance();
286 MS_EXCEPTION_IF_NULL(context);
287 if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
288 return AbstractBase::Broaden();
289 }
290 auto type_id = GetTypeTrack()->type_id();
291 if (type_id == kObjectTypeEnvType) {
292 return AbstractBase::Broaden();
293 }
294 return Clone();
295 }
296
// Joins this scalar with `other`.
// - Equal abstracts: this abstract is returned.
// - A type join that yields kTypeAny is reported through TypeJoinLogging (TypeError).
// - Joining with an AbstractNegligible while this value is not ValueAny yields an AbstractAny.
// - Otherwise the values are joined; this abstract is reused when the joined value is this value,
//   else a new AbstractScalar with the joined value and type is returned.
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  const auto &lhs_type = GetTypeTrack();
  const auto &rhs_type = other->GetTypeTrack();
  TypePtr joined_type = TypeJoin(lhs_type, rhs_type);
  if (joined_type == kTypeAny) {
    TypeJoinLogging(lhs_type, rhs_type, shared_from_base<AbstractBase>(), other);
  }
  const auto &lhs_value = GetValueTrack();
  const auto &rhs_value = other->GetValueTrack();
  const bool other_negligible = other->isa<AbstractNegligible>();
  if (other_negligible && !lhs_value->isa<ValueAny>()) {
    return std::make_shared<AbstractAny>();
  }

  ValuePtr joined_value = ValueJoin(lhs_value, rhs_value);
  if (joined_value != lhs_value) {
    return std::make_shared<AbstractScalar>(joined_value, joined_type);
  }
  return shared_from_base<AbstractBase>();
}
320
set_is_variable(bool is_variable)321 void AbstractScalar::set_is_variable(bool is_variable) { is_variable_ = is_variable; }
322
hash() const323 std::size_t AbstractScalar::hash() const {
324 return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()});
325 }
326
BuildType() const327 TypePtr AbstractScalar::BuildType() const { return GetTypeTrack(); }
328
Clone() const329 AbstractBasePtr AbstractScalar::Clone() const {
330 auto abs = std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
331 abs->set_is_variable(is_variable_);
332 abs->SetSymbolicShape(this->GetSymbolicShape());
333 abs->SetSymbolicValue(this->GetSymbolicValue());
334 return abs;
335 }
336
Clone() const337 AbstractBasePtr AbstractType::Clone() const {
338 ValuePtr this_value = GetValueTrack();
339 if (this_value == nullptr || !this_value->isa<Type>()) {
340 return nullptr;
341 }
342 auto this_type = this_value->cast_ptr<Type>();
343 return std::make_shared<AbstractType>(this_type->Clone());
344 }
345
operator ==(const AbstractBase & other) const346 bool AbstractType::operator==(const AbstractBase &other) const {
347 if (this == &other) {
348 return true;
349 }
350 return tid() == other.tid() &&
351 IsEqual(dyn_cast_ptr<Type>(GetValueTrack()), dyn_cast_ptr<Type>(other.GetValueTrack()));
352 }
353
ToString() const354 std::string AbstractType::ToString() const {
355 std::ostringstream buffer;
356 ValuePtr this_value = GetValueTrack();
357 if (this_value == nullptr) {
358 buffer << "AbstractType value: nullptr";
359 return buffer.str();
360 }
361 if (!this_value->isa<Type>()) {
362 buffer << type_name() << "(Value: nullptr)";
363 return buffer.str();
364 }
365 auto this_type = this_value->cast_ptr<Type>();
366 buffer << type_name() << "("
367 << "Value: " << this_type->ToString() << ")";
368 return buffer.str();
369 }
370
// Two class abstracts join only when they are equal; otherwise the mismatch is reported through
// TypeJoinLogging (TypeError).
AbstractBasePtr AbstractClass::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  const bool same = (*this == *other);
  if (same) {
    return shared_from_base<AbstractBase>();
  }
  const auto &lhs_type = GetTypeTrack();
  const auto &rhs_type = other->GetTypeTrack();
  TypeJoinLogging(lhs_type, rhs_type, shared_from_base<AbstractBase>(), other);
  return shared_from_base<AbstractBase>();
}
381
Clone() const382 AbstractBasePtr AbstractClass::Clone() const { return std::make_shared<AbstractClass>(GetValueTrack()); }
383
operator ==(const AbstractBase & other) const384 bool AbstractClass::operator==(const AbstractBase &other) const {
385 if (this == &other) {
386 return true;
387 }
388 return tid() == other.tid() && IsEqual(GetValueTrack(), other.GetValueTrack());
389 }
390
ToString() const391 std::string AbstractClass::ToString() const {
392 std::ostringstream buffer;
393 ValuePtr val = GetValueTrack();
394 MS_EXCEPTION_IF_NULL(val);
395 buffer << type_name() << "(" << val->ToString() << ")";
396 return buffer.str();
397 }
398
BuildType() const399 TypePtr AbstractType::BuildType() const { return std::make_shared<TypeType>(); }
400
Broaden() const401 AbstractBasePtr AbstractType::Broaden() const { return Clone(); }
402
AbstractProblem(const ValueProblemPtr & err,const AnfNodePtr & node)403 AbstractProblem::AbstractProblem(const ValueProblemPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
404 if (err == nullptr || node == nullptr) {
405 MS_LOG(EXCEPTION) << "err or node is nullptr";
406 }
407 }
408
ToString() const409 std::string AbstractProblem::ToString() const {
410 std::ostringstream buffer;
411 auto value_track = GetValueTrack();
412 MS_EXCEPTION_IF_NULL(value_track);
413 MS_EXCEPTION_IF_NULL(node_);
414 buffer << type_name() << "("
415 << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
416 return buffer.str();
417 }
418
BuildType() const419 TypePtr AbstractProblem::BuildType() const { return std::make_shared<Problem>(); }
420
Broaden() const421 AbstractBasePtr AbstractProblem::Broaden() const { return Clone(); }
422
Clone() const423 AbstractBasePtr AbstractProblem::Clone() const {
424 return std::make_shared<AbstractProblem>(GetValueTrack()->cast<ValueProblemPtr>(), node_);
425 }
426
AbstractScript()427 AbstractScript::AbstractScript() : AbstractBase(kValueAny, kTypeAny) {}
428
AbstractScript(const ValuePtr & value,const TypePtr & type)429 AbstractScript::AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
430
AbstractScript(const ValuePtr & value)431 AbstractScript::AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {}
432
hash() const433 std::size_t AbstractScript::hash() const {
434 return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()});
435 }
436
BuildType() const437 TypePtr AbstractScript::BuildType() const { return GetTypeTrack(); }
438
Clone() const439 AbstractBasePtr AbstractScript::Clone() const {
440 return std::make_shared<AbstractScript>(GetValueTrack(), GetTypeTrack()->Clone());
441 }
442
Broaden() const443 AbstractBasePtr AbstractScript::Broaden() const { return Clone(); }
444
// Joins this function abstract with an arbitrary abstract.
// An AbstractNegligible leaves this abstract unchanged; a non-function abstract is reported through
// AbstractTypeJoinLogging (TypeError); a function abstract is forwarded to the function-typed Join.
AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  const bool other_negligible = other->isa<AbstractNegligible>();
  if (other_negligible) {
    return shared_from_base<AbstractBase>();
  }
  auto other_function = dyn_cast<AbstractFunction>(other);
  if (other_function == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  return Join(other_function);
}
456
operator ==(const AbstractBase & other) const457 bool AbstractFunction::operator==(const AbstractBase &other) const {
458 if (this == &other) {
459 return true;
460 }
461 if (!other.isa<AbstractFunction>()) {
462 return false;
463 }
464 return *this == static_cast<const AbstractFunction &>(other);
465 }
466
BuildType() const467 TypePtr AbstractFunction::BuildType() const { return std::make_shared<Function>(); }
468
Clone() const469 AbstractBasePtr AbstractFunction::Clone() const { return Copy(); }
470
Broaden() const471 AbstractBasePtr AbstractFunction::Broaden() const {
472 return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
473 }
474
tracking_id() const475 std::uintptr_t AbstractFunction::tracking_id() const { return 0; }
476
CopyWithoutTrackingId() const477 AbstractFunctionPtr AbstractFunction::CopyWithoutTrackingId() const { return Copy(); }
478
context() const479 AnalysisContextPtr AbstractFunction::context() const { return nullptr; }
480
ToTrackingId(const AnfNodePtr & node)481 std::uintptr_t AbstractFunction::ToTrackingId(const AnfNodePtr &node) {
482 return reinterpret_cast<std::uintptr_t>(node.get());
483 }
484
AbstractKeywordArg(const std::string & key,const AbstractBasePtr & argument)485 AbstractKeywordArg::AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument)
486 : arg_name_(key), arg_value_(argument) {}
487
get_key() const488 std::string AbstractKeywordArg::get_key() const { return arg_name_; }
489
get_arg() const490 AbstractBasePtr AbstractKeywordArg::get_arg() const { return arg_value_; }
491
AbstractUndetermined()492 AbstractUndetermined::AbstractUndetermined() : AbstractBase(kValueAny) {}
493
AbstractUndetermined(const AbstractBasePtr & element,const BaseShapePtr & shape)494 AbstractUndetermined::AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape)
495 : AbstractBase(kValueAny), element_(element) {
496 if (element == nullptr) {
497 MS_LOG(EXCEPTION) << "element is nullptr";
498 }
499 if (element->isa<AbstractUndetermined>()) {
500 MS_LOG(EXCEPTION) << "element type error";
501 }
502 MS_EXCEPTION_IF_NULL(shape);
503 if (shape->isa<NoShape>()) {
504 MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
505 }
506 AbstractBase::set_shape(shape);
507 }
508
AbstractUndetermined(const TypePtr & element_type,const ShapeVector & shape)509 AbstractUndetermined::AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape)
510 : AbstractBase(kValueAny), element_(std::make_shared<AbstractScalar>(kValueAny, element_type)) {
511 if (element_type == nullptr) {
512 MS_LOG(EXCEPTION) << "element_type is nullptr";
513 }
514 AbstractBase::set_shape(std::make_shared<Shape>(shape));
515 }
516
AbstractUndetermined(const TypePtr & element_type,const BaseShapePtr & shape)517 AbstractUndetermined::AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape)
518 : AbstractBase(kValueAny), element_(std::make_shared<AbstractScalar>(kValueAny, element_type)) {
519 if (element_type == nullptr) {
520 MS_LOG(EXCEPTION) << "element_type is nullptr";
521 }
522 MS_EXCEPTION_IF_NULL(shape);
523 if (shape->isa<NoShape>()) {
524 MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
525 }
526 AbstractBase::set_shape(shape);
527 }
528
BuildType() const529 TypePtr AbstractUndetermined::BuildType() const { return std::make_shared<UndeterminedType>(); }
530
Clone() const531 AbstractBasePtr AbstractUndetermined::Clone() const { return std::make_shared<AbstractUndetermined>(); }
532
element() const533 AbstractBasePtr AbstractUndetermined::element() const { return element_; }
534
AbstractTensor(const AbstractBasePtr & element,const BaseShapePtr & shape)535 AbstractTensor::AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape)
536 : AbstractUndetermined(element, shape) {}
537
AbstractTensor(const TypePtr & element_type,const ShapeVector & shape)538 AbstractTensor::AbstractTensor(const TypePtr &element_type, const ShapeVector &shape)
539 : AbstractUndetermined(element_type, shape) {}
540
AbstractTensor(const tensor::TensorPtr & tensor)541 AbstractTensor::AbstractTensor(const tensor::TensorPtr &tensor)
542 : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
543
AbstractTensor(const TypePtr & element_type,const BaseShapePtr & shape)544 AbstractTensor::AbstractTensor(const TypePtr &element_type, const BaseShapePtr &shape)
545 : AbstractUndetermined(element_type, shape) {}
546
hash() const547 std::size_t AbstractTensor::hash() const {
548 // We have to exclude value pointer from hash, because CSE (Common Subexpression Elimination)
549 // will use this hash to find duplicate ValueNodes that Tensor values are equal.
550 auto hash_sum = hash_combine(tid(), element_->hash());
551 const auto &shape = GetShapeTrack();
552 if (shape != nullptr) {
553 hash_sum = hash_combine(hash_sum, shape->hash());
554 }
555 return hash_sum;
556 }
557
is_adapter() const558 bool AbstractTensor::is_adapter() const { return is_adapter_; }
559
set_is_adapter(bool is_adapter)560 void AbstractTensor::set_is_adapter(bool is_adapter) { is_adapter_ = is_adapter; }
561
AbstractAny()562 AbstractAny::AbstractAny()
563 : AbstractTensor(DefaultDtype(), std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny}))) {}
564
// Joining anything (non-null) with AbstractAny yields a fresh AbstractAny.
AbstractBasePtr AbstractAny::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto joined = std::make_shared<AbstractAny>();
  return joined;
}
569
Broaden() const570 AbstractBasePtr AbstractAny::Broaden() const { return Clone(); }
571
Clone() const572 AbstractBasePtr AbstractAny::Clone() const {
573 auto any_abstract = std::make_shared<AbstractAny>();
574 if (supposed_tensor_dtype()) {
575 MS_EXCEPTION_IF_NULL(element());
576 const auto &dtype = element()->BuildType();
577 MS_EXCEPTION_IF_NULL(any_abstract->element());
578 any_abstract->element()->set_type(dtype);
579 any_abstract->set_supposed_tensor_dtype(true);
580 }
581 return any_abstract;
582 }
583
ToString() const584 std::string AbstractAny::ToString() const { return type_name(); }
585
supposed_tensor_dtype() const586 bool AbstractAny::supposed_tensor_dtype() const { return supposed_tensor_dtype_; }
587
set_supposed_tensor_dtype(bool flag)588 void AbstractAny::set_supposed_tensor_dtype(bool flag) { supposed_tensor_dtype_ = flag; }
589
DefaultDtype()590 TypePtr AbstractAny::DefaultDtype() { return kFloat64; }
591
message() const592 const std::string &AbstractJoinedAny::message() const { return message_; }
593
set_message(const std::string & message)594 void AbstractJoinedAny::set_message(const std::string &message) { message_ = message; }
595
exception() const596 AbstractJoinedAny::ExceptionType AbstractJoinedAny::exception() const { return exception_; }
597
set_exception(ExceptionType exception)598 void AbstractJoinedAny::set_exception(ExceptionType exception) { exception_ = exception; }
599
ThrowException() const600 void AbstractJoinedAny::ThrowException() const {
601 if (exception_ == kTypeError) {
602 MS_EXCEPTION(TypeError) << message_;
603 } else if (exception_ == kValueError) {
604 MS_EXCEPTION(ValueError) << message_;
605 } else {
606 MS_LOG(EXCEPTION) << message_;
607 }
608 }
609
610 namespace {
CollectSequenceNodes(const AnfNodeWeakPtrList & source_sequence_nodes,AnfNodeWeakPtrList * sequence_nodes_ptr)611 void CollectSequenceNodes(const AnfNodeWeakPtrList &source_sequence_nodes, AnfNodeWeakPtrList *sequence_nodes_ptr) {
612 AnfNodeWeakPtrList &sequence_nodes = *sequence_nodes_ptr;
613 auto sequence_nodes_size = source_sequence_nodes.size();
614 for (size_t i = 0; i < sequence_nodes_size; ++i) {
615 // Lock sequence nodes of this.
616 auto &source_weak_node = source_sequence_nodes[i];
617 auto source_sequence_node = source_weak_node.lock();
618 if (source_sequence_node == nullptr) {
619 continue;
620 }
621 // Check and emplace sequence node for this.
622 auto this_iter = std::find_if(
623 sequence_nodes.begin(), sequence_nodes.end(),
624 [&source_sequence_node](const AnfNodeWeakPtr &weak_node) { return source_sequence_node == weak_node.lock(); });
625 if (this_iter == sequence_nodes.end()) {
626 (void)sequence_nodes.emplace_back(AnfNodeWeakPtr(source_sequence_node));
627 }
628 }
629 }
630
SynchronizeSequenceNodesElementsUseFlagsInner(const AnfNodeWeakPtrList & sequence_nodes)631 void SynchronizeSequenceNodesElementsUseFlagsInner(const AnfNodeWeakPtrList &sequence_nodes) {
632 // Choose the candidate sequence node, that we use its flags as unique one.
633 AnfNodePtr candidate_sequence_node = sequence_nodes[0].lock();
634 MS_EXCEPTION_IF_NULL(candidate_sequence_node);
635 size_t candidate_index = 0;
636 for (size_t i = 1; i < sequence_nodes.size(); ++i) {
637 auto current_sequence_node = sequence_nodes[i].lock();
638 MS_EXCEPTION_IF_NULL(current_sequence_node);
639 if (candidate_sequence_node == current_sequence_node) {
640 continue;
641 }
642 auto candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
643 MS_EXCEPTION_IF_NULL(candidate_flags);
644 auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
645 MS_EXCEPTION_IF_NULL(current_flags);
646 if (candidate_flags == current_flags) {
647 continue;
648 }
649
650 // Find the sequence node whose flags are most used.
651 auto candidate_count = candidate_flags.use_count();
652 auto current_count = current_flags.use_count();
653 if (candidate_count < current_count) {
654 candidate_sequence_node = current_sequence_node;
655 candidate_index = i;
656 }
657 }
658
659 // Synchronize the elements use flags for all sequence nodes with candidate sequence node.
660 // We set the same 'elements_use_flags' for them after here.
661 auto candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
662 MS_LOG(DEBUG) << "Sequence nodes size: " << sequence_nodes.size() << ", candidate node index: " << candidate_index
663 << ", candidate node: " << candidate_sequence_node->DebugString() << ", flags: " << candidate_flags;
664 for (size_t i = 0; i < sequence_nodes.size(); ++i) {
665 auto current_sequence_node = sequence_nodes[i].lock();
666 MS_EXCEPTION_IF_NULL(current_sequence_node);
667 if (candidate_sequence_node == current_sequence_node) {
668 continue;
669 }
670 auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
671 if (candidate_flags == current_flags) {
672 continue;
673 }
674
675 // Merge the use flags, set true if either is true.
676 for (size_t j = 0; j < candidate_flags->size(); ++j) {
677 MS_LOG(DEBUG) << "Check elements_use_flags[" << j << "], this_flag: " << (*candidate_flags)[j]
678 << ", other_flag: " << (*current_flags)[j];
679 (*candidate_flags)[j] = ((*candidate_flags)[j] || (*current_flags)[j]);
680 }
681 // Use the candidate sequence node flags.
682 SetSequenceNodeElementsUseFlags(current_sequence_node, candidate_flags);
683 MS_LOG(DEBUG) << "Reset flags for sequence node[" << i << "]: " << current_sequence_node->DebugString()
684 << ", flags: " << candidate_flags;
685 }
686 }
687
CheckSequenceNodesValid(const AnfNodeWeakPtrList & sequence_nodes)688 void CheckSequenceNodesValid(const AnfNodeWeakPtrList &sequence_nodes) {
689 if (!IS_OUTPUT_ON(MsLogLevel::kDebug)) {
690 return;
691 }
692 if (sequence_nodes.size() <= 1) {
693 return;
694 }
695 AnfNodePtr candidate_sequence_node = sequence_nodes[0].lock();
696 if (candidate_sequence_node == nullptr) {
697 MS_LOG(DEBUG) << "candidate_sequence_node is null.";
698 return;
699 }
700 auto candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
701 if (candidate_flags == nullptr) {
702 MS_LOG(DEBUG) << "The candidate_flags is null, sequence_nodes[0]: " << candidate_sequence_node->DebugString();
703 return;
704 }
705 for (size_t i = 0; i < sequence_nodes.size(); ++i) {
706 auto current_sequence_node = sequence_nodes[i].lock();
707 if (current_sequence_node == nullptr) {
708 MS_LOG(DEBUG) << "current_sequence_node is null.";
709 return;
710 }
711 MS_LOG(DEBUG) << "sequence_nodes[" << i << "]: " << current_sequence_node << "/"
712 << current_sequence_node->DebugString()
713 << ", flags: " << GetSequenceNodeElementsUseFlags(current_sequence_node);
714 }
715 for (size_t i = 1; i < sequence_nodes.size(); ++i) {
716 auto current_sequence_node = sequence_nodes[i].lock();
717 if (current_sequence_node == nullptr) {
718 MS_LOG(DEBUG) << "current_sequence_node is null.";
719 return;
720 }
721 if (candidate_sequence_node == current_sequence_node) {
722 continue;
723 }
724 candidate_flags = GetSequenceNodeElementsUseFlags(candidate_sequence_node);
725 MS_EXCEPTION_IF_NULL(candidate_flags);
726 auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
727 MS_EXCEPTION_IF_NULL(current_flags);
728 if (candidate_flags == current_flags) {
729 continue;
730 }
731 MS_LOG(DEBUG) << "Should use same flags pointer, candidate_node: " << candidate_sequence_node->DebugString()
732 << ", current_node: " << current_sequence_node->DebugString();
733
734 if (candidate_flags->size() != current_flags->size()) {
735 MS_LOG(INTERNAL_EXCEPTION) << "Flag not same size";
736 }
737 for (size_t j = 0; j < candidate_flags->size(); ++j) {
738 if ((*candidate_flags)[j] != (*current_flags)[j]) {
739 MS_LOG(INTERNAL_EXCEPTION) << "Not equal elements_use_flags[" << j << "], this_flag: " << (*candidate_flags)[j]
740 << ", other_flag: " << (*current_flags)[j];
741 }
742 }
743 }
744 }
745
// Collects the nodes of both lists into one list and, when that list holds more than one node, synchronizes
// their elements-use flags and validates the result. The collected list is returned in every case.
AnfNodeWeakPtrList SynchronizeSequenceNodesElementsUseFlags(const AnfNodeWeakPtrList &lhs_sequence_nodes,
                                                            const AnfNodeWeakPtrList &rhs_sequence_nodes) {
  AnfNodeWeakPtrList merged_nodes;
  CollectSequenceNodes(lhs_sequence_nodes, &merged_nodes);
  CollectSequenceNodes(rhs_sequence_nodes, &merged_nodes);
  if (merged_nodes.size() > 1) {
    // Synchronize first, then verify that the nodes really agree.
    SynchronizeSequenceNodesElementsUseFlagsInner(merged_nodes);
    CheckSequenceNodesValid(merged_nodes);
  } else {
    MS_LOG(DEBUG) << "Sequence nodes size should exceed 1.";
  }
  return merged_nodes;
}
761
AbstractCanJoin(const AbstractBasePtr & abs1,const AbstractBasePtr & abs2)762 bool AbstractCanJoin(const AbstractBasePtr &abs1, const AbstractBasePtr &abs2) {
763 try {
764 MS_LOG_TRY_CATCH_SCOPE;
765 (void)abs1->Join(abs2);
766 } catch (std::exception &) {
767 return false;
768 }
769 return true;
770 }
771
CheckElementAbstractSame(const AbstractBasePtr & first_element,const AbstractBasePtr & cur_element,const size_t i,bool raise_exception)772 bool CheckElementAbstractSame(const AbstractBasePtr &first_element, const AbstractBasePtr &cur_element, const size_t i,
773 bool raise_exception) {
774 MS_EXCEPTION_IF_NULL(first_element);
775 MS_EXCEPTION_IF_NULL(cur_element);
776 if (first_element->isa<abstract::AbstractAny>() || cur_element->isa<abstract::AbstractAny>()) {
777 return true;
778 }
779 auto first_element_type_id = first_element->BuildType()->generic_type_id();
780 auto cur_element_type_id = cur_element->BuildType()->generic_type_id();
781 if (first_element_type_id != cur_element_type_id) {
782 if (!raise_exception) {
783 return false;
784 }
785 MS_EXCEPTION(ValueError) << "In graph mode, the element type of dynamic length array must be the same."
786 << "The element type do not match, can not convert to dynamic length sequence. "
787 << "The 0th element type is: " << TypeIdToString(first_element_type_id) << ". The " << i
788 << "th element type is: " << TypeIdToString(cur_element_type_id);
789 }
790 auto first_element_shape = first_element->GetShape();
791 MS_EXCEPTION_IF_NULL(first_element_shape);
792 auto cur_element_shape = cur_element->GetShape();
793 MS_EXCEPTION_IF_NULL(cur_element_shape);
794 if (*first_element_shape != *cur_element_shape) {
795 return false;
796 }
797 if (!AbstractCanJoin(first_element, cur_element)) {
798 if (!raise_exception) {
799 return false;
800 }
801 MS_EXCEPTION(TypeError) << "In graph mode, the element shape of dynamic length array must be the same."
802 << "The element do not match, can not convert to dynamic length sequence. "
803 << "The 0th element is: " << first_element->ToString() << ". The " << i
804 << "th element shape is: " << cur_element->ToString();
805 }
806 return true;
807 }
808 } // namespace
809
AbstractSequence(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & sequence_nodes)810 AbstractSequence::AbstractSequence(AbstractBasePtrList &&elements,
811 const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes)
812 : elements_(std::move(elements)), sequence_nodes_(sequence_nodes) {
813 if (sequence_nodes != nullptr) {
814 CheckSequenceNodesValid(*sequence_nodes);
815 }
816 }
817
AbstractSequence(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & sequence_nodes)818 AbstractSequence::AbstractSequence(const AbstractBasePtrList &elements,
819 const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes)
820 : elements_(elements), sequence_nodes_(sequence_nodes) {
821 if (sequence_nodes != nullptr) {
822 CheckSequenceNodesValid(*sequence_nodes);
823 }
824 }
825
operator [](const std::size_t & dim) const826 const AbstractBasePtr AbstractSequence::operator[](const std::size_t &dim) const {
827 if (dynamic_len_) {
828 MS_LOG(EXCEPTION) << "Can not get element from dynamic length sequence " << ToString();
829 }
830 if (dim >= size()) {
831 MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list.";
832 }
833 return elements_[dim];
834 }
835
ToStringInternal() const836 std::string AbstractSequence::ToStringInternal() const {
837 std::ostringstream buffer;
838 if (dynamic_len_) {
839 buffer << "dynamic_len_element_abs: "
840 << (dynamic_len_element_abs_ == nullptr ? "nullptr" : dynamic_len_element_abs_->ToString());
841 return buffer.str();
842 }
843 size_t i = 0;
844 size_t size = elements_.size();
845 auto prefix_space = std::string(space_num_, ' ');
846 for (const auto &element : elements_) {
847 MS_EXCEPTION_IF_NULL(element);
848 if (element->isa<AbstractSequence>()) {
849 constexpr auto kIndent = 4;
850 element->cast<AbstractSequencePtr>()->space_num_ = space_num_ + kIndent;
851 }
852 buffer << "\n" << prefix_space << "element[" << i << "]: " << element->ToString();
853 if (i < size - 1) {
854 buffer << ", ";
855 }
856 i++;
857 }
858 return buffer.str();
859 }
860
ToString() const861 std::string AbstractSequence::ToString() const {
862 std::stringstream ss;
863 ss << "\n";
864 ss << type_name();
865 ss << "{";
866 ss << ToStringInternal();
867 if (!dynamic_len_ && sequence_nodes() != nullptr && !sequence_nodes()->empty()) {
868 ss << ", " << std::string(space_num_, ' ') << "sequence_nodes: {";
869 for (size_t i = 0; i < sequence_nodes()->size(); ++i) {
870 auto sequence_node = (*sequence_nodes())[i].lock();
871 if (sequence_node == nullptr) {
872 ss << "<freed node>";
873 continue;
874 } else {
875 ss << sequence_node->DebugString();
876 }
877 auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
878 if (flags != nullptr) {
879 ss << ", elements_use_flags: {ptr: " << flags << ", value: " << (*flags) << "}";
880 }
881 if (i != sequence_nodes()->size() - 1) {
882 ss << ", ";
883 }
884 }
885 ss << "}";
886 }
887 ss << ", dynamic_len:" << dynamic_len_ << ", is dyn arg:" << dyn_len_arg_;
888 ss << "}";
889 ss << "\n";
890 return ss.str();
891 }
892
ToString(bool verbose) const893 std::string AbstractSequence::ToString(bool verbose) const {
894 if (verbose) {
895 return ToString();
896 }
897 std::ostringstream buffer;
898 size_t i = 0;
899 size_t size = elements_.size();
900 buffer << type_name() << " {";
901 for (const auto &element : elements_) {
902 MS_EXCEPTION_IF_NULL(element);
903 buffer << element->ToString(false);
904 if (i < size - 1) {
905 buffer << ", ";
906 }
907 i++;
908 }
909 buffer << "}";
910 return buffer.str();
911 }
912
SequenceNodesJoin(const AbstractBasePtr & other)913 AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &other) {
914 AnfNodeWeakPtrList sequence_nodes;
915 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
916 if (!enable_eliminate_unused_element || this->sequence_nodes() == nullptr) {
917 return sequence_nodes;
918 }
919 auto other_sequence = dyn_cast_ptr<AbstractSequence>(other);
920 if (other_sequence == nullptr) {
921 return sequence_nodes;
922 }
923 auto this_sequence_nodes_size = (this->sequence_nodes() == nullptr ? 0 : this->sequence_nodes()->size());
924 auto other_sequence_nodes_size =
925 (other_sequence->sequence_nodes() == nullptr ? 0 : other_sequence->sequence_nodes()->size());
926 // The tuple or list output which has sequence_nodes may be joined with a tuple output node like top_k,
927 // we should return the branch which has sequence_nodes.
928 if (this_sequence_nodes_size == 0 && other_sequence_nodes_size == 0) {
929 return sequence_nodes;
930 } else if (this_sequence_nodes_size == 0) {
931 return *(other_sequence->sequence_nodes());
932 } else if (other_sequence_nodes_size == 0) {
933 return *(this->sequence_nodes());
934 }
935 // Collect this and other sequence nodes.
936 if (this->sequence_nodes() != nullptr) {
937 CollectSequenceNodes(*this->sequence_nodes(), &sequence_nodes);
938 }
939 if (other_sequence->sequence_nodes() != nullptr) {
940 CollectSequenceNodes(*other_sequence->sequence_nodes(), &sequence_nodes);
941 }
942 if (sequence_nodes.empty()) {
943 MS_LOG(INFO) << "Sequence nodes size should not be empty.";
944 return sequence_nodes;
945 }
946 // Synchronize the elements use flags for all sequence nodes.
947 SynchronizeSequenceNodesElementsUseFlagsInner(sequence_nodes);
948
949 CheckSequenceNodesValid(sequence_nodes);
950 this->InsertSequenceNodes(sequence_nodes);
951 other_sequence->InsertSequenceNodes(sequence_nodes);
952 return sequence_nodes;
953 }
954
SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr & lhs_sequence,const AbstractSequencePtr & rhs_sequence)955 void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
956 const AbstractSequencePtr &rhs_sequence) {
957 if (lhs_sequence->sequence_nodes() == nullptr || rhs_sequence->sequence_nodes() == nullptr) {
958 return;
959 }
960 auto sequence_nodes =
961 SynchronizeSequenceNodesElementsUseFlags(*lhs_sequence->sequence_nodes(), *rhs_sequence->sequence_nodes());
962 lhs_sequence->InsertSequenceNodes(sequence_nodes);
963 rhs_sequence->InsertSequenceNodes(sequence_nodes);
964 if (lhs_sequence->elements().size() != rhs_sequence->elements().size()) {
965 MS_LOG(INTERNAL_EXCEPTION) << "The elements size should be equal, " << lhs_sequence->ToString() << ", "
966 << rhs_sequence->ToString();
967 }
968 for (size_t i = 0; i < lhs_sequence->elements().size(); ++i) {
969 auto lhs_inner_sequence = dyn_cast<AbstractSequence>(lhs_sequence->elements()[i]);
970 if (lhs_inner_sequence == nullptr) {
971 continue;
972 }
973 auto rhs_inner_sequence = dyn_cast<AbstractSequence>(rhs_sequence->elements()[i]);
974 if (rhs_inner_sequence == nullptr) {
975 continue;
976 }
977 SynchronizeSequenceElementsUseFlagsRecursively(lhs_inner_sequence, rhs_inner_sequence);
978 }
979 }
980
InsertSequenceNodes(const AnfNodeWeakPtrList & sequence_nodes)981 void AbstractSequence::InsertSequenceNodes(const AnfNodeWeakPtrList &sequence_nodes) {
982 if (dynamic_len_) {
983 MS_LOG(INTERNAL_EXCEPTION) << "Can not insert sequence nodes for dynamic length sequence " << ToString();
984 }
985 if (sequence_nodes_ == nullptr) {
986 MS_LOG(DEBUG) << "The sequence_nodes is null.";
987 sequence_nodes_ = std::make_shared<AnfNodeWeakPtrList>();
988 }
989 for (auto &weak_node : sequence_nodes) {
990 auto sequence_node = weak_node.lock();
991 InsertSequenceNode(sequence_node);
992 }
993 }
994
InsertSequenceNode(const AnfNodePtr & sequence_node)995 void AbstractSequence::InsertSequenceNode(const AnfNodePtr &sequence_node) {
996 if (dynamic_len_) {
997 MS_LOG(INTERNAL_EXCEPTION) << "Can not insert sequence node for dynamic length sequence " << ToString();
998 }
999 if (sequence_nodes_ == nullptr) {
1000 MS_LOG(DEBUG) << "The sequence_nodes is null.";
1001 sequence_nodes_ = std::make_shared<AnfNodeWeakPtrList>();
1002 }
1003 auto iter =
1004 std::find_if(sequence_nodes_->begin(), sequence_nodes_->end(),
1005 [&sequence_node](const AnfNodeWeakPtr &weak_node) { return sequence_node == weak_node.lock(); });
1006 if (iter == sequence_nodes_->end()) {
1007 (void)sequence_nodes_->emplace_back(sequence_node);
1008 CheckSequenceNodesValid(*sequence_nodes_);
1009 } else {
1010 MS_LOG(DEBUG) << "Fail to insert node \'" << sequence_node->DebugString() << "\' into sequence nodes.";
1011 }
1012 }
1013
UpdateSequenceNode(const AnfNodePtr & old_sequence_node,const AnfNodePtr & new_sequence_node)1014 void AbstractSequence::UpdateSequenceNode(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node) {
1015 if (dynamic_len_) {
1016 MS_LOG(INTERNAL_EXCEPTION) << "Can not update sequence node for dynamic length sequence " << ToString();
1017 }
1018 if (sequence_nodes_ == nullptr) {
1019 MS_LOG(DEBUG) << "The sequence_nodes is null.";
1020 sequence_nodes_ = std::make_shared<AnfNodeWeakPtrList>();
1021 }
1022 auto iter = std::find_if(
1023 sequence_nodes_->begin(), sequence_nodes_->end(),
1024 [&old_sequence_node](const AnfNodeWeakPtr &weak_node) { return old_sequence_node == weak_node.lock(); });
1025 if (iter != sequence_nodes_->end()) {
1026 *iter = new_sequence_node;
1027 CheckSequenceNodesValid(*sequence_nodes_);
1028 return;
1029 }
1030 MS_LOG(INTERNAL_EXCEPTION) << "Not found old node \'" << old_sequence_node->DebugString() << "\' in sequence nodes.";
1031 }
1032
PurifyElements()1033 bool AbstractSequence::PurifyElements() {
1034 if (dynamic_len_ || sequence_nodes_ == nullptr || sequence_nodes_->empty()) {
1035 return true;
1036 }
1037 // Just use any sequence node's elements_use_flags.
1038 AnfNodePtr not_free_node = nullptr;
1039 std::shared_ptr<std::vector<bool>> elements_use_flags_ptr = nullptr;
1040 for (auto &weak_node : *sequence_nodes_) {
1041 auto sequence_node = weak_node.lock();
1042 if (sequence_node == nullptr) {
1043 MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
1044 continue;
1045 }
1046 not_free_node = sequence_node;
1047 auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
1048 if (flags != nullptr) {
1049 elements_use_flags_ptr = flags;
1050 break;
1051 }
1052 }
1053 if (elements_use_flags_ptr == nullptr) {
1054 if (not_free_node == nullptr) {
1055 MS_LOG(INFO) << "Check if all sequence nodes are released, or none elements use flags in them. nodes size: "
1056 << sequence_nodes_->size();
1057 } else {
1058 MS_LOG(INFO) << "Check if none elements use flags in sequence ndoes. one of node: "
1059 << not_free_node->DebugString();
1060 }
1061 return false;
1062 }
1063
1064 // Purify the elements.
1065 auto &elements_use_flags = *elements_use_flags_ptr;
1066 if (elements_use_flags.size() < elements_.size()) {
1067 MS_LOG(INTERNAL_EXCEPTION) << "Elements size should not be greater to elements use flags size. " << ToString();
1068 }
1069 for (size_t i = 0; i < elements_.size(); ++i) {
1070 MS_EXCEPTION_IF_NULL(elements_[i]);
1071 if (!elements_use_flags[i]) {
1072 const auto unuse_node_none = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(0));
1073 if (elements_[i]->isa<AbstractProblem>()) {
1074 unuse_node_none->set_type(std::make_shared<Problem>());
1075 }
1076 elements_[i] = unuse_node_none;
1077 MS_LOG(DEBUG) << "Erase elements[" << i << "] abstract as Zero for " << ToString();
1078 } else {
1079 MS_LOG(DEBUG) << "Keep elements[" << i << "] abstract as " << elements_[i]->ToString();
1080 }
1081 }
1082 return true;
1083 }
1084
1085 // Convert self from a fixed length sequence to dynamic length sequence.
CheckAndConvertToDynamicLenSequence(bool raise_exception)1086 void AbstractSequence::CheckAndConvertToDynamicLenSequence(bool raise_exception) {
1087 // Can not use size() since it will raise error when sequence is already dynamic length.
1088 const size_t input_len = elements_.size();
1089 if (input_len > 1) {
1090 auto first_element = elements()[0];
1091 MS_EXCEPTION_IF_NULL(first_element);
1092 for (size_t i = 1; i < input_len; ++i) {
1093 auto cur_element = elements()[i];
1094 MS_EXCEPTION_IF_NULL(cur_element);
1095 bool ret = CheckElementAbstractSame(first_element, cur_element, i, raise_exception);
1096 if (!ret) {
1097 return;
1098 }
1099 }
1100 set_dynamic_len_element_abs(first_element);
1101 } else if (input_len == 1) {
1102 set_dynamic_len_element_abs(elements()[0]);
1103 }
1104 set_value(kValueAny);
1105 // Set sequence nodes to nulltpr to disable DDE.
1106 sequence_nodes_ = nullptr;
1107 set_dynamic_len(true);
1108 }
1109
1110 // Convert self's cloned abstract from a fixed length sequence to dynamic length sequence, just like tensor broaden.
BroadenToDynamicLenSequence()1111 AbstractSequencePtr AbstractSequence::BroadenToDynamicLenSequence() {
1112 if (dynamic_len()) {
1113 return shared_from_base<AbstractSequence>();
1114 }
1115 if (isa<AbstractSparseTensor>()) {
1116 return shared_from_base<AbstractSequence>();
1117 }
1118 auto clone_sequence = Clone()->cast<AbstractSequencePtr>();
1119 clone_sequence->CheckAndConvertToDynamicLenSequence(false);
1120 if (clone_sequence->dynamic_len()) {
1121 set_dyn_len_arg();
1122 // Set all sequence inputs as used if trans from fixed length to dynamic length.
1123 SetSequenceElementsUseFlagsRecursively(shared_from_base<AbstractSequence>(), true);
1124 return clone_sequence;
1125 }
1126 // Convert to dynamic len failed, return original abstract.
1127 return shared_from_base<AbstractSequence>();
1128 }
1129
ElementsType() const1130 TypePtrList AbstractSequence::ElementsType() const {
1131 TypePtrList element_type_list;
1132 for (const auto &element : elements_) {
1133 MS_EXCEPTION_IF_NULL(element);
1134 TypePtr element_type = element->BuildType();
1135 element_type_list.push_back(element_type);
1136 }
1137 return element_type_list;
1138 }
1139
ElementsShape() const1140 BaseShapePtrList AbstractSequence::ElementsShape() const {
1141 BaseShapePtrList element_shape_list;
1142 for (const auto &element : elements_) {
1143 MS_EXCEPTION_IF_NULL(element);
1144 BaseShapePtr element_shape = element->GetShape();
1145 element_shape_list.push_back(element_shape);
1146 }
1147 return element_shape_list;
1148 }
1149
ElementsClone() const1150 AbstractBasePtrList AbstractSequence::ElementsClone() const {
1151 AbstractBasePtrList element_list;
1152 for (const auto &element : elements_) {
1153 MS_EXCEPTION_IF_NULL(element);
1154 AbstractBasePtr clone = element->Clone();
1155 element_list.push_back(clone);
1156 }
1157 return element_list;
1158 }
1159
ElementsBroaden() const1160 AbstractBasePtrList AbstractSequence::ElementsBroaden() const {
1161 AbstractBasePtrList element_list;
1162 for (const auto &element : elements_) {
1163 MS_EXCEPTION_IF_NULL(element);
1164 AbstractBasePtr broadend = element->Broaden();
1165 element_list.push_back(broadend);
1166 }
1167 return element_list;
1168 }
1169
ElementsPartialBroaden() const1170 AbstractBasePtrList AbstractSequence::ElementsPartialBroaden() const {
1171 AbstractBasePtrList element_list;
1172 for (const auto &element : elements_) {
1173 MS_EXCEPTION_IF_NULL(element);
1174 AbstractBasePtr broadend = element->PartialBroaden();
1175 element_list.push_back(broadend);
1176 }
1177 return element_list;
1178 }
1179
GetValueFromUserData(const AbstractBasePtr & element_abs)1180 std::pair<bool, ValuePtr> GetValueFromUserData(const AbstractBasePtr &element_abs) {
1181 MS_EXCEPTION_IF_NULL(element_abs);
1182 if (abstract::AbstractBase::pyexecute_user_data_catcher()) {
1183 return abstract::AbstractBase::pyexecute_user_data_catcher()(element_abs);
1184 }
1185 return {false, nullptr};
1186 }
1187
1188 template <typename T>
ElementsBuildValue() const1189 ValuePtr AbstractSequence::ElementsBuildValue() const {
1190 std::vector<ValuePtr> element_value_list;
1191 for (const auto &element : elements_) {
1192 MS_EXCEPTION_IF_NULL(element);
1193 auto [has_user_data, element_value] = GetValueFromUserData(element);
1194 if (has_user_data && element_value != nullptr) {
1195 element_value_list.push_back(element_value);
1196 continue;
1197 }
1198 element_value = element->BuildValue();
1199 MS_EXCEPTION_IF_NULL(element_value);
1200 if (element_value->isa<ValueAny>()) {
1201 return kValueAny;
1202 }
1203 element_value_list.push_back(element_value);
1204 }
1205 return std::make_shared<T>(element_value_list);
1206 }
1207 template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue<ValueTuple>() const;
1208 template MS_CORE_API ValuePtr AbstractSequence::ElementsBuildValue<ValueList>() const;
1209
1210 template <typename T>
ElementsJoin(const AbstractSequencePtr & other)1211 AbstractBasePtr AbstractSequence::ElementsJoin(const AbstractSequencePtr &other) {
1212 MS_EXCEPTION_IF_NULL(other);
1213 auto joined_list = AbstractJoin(elements_, other->elements_);
1214 bool changes = false;
1215 for (std::size_t i = 0; i < elements_.size(); i++) {
1216 if (elements_[i] != joined_list[i]) {
1217 changes = true;
1218 break;
1219 }
1220 }
1221 if (!changes) {
1222 return shared_from_base<AbstractBase>();
1223 }
1224 return std::make_shared<T>(joined_list);
1225 }
1226 template AbstractBasePtr AbstractSequence::ElementsJoin<AbstractList>(const AbstractSequencePtr &);
1227 template AbstractBasePtr AbstractSequence::ElementsJoin<AbstractTuple>(const AbstractSequencePtr &);
1228
hash() const1229 std::size_t AbstractSequence::hash() const {
1230 if (dynamic_len_) {
1231 size_t hash_val = hash_combine(tid(), static_cast<size_t>(dynamic_len_));
1232 if (dynamic_len_element_abs_ != nullptr) {
1233 return hash_combine(hash_val, static_cast<size_t>(dynamic_len_element_abs_->hash()));
1234 }
1235 return hash_val;
1236 }
1237 return hash_combine(tid(), AbstractBasePtrListHash(elements_));
1238 }
1239
size() const1240 std::size_t AbstractSequence::size() const {
1241 if (dynamic_len_) {
1242 if (dynamic_len_element_abs_ == nullptr) {
1243 return 0;
1244 }
1245 MS_LOG(INTERNAL_EXCEPTION) << "Can not get size for dynamic length sequence " << ToString();
1246 }
1247 return elements_.size();
1248 }
1249
empty() const1250 bool AbstractSequence::empty() const {
1251 if (dynamic_len_) {
1252 if (dynamic_len_element_abs_ == nullptr) {
1253 return true;
1254 }
1255 MS_LOG(INTERNAL_EXCEPTION) << "Can not call function empty() for dynamic length sequence " << ToString();
1256 }
1257 return elements_.empty();
1258 }
1259
set_dynamic_len(bool dynamic_len)1260 void AbstractSequence::set_dynamic_len(bool dynamic_len) {
1261 if (dynamic_len) {
1262 sequence_nodes_ = nullptr;
1263 }
1264 dynamic_len_ = dynamic_len;
1265 }
1266
set_dynamic_len_element_abs(const AbstractBasePtr & dynamic_len_element_abs)1267 void AbstractSequence::set_dynamic_len_element_abs(const AbstractBasePtr &dynamic_len_element_abs) {
1268 if (dynamic_len_element_abs == nullptr) {
1269 return;
1270 }
1271 if (dynamic_len_element_abs->isa<abstract::AbstractDictionary>()) {
1272 MS_EXCEPTION(TypeError) << "DynamicSequence does not support dictionary type as element type now.";
1273 }
1274 // dynamic_len_element_abs should ignore value.
1275 dynamic_len_element_abs_ = AbstractBroaden(dynamic_len_element_abs);
1276 }
1277
operator ==(const AbstractBase & other) const1278 bool AbstractSequence::operator==(const AbstractBase &other) const {
1279 if (this == &other) {
1280 return true;
1281 }
1282 if (tid() != other.tid()) {
1283 return false;
1284 }
1285 const auto &other_sequence = dynamic_cast<const AbstractSequence &>(other);
1286 if (dynamic_len_ != other_sequence.dynamic_len()) {
1287 // Variable length sequence and constant length sequence can not be the same.
1288 return false;
1289 }
1290
1291 if (dynamic_len_) {
1292 // If the abstract of element for two variable sequence is the same, these two sequence is the same.
1293 return IsEqual(dynamic_len_element_abs_, other_sequence.dynamic_len_element_abs());
1294 }
1295
1296 if (elements_.size() != other_sequence.elements_.size()) {
1297 return false;
1298 }
1299 for (size_t i = 0; i < elements_.size(); ++i) {
1300 if (!IsEqual(elements_[i], other_sequence.elements_[i])) {
1301 return false;
1302 }
1303 }
1304 return true;
1305 }
1306
elements() const1307 const AbstractBasePtrList &AbstractSequence::elements() const { return elements_; }
1308
sequence_nodes() const1309 const std::shared_ptr<AnfNodeWeakPtrList> &AbstractSequence::sequence_nodes() const { return sequence_nodes_; }
1310
set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> & sequence_nodes)1311 void AbstractSequence::set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes) {
1312 sequence_nodes_ = sequence_nodes;
1313 }
1314
dynamic_len() const1315 bool AbstractSequence::dynamic_len() const { return dynamic_len_; }
1316
dynamic_len_element_abs() const1317 AbstractBasePtr AbstractSequence::dynamic_len_element_abs() const { return dynamic_len_element_abs_; }
1318
set_dyn_len_arg()1319 void AbstractSequence::set_dyn_len_arg() { dyn_len_arg_ = true; }
1320
dyn_len_arg() const1321 bool AbstractSequence::dyn_len_arg() const { return dyn_len_arg_; }
1322
ContainsAllBroadenTensors() const1323 bool AbstractSequence::ContainsAllBroadenTensors() const {
1324 if (dynamic_len_) {
1325 if (dynamic_len_element_abs_ != nullptr && dynamic_len_element_abs_->isa<AbstractTensor>()) {
1326 return true;
1327 }
1328 return false;
1329 }
1330 if (elements_.empty()) {
1331 return false;
1332 }
1333 auto exist_not_broadened_tensor = [](const AbstractBasePtr &abs) {
1334 bool is_broaden_tensor = abs->isa<abstract::AbstractUndetermined>() && abs->IsBroaden();
1335 bool is_broaden_sequence = abs->isa<abstract::AbstractSequence>() &&
1336 abs->cast_ptr<abstract::AbstractSequence>()->ContainsAllBroadenTensors();
1337 return !is_broaden_tensor && !is_broaden_sequence;
1338 };
1339 return !std::any_of(elements_.cbegin(), elements_.cend(), exist_not_broadened_tensor);
1340 }
1341
AbstractTuple(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)1342 AbstractTuple::AbstractTuple(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
1343 : AbstractSequence(std::move(elements), tuple_nodes) {}
1344
AbstractTuple(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)1345 AbstractTuple::AbstractTuple(const AbstractBasePtrList &elements,
1346 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
1347 : AbstractSequence(elements, tuple_nodes) {}
1348
BuildType() const1349 TypePtr AbstractTuple::BuildType() const {
1350 auto ret = std::make_shared<Tuple>(ElementsType());
1351 if (dynamic_len_) {
1352 ret->set_dynamic_len(dynamic_len_);
1353 if (dynamic_len_element_abs_ != nullptr) {
1354 ret->set_dynamic_element_type(dynamic_len_element_abs_->BuildType());
1355 }
1356 }
1357 return ret;
1358 }
1359
BuildShape() const1360 BaseShapePtr AbstractTuple::BuildShape() const {
1361 if (dynamic_len_) {
1362 if (dynamic_len_element_abs_ == nullptr) {
1363 return std::make_shared<DynamicSequenceShape>(nullptr);
1364 }
1365 return std::make_shared<DynamicSequenceShape>(dynamic_len_element_abs_->GetShape());
1366 }
1367 return std::make_shared<TupleShape>(ElementsShape());
1368 }
1369
Clone() const1370 AbstractBasePtr AbstractTuple::Clone() const {
1371 auto ret = std::make_shared<AbstractTuple>(ElementsClone(), sequence_nodes());
1372 ret->dyn_len_arg_ = dyn_len_arg_;
1373 ret->set_dynamic_len(dynamic_len_);
1374 ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1375 ret->SetSymbolicShape(this->GetSymbolicShape());
1376 ret->SetSymbolicValue(this->GetSymbolicValue());
1377 return ret;
1378 }
1379
Broaden() const1380 AbstractBasePtr AbstractTuple::Broaden() const {
1381 auto ret = std::make_shared<AbstractTuple>(ElementsBroaden(), sequence_nodes());
1382 ret->set_dynamic_len(dynamic_len_);
1383 ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1384 return ret;
1385 }
1386
PartialBroaden() const1387 AbstractBasePtr AbstractTuple::PartialBroaden() const {
1388 auto ret = std::make_shared<AbstractTuple>(ElementsPartialBroaden(), sequence_nodes());
1389 ret->set_dynamic_len(dynamic_len_);
1390 ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1391 return ret;
1392 }
1393
RealBuildValue() const1394 ValuePtr AbstractTuple::RealBuildValue() const {
1395 if (dynamic_len_) {
1396 return kValueAny;
1397 }
1398 return ElementsBuildValue<ValueTuple>();
1399 }
1400
set_shape(const BaseShapePtr & shape)1401 void AbstractTuple::set_shape(const BaseShapePtr &shape) {
1402 auto tuple_shape = dyn_cast_ptr<TupleShape>(shape);
1403 MS_EXCEPTION_IF_NULL(tuple_shape);
1404 if (tuple_shape->shape().size() != elements_.size()) {
1405 MS_LOG(INTERNAL_EXCEPTION) << "Size mismatch: " << tuple_shape->shape().size() << " vs " << elements_.size();
1406 }
1407
1408 for (size_t i = 0; i < elements_.size(); ++i) {
1409 MS_EXCEPTION_IF_NULL(elements_[i]);
1410 elements_[i]->set_shape(tuple_shape->shape()[i]);
1411 }
1412 }
1413
operator ==(const AbstractBase & other) const1414 bool AbstractTuple::operator==(const AbstractBase &other) const {
1415 if (this == &other) {
1416 return true;
1417 }
1418 if (!other.isa<AbstractTuple>()) {
1419 return false;
1420 }
1421 return AbstractSequence::operator==(static_cast<const AbstractSequence &>(other));
1422 }
1423
AbstractList(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & list_nodes)1424 AbstractList::AbstractList(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes)
1425 : AbstractSequence(std::move(elements), list_nodes) {}
1426
AbstractList(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & list_nodes)1427 AbstractList::AbstractList(const AbstractBasePtrList &elements, const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes)
1428 : AbstractSequence(elements, list_nodes) {}
1429
operator ==(const AbstractBase & other) const1430 bool AbstractList::operator==(const AbstractBase &other) const {
1431 if (this == &other) {
1432 return true;
1433 }
1434 if (!other.isa<AbstractList>()) {
1435 return false;
1436 }
1437 auto other_extra_info = static_cast<const AbstractList &>(other).extra_info();
1438 if (extra_info_->size() != 0 && other_extra_info->size() != 0 && extra_info_ != other_extra_info) {
1439 return false;
1440 }
1441 return AbstractSequence::operator==(static_cast<const AbstractSequence &>(other));
1442 }
1443
BuildType() const1444 TypePtr AbstractList::BuildType() const {
1445 auto ret = std::make_shared<List>(ElementsType());
1446 if (dynamic_len_) {
1447 ret->set_dynamic_len(dynamic_len_);
1448 if (dynamic_len_element_abs_ != nullptr) {
1449 ret->set_dynamic_element_type(dynamic_len_element_abs_->BuildType());
1450 }
1451 }
1452 return ret;
1453 }
1454
BuildShape() const1455 BaseShapePtr AbstractList::BuildShape() const {
1456 if (dynamic_len_) {
1457 if (dynamic_len_element_abs_ == nullptr) {
1458 return std::make_shared<DynamicSequenceShape>(nullptr);
1459 }
1460 return std::make_shared<DynamicSequenceShape>(dynamic_len_element_abs_->GetShape());
1461 }
1462 return std::make_shared<ListShape>(ElementsShape());
1463 }
1464
Clone() const1465 AbstractBasePtr AbstractList::Clone() const {
1466 auto ret = std::make_shared<AbstractList>(ElementsClone(), sequence_nodes());
1467 ret->dyn_len_arg_ = dyn_len_arg_;
1468 ret->set_dynamic_len(dynamic_len_);
1469 ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1470 ret->set_extra_info(extra_info_);
1471 ret->SetSymbolicShape(this->GetSymbolicShape());
1472 ret->SetSymbolicValue(this->GetSymbolicValue());
1473 return ret;
1474 }
1475
Broaden() const1476 AbstractBasePtr AbstractList::Broaden() const {
1477 auto ret = std::make_shared<AbstractList>(ElementsBroaden(), sequence_nodes());
1478 ret->set_dynamic_len(dynamic_len_);
1479 ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1480 ret->set_extra_info(extra_info_);
1481 return ret;
1482 }
1483
PartialBroaden() const1484 AbstractBasePtr AbstractList::PartialBroaden() const {
1485 auto ret = std::make_shared<AbstractList>(ElementsPartialBroaden(), sequence_nodes());
1486 ret->set_dynamic_len(dynamic_len_);
1487 ret->set_dynamic_len_element_abs(dynamic_len_element_abs_);
1488 ret->set_extra_info(extra_info_);
1489 return ret;
1490 }
1491
RealBuildValue() const1492 ValuePtr AbstractList::RealBuildValue() const {
1493 if (dynamic_len_) {
1494 return kValueAny;
1495 }
1496 return ElementsBuildValue<ValueList>();
1497 }
1498
CheckAndConvertToDynamicLenSequence(bool raise_exception)1499 void AbstractList::CheckAndConvertToDynamicLenSequence(bool raise_exception) {
1500 AbstractSequence::CheckAndConvertToDynamicLenSequence(raise_exception);
1501 ClearExtraInfo();
1502 }
1503
DynamicLenSequenceJoin(const AbstractSequencePtr & other)1504 std::shared_ptr<AbstractSequence> AbstractSequence::DynamicLenSequenceJoin(const AbstractSequencePtr &other) {
1505 auto other_dyn_sequence_abs = other;
1506 if (!dynamic_len() || !other->dynamic_len()) {
1507 MS_LOG(EXCEPTION) << "Can't join fixed length sequence. \nthis:" << ToString() << ", \nother:" << other->ToString();
1508 }
1509 auto element_abs1 = dynamic_len_element_abs_;
1510 auto element_abs2 = other_dyn_sequence_abs->dynamic_len_element_abs();
1511 AbstractBasePtr join_element_abs = nullptr;
1512 // When two element abstracts are not nullptr, join them to get the new element abstract.
1513 // When one or none of the element abstract is nullptr, the result element abstract is another.
1514 if (element_abs1 == nullptr) {
1515 join_element_abs = element_abs2;
1516 } else if (element_abs2 == nullptr) {
1517 join_element_abs = element_abs1;
1518 } else {
1519 join_element_abs = element_abs1->Join(element_abs2);
1520 }
1521 auto ret = Clone()->cast<AbstractSequencePtr>();
1522 ret->set_dynamic_len_element_abs(join_element_abs);
1523 return ret;
1524 }
1525
// Join this tuple abstract with `other`.
// - An AbstractNegligible operand leaves this abstract unchanged.
// - If either side is dynamic-length, or the sizes differ, both sides are broadened to dynamic-length
//   sequences and joined as such.
// - Otherwise the elements are joined pairwise and the joined sequence nodes are inserted into the result.
// A non-tuple `other` is reported through AbstractTypeJoinLogging.
AbstractBasePtr AbstractTuple::Join(const AbstractBasePtr &other) {
  // Reject a null operand explicitly, as the other Join implementations in this file do.
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_sequence = other->cast<AbstractTuplePtr>();
  if (other_sequence == nullptr) {
    // NOTE(review): the code below relies on this call raising — confirm against its definition.
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  try {
    if (dynamic_len_) {
      return DynamicLenSequenceJoin(other_sequence->BroadenToDynamicLenSequence());
    }
    if (other_sequence->dynamic_len()) {
      return other_sequence->DynamicLenSequenceJoin(BroadenToDynamicLenSequence());
    }
    if (other_sequence->size() != size()) {
      auto dyn_len_sequence = BroadenToDynamicLenSequence();
      return dyn_len_sequence->Join(other_sequence->BroadenToDynamicLenSequence());
    }
  } catch (const std::exception &e) {
    // Log the underlying failure, then report it as a type-join failure.
    MS_LOG(ERROR) << e.what();
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractTuple>(other_sequence));
  MS_EXCEPTION_IF_NULL(res);
  res->InsertSequenceNodes(SequenceNodesJoin(other));
  return res;
}
1554
// Join this list abstract with `other`.
// - An AbstractNegligible operand leaves this abstract unchanged.
// - If either side is dynamic-length, or the sizes differ, both sides are broadened to dynamic-length
//   sequences and joined as such.
// - Otherwise the elements are joined pairwise and the joined sequence nodes are inserted into the result.
// A non-list `other` is reported through AbstractTypeJoinLogging.
AbstractBasePtr AbstractList::Join(const AbstractBasePtr &other) {
  // Reject a null operand explicitly, as the other Join implementations in this file do.
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_sequence = other->cast<AbstractListPtr>();
  if (other_sequence == nullptr) {
    // NOTE(review): the code below relies on this call raising — confirm against its definition.
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  try {
    if (dynamic_len_) {
      return DynamicLenSequenceJoin(other_sequence->BroadenToDynamicLenSequence());
    }
    if (other_sequence->dynamic_len()) {
      return other_sequence->DynamicLenSequenceJoin(BroadenToDynamicLenSequence());
    }
    if (other_sequence->size() != size()) {
      auto dyn_len_sequence = BroadenToDynamicLenSequence();
      return dyn_len_sequence->DynamicLenSequenceJoin(other_sequence->BroadenToDynamicLenSequence());
    }
  } catch (const std::exception &e) {
    // Log the underlying failure, then report it as a type-join failure.
    MS_LOG(ERROR) << e.what();
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }

  auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractList>(other_sequence));
  MS_EXCEPTION_IF_NULL(res);
  res->InsertSequenceNodes(SequenceNodesJoin(other));
  return res;
}
1584
RealBuildValue() const1585 ValuePtr AbstractNamedTuple::RealBuildValue() const {
1586 std::vector<ValuePtr> element_value_list;
1587 for (const auto &element : elements_) {
1588 MS_EXCEPTION_IF_NULL(element);
1589 auto element_value = element->BuildValue();
1590 MS_EXCEPTION_IF_NULL(element_value);
1591 element_value_list.push_back(element_value);
1592 }
1593 std::vector<ValuePtr> key_value_list;
1594 for (const auto &key : keys_) {
1595 MS_EXCEPTION_IF_NULL(key);
1596 auto key_value = key->BuildValue();
1597 MS_EXCEPTION_IF_NULL(key_value);
1598 key_value_list.push_back(key_value);
1599 }
1600 return std::make_shared<ValueNamedTuple>(sub_class_name_, key_value_list, element_value_list);
1601 }
1602
BuildType() const1603 TypePtr AbstractSlice::BuildType() const {
1604 MS_EXCEPTION_IF_NULL(start_);
1605 MS_EXCEPTION_IF_NULL(stop_);
1606 MS_EXCEPTION_IF_NULL(step_);
1607 TypePtr start = start_->BuildType();
1608 TypePtr stop = stop_->BuildType();
1609 TypePtr step = step_->BuildType();
1610 return std::make_shared<Slice>(start, stop, step);
1611 }
1612
AbstractDictionary(const std::vector<AbstractElementPair> & key_values)1613 AbstractDictionary::AbstractDictionary(const std::vector<AbstractElementPair> &key_values) : key_values_(key_values) {}
1614
size() const1615 std::size_t AbstractDictionary::size() const { return key_values_.size(); }
1616
elements() const1617 const std::vector<AbstractElementPair> &AbstractDictionary::elements() const { return key_values_; }
1618
AbstractSlice(const AbstractBasePtr & start,const AbstractBasePtr & stop,const AbstractBasePtr & step)1619 AbstractSlice::AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step)
1620 : start_(start), stop_(stop), step_(step) {}
1621
operator ==(const AbstractBase & other) const1622 bool AbstractSlice::operator==(const AbstractBase &other) const {
1623 if (this == &other) {
1624 return true;
1625 }
1626 if (!other.isa<AbstractSlice>()) {
1627 return false;
1628 }
1629 const auto &other_slice = dynamic_cast<const AbstractSlice &>(other);
1630 return IsEqual(start_, other_slice.start_) && IsEqual(stop_, other_slice.stop_) && IsEqual(step_, other_slice.step_);
1631 }
1632
Clone() const1633 AbstractBasePtr AbstractSlice::Clone() const {
1634 MS_EXCEPTION_IF_NULL(start_);
1635 MS_EXCEPTION_IF_NULL(stop_);
1636 MS_EXCEPTION_IF_NULL(step_);
1637 AbstractBasePtr start = start_->Clone();
1638 AbstractBasePtr stop = stop_->Clone();
1639 AbstractBasePtr step = step_->Clone();
1640 return std::make_shared<AbstractSlice>(start, stop, step);
1641 }
1642
Broaden() const1643 AbstractBasePtr AbstractSlice::Broaden() const {
1644 MS_EXCEPTION_IF_NULL(start_);
1645 MS_EXCEPTION_IF_NULL(stop_);
1646 MS_EXCEPTION_IF_NULL(step_);
1647 AbstractBasePtr start = start_->Broaden();
1648 AbstractBasePtr stop = stop_->Broaden();
1649 AbstractBasePtr step = step_->Broaden();
1650 return std::make_shared<AbstractSlice>(start, stop, step);
1651 }
1652
ToString() const1653 std::string AbstractSlice::ToString() const {
1654 std::ostringstream buffer;
1655 buffer << type_name() << "[";
1656 MS_EXCEPTION_IF_NULL(start_);
1657 buffer << start_->ToString() << " : ";
1658 MS_EXCEPTION_IF_NULL(stop_);
1659 buffer << stop_->ToString() << " : ";
1660 MS_EXCEPTION_IF_NULL(step_);
1661 buffer << step_->ToString();
1662 buffer << "]";
1663 return buffer.str();
1664 }
1665
RealBuildValue() const1666 ValuePtr AbstractSlice::RealBuildValue() const {
1667 MS_EXCEPTION_IF_NULL(start_);
1668 MS_EXCEPTION_IF_NULL(stop_);
1669 MS_EXCEPTION_IF_NULL(step_);
1670 ValuePtr start = start_->BuildValue();
1671 ValuePtr stop = stop_->BuildValue();
1672 ValuePtr step = step_->BuildValue();
1673 if (start->isa<ValueAny>() || stop->isa<ValueAny>() || step->isa<ValueAny>()) {
1674 return kValueAny;
1675 }
1676 return std::make_shared<ValueSlice>(start, stop, step);
1677 }
1678
hash() const1679 std::size_t AbstractSlice::hash() const {
1680 MS_EXCEPTION_IF_NULL(start_);
1681 MS_EXCEPTION_IF_NULL(stop_);
1682 MS_EXCEPTION_IF_NULL(step_);
1683 return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
1684 }
1685
start() const1686 AbstractBasePtr AbstractSlice::start() const { return start_; }
1687
stop() const1688 AbstractBasePtr AbstractSlice::stop() const { return stop_; }
1689
step() const1690 AbstractBasePtr AbstractSlice::step() const { return step_; }
1691
shape() const1692 ShapePtr AbstractUndetermined::shape() const {
1693 auto shp = dyn_cast<Shape>(GetShapeTrack());
1694 if (shp == nullptr) {
1695 MS_LOG(INTERNAL_EXCEPTION) << "Tensor should have a shape.";
1696 }
1697 return shp;
1698 }
1699
set_shape(const BaseShapePtr & shape)1700 void AbstractUndetermined::set_shape(const BaseShapePtr &shape) {
1701 MS_EXCEPTION_IF_NULL(shape);
1702 if (shape->isa<NoShape>()) {
1703 MS_LOG(INTERNAL_EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
1704 }
1705 AbstractBase::set_shape(shape);
1706 }
1707
BuildType() const1708 TypePtr AbstractTensor::BuildType() const {
1709 MS_EXCEPTION_IF_NULL(element_);
1710 TypePtr element_type = element_->BuildType();
1711 return std::make_shared<TensorType>(element_type);
1712 }
1713
BuildShape() const1714 BaseShapePtr AbstractTensor::BuildShape() const {
1715 auto shape = GetShapeTrack();
1716 // Guard from using set_shape(nullptr)
1717 if (shape == nullptr) {
1718 return kNoShape;
1719 }
1720 return shape;
1721 }
1722
// Join this tensor abstract with `other`.
// - An AbstractNegligible operand leaves this abstract unchanged.
// - When other's built type has the undetermined type id, shapes and elements are joined into an
//   AbstractUndetermined.
// - Otherwise `other` must be an AbstractTensor: an equal operand returns this abstract, else shapes and
//   elements are joined into a new AbstractTensor that keeps this tensor's adapter flag.
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_type = other->BuildType();
  MS_EXCEPTION_IF_NULL(other_type);
  MS_EXCEPTION_IF_NULL(element_);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }

  // Case 1: AbstractTensor join with AbstractUndetermined.
  const bool joins_undetermined = (other_type->type_id() == kObjectTypeUndeterminedType);
  if (joins_undetermined) {
    auto other_undetermined_tensor = dyn_cast_ptr<AbstractUndetermined>(other);
    MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
    // Shapes must be joinable.
    auto joined_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());
    if (joined_shape == nullptr) {
      ShapeJoinLogging(shape(), other_undetermined_tensor->shape(), shared_from_base<AbstractBase>(), other);
    }
    // Elements must be joinable.
    auto element = element_->Join(other_undetermined_tensor->element());
    MS_EXCEPTION_IF_NULL(element);
    return std::make_shared<AbstractUndetermined>(element, joined_shape);
  }

  // Case 2: AbstractTensor join with AbstractTensor.
  auto other_tensor = dyn_cast_ptr<AbstractTensor>(other);
  if (other_tensor == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  // Shapes must be joinable.
  auto joined_shape = ShapeJoin(this->shape(), other_tensor->shape());
  if (joined_shape == nullptr) {
    ShapeJoinLogging(shape(), other_tensor->shape(), shared_from_base<AbstractBase>(), other);
  }
  // Elements must be joinable.
  auto element = element_->Join(other_tensor->element_);
  MS_EXCEPTION_IF_NULL(element);
  auto joined_tensor = std::make_shared<AbstractTensor>(element, joined_shape);
  joined_tensor->set_is_adapter(is_adapter_);
  return joined_tensor;
}
1766
equal_to(const AbstractTensor & other) const1767 bool AbstractTensor::equal_to(const AbstractTensor &other) const {
1768 if (this == &other) {
1769 return true;
1770 }
1771 // Check if both Tensor or both AdapterTensor.
1772 if (is_adapter() != other.is_adapter()) {
1773 return false;
1774 }
1775 const auto &v1 = GetValueTrack();
1776 const auto &v2 = other.GetValueTrack();
1777 MS_EXCEPTION_IF_NULL(v1);
1778 MS_EXCEPTION_IF_NULL(v2);
1779 // Check if both point to same specific value.
1780 if (!v1->isa<ValueAny>()) {
1781 return v1 == v2;
1782 }
1783 // Check if both are ValueAny.
1784 if (!v2->isa<ValueAny>()) {
1785 return false;
1786 }
1787 // Check element type.
1788 if (!IsEqual(element_, other.element_)) {
1789 return false;
1790 }
1791 // Check shape.
1792 return IsEqual(shape(), other.shape());
1793 }
1794
operator ==(const AbstractTensor & other) const1795 bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
1796
operator ==(const AbstractBase & other) const1797 bool AbstractTensor::operator==(const AbstractBase &other) const {
1798 if (this == &other) {
1799 return true;
1800 }
1801 if (tid() != other.tid()) {
1802 return false;
1803 }
1804 return equal_to(static_cast<const AbstractTensor &>(other));
1805 }
1806
Clone() const1807 AbstractBasePtr AbstractTensor::Clone() const {
1808 MS_EXCEPTION_IF_NULL(element_);
1809 auto clone = std::make_shared<AbstractTensor>(element_->Clone());
1810 ShapePtr shp = shape();
1811 clone->set_shape(shp->Clone());
1812 clone->set_value(GetValueTrack());
1813 clone->set_is_adapter(is_adapter());
1814 clone->SetSymbolicShape(this->GetSymbolicShape());
1815 clone->SetSymbolicValue(this->GetSymbolicValue());
1816 return clone;
1817 }
1818
Broaden() const1819 AbstractBasePtr AbstractTensor::Broaden() const {
1820 MS_EXCEPTION_IF_NULL(element_);
1821 auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
1822 auto shp = shape();
1823 MS_EXCEPTION_IF_NULL(shp);
1824 broaden->set_shape(shp->Clone());
1825 broaden->set_value(kValueAny);
1826 broaden->set_is_adapter(is_adapter());
1827 return broaden;
1828 }
1829
BroadenWithShape() const1830 AbstractBasePtr AbstractTensor::BroadenWithShape() const {
1831 MS_EXCEPTION_IF_NULL(element_);
1832 auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
1833 auto shp = shape()->Clone();
1834 MS_EXCEPTION_IF_NULL(shp);
1835 shp->Broaden();
1836 broaden->set_shape(shp);
1837 broaden->set_value(kValueAny);
1838 broaden->set_is_adapter(is_adapter());
1839 return broaden;
1840 }
1841
PartialBroaden() const1842 AbstractBasePtr AbstractTensor::PartialBroaden() const { return Broaden(); }
1843
ToString() const1844 std::string AbstractTensor::ToString() const {
1845 std::ostringstream buffer;
1846 BaseShapePtr shape_track = GetShapeTrack();
1847 MS_EXCEPTION_IF_NULL(shape_track);
1848 MS_EXCEPTION_IF_NULL(element_);
1849 auto value_track = GetValueTrack();
1850 MS_EXCEPTION_IF_NULL(value_track);
1851 std::string is_adapter = this->is_adapter() ? "True" : "False";
1852 buffer << type_name() << "("
1853 << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
1854 << ", is adapter: " << is_adapter << ", value_ptr: " << value_track << ", value: " << value_track->ToString()
1855 << ")";
1856 return buffer.str();
1857 }
1858
BuildType() const1859 TypePtr AbstractAny::BuildType() const {
1860 MS_EXCEPTION_IF_NULL(element_);
1861 TypePtr element_type = element_->BuildType();
1862 return std::make_shared<AnyType>(element_type);
1863 }
1864
// Join a negligible abstract with `other`.
// A scalar that tracks a specific (non-ValueAny) value joins to a fresh AbstractAny; every other operand is
// returned unchanged.
AbstractBasePtr AbstractNegligible::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractScalar>()) {
    const auto &value_other = other->GetValueTrack();
    // Guard the tracked value before dereferencing it.
    MS_EXCEPTION_IF_NULL(value_other);
    if (!value_other->isa<ValueAny>()) {
      return std::make_shared<AbstractAny>();
    }
  }
  return other;
}
1875
Clone() const1876 AbstractBasePtr AbstractNegligible::Clone() const { return std::make_shared<AbstractNegligible>(); }
1877
Broaden() const1878 AbstractBasePtr AbstractNegligible::Broaden() const { return Clone(); }
1879
ToString() const1880 std::string AbstractNegligible::ToString() const { return type_name(); }
1881
BuildType() const1882 TypePtr AbstractNegligible::BuildType() const {
1883 MS_EXCEPTION_IF_NULL(element_);
1884 TypePtr element_type = element_->BuildType();
1885 return std::make_shared<NegligibleType>(element_type);
1886 }
1887
BuildType() const1888 TypePtr AbstractDictionary::BuildType() const {
1889 std::vector<std::pair<ValuePtr, TypePtr>> key_values;
1890 for (const auto &item : key_values_) {
1891 MS_EXCEPTION_IF_NULL(item.first);
1892 MS_EXCEPTION_IF_NULL(item.second);
1893 ValuePtr key_type = item.first->BuildValue();
1894 TypePtr value_type = item.second->BuildType();
1895 (void)key_values.emplace_back(key_type, value_type);
1896 }
1897 return std::make_shared<Dictionary>(key_values);
1898 }
1899
operator ==(const AbstractBase & other) const1900 bool AbstractDictionary::operator==(const AbstractBase &other) const {
1901 if (this == &other) {
1902 return true;
1903 }
1904 if (!other.isa<AbstractDictionary>()) {
1905 return false;
1906 }
1907 const auto &other_dict = dynamic_cast<const AbstractDictionary &>(other);
1908 if (key_values_.size() != other_dict.key_values_.size()) {
1909 return false;
1910 }
1911 for (size_t index = 0; index < key_values_.size(); ++index) {
1912 auto &kv1 = key_values_[index];
1913 auto &kv2 = other_dict.key_values_[index];
1914 if (!IsEqual(kv1.first, kv2.first) || !IsEqual(kv1.second, kv2.second)) {
1915 return false;
1916 }
1917 }
1918 return true;
1919 }
1920
Clone() const1921 AbstractBasePtr AbstractDictionary::Clone() const {
1922 std::vector<AbstractElementPair> kv;
1923 (void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
1924 [](const AbstractElementPair &item) {
1925 MS_EXCEPTION_IF_NULL(item.first);
1926 MS_EXCEPTION_IF_NULL(item.second);
1927 return std::make_pair(item.first->Clone(), item.second->Clone());
1928 });
1929 auto ret = std::make_shared<AbstractDictionary>(kv);
1930 ret->set_extra_info(extra_info_);
1931 return ret;
1932 }
1933
Broaden() const1934 AbstractBasePtr AbstractDictionary::Broaden() const {
1935 std::vector<AbstractElementPair> kv;
1936 (void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
1937 [](const AbstractElementPair &item) {
1938 MS_EXCEPTION_IF_NULL(item.second);
1939 return std::make_pair(item.first, item.second->Broaden());
1940 });
1941 auto ret = std::make_shared<AbstractDictionary>(kv);
1942 ret->set_extra_info(extra_info_);
1943 return ret;
1944 }
1945
ToString() const1946 std::string AbstractDictionary::ToString() const {
1947 std::ostringstream buffer;
1948 buffer << type_name() << "{ ";
1949 for (const auto &kv : key_values_) {
1950 MS_EXCEPTION_IF_NULL(kv.first);
1951 MS_EXCEPTION_IF_NULL(kv.second);
1952 buffer << "(" << kv.first->ToString() << ": " << kv.second->ToString() << ") ";
1953 }
1954 buffer << "}";
1955 return buffer.str();
1956 }
1957
hash() const1958 std::size_t AbstractDictionary::hash() const {
1959 std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(),
1960 [](std::size_t hash_sum, const AbstractElementPair &item) {
1961 MS_EXCEPTION_IF_NULL(item.first);
1962 MS_EXCEPTION_IF_NULL(item.second);
1963 hash_sum = hash_combine(hash_sum, item.first->hash());
1964 hash_sum = hash_combine(hash_sum, item.second->hash());
1965 return hash_sum;
1966 });
1967 return hash_sum;
1968 }
1969
RealBuildValue() const1970 ValuePtr AbstractDictionary::RealBuildValue() const {
1971 std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
1972 for (const auto &item : key_values_) {
1973 MS_EXCEPTION_IF_NULL(item.first);
1974 MS_EXCEPTION_IF_NULL(item.second);
1975 auto key_element_value = item.first->BuildValue();
1976 auto value_element_value = item.second->BuildValue();
1977 MS_EXCEPTION_IF_NULL(key_element_value);
1978 MS_EXCEPTION_IF_NULL(value_element_value);
1979 if (value_element_value->isa<ValueAny>()) {
1980 return kValueAny;
1981 }
1982 (void)key_values.emplace_back(key_element_value, value_element_value);
1983 }
1984 return std::make_shared<ValueDictionary>(key_values);
1985 }
1986
// Join this dictionary abstract with `other`.
// - An AbstractNegligible operand leaves this abstract unchanged.
// - A non-dictionary operand is reported through AbstractTypeJoinLogging.
// - Dictionaries of different sizes cannot be joined and raise an exception.
// - Otherwise keys and values are joined position by position. When no join result differs (by pointer)
//   from this dictionary's own entry, this abstract is returned; else a new dictionary holds the results.
AbstractBasePtr AbstractDictionary::Join(const AbstractBasePtr &other) {
  // Reject a null operand explicitly, as the other Join implementations in this file do.
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto other_dict = other->cast<AbstractDictionaryPtr>();
  if (other_dict == nullptr) {
    // NOTE(review): the code below relies on this call raising — confirm against its definition.
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  const auto &self_elements = elements();
  const auto &other_elements = other_dict->elements();
  if (self_elements.size() != other_elements.size()) {
    MS_LOG(EXCEPTION)
      << "Join failed as dict don't have the same size. spec1: " << ::mindspore::ToString(self_elements)
      << ", spec2: " << ::mindspore::ToString(other_elements)
      << ".\nFor more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed\n";
  }

  // Join two abstracts and make sure the result is valid.
  auto JoinElement = [](const AbstractBasePtr &abs1, const AbstractBasePtr &abs2) -> AbstractBasePtr {
    MS_EXCEPTION_IF_NULL(abs1);
    auto joined_res = abs1->Join(abs2);
    MS_EXCEPTION_IF_NULL(joined_res);
    return joined_res;
  };

  std::vector<AbstractElementPair> joined_key_values;
  joined_key_values.reserve(self_elements.size());
  bool changes = false;
  for (std::size_t i = 0; i < self_elements.size(); ++i) {
    const auto &self_pair = self_elements[i];
    const auto &other_pair = other_elements[i];
    auto joined_key = JoinElement(self_pair.first, other_pair.first);
    auto joined_value = JoinElement(self_pair.second, other_pair.second);
    // Pointer comparison: a join that hands back the very same abstract counts as "no change".
    if (joined_key != self_pair.first || joined_value != self_pair.second) {
      changes = true;
    }
    (void)joined_key_values.emplace_back(joined_key, joined_value);
  }
  if (!changes) {
    return shared_from_base<AbstractBase>();
  }
  return std::make_shared<AbstractDictionary>(joined_key_values);
}
2029
AbstractJTagged(const AbstractBasePtr & element)2030 AbstractJTagged::AbstractJTagged(const AbstractBasePtr &element) : element_(element) {}
2031
BuildType() const2032 TypePtr AbstractJTagged::BuildType() const {
2033 MS_EXCEPTION_IF_NULL(element_);
2034 TypePtr subtype = element_->BuildType();
2035 return std::make_shared<JTagged>(subtype);
2036 }
2037
// Join this J-tagged abstract with `other`.
// An AbstractNegligible operand leaves this abstract unchanged; a non-J-tagged operand is reported through
// AbstractTypeJoinLogging; otherwise the wrapped elements are joined and re-wrapped.
AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  auto rhs_jtagged = dyn_cast_ptr<AbstractJTagged>(other);
  const bool is_jtagged = (rhs_jtagged != nullptr);
  if (!is_jtagged) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  MS_EXCEPTION_IF_NULL(element_);
  return std::make_shared<AbstractJTagged>(element_->Join(rhs_jtagged->element_));
}
2051
operator ==(const AbstractBase & other) const2052 bool AbstractJTagged::operator==(const AbstractBase &other) const {
2053 if (this == &other) {
2054 return true;
2055 }
2056 if (!other.isa<AbstractJTagged>()) {
2057 return false;
2058 }
2059 const auto &other_jtagged = dynamic_cast<const AbstractJTagged &>(other);
2060 return IsEqual(element_, other_jtagged.element_);
2061 }
2062
ToString() const2063 std::string AbstractJTagged::ToString() const {
2064 std::ostringstream buffer;
2065 MS_EXCEPTION_IF_NULL(element_);
2066 buffer << type_name() << "("
2067 << "element: " << element_->ToString() << ")";
2068 return buffer.str();
2069 }
2070
Clone() const2071 AbstractBasePtr AbstractJTagged::Clone() const { return std::make_shared<AbstractJTagged>(element_->Clone()); }
2072
Broaden() const2073 AbstractBasePtr AbstractJTagged::Broaden() const { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
2074
element()2075 AbstractBasePtr AbstractJTagged::element() { return element_; }
2076
hash() const2077 std::size_t AbstractJTagged::hash() const { return hash_combine(tid(), element_->hash()); }
2078
AbstractRefTensor(const AbstractTensorPtr & ref_value,const ValuePtr & ref_key_value)2079 AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value)
2080 : AbstractTensor(*ref_value), ref_key_value_(ref_key_value) {
2081 set_type(std::make_shared<RefType>());
2082 set_is_adapter(ref_value->is_adapter());
2083 MS_EXCEPTION_IF_NULL(ref_key_value);
2084 if (ref_key_value != kValueAny && !ref_key_value->isa<RefKey>()) {
2085 MS_LOG(INTERNAL_EXCEPTION) << "ref_key_value must be kValueAny or RefKey, but got:" << ref_key_value->ToString();
2086 }
2087 }
2088
BuildType() const2089 TypePtr AbstractRefTensor::BuildType() const {
2090 auto type = AbstractTensor::BuildType();
2091 auto subtype = dyn_cast_ptr<TensorType>(type);
2092 MS_EXCEPTION_IF_NULL(subtype);
2093 return std::make_shared<RefType>(subtype);
2094 }
2095
operator ==(const AbstractBase & other) const2096 bool AbstractRefTensor::operator==(const AbstractBase &other) const {
2097 if (this == &other) {
2098 return true;
2099 }
2100 if (!other.isa<AbstractRefTensor>()) {
2101 return false;
2102 }
2103 return AbstractTensor::equal_to(dynamic_cast<const AbstractTensor &>(other));
2104 }
2105
Join(const std::shared_ptr<AbstractRefTensor> & other)2106 AbstractBasePtr AbstractRefTensor::Join(const std::shared_ptr<AbstractRefTensor> &other) {
2107 if (*this == *other) {
2108 return shared_from_base<AbstractRefTensor>();
2109 }
2110 if (other->isa<AbstractNegligible>()) {
2111 return shared_from_base<AbstractBase>();
2112 }
2113 // Firstly, join the ref_key_value.
2114 auto joined_ref_key = ValueJoin(ref_key_value_, other->ref_key_value_);
2115 // Secondly , join the tensor value.
2116 auto joined_tensor = AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
2117 MS_EXCEPTION_IF_NULL(joined_tensor);
2118 return std::make_shared<AbstractRefTensor>(joined_tensor, joined_ref_key);
2119 }
2120
// Join this ref tensor abstract with an arbitrary abstract.
// A ref tensor operand goes through the ref-specific Join overload; an AbstractNegligible operand leaves
// this abstract unchanged; anything else is joined as a plain AbstractTensor, and that result must be an
// AbstractTensor.
AbstractBasePtr AbstractRefTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  // Abstract ref join abstract ref.
  if (other->isa<AbstractRefTensor>()) {
    return AbstractRefTensor::Join(other->cast<AbstractRefPtr>());
  }
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  // Joining with any other abstract behaves like AbstractTensor::Join.
  auto joined = AbstractTensor::Join(other);
  const bool joined_is_tensor = joined->isa<AbstractTensor>();
  if (!joined_is_tensor) {
    MS_LOG(INTERNAL_EXCEPTION) << "Expect an AbstractTensor, but got:" << joined->ToString()
                               << ", other:" << other->ToString();
  }
  return joined;
}
2138
Clone() const2139 AbstractBasePtr AbstractRefTensor::Clone() const {
2140 auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
2141 return std::make_shared<AbstractRefTensor>(abs_tensor, ref_key_value_);
2142 }
2143
Broaden() const2144 AbstractBasePtr AbstractRefTensor::Broaden() const {
2145 // Always broaden for ref
2146 auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
2147 // Broaden the tensor value and keep the ref_key_value.
2148 auto ret = std::make_shared<AbstractRefTensor>(abs_tensor, ref_key_value_);
2149 return ret;
2150 }
2151
ToString() const2152 std::string AbstractRefTensor::ToString() const {
2153 std::ostringstream buffer;
2154 MS_EXCEPTION_IF_NULL(ref_key_value_);
2155 buffer << type_name() << "("
2156 << "key: " << ref_key_value_->ToString() << " ref_value: " << AbstractTensor::ToString();
2157 auto value = GetValueTrack();
2158 if (value != nullptr) {
2159 buffer << ", value: " << value->ToString();
2160 }
2161 buffer << ")";
2162 return buffer.str();
2163 }
2164
PartialBroaden() const2165 AbstractBasePtr AbstractRefTensor::PartialBroaden() const { return Clone(); }
2166
AbstractNone()2167 AbstractNone::AbstractNone() : AbstractBase() { set_type(std::make_shared<TypeNone>()); }
2168
BuildType() const2169 TypePtr AbstractNone::BuildType() const { return std::make_shared<TypeNone>(); }
2170
Clone() const2171 AbstractBasePtr AbstractNone::Clone() const { return std::make_shared<AbstractNone>(); }
2172
operator ==(const AbstractBase & other) const2173 bool AbstractNone::operator==(const AbstractBase &other) const { return other.isa<AbstractNone>(); }
2174
ToString() const2175 std::string AbstractNone::ToString() const {
2176 std::ostringstream buffer;
2177 buffer << type_name() << "(Value: None)";
2178 return buffer.str();
2179 }
2180
// Join the None abstract with `other`.
// An AbstractNegligible or AbstractNone operand yields this abstract; any other operand is reported through
// AbstractTypeJoinLogging.
AbstractBasePtr AbstractNone::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  const bool other_is_none = other->isa<AbstractNone>();
  if (!other_is_none) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractNone>();
}
2191
AbstractNull()2192 AbstractNull::AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
2193
BuildType() const2194 TypePtr AbstractNull::BuildType() const { return std::make_shared<TypeNull>(); }
2195
Clone() const2196 AbstractBasePtr AbstractNull::Clone() const { return std::make_shared<AbstractNull>(); }
2197
RealBuildValue() const2198 ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
2199
operator ==(const AbstractBase & other) const2200 bool AbstractNull::operator==(const AbstractBase &other) const { return other.isa<AbstractNull>(); }
2201
ToString() const2202 std::string AbstractNull::ToString() const {
2203 std::ostringstream buffer;
2204 buffer << type_name() << "(Value: Null)";
2205 return buffer.str();
2206 }
2207
AbstractTimeOut()2208 AbstractTimeOut::AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
2209
BuildType() const2210 TypePtr AbstractTimeOut::BuildType() const { return std::make_shared<TypeNull>(); }
2211
Clone() const2212 AbstractBasePtr AbstractTimeOut::Clone() const { return std::make_shared<AbstractTimeOut>(); }
2213
operator ==(const AbstractBase & other) const2214 bool AbstractTimeOut::operator==(const AbstractBase &other) const { return other.isa<AbstractTimeOut>(); }
2215
ToString() const2216 std::string AbstractTimeOut::ToString() const {
2217 std::ostringstream buffer;
2218 buffer << "AbstractTimeOut "
2219 << "(Value: Null)";
2220 return buffer.str();
2221 }
2222
AbstractEllipsis()2223 AbstractEllipsis::AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }
2224
BuildType() const2225 TypePtr AbstractEllipsis::BuildType() const { return std::make_shared<TypeEllipsis>(); }
2226
Clone() const2227 AbstractBasePtr AbstractEllipsis::Clone() const { return std::make_shared<AbstractEllipsis>(); }
2228
operator ==(const AbstractBase & other) const2229 bool AbstractEllipsis::operator==(const AbstractBase &other) const { return other.isa<AbstractEllipsis>(); }
2230
ToString() const2231 std::string AbstractEllipsis::ToString() const {
2232 std::ostringstream buffer;
2233 buffer << type_name() << "(Value: Ellipsis)";
2234 return buffer.str();
2235 }
2236
CloneAsTensor() const2237 AbstractBasePtr AbstractRefTensor::CloneAsTensor() const { return AbstractTensor::Clone(); }
2238
ref()2239 AbstractTensorPtr AbstractRefTensor::ref() { return shared_from_base<AbstractTensor>(); }
2240
ref_key_value() const2241 ValuePtr AbstractRefTensor::ref_key_value() const { return ref_key_value_; }
2242
AbstractSparseTensor(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2243 AbstractSparseTensor::AbstractSparseTensor(AbstractBasePtrList &&elements,
2244 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2245 : AbstractTuple(std::move(elements), tuple_nodes) {}
2246
AbstractSparseTensor(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2247 AbstractSparseTensor::AbstractSparseTensor(const AbstractBasePtrList &elements,
2248 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2249 : AbstractTuple(elements, tuple_nodes) {}
2250
BuildShape() const2251 BaseShapePtr AbstractSparseTensor::BuildShape() const {
2252 return std::make_shared<TupleShape>(ElementsShapeTupleRecursive());
2253 }
2254
AbstractRowTensor(const AbstractBasePtr & element,const BaseShapePtr & shape)2255 AbstractRowTensor::AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape)
2256 : AbstractUndetermined(element, shape) {}
2257
AbstractRowTensor(const TypePtr & element_type,const ShapeVector & shape)2258 AbstractRowTensor::AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape)
2259 : AbstractUndetermined(element_type, shape) {}
2260
indices() const2261 const AbstractTensorPtr AbstractRowTensor::indices() const { return indices_; }
2262
set_indices(const AbstractTensorPtr & indices)2263 void AbstractRowTensor::set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
2264
values() const2265 const AbstractTensorPtr AbstractRowTensor::values() const { return values_; }
2266
set_values(const AbstractTensorPtr & values)2267 void AbstractRowTensor::set_values(const AbstractTensorPtr &values) { values_ = values; }
2268
dense_shape() const2269 const AbstractTuplePtr AbstractRowTensor::dense_shape() const { return dense_shape_; }
2270
set_dense_shape(const AbstractTuplePtr & dense_shape)2271 void AbstractRowTensor::set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
2272
AbstractCOOTensor(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2273 AbstractCOOTensor::AbstractCOOTensor(AbstractBasePtrList &&elements,
2274 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2275 : AbstractSparseTensor(std::move(elements), tuple_nodes) {}
2276
AbstractCOOTensor(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2277 AbstractCOOTensor::AbstractCOOTensor(const AbstractBasePtrList &elements,
2278 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2279 : AbstractSparseTensor(elements, tuple_nodes) {}
2280
AbstractCSRTensor(AbstractBasePtrList && elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2281 AbstractCSRTensor::AbstractCSRTensor(AbstractBasePtrList &&elements,
2282 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2283 : AbstractSparseTensor(std::move(elements), tuple_nodes) {}
2284
AbstractCSRTensor(const AbstractBasePtrList & elements,const std::shared_ptr<AnfNodeWeakPtrList> & tuple_nodes)2285 AbstractCSRTensor::AbstractCSRTensor(const AbstractBasePtrList &elements,
2286 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes)
2287 : AbstractSparseTensor(elements, tuple_nodes) {}
2288
hash() const2289 std::size_t AbstractMonad::hash() const { return hash_combine({tid()}); }
2290
BuildType() const2291 TypePtr AbstractMonad::BuildType() const { return GetTypeTrack(); }
2292
Broaden() const2293 AbstractBasePtr AbstractMonad::Broaden() const { return AbstractBase::Broaden(); }
2294
ToString() const2295 std::string AbstractMonad::ToString() const {
2296 std::ostringstream buffer;
2297 buffer << type_name() << "(" << GetValueTrack()->ToString() << ")";
2298 return buffer.str();
2299 }
2300
AbstractMonad(const ValuePtr & value,const TypePtr & type)2301 AbstractMonad::AbstractMonad(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
2302
AbstractUMonad(const ValuePtr & value)2303 AbstractUMonad::AbstractUMonad(const ValuePtr &value) : AbstractMonad(value, kUMonadType) {}
2304
Clone() const2305 AbstractBasePtr AbstractUMonad::Clone() const { return std::make_shared<AbstractUMonad>(GetValueTrack()); }
2306
AbstractIOMonad(const ValuePtr & value)2307 AbstractIOMonad::AbstractIOMonad(const ValuePtr &value) : AbstractMonad(value, kIOMonadType) {}
2308
Clone() const2309 AbstractBasePtr AbstractIOMonad::Clone() const { return std::make_shared<AbstractIOMonad>(GetValueTrack()); }
2310
map_tensor_type() const2311 MapTensorTypePtr AbstractMapTensor::map_tensor_type() const { return dyn_cast<MapTensorType>(GetTypeTrack()); }
2312
shape() const2313 ShapePtr AbstractMapTensor::shape() const { return dyn_cast<Shape>(GetShapeTrack()); }
2314
value_shape() const2315 const ShapePtr &AbstractMapTensor::value_shape() const { return value_shape_; }
2316
ref_key_value() const2317 const ValuePtr &AbstractMapTensor::ref_key_value() const { return ref_key_value_; }
2318
default_value() const2319 const ValuePtr &AbstractMapTensor::default_value() const { return default_value_; }
2320
permit_filter_value() const2321 const ValuePtr &AbstractMapTensor::permit_filter_value() const { return permit_filter_value_; }
2322
evict_filter_value() const2323 const ValuePtr &AbstractMapTensor::evict_filter_value() const { return evict_filter_value_; }
2324
BuildType() const2325 TypePtr AbstractMapTensor::BuildType() const { return GetTypeTrack(); }
2326
BuildShape() const2327 BaseShapePtr AbstractMapTensor::BuildShape() const { return GetShapeTrack(); }
2328
BuildType() const2329 TypePtr AbstractKeywordArg::BuildType() const {
2330 MS_EXCEPTION_IF_NULL(arg_value_);
2331 TypePtr type = arg_value_->BuildType();
2332 return std::make_shared<Keyword>(arg_name_, type);
2333 }
2334
Clone() const2335 AbstractBasePtr AbstractKeywordArg::Clone() const {
2336 MS_EXCEPTION_IF_NULL(arg_value_);
2337 return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
2338 }
2339
Broaden() const2340 AbstractBasePtr AbstractKeywordArg::Broaden() const {
2341 MS_EXCEPTION_IF_NULL(arg_value_);
2342 return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
2343 }
2344
hash() const2345 std::size_t AbstractKeywordArg::hash() const {
2346 MS_EXCEPTION_IF_NULL(arg_value_);
2347 return hash_combine({tid(), std::hash<std::string>{}(arg_name_), arg_value_->hash()});
2348 }
2349
ToString() const2350 std::string AbstractKeywordArg::ToString() const {
2351 std::ostringstream buffer;
2352 MS_EXCEPTION_IF_NULL(arg_value_);
2353 buffer << type_name() << "(";
2354 buffer << "key: " << arg_name_;
2355 buffer << ", value: " << arg_value_->ToString();
2356 buffer << ")";
2357 return buffer.str();
2358 }
2359
operator ==(const AbstractBase & other) const2360 bool AbstractKeywordArg::operator==(const AbstractBase &other) const {
2361 if (this == &other) {
2362 return true;
2363 }
2364 if (!other.isa<AbstractKeywordArg>()) {
2365 return false;
2366 }
2367 return *this == static_cast<const AbstractKeywordArg &>(other);
2368 }
2369
operator ==(const AbstractKeywordArg & other) const2370 bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const {
2371 if (this == &other) {
2372 return true;
2373 }
2374 return other.arg_name_ == arg_name_ && IsEqual(other.arg_value_, arg_value_);
2375 }
2376
RealBuildValue() const2377 ValuePtr AbstractKeywordArg::RealBuildValue() const {
2378 MS_EXCEPTION_IF_NULL(arg_value_);
2379 ValuePtr value = arg_value_->BuildValue();
2380 MS_EXCEPTION_IF_NULL(value);
2381 if (value->isa<ValueAny>()) {
2382 return kValueAny;
2383 }
2384 return std::make_shared<KeywordArg>(arg_name_, value);
2385 }
2386
AbstractBasePtrListHash(const AbstractBasePtrList & args_abs_list)2387 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_abs_list) {
2388 // Hash for empty list is zero.
2389 if (args_abs_list.empty()) {
2390 return 0;
2391 }
2392 // Hashing all elements is costly, we only calculate hash from
2393 // the first element and last few elements base on some experiments.
2394 // In some scenarios, this may lead high hash conflicts. Therefore,
2395 // we should use this hash function in hash tables that can tolerate
2396 // high hash conflicts, such as std::unordered_map.
2397 constexpr size_t kMaxLastElements = 4;
2398 const size_t n_args = args_abs_list.size();
2399 // Hash from list size and the first element.
2400 std::size_t hash_value = hash_combine(n_args, args_abs_list[0]->hash());
2401 // Hash from last few elements.
2402 const size_t start = ((n_args > kMaxLastElements) ? (n_args - kMaxLastElements) : 1);
2403 for (size_t i = start; i < n_args; ++i) {
2404 hash_value = hash_combine(hash_value, args_abs_list[i]->hash());
2405 }
2406 return hash_value;
2407 }
2408
AbstractBasePtrListDeepEqual(const AbstractBasePtrList & lhs,const AbstractBasePtrList & rhs)2409 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
2410 const std::size_t size = lhs.size();
2411 if (size != rhs.size()) {
2412 return false;
2413 }
2414 for (std::size_t i = 0; i < size; ++i) {
2415 if (!IsEqual(lhs[i], rhs[i])) {
2416 return false;
2417 }
2418 }
2419 return true;
2420 }
2421
2422 // SparseTensor
2423 template <typename T>
GetAbsPtrAt(size_t index) const2424 const T AbstractSparseTensor::GetAbsPtrAt(size_t index) const {
2425 if (index >= size()) {
2426 MS_LOG(EXCEPTION) << "Index should be in range of [0, " << size() << "), but got " << index
2427 << " for abstract: " << type_name();
2428 }
2429 AbstractBasePtr base = elements()[index];
2430 MS_EXCEPTION_IF_NULL(base);
2431 return base->cast<T>();
2432 }
2433
ElementsShapeTupleRecursive() const2434 BaseShapePtrList AbstractSparseTensor::ElementsShapeTupleRecursive() const {
2435 BaseShapePtrList element_shape_list;
2436 for (const auto &element : elements()) {
2437 MS_EXCEPTION_IF_NULL(element);
2438 auto abs_tuple = element->cast_ptr<AbstractTuple>();
2439 if (abs_tuple == nullptr) {
2440 element_shape_list.push_back(element->GetShape());
2441 } else {
2442 for (const auto &scalar : abs_tuple->elements()) {
2443 MS_EXCEPTION_IF_NULL(scalar);
2444 element_shape_list.push_back(scalar->GetShape());
2445 }
2446 }
2447 }
2448 return element_shape_list;
2449 }
2450
GetTensorTypeIdAt(size_t index) const2451 const TypeId AbstractSparseTensor::GetTensorTypeIdAt(size_t index) const {
2452 size_t shape_idx = size() - 1;
2453 if (index >= shape_idx) {
2454 MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape_idx << "), but got " << index << " for "
2455 << ToString();
2456 }
2457 auto abs_tensor = GetAbsPtrAt<abstract::AbstractTensorPtr>(index);
2458 MS_EXCEPTION_IF_NULL(abs_tensor);
2459 return abs_tensor->element()->BuildType()->type_id();
2460 }
2461
GetShapeTypeIdAt(size_t index) const2462 const TypeId AbstractSparseTensor::GetShapeTypeIdAt(size_t index) const {
2463 if (index >= shape()->size()) {
2464 MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape()->size() << "), but got " << index << " for "
2465 << ToString();
2466 }
2467 return shape()->elements()[index]->BuildType()->type_id();
2468 }
2469
BuildType() const2470 TypePtr AbstractSparseTensor::BuildType() const { return std::make_shared<SparseTensorType>(); }
2471
shape() const2472 const AbstractTuplePtr AbstractSparseTensor::shape() const {
2473 auto res = GetAbsPtrAt<abstract::AbstractTuplePtr>(size() - 1);
2474 if (res == nullptr) {
2475 MS_LOG(INTERNAL_EXCEPTION) << "Get shape nullptr in AbstractSparseTensor: " << ToString();
2476 }
2477 return res;
2478 }
2479
2480 // RowTensor
BuildType() const2481 TypePtr AbstractRowTensor::BuildType() const {
2482 MS_EXCEPTION_IF_NULL(element());
2483 TypePtr element_type = element()->BuildType();
2484 return std::make_shared<RowTensorType>(element_type);
2485 }
2486
Clone() const2487 AbstractBasePtr AbstractRowTensor::Clone() const {
2488 MS_EXCEPTION_IF_NULL(element());
2489 auto clone = std::make_shared<AbstractRowTensor>(element()->Clone());
2490 ShapePtr shp = shape();
2491 MS_EXCEPTION_IF_NULL(shp);
2492 clone->set_shape(shp->Clone());
2493 clone->set_value(GetValueTrack());
2494 MS_EXCEPTION_IF_NULL(indices_);
2495 MS_EXCEPTION_IF_NULL(values_);
2496 MS_EXCEPTION_IF_NULL(dense_shape_);
2497 auto indices_clone = indices_->Clone();
2498 auto value_clone = values_->Clone();
2499 auto dense_clone = dense_shape_->Clone();
2500 MS_EXCEPTION_IF_NULL(indices_clone);
2501 MS_EXCEPTION_IF_NULL(value_clone);
2502 MS_EXCEPTION_IF_NULL(dense_clone);
2503 clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
2504 clone->set_values(value_clone->cast<AbstractTensorPtr>());
2505 clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
2506 return clone;
2507 }
2508
MakeAbstract(const BaseShapePtr & shp) const2509 AbstractRowTensorPtr AbstractRowTensor::MakeAbstract(const BaseShapePtr &shp) const {
2510 MS_EXCEPTION_IF_NULL(element());
2511 auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
2512 broaden->set_shape(shp);
2513 broaden->set_value(kValueAny);
2514 MS_EXCEPTION_IF_NULL(indices_);
2515 MS_EXCEPTION_IF_NULL(values_);
2516 MS_EXCEPTION_IF_NULL(dense_shape_);
2517 auto indices_clone = indices_->Clone();
2518 auto value_clone = values_->Clone();
2519 auto dense_clone = dense_shape_->Clone();
2520 MS_EXCEPTION_IF_NULL(indices_clone);
2521 MS_EXCEPTION_IF_NULL(value_clone);
2522 MS_EXCEPTION_IF_NULL(dense_clone);
2523 broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
2524 broaden->set_values(value_clone->cast<AbstractTensorPtr>());
2525 broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
2526 return broaden;
2527 }
2528
Broaden() const2529 AbstractBasePtr AbstractRowTensor::Broaden() const {
2530 auto shp = shape()->Clone();
2531 MS_EXCEPTION_IF_NULL(shp);
2532 return MakeAbstract(shp);
2533 }
2534
BroadenWithShape() const2535 AbstractBasePtr AbstractRowTensor::BroadenWithShape() const {
2536 auto shp = shape()->Clone();
2537 MS_EXCEPTION_IF_NULL(shp);
2538 shp->Broaden();
2539 return MakeAbstract(shp);
2540 }
2541
ToString() const2542 std::string AbstractRowTensor::ToString() const {
2543 std::ostringstream buffer;
2544 BaseShapePtr shape_track = GetShapeTrack();
2545 MS_EXCEPTION_IF_NULL(shape_track);
2546 MS_EXCEPTION_IF_NULL(element());
2547 auto value_track = GetValueTrack();
2548 MS_EXCEPTION_IF_NULL(value_track);
2549 MS_EXCEPTION_IF_NULL(indices_);
2550 MS_EXCEPTION_IF_NULL(values_);
2551 MS_EXCEPTION_IF_NULL(dense_shape_);
2552 buffer << type_name() << "("
2553 << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
2554 << ", value_ptr: " << value_track << ", value: " << value_track->ToString()
2555 << ", indices: " << indices_->ToString() << ", values: " << values_->ToString()
2556 << ", dense_shape: " << dense_shape_->ToString() << ")";
2557 return buffer.str();
2558 }
2559
2560 // COOTensor
BuildType() const2561 TypePtr AbstractCOOTensor::BuildType() const {
2562 MS_EXCEPTION_IF_NULL(indices());
2563 MS_EXCEPTION_IF_NULL(values());
2564 MS_EXCEPTION_IF_NULL(shape());
2565 TypePtrList elements{indices()->element()->BuildType(), values()->element()->BuildType()};
2566 (void)std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
2567 [](const AbstractBasePtr &p) { return p->BuildType(); });
2568 return std::make_shared<COOTensorType>(elements);
2569 }
2570
Clone() const2571 AbstractBasePtr AbstractCOOTensor::Clone() const {
2572 AbstractBasePtrList element_list;
2573 for (const auto &element : elements()) {
2574 MS_EXCEPTION_IF_NULL(element);
2575 AbstractBasePtr clone = element->Clone();
2576 element_list.push_back(clone);
2577 }
2578 return std::make_shared<abstract::AbstractCOOTensor>(element_list);
2579 }
2580
Broaden() const2581 AbstractBasePtr AbstractCOOTensor::Broaden() const {
2582 return std::make_shared<abstract::AbstractCOOTensor>(ElementsBroaden());
2583 }
2584
PartialBroaden() const2585 AbstractBasePtr AbstractCOOTensor::PartialBroaden() const { return Broaden(); }
2586
ToString() const2587 std::string AbstractCOOTensor::ToString() const {
2588 std::ostringstream buffer;
2589 auto indices_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2590 auto values_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2591 auto shape_ = GetAbsPtrAt<abstract::AbstractTuplePtr>(size() - 1);
2592 MS_EXCEPTION_IF_NULL(indices_);
2593 MS_EXCEPTION_IF_NULL(values_);
2594 MS_EXCEPTION_IF_NULL(shape_);
2595 buffer << type_name() << "("
2596 << "indices: " << indices_->ToString() << ", values" << values_->ToString()
2597 << ", dense_shape: " << shape_->ToString();
2598 return buffer.str();
2599 }
2600
indices() const2601 const AbstractTensorPtr AbstractCOOTensor::indices() const {
2602 auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2603 if (res == nullptr) {
2604 MS_LOG(INTERNAL_EXCEPTION) << "Get indices nullptr in AbstractCOOTensor: " << ToString();
2605 }
2606 return res;
2607 }
2608
values() const2609 const AbstractTensorPtr AbstractCOOTensor::values() const {
2610 auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2611 if (res == nullptr) {
2612 MS_LOG(INTERNAL_EXCEPTION) << "Get values nullptr in AbstractCOOTensor: " << ToString();
2613 }
2614 return res;
2615 }
2616
2617 // CSRTensor
BuildType() const2618 TypePtr AbstractCSRTensor::BuildType() const {
2619 MS_EXCEPTION_IF_NULL(indptr());
2620 MS_EXCEPTION_IF_NULL(indices());
2621 MS_EXCEPTION_IF_NULL(values());
2622 MS_EXCEPTION_IF_NULL(shape());
2623 TypePtrList elements{indptr()->element()->BuildType(), indices()->element()->BuildType(),
2624 values()->element()->BuildType()};
2625 (void)std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
2626 [](const AbstractBasePtr &p) { return p->BuildType(); });
2627 return std::make_shared<CSRTensorType>(elements);
2628 }
2629
Clone() const2630 AbstractBasePtr AbstractCSRTensor::Clone() const {
2631 AbstractBasePtrList element_list;
2632 for (const auto &element : elements()) {
2633 MS_EXCEPTION_IF_NULL(element);
2634 AbstractBasePtr clone = element->Clone();
2635 element_list.push_back(clone);
2636 }
2637 return std::make_shared<abstract::AbstractCSRTensor>(element_list);
2638 }
2639
Broaden() const2640 AbstractBasePtr AbstractCSRTensor::Broaden() const {
2641 return std::make_shared<abstract::AbstractCSRTensor>(ElementsBroaden());
2642 }
2643
PartialBroaden() const2644 AbstractBasePtr AbstractCSRTensor::PartialBroaden() const { return Broaden(); }
2645
ToString() const2646 std::string AbstractCSRTensor::ToString() const {
2647 std::ostringstream buffer;
2648 auto indptr_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndptrIdx);
2649 auto indices_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2650 auto values_ = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2651 auto shape_ = GetAbsPtrAt<abstract::AbstractTuplePtr>(size() - 1);
2652 MS_EXCEPTION_IF_NULL(indptr_);
2653 MS_EXCEPTION_IF_NULL(indices_);
2654 MS_EXCEPTION_IF_NULL(values_);
2655 MS_EXCEPTION_IF_NULL(shape_);
2656 buffer << type_name() << "("
2657 << "indptr: " << indptr_->ToString() << ", indices: " << indices_->ToString() << ", values"
2658 << values_->ToString() << ", dense_shape: " << shape_->ToString();
2659 return buffer.str();
2660 }
2661
indptr() const2662 const AbstractTensorPtr AbstractCSRTensor::indptr() const {
2663 auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndptrIdx);
2664 if (res == nullptr) {
2665 MS_LOG(INTERNAL_EXCEPTION) << "Get indptr nullptr in AbstractCSRTensor: " << ToString();
2666 }
2667 return res;
2668 }
2669
indices() const2670 const AbstractTensorPtr AbstractCSRTensor::indices() const {
2671 auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kIndicesIdx);
2672 if (res == nullptr) {
2673 MS_LOG(INTERNAL_EXCEPTION) << "Get indices nullptr in AbstractCSRTensor: " << ToString();
2674 }
2675 return res;
2676 }
2677
values() const2678 const AbstractTensorPtr AbstractCSRTensor::values() const {
2679 auto res = GetAbsPtrAt<abstract::AbstractTensorPtr>(kValuesIdx);
2680 if (res == nullptr) {
2681 MS_LOG(INTERNAL_EXCEPTION) << "Get values nullptr in AbstractCSRTensor: " << ToString();
2682 }
2683 return res;
2684 }
2685
AbstractMapTensor(const MapTensorPtr & map_tensor)2686 AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor)
2687 : AbstractBase(map_tensor, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
2688 std::make_shared<Shape>(map_tensor->shape())),
2689 ref_key_value_(kValueAny),
2690 default_value_(map_tensor->default_value()),
2691 permit_filter_value_(map_tensor->permit_filter_value()),
2692 evict_filter_value_(map_tensor->evict_filter_value()),
2693 value_shape_(std::make_shared<Shape>(map_tensor->value_shape())) {}
2694
AbstractMapTensor(const MapTensorPtr & map_tensor,const ValuePtr & ref_key_value)2695 AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value)
2696 : AbstractBase(kValueAny, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
2697 std::make_shared<Shape>(map_tensor->shape())),
2698 ref_key_value_(ref_key_value),
2699 default_value_(map_tensor->default_value()),
2700 permit_filter_value_(map_tensor->permit_filter_value()),
2701 evict_filter_value_(map_tensor->evict_filter_value()),
2702 value_shape_(std::make_shared<Shape>(map_tensor->value_shape())) {}
2703
// Copy constructor: copies the AbstractBase part and every MapTensor-specific
// field, then sets the shape from other.shape() (the tracked shape cast to
// Shape).
AbstractMapTensor::AbstractMapTensor(const AbstractMapTensor &other)
    : AbstractBase(other),
      ref_key_value_(other.ref_key_value_),
      default_value_(other.default_value_),
      permit_filter_value_(other.permit_filter_value()),
      evict_filter_value_(other.evict_filter_value()),
      value_shape_(other.value_shape_) {
  const auto source_shape = other.shape();
  set_shape(source_shape);
}
2713
AbstractMapTensor(const TypePtr & type,const ShapePtr & value_shape,const ValuePtr & value,const ValuePtr & ref_key_value,const ValuePtr & default_value,const ValuePtr & permit_filter_value,const ValuePtr & evict_filter_value)2714 AbstractMapTensor::AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value,
2715 const ValuePtr &ref_key_value, const ValuePtr &default_value,
2716 const ValuePtr &permit_filter_value, const ValuePtr &evict_filter_value) {
2717 set_value(value);
2718 set_type(type);
2719 ref_key_value_ = ref_key_value;
2720 default_value_ = default_value;
2721 permit_filter_value_ = permit_filter_value;
2722 evict_filter_value_ = evict_filter_value;
2723 ShapeVector shape = {abstract::Shape::kShapeDimAny};
2724 (void)shape.insert(shape.end(), value_shape->shape().begin(), value_shape->shape().end());
2725 set_shape(std::make_shared<mindspore::abstract::Shape>(shape));
2726 }
2727
// Copy assignment.
//
// Fix: the AbstractBase part is now assigned as well. Previously only the
// MapTensor-specific fields and the shape were copied, so the assigned-to
// object kept its old value and type tracks, unlike the copy constructor,
// which copies the whole base.
AbstractMapTensor &AbstractMapTensor::operator=(const AbstractMapTensor &other) {
  if (this == &other) {
    return *this;
  }
  // Copy the base-class state first, mirroring AbstractBase(other) in the copy constructor.
  (void)AbstractBase::operator=(other);
  this->ref_key_value_ = other.ref_key_value();
  this->default_value_ = other.default_value();
  this->permit_filter_value_ = other.permit_filter_value();
  this->evict_filter_value_ = other.evict_filter_value();
  this->value_shape_ = other.value_shape_;
  this->set_shape(other.shape());
  return *this;
}
2740
Clone() const2741 AbstractBasePtr AbstractMapTensor::Clone() const { return std::make_shared<AbstractMapTensor>(*this); }
2742
// Joins this MapTensor abstract with `other` and returns the joined abstract.
//
// - A negligible `other`, or `other` being this very object, returns this
//   object unchanged.
// - Otherwise `other` must be an AbstractMapTensor; type, value shape, value,
//   reference key, default value and both filter values are joined one by one
//   and a new AbstractMapTensor is built from the results.
// - The default value and both filter values must join without producing
//   ValueAny; otherwise a ValueError is raised.
//
// Fix: the failure message for permit_filter_value used to say
// "Join default value failed", which pointed at the wrong field.
AbstractBasePtr AbstractMapTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (other->isa<AbstractNegligible>()) {
    return shared_from_base<AbstractBase>();
  }
  // Same pointer.
  if (this == other.get()) {
    return shared_from_base<AbstractMapTensor>();
  }

  // Check class.
  // NOTE(review): the code below dereferences other_abs, so
  // AbstractTypeJoinLogging is presumably expected not to return - confirm.
  auto other_abs = dyn_cast<AbstractMapTensor>(other);
  if (other_abs == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }

  // Join type.
  auto joined_type = TypeJoin(GetTypeTrack(), other_abs->GetTypeTrack());
  if (joined_type == kTypeAny) {
    TypeJoinLogging(GetTypeTrack(), other_abs->GetTypeTrack(), shared_from_base<AbstractBase>(), other);
  }

  // Join shape.
  auto joined_shape = ShapeJoin(value_shape(), other_abs->value_shape());
  if (joined_shape == nullptr) {
    ShapeJoinLogging(value_shape(), other_abs->value_shape(), shared_from_base<AbstractBase>(), other);
  }

  // Join value: keep the value only when both sides track the same value object.
  auto joined_value = (GetValueTrack() == other_abs->GetValueTrack() ? GetValueTrack() : kValueAny);

  // Join the ref_key_value.
  auto joined_ref_key = ValueJoin(ref_key_value_, other_abs->ref_key_value_);

  // Join the default_value.
  auto joined_default_value = ValueJoin(default_value_, other_abs->default_value_);
  if (joined_default_value->ContainsValueAny()) {
    MS_EXCEPTION(ValueError) << "Join default value failed for MapTensor. " << default_value_->ToString()
                             << " != " << other_abs->default_value_->ToString();
  }

  // Join the permit_filter_value.
  auto joined_permit_filter_value = ValueJoin(permit_filter_value_, other_abs->permit_filter_value_);
  if (joined_permit_filter_value->ContainsValueAny()) {
    MS_EXCEPTION(ValueError) << "Join permit_filter_value failed for MapTensor. " << permit_filter_value_->ToString()
                             << " != " << other_abs->permit_filter_value_->ToString();
  }

  // Join the evict_filter_value.
  auto joined_evict_filter_value = ValueJoin(evict_filter_value_, other_abs->evict_filter_value_);
  if (joined_evict_filter_value->ContainsValueAny()) {
    MS_EXCEPTION(ValueError) << "Join evict_filter_value failed for MapTensor. " << evict_filter_value_->ToString()
                             << " != " << other_abs->evict_filter_value_->ToString();
  }

  return std::make_shared<AbstractMapTensor>(joined_type, joined_shape, joined_value, joined_ref_key,
                                             joined_default_value, joined_permit_filter_value,
                                             joined_evict_filter_value);
}
2802
operator ==(const AbstractBase & other) const2803 bool AbstractMapTensor::operator==(const AbstractBase &other) const {
2804 if (this == &other) {
2805 return true;
2806 }
2807 if (!other.isa<AbstractMapTensor>()) {
2808 return false;
2809 }
2810 const auto &v1 = GetValueTrack();
2811 const auto &v2 = other.GetValueTrack();
2812 MS_EXCEPTION_IF_NULL(v1);
2813 MS_EXCEPTION_IF_NULL(v2);
2814 // Check if both point to same specific value.
2815 if (!v1->isa<ValueAny>()) {
2816 return v1 == v2;
2817 }
2818 // Check if both are ValueAny.
2819 if (!v2->isa<ValueAny>()) {
2820 return false;
2821 }
2822 const auto &other_map_tensor = dynamic_cast<const AbstractMapTensor &>(other);
2823 return common::IsEqual(GetTypeTrack(), other_map_tensor.GetTypeTrack()) &&
2824 common::IsEqual(GetShapeTrack(), other_map_tensor.GetShapeTrack()) &&
2825 common::IsEqual(default_value(), other_map_tensor.default_value());
2826 }
2827
hash() const2828 std::size_t AbstractMapTensor::hash() const {
2829 const auto &type = GetTypeTrack();
2830 const auto &value_shape = GetShapeTrack();
2831 MS_EXCEPTION_IF_NULL(type);
2832 MS_EXCEPTION_IF_NULL(value_shape);
2833 MS_EXCEPTION_IF_NULL(default_value_);
2834 std::size_t hash_value = hash_combine(tid(), type->hash());
2835 hash_value = hash_combine(hash_value, value_shape->hash());
2836 return hash_combine(hash_value, default_value_->hash());
2837 }
2838
ToString() const2839 std::string AbstractMapTensor::ToString() const {
2840 const auto &type = GetTypeTrack();
2841 const auto &value = GetValueTrack();
2842 const auto &value_shape = GetShapeTrack();
2843 MS_EXCEPTION_IF_NULL(type);
2844 MS_EXCEPTION_IF_NULL(value);
2845 MS_EXCEPTION_IF_NULL(value_shape);
2846 return type_name() + "(" + type->ToString() + " " + value_shape->ToString() +
2847 " key: " + (ref_key_value_ == nullptr ? "<null>" : ref_key_value_->ToString()) +
2848 " value: " + value->ToString() + ")";
2849 }
2850
// Joins this UMonad abstract with `other`.
// When `other` is not an AbstractUMonad, the mismatch between the two tracked
// types is reported through TypeJoinLogging. This object is returned as the
// join result.
//
// Fix: the second null check used to test `other` again, so the freshly
// fetched `other_type` was never validated before being passed on.
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractUMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
2862
operator ==(const AbstractBase & other) const2863 bool AbstractUMonad::operator==(const AbstractBase &other) const { return other.isa<AbstractUMonad>(); }
2864
// Joins this IOMonad abstract with `other`.
// When `other` is not an AbstractIOMonad, the mismatch between the two tracked
// types is reported through TypeJoinLogging. This object is returned as the
// join result.
//
// Fix: the second null check used to test `other` again, so the freshly
// fetched `other_type` was never validated before being passed on.
AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractIOMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
2876
operator ==(const AbstractBase & other) const2877 bool AbstractIOMonad::operator==(const AbstractBase &other) const { return other.isa<AbstractIOMonad>(); }
2878
GetRefKeyValue(const AbstractBasePtr & abs)2879 ValuePtr GetRefKeyValue(const AbstractBasePtr &abs) {
2880 auto abs_ref = abs->cast_ptr<AbstractRefTensor>();
2881 if (abs_ref != nullptr) {
2882 return abs_ref->ref_key_value();
2883 }
2884 auto abs_map_tensor = abs->cast_ptr<AbstractMapTensor>();
2885 if (abs_map_tensor != nullptr) {
2886 return abs_map_tensor->ref_key_value();
2887 }
2888 return nullptr;
2889 }
2890 } // namespace abstract
2891 } // namespace mindspore
2892