1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "abstract/abstract_value.h"

#include <algorithm>
#include <numeric>
#include <regex>
#include <sstream>

#include "utils/symbolic.h"
#include "abstract/utils.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
28
29 namespace mindspore {
30 namespace abstract {
GetTraceNode(const AbstractBasePtr & abs)31 AnfNodePtr GetTraceNode(const AbstractBasePtr &abs) {
32 AnfNodePtr node = nullptr;
33 if (mindspore::abstract::AbstractBase::trace_node_provider_ != nullptr) {
34 mindspore::abstract::AbstractBase::trace_node_provider_(&node);
35 }
36 return node;
37 }
38
AbstractTypeJoinLogging(const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)39 inline void AbstractTypeJoinLogging(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
40 std::ostringstream oss;
41 oss << "Type Join Failed: abstract type " << abstract1->type_name() << " cannot not join with "
42 << abstract2->type_name() << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
43 << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
44 auto node = GetTraceNode(abstract1);
45 if (node != nullptr) {
46 oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
47 }
48 MS_EXCEPTION(TypeError) << oss.str();
49 }
50
TypeJoinLogging(const TypePtr & type1,const TypePtr & type2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)51 inline void TypeJoinLogging(const TypePtr &type1, const TypePtr &type2, const AbstractBasePtr &abstract1,
52 const AbstractBasePtr &abstract2) {
53 std::ostringstream oss;
54 oss << "Type Join Failed: dtype1 = " << type1->ToString() << ", dtype2 = " << type2->ToString()
55 << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
56 << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
57 auto node = GetTraceNode(abstract1);
58 if (node != nullptr) {
59 oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
60 }
61 MS_EXCEPTION(TypeError) << oss.str();
62 }
63
ShapeJoinLogging(const BaseShapePtr & shape1,const BaseShapePtr & shape2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)64 inline void ShapeJoinLogging(const BaseShapePtr &shape1, const BaseShapePtr &shape2, const AbstractBasePtr &abstract1,
65 const AbstractBasePtr &abstract2) {
66 std::ostringstream oss;
67 oss << "Shape Join Failed: shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString()
68 << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
69 << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
70 auto node = GetTraceNode(abstract1);
71 if (node != nullptr) {
72 oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
73 }
74 MS_EXCEPTION(ValueError) << oss.str();
75 }
76
// Returns the first "Type Join Failed..." or "Shape Join Failed..." fragment of `info`, up to
// and including the first '.' that follows the keyword. Returns "" when neither keyword
// (followed by a '.') occurs.
std::string ExtractLoggingInfo(const std::string &info) {
  // Compile the pattern once rather than on every call. A function-local static is
  // initialized thread-safely (C++11) and is only read afterwards.
  static const std::regex kJoinFailedPattern("(Type Join Failed|Shape Join Failed).*?\\.");
  std::smatch result;
  if (std::regex_search(info, result, kJoinFailedPattern)) {
    return result.str();
  }
  return "";
}
87
operator ==(const AbstractBase & other) const88 bool AbstractBase::operator==(const AbstractBase &other) const {
89 if (tid() != other.tid()) {
90 return false;
91 }
92 auto type = BuildType();
93 auto other_type = BuildType();
94 MS_EXCEPTION_IF_NULL(other_type);
95 MS_EXCEPTION_IF_NULL(type);
96 if (type->type_id() == kObjectTypeUndeterminedType && other_type->type_id() == kObjectTypeUndeterminedType) {
97 return true;
98 }
99 if (value_ == nullptr || other.value_ == nullptr) {
100 MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
101 << this->ToString() << ", other: " << other.ToString();
102 }
103
104 bool value_equal = false;
105 if (value_ == other.value_) {
106 value_equal = true;
107 } else if (*value_ == *other.value_) {
108 value_equal = true;
109 }
110 bool type_equal = false;
111 MS_EXCEPTION_IF_NULL(type_);
112 MS_EXCEPTION_IF_NULL(other.type_);
113 if (type_ == other.type_) {
114 type_equal = true;
115 } else if (*type_ == *other.type_) {
116 type_equal = true;
117 }
118 bool shape_equal = false;
119 MS_EXCEPTION_IF_NULL(shape_);
120 MS_EXCEPTION_IF_NULL(other.shape_);
121 if (shape_ == other.shape_) {
122 shape_equal = true;
123 } else if (*shape_ == *other.shape_) {
124 shape_equal = true;
125 }
126 return value_equal && type_equal && shape_equal;
127 }
128
BuildValue() const129 ValuePtr AbstractBase::BuildValue() const {
130 if (value_ == nullptr) {
131 return RealBuildValue();
132 }
133 return value_;
134 }
135
Broaden() const136 AbstractBasePtr AbstractBase::Broaden() const {
137 AbstractBasePtr clone = Clone();
138 MS_EXCEPTION_IF_NULL(clone);
139 clone->set_value(kAnyValue);
140 return clone;
141 }
142
PartialBroaden() const143 AbstractBasePtr AbstractBase::PartialBroaden() const { return Clone(); }
144
ToString() const145 std::string AbstractBase::ToString() const {
146 std::ostringstream buffer;
147 std::string value = std::string("value is null");
148 if (value_ != nullptr) {
149 value = value_->ToString();
150 }
151 MS_EXCEPTION_IF_NULL(type_);
152 MS_EXCEPTION_IF_NULL(shape_);
153 buffer << type_name() << "("
154 << "Type: " << type_->ToString() << ", Value: " << value << ", Shape: " << shape_->ToString() << ")";
155 return buffer.str();
156 }
157
Broaden() const158 AbstractBasePtr AbstractScalar::Broaden() const {
159 auto context = MsContext::GetInstance();
160 MS_EXCEPTION_IF_NULL(context);
161 if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
162 return AbstractBase::Broaden();
163 }
164 auto type_id = GetTypeTrack()->type_id();
165 if (type_id == kObjectTypeEnvType) {
166 return AbstractBase::Broaden();
167 }
168 return Clone();
169 }
170
// Joins this scalar with `other`.
// - If the two are equal, returns this.
// - Otherwise the tracked types are joined; a result of kAnyType means they are
//   incompatible and TypeJoinLogging raises.
// - The values are joined with ValueJoin. If the joined value is this scalar's own value,
//   this is returned; otherwise a new AbstractScalar(joined value, joined type) is returned.
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto my_type = GetTypeTrack();
  auto their_type = other->GetTypeTrack();
  TypePtr joined_type = TypeJoin(my_type, their_type);
  if (joined_type == kAnyType) {
    TypeJoinLogging(my_type, their_type, shared_from_base<AbstractBase>(), other);
  }
  auto my_value = GetValueTrack();
  auto their_value = other->GetValueTrack();
  ValuePtr joined_value = ValueJoin(my_value, their_value);
  if (joined_value != my_value) {
    return std::make_shared<AbstractScalar>(joined_value, joined_type);
  }
  return shared_from_base<AbstractBase>();
}
190
Clone() const191 AbstractBasePtr AbstractType::Clone() const {
192 ValuePtr value_self = GetValueTrack();
193 if (value_self == nullptr || !value_self->isa<Type>()) {
194 return nullptr;
195 }
196 TypePtr type_self = value_self->cast<TypePtr>();
197 return std::make_shared<AbstractType>(type_self->Clone());
198 }
199
operator ==(const AbstractBase & other) const200 bool AbstractType::operator==(const AbstractBase &other) const {
201 if (tid() != other.tid()) {
202 return false;
203 }
204 // Have to compare TypePtr with value;
205 ValuePtr value_self = GetValueTrack();
206 ValuePtr value_other = other.GetValueTrack();
207 if (value_self == nullptr || value_other == nullptr) {
208 MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString()
209 << ", other: " << other.ToString();
210 }
211 if (!value_self->isa<Type>() || !value_other->isa<Type>()) {
212 return false;
213 }
214 TypePtr type_self = value_self->cast<TypePtr>();
215 TypePtr type_other = value_other->cast<TypePtr>();
216 bool value_equal = *type_self == *type_other;
217 return value_equal;
218 }
219
ToString() const220 std::string AbstractType::ToString() const {
221 std::ostringstream buffer;
222 ValuePtr value_self = GetValueTrack();
223 if (value_self == nullptr) {
224 buffer << "AbstractType value: nullptr";
225 return buffer.str();
226 }
227 if (!value_self->isa<Type>()) {
228 buffer << type_name() << "(Value: nullptr)";
229 return buffer.str();
230 }
231 TypePtr type_self = value_self->cast<TypePtr>();
232 buffer << type_name() << "("
233 << "Value: " << type_self->ToString() << ")";
234 return buffer.str();
235 }
236
ToString() const237 std::string AbstractError::ToString() const {
238 std::ostringstream buffer;
239 auto value_track = GetValueTrack();
240 MS_EXCEPTION_IF_NULL(value_track);
241 MS_EXCEPTION_IF_NULL(node_);
242 buffer << type_name() << "("
243 << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
244 return buffer.str();
245 }
246
// Generic-abstract entry point for function joins. `other` must be an AbstractFunction;
// otherwise AbstractTypeJoinLogging raises. The actual join is delegated to the
// AbstractFunctionPtr overload of Join.
AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_func = dyn_cast<AbstractFunction>(other);
  const bool is_function = (other_func != nullptr);
  if (!is_function) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  return Join(other_func);
}
255
operator ==(const AbstractBase & other) const256 bool AbstractFunction::operator==(const AbstractBase &other) const {
257 if (!other.isa<AbstractFunction>()) {
258 return false;
259 }
260 const auto &other_func = static_cast<const AbstractFunction &>(other);
261 bool value_equal = (*this == other_func);
262 return value_equal;
263 }
264
operator [](const std::size_t & dim) const265 const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const {
266 if (dim >= size()) {
267 MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list.";
268 }
269 return elements_[dim];
270 }
271
ToString() const272 std::string AbstractSequeue::ToString() const {
273 std::ostringstream buffer;
274 size_t i = 0;
275 size_t size = elements_.size();
276 for (const auto &ele : elements_) {
277 MS_EXCEPTION_IF_NULL(ele);
278 buffer << "element[" << i << "]: " << ele->ToString();
279 if (i < size - 1) {
280 buffer << ", ";
281 }
282 i++;
283 }
284 return buffer.str();
285 }
286
ElementsType() const287 TypePtrList AbstractSequeue::ElementsType() const {
288 TypePtrList element_type_list;
289 for (const auto &ele : elements_) {
290 MS_EXCEPTION_IF_NULL(ele);
291 TypePtr element_type = ele->BuildType();
292 element_type_list.push_back(element_type);
293 }
294 return element_type_list;
295 }
296
ElementsShape() const297 BaseShapePtrList AbstractSequeue::ElementsShape() const {
298 BaseShapePtrList element_shape_list;
299 for (const auto &ele : elements_) {
300 MS_EXCEPTION_IF_NULL(ele);
301 BaseShapePtr element_shape = ele->BuildShape();
302 element_shape_list.push_back(element_shape);
303 }
304 return element_shape_list;
305 }
306
ElementsClone() const307 AbstractBasePtrList AbstractSequeue::ElementsClone() const {
308 AbstractBasePtrList ele_list;
309 for (const auto &ele : elements_) {
310 MS_EXCEPTION_IF_NULL(ele);
311 AbstractBasePtr clone = ele->Clone();
312 ele_list.push_back(clone);
313 }
314 return ele_list;
315 }
316
ElementsBroaden() const317 AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
318 AbstractBasePtrList ele_list;
319 for (const auto &ele : elements_) {
320 MS_EXCEPTION_IF_NULL(ele);
321 AbstractBasePtr broadend = ele->Broaden();
322 ele_list.push_back(broadend);
323 }
324 return ele_list;
325 }
326
ElementsPartialBroaden() const327 AbstractBasePtrList AbstractSequeue::ElementsPartialBroaden() const {
328 AbstractBasePtrList ele_list;
329 for (const auto &ele : elements_) {
330 MS_EXCEPTION_IF_NULL(ele);
331 AbstractBasePtr broadend = ele->PartialBroaden();
332 ele_list.push_back(broadend);
333 }
334 return ele_list;
335 }
336
337 template <typename T>
ElementsBuildValue() const338 ValuePtr AbstractSequeue::ElementsBuildValue() const {
339 std::vector<ValuePtr> element_value_list;
340 for (const auto &ele : elements_) {
341 MS_EXCEPTION_IF_NULL(ele);
342 ValuePtr element_value = ele->BuildValue();
343 MS_EXCEPTION_IF_NULL(element_value);
344 if (element_value->isa<AnyValue>()) {
345 return kAnyValue;
346 }
347 element_value_list.push_back(element_value);
348 }
349 return std::make_shared<T>(element_value_list);
350 }
351 template ValuePtr AbstractSequeue::ElementsBuildValue<ValueTuple>() const;
352 template ValuePtr AbstractSequeue::ElementsBuildValue<ValueList>() const;
353
354 template <typename T>
ElementsJoin(const AbstractBasePtr & other)355 AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) {
356 MS_EXCEPTION_IF_NULL(other);
357 auto other_sequeue = dyn_cast<T>(other);
358 if (other_sequeue == nullptr) {
359 AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
360 }
361 auto joined_list = AbstractJoin(elements_, other_sequeue->elements_);
362 bool changes = false;
363 for (std::size_t i = 0; i < elements_.size(); i++) {
364 if (elements_[i] != joined_list[i]) {
365 changes = true;
366 break;
367 }
368 }
369 if (!changes) {
370 return shared_from_base<AbstractBase>();
371 }
372 return std::make_shared<T>(joined_list);
373 }
374 template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractList>(const AbstractBasePtr &);
375 template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractTuple>(const AbstractBasePtr &);
376
hash() const377 std::size_t AbstractSequeue::hash() const {
378 std::size_t hash_sum = hash_combine(tid(), std::hash<size_t>{}(elements_.size()));
379 // Hashing all elements is costly, so only take at most 4 elements into account based on
380 // some experiments.
381 constexpr size_t max_elements_cnt = 4;
382 for (size_t i = 0; (i < elements_.size()) && (i < max_elements_cnt); i++) {
383 hash_sum = hash_combine(hash_sum, elements_[i]->hash());
384 }
385 return hash_sum;
386 }
387
operator ==(const AbstractSequeue & other) const388 bool AbstractSequeue::operator==(const AbstractSequeue &other) const {
389 if (&other == this) {
390 return true;
391 }
392
393 if (elements_.size() != other.elements_.size()) {
394 return false;
395 }
396 for (size_t i = 0; i < elements_.size(); i++) {
397 MS_EXCEPTION_IF_NULL(elements_[i]);
398 MS_EXCEPTION_IF_NULL(other.elements_[i]);
399 if (!(*(elements_[i]) == *(other.elements_[i]))) {
400 return false;
401 }
402 }
403 return true;
404 }
405
operator ==(const AbstractTuple & other) const406 bool AbstractTuple::operator==(const AbstractTuple &other) const { return AbstractSequeue::operator==(other); }
407
operator ==(const AbstractBase & other) const408 bool AbstractTuple::operator==(const AbstractBase &other) const {
409 if (&other == this) {
410 return true;
411 }
412
413 if (other.isa<AbstractTuple>()) {
414 auto other_tuple = static_cast<const AbstractTuple *>(&other);
415 return *this == *other_tuple;
416 }
417
418 return false;
419 }
420
operator ==(const AbstractList & other) const421 bool AbstractList::operator==(const AbstractList &other) const { return AbstractSequeue::operator==(other); }
422
operator ==(const AbstractBase & other) const423 bool AbstractList::operator==(const AbstractBase &other) const {
424 if (&other == this) {
425 return true;
426 }
427
428 if (other.isa<AbstractList>()) {
429 auto other_list = static_cast<const AbstractList *>(&other);
430 return *this == *other_list;
431 }
432 return false;
433 }
434
BuildType() const435 TypePtr AbstractSlice::BuildType() const {
436 MS_EXCEPTION_IF_NULL(start_);
437 MS_EXCEPTION_IF_NULL(stop_);
438 MS_EXCEPTION_IF_NULL(step_);
439 TypePtr start = start_->BuildType();
440 TypePtr stop = stop_->BuildType();
441 TypePtr step = step_->BuildType();
442 return std::make_shared<Slice>(start, stop, step);
443 }
444
operator ==(const AbstractSlice & other) const445 bool AbstractSlice::operator==(const AbstractSlice &other) const {
446 if (&other == this) {
447 return true;
448 }
449 return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_);
450 }
451
operator ==(const AbstractBase & other) const452 bool AbstractSlice::operator==(const AbstractBase &other) const {
453 if (&other == this) {
454 return true;
455 }
456 if (!other.isa<AbstractSlice>()) {
457 return false;
458 }
459 auto other_slice = static_cast<const AbstractSlice *>(&other);
460 return *this == *other_slice;
461 }
462
Clone() const463 AbstractBasePtr AbstractSlice::Clone() const {
464 MS_EXCEPTION_IF_NULL(start_);
465 MS_EXCEPTION_IF_NULL(stop_);
466 MS_EXCEPTION_IF_NULL(step_);
467 AbstractBasePtr start = start_->Clone();
468 AbstractBasePtr stop = stop_->Clone();
469 AbstractBasePtr step = step_->Clone();
470 return std::make_shared<AbstractSlice>(start, stop, step);
471 }
472
Broaden() const473 AbstractBasePtr AbstractSlice::Broaden() const {
474 MS_EXCEPTION_IF_NULL(start_);
475 MS_EXCEPTION_IF_NULL(stop_);
476 MS_EXCEPTION_IF_NULL(step_);
477 AbstractBasePtr start = start_->Broaden();
478 AbstractBasePtr stop = stop_->Broaden();
479 AbstractBasePtr step = step_->Broaden();
480 return std::make_shared<AbstractSlice>(start, stop, step);
481 }
482
ToString() const483 std::string AbstractSlice::ToString() const {
484 std::ostringstream buffer;
485 buffer << type_name() << "[";
486 MS_EXCEPTION_IF_NULL(start_);
487 buffer << start_->ToString() << " : ";
488 MS_EXCEPTION_IF_NULL(stop_);
489 buffer << stop_->ToString() << " : ";
490 MS_EXCEPTION_IF_NULL(step_);
491 buffer << step_->ToString();
492 buffer << "]";
493 return buffer.str();
494 }
495
RealBuildValue() const496 ValuePtr AbstractSlice::RealBuildValue() const {
497 MS_EXCEPTION_IF_NULL(start_);
498 MS_EXCEPTION_IF_NULL(stop_);
499 MS_EXCEPTION_IF_NULL(step_);
500 ValuePtr start = start_->BuildValue();
501 ValuePtr stop = stop_->BuildValue();
502 ValuePtr step = step_->BuildValue();
503 if (start->isa<AnyValue>() || stop->isa<AnyValue>() || step->isa<AnyValue>()) {
504 return kAnyValue;
505 }
506 return std::make_shared<ValueSlice>(start, stop, step);
507 }
508
hash() const509 std::size_t AbstractSlice::hash() const {
510 MS_EXCEPTION_IF_NULL(start_);
511 MS_EXCEPTION_IF_NULL(stop_);
512 MS_EXCEPTION_IF_NULL(step_);
513 return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
514 }
515
shape() const516 ShapePtr AbstractUndetermined::shape() const {
517 auto shp = dyn_cast<Shape>(GetShapeTrack());
518 if (shp == nullptr) {
519 MS_LOG(EXCEPTION) << "Tensor should have a shape.";
520 }
521 return shp;
522 }
523
set_shape(const BaseShapePtr & shape)524 void AbstractUndetermined::set_shape(const BaseShapePtr &shape) {
525 MS_EXCEPTION_IF_NULL(shape);
526 if (shape->isa<NoShape>()) {
527 MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
528 }
529 AbstractBase::set_shape(shape);
530 }
531
BuildType() const532 TypePtr AbstractTensor::BuildType() const {
533 MS_EXCEPTION_IF_NULL(element_);
534 TypePtr element_type = element_->BuildType();
535 return std::make_shared<TensorType>(element_type);
536 }
537
BuildShape() const538 BaseShapePtr AbstractTensor::BuildShape() const {
539 auto shape = GetShapeTrack();
540 // Guard from using set_shape(nullptr)
541 if (shape == nullptr) {
542 return kNoShape;
543 }
544 return shape;
545 }
546
// Joins this tensor abstract with `other`:
//  * `other` builds to UndeterminedType: it must be an AbstractUndetermined. Shapes and
//    elements are joined and a new AbstractUndetermined is returned.
//  * `other` is an AbstractTensor: if it is equal, this is returned; otherwise shapes and
//    elements are joined into a new AbstractTensor.
//  * anything else raises via AbstractTypeJoinLogging.
// A failed shape join raises via ShapeJoinLogging.
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto type = other->BuildType();
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(element_);

  const bool other_is_undetermined = (type->type_id() == kObjectTypeUndeterminedType);
  if (other_is_undetermined) {
    auto other_undetermined_tensor = dyn_cast<AbstractUndetermined>(other);
    MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
    auto joined_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());
    if (joined_shape == nullptr) {
      ShapeJoinLogging(shape(), other_undetermined_tensor->shape(), shared_from_base<AbstractBase>(), other);
    }
    auto element = element_->Join(other_undetermined_tensor->element());
    MS_EXCEPTION_IF_NULL(element);
    return std::make_shared<AbstractUndetermined>(element, joined_shape);
  }

  auto tensor_other = dyn_cast<AbstractTensor>(other);
  if (tensor_other == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto joined_shape = ShapeJoin(shape(), tensor_other->shape());
  if (joined_shape == nullptr) {
    ShapeJoinLogging(shape(), tensor_other->shape(), shared_from_base<AbstractBase>(), other);
  }
  auto element = element_->Join(tensor_other->element_);
  MS_EXCEPTION_IF_NULL(element);
  return std::make_shared<AbstractTensor>(element, joined_shape);
}
586
equal_to(const AbstractTensor & other) const587 bool AbstractTensor::equal_to(const AbstractTensor &other) const {
588 if (&other == this) {
589 return true;
590 }
591
592 auto v1 = GetValueTrack();
593 auto v2 = other.GetValueTrack();
594 if (v1 == nullptr || v2 == nullptr) {
595 MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr";
596 }
597
598 bool is_value_equal = (v1 == v2);
599 if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
600 is_value_equal = true;
601 }
602 MS_EXCEPTION_IF_NULL(element_);
603 MS_EXCEPTION_IF_NULL(other.element_);
604 return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
605 }
606
operator ==(const AbstractTensor & other) const607 bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
608
operator ==(const AbstractBase & other) const609 bool AbstractTensor::operator==(const AbstractBase &other) const {
610 if (&other == this) {
611 return true;
612 }
613
614 if (other.tid() == tid()) {
615 auto other_tensor = static_cast<const AbstractTensor *>(&other);
616 return *this == *other_tensor;
617 } else {
618 return false;
619 }
620 }
621
Clone() const622 AbstractBasePtr AbstractTensor::Clone() const {
623 MS_EXCEPTION_IF_NULL(element_);
624 auto clone = std::make_shared<AbstractTensor>(element_->Clone());
625 ShapePtr shp = shape();
626 clone->set_shape(shp->Clone());
627 clone->set_value(GetValueTrack());
628 clone->set_value_range(get_min_value(), get_max_value());
629 return clone;
630 }
631
Broaden() const632 AbstractBasePtr AbstractTensor::Broaden() const {
633 MS_EXCEPTION_IF_NULL(element_);
634 auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
635 auto shp = shape();
636 MS_EXCEPTION_IF_NULL(shp);
637 broaden->set_shape(shp->Clone());
638 broaden->set_value(kAnyValue);
639 return broaden;
640 }
641
BroadenWithShape() const642 AbstractBasePtr AbstractTensor::BroadenWithShape() const {
643 MS_EXCEPTION_IF_NULL(element_);
644 auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
645 auto shp = shape()->Clone();
646 MS_EXCEPTION_IF_NULL(shp);
647 shp->Broaden();
648 broaden->set_shape(shp);
649 broaden->set_value(kAnyValue);
650 return broaden;
651 }
652
PartialBroaden() const653 AbstractBasePtr AbstractTensor::PartialBroaden() const { return Broaden(); }
654
ToString() const655 std::string AbstractTensor::ToString() const {
656 std::ostringstream buffer;
657 BaseShapePtr shape_track = GetShapeTrack();
658 MS_EXCEPTION_IF_NULL(shape_track);
659 MS_EXCEPTION_IF_NULL(element_);
660 auto value_track = GetValueTrack();
661 MS_EXCEPTION_IF_NULL(value_track);
662 buffer << type_name() << "("
663 << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
664 << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
665 return buffer.str();
666 }
667
BuildType() const668 TypePtr AbstractDictionary::BuildType() const {
669 std::vector<std::pair<std::string, TypePtr>> key_values;
670 for (const auto &item : key_values_) {
671 MS_EXCEPTION_IF_NULL(item.second);
672 TypePtr type = item.second->BuildType();
673 key_values.emplace_back(item.first, type);
674 }
675 return std::make_shared<Dictionary>(key_values);
676 }
677
operator ==(const AbstractDictionary & other) const678 bool AbstractDictionary::operator==(const AbstractDictionary &other) const {
679 if (key_values_.size() != other.key_values_.size()) {
680 return false;
681 }
682
683 for (size_t index = 0; index < key_values_.size(); index++) {
684 if (key_values_[index].first != other.key_values_[index].first) {
685 return false;
686 }
687 MS_EXCEPTION_IF_NULL(key_values_[index].second);
688 MS_EXCEPTION_IF_NULL(other.key_values_[index].second);
689 if (!(*key_values_[index].second == *other.key_values_[index].second)) {
690 return false;
691 }
692 }
693 return true;
694 }
695
operator ==(const AbstractBase & other) const696 bool AbstractDictionary::operator==(const AbstractBase &other) const {
697 if (&other == this) {
698 return true;
699 }
700 if (other.isa<AbstractDictionary>()) {
701 auto other_class = static_cast<const AbstractDictionary *>(&other);
702 return *this == *other_class;
703 }
704 return false;
705 }
706
Clone() const707 AbstractBasePtr AbstractDictionary::Clone() const {
708 std::vector<AbstractAttribute> kv;
709 (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
710 [](const AbstractAttribute &item) {
711 MS_EXCEPTION_IF_NULL(item.second);
712 return std::make_pair(item.first, item.second->Clone());
713 });
714 return std::make_shared<AbstractDictionary>(kv);
715 }
716
Broaden() const717 AbstractBasePtr AbstractDictionary::Broaden() const {
718 std::vector<AbstractAttribute> kv;
719 (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
720 [](const AbstractAttribute &item) {
721 MS_EXCEPTION_IF_NULL(item.second);
722 return std::make_pair(item.first, item.second->Broaden());
723 });
724 return std::make_shared<AbstractDictionary>(kv);
725 }
726
ToString() const727 std::string AbstractDictionary::ToString() const {
728 std::ostringstream buffer;
729 buffer << type_name() << "{ ";
730 for (const auto &kv : key_values_) {
731 MS_EXCEPTION_IF_NULL(kv.second);
732 buffer << "(" << kv.first << ": " << kv.second->ToString() << ") ";
733 }
734 buffer << "}";
735 return buffer.str();
736 }
737
hash() const738 std::size_t AbstractDictionary::hash() const {
739 std::size_t hash_sum = std::accumulate(key_values_.begin(), key_values_.end(), tid(),
740 [](std::size_t hash_sum, const AbstractAttribute &item) {
741 hash_sum = hash_combine(hash_sum, std::hash<std::string>()(item.first));
742 MS_EXCEPTION_IF_NULL(item.second);
743 hash_sum = hash_combine(hash_sum, item.second->hash());
744 return hash_sum;
745 });
746 return hash_sum;
747 }
748
RealBuildValue() const749 ValuePtr AbstractDictionary::RealBuildValue() const {
750 std::vector<std::pair<std::string, ValuePtr>> key_values;
751 for (const auto &item : key_values_) {
752 MS_EXCEPTION_IF_NULL(item.second);
753 auto element_value = item.second->BuildValue();
754 MS_EXCEPTION_IF_NULL(element_value);
755 if (element_value->isa<AnyValue>()) {
756 return kAnyValue;
757 }
758 key_values.emplace_back(item.first, element_value);
759 }
760 return std::make_shared<ValueDictionary>(key_values);
761 }
762
BuildType() const763 TypePtr AbstractClass::BuildType() const {
764 ClassAttrVector attributes_type;
765 for (const auto &attr : attributes_) {
766 MS_EXCEPTION_IF_NULL(attr.second);
767 TypePtr type = attr.second->BuildType();
768 std::pair<std::string, TypePtr> elem(attr.first, type);
769 attributes_type.push_back(elem);
770 }
771
772 return std::make_shared<Class>(tag_, attributes_type, methods_);
773 }
774
operator ==(const AbstractClass & other) const775 bool AbstractClass::operator==(const AbstractClass &other) const {
776 if (!(tag_ == other.tag_)) {
777 return false;
778 }
779 if (attributes_.size() != other.attributes_.size()) {
780 return false;
781 }
782 for (size_t i = 0; i < attributes_.size(); i++) {
783 MS_EXCEPTION_IF_NULL(attributes_[i].second);
784 MS_EXCEPTION_IF_NULL(other.attributes_[i].second);
785 if (!(*attributes_[i].second == *other.attributes_[i].second)) {
786 MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString()
787 << " arg2:" << other.attributes_[i].second->ToString();
788 return false;
789 }
790 }
791 // method compare;
792 if (methods_.size() != other.methods_.size()) {
793 return false;
794 }
795 for (const auto &iter : methods_) {
796 auto iter_other = other.methods_.find(iter.first);
797 if (iter_other == other.methods_.end()) {
798 return false;
799 }
800 if (!(*iter.second == *iter_other->second)) {
801 return false;
802 }
803 }
804 return true;
805 }
806
operator ==(const AbstractBase & other) const807 bool AbstractClass::operator==(const AbstractBase &other) const {
808 if (other.isa<AbstractClass>()) {
809 auto other_class = static_cast<const AbstractClass *>(&other);
810 return *this == *other_class;
811 }
812 return false;
813 }
814
GetAttribute(const std::string & name)815 AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) {
816 auto it = std::find_if(attributes_.begin(), attributes_.end(),
817 [name](const AbstractAttribute &pair) -> bool { return pair.first == name; });
818 if (it != attributes_.end()) {
819 return it->second;
820 }
821 return nullptr;
822 }
823
GetMethod(const std::string & name)824 ValuePtr AbstractClass::GetMethod(const std::string &name) {
825 auto method_pair = methods_.find(name);
826 if (method_pair != methods_.end()) {
827 return method_pair->second;
828 }
829 return kAnyValue;
830 }
831
Clone() const832 AbstractBasePtr AbstractClass::Clone() const {
833 std::vector<AbstractAttribute> attributes_clone;
834 for (const auto &attr : attributes_) {
835 MS_EXCEPTION_IF_NULL(attr.second);
836 AbstractBasePtr clone = attr.second->Clone();
837 AbstractAttribute elem(attr.first, clone);
838 attributes_clone.push_back(elem);
839 }
840 return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
841 }
842
Broaden() const843 AbstractBasePtr AbstractClass::Broaden() const {
844 std::vector<AbstractAttribute> attributes_clone;
845 for (const auto &attr : attributes_) {
846 MS_EXCEPTION_IF_NULL(attr.second);
847 AbstractBasePtr clone = attr.second->Broaden();
848 AbstractAttribute elem(attr.first, clone);
849 attributes_clone.push_back(elem);
850 }
851 return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
852 }
853
ToString() const854 std::string AbstractClass::ToString() const {
855 std::ostringstream buffer;
856 buffer << type_name() << "(tag: " << tag_ << ") attrs:(";
857 bool append_comma = false;
858 for (const auto &attr : attributes_) {
859 if (append_comma) {
860 buffer << ", ";
861 } else {
862 append_comma = true;
863 }
864 MS_EXCEPTION_IF_NULL(attr.second);
865 buffer << attr.first << ":" << attr.second->ToString();
866 }
867 buffer << ") method:(";
868 append_comma = false;
869 for (const auto &iter : methods_) {
870 if (append_comma) {
871 buffer << ", ";
872 } else {
873 append_comma = true;
874 }
875 MS_EXCEPTION_IF_NULL(iter.second);
876 buffer << iter.first << ":" << iter.second->ToString();
877 }
878 buffer << ")";
879 return buffer.str();
880 }
881
hash() const882 std::size_t AbstractClass::hash() const {
883 std::size_t hash_sum = std::accumulate(attributes_.begin(), attributes_.end(), hash_combine(tid(), tag_.hash()),
884 [](std::size_t hash_sum, const AbstractAttribute &item) {
885 MS_EXCEPTION_IF_NULL(item.second);
886 return hash_combine(hash_sum, item.second->hash());
887 });
888
889 return hash_sum;
890 }
891
RealBuildValue() const892 ValuePtr AbstractClass::RealBuildValue() const {
893 auto type = BuildType();
894 MS_EXCEPTION_IF_NULL(type);
895 auto cls = type->cast<ClassPtr>();
896 std::unordered_map<std::string, ValuePtr> attributes_value_map;
897 for (const auto &attr : attributes_) {
898 MS_EXCEPTION_IF_NULL(attr.second);
899 ValuePtr value = attr.second->BuildValue();
900 MS_EXCEPTION_IF_NULL(value);
901 if (value->isa<AnyValue>()) {
902 return kAnyValue;
903 }
904 attributes_value_map[attr.first] = value;
905 }
906 cls->set_value(attributes_value_map);
907 return cls;
908 }
909
BuildType() const910 TypePtr AbstractJTagged::BuildType() const {
911 MS_EXCEPTION_IF_NULL(element_);
912 TypePtr subtype = element_->BuildType();
913 return std::make_shared<JTagged>(subtype);
914 }
915
// Joins two JTagged abstracts by joining their elements.
// A non-JTagged `other` is reported through AbstractTypeJoinLogging.
AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_jtagged = dyn_cast<AbstractJTagged>(other);
  if (other_jtagged == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  // NOTE(review): AbstractTypeJoinLogging is presumably expected to raise. If it ever returns,
  // stop here rather than dereferencing a null pointer below.
  MS_EXCEPTION_IF_NULL(other_jtagged);
  MS_EXCEPTION_IF_NULL(element_);
  auto joined_elem = element_->Join(other_jtagged->element_);
  return std::make_shared<AbstractJTagged>(joined_elem);
}
926
operator ==(const AbstractJTagged & other) const927 bool AbstractJTagged::operator==(const AbstractJTagged &other) const {
928 MS_EXCEPTION_IF_NULL(element_);
929 MS_EXCEPTION_IF_NULL(other.element_);
930 return (*element_ == *other.element_);
931 }
932
operator ==(const AbstractBase & other) const933 bool AbstractJTagged::operator==(const AbstractBase &other) const {
934 if (other.isa<AbstractJTagged>()) {
935 auto other_jtagged = static_cast<const AbstractJTagged *>(&other);
936 return *this == *other_jtagged;
937 }
938 return false;
939 }
940
ToString() const941 std::string AbstractJTagged::ToString() const {
942 std::ostringstream buffer;
943 MS_EXCEPTION_IF_NULL(element_);
944 buffer << type_name() << "("
945 << "element: " << element_->ToString() << ")";
946 return buffer.str();
947 }
948
AbstractRef(const AbstractBasePtr & ref_key,const AbstractTensorPtr & ref_value)949 AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
950 : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
951 set_type(std::make_shared<RefType>());
952 if (ref_key && ref_key->isa<AbstractRefKey>()) {
953 ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
954 }
955 }
956
BuildType() const957 TypePtr AbstractRef::BuildType() const {
958 auto type = AbstractTensor::BuildType();
959 MS_EXCEPTION_IF_NULL(type);
960 auto subtype = type->cast<TensorTypePtr>();
961 return std::make_shared<RefType>(subtype);
962 }
963
operator ==(const AbstractRef & other) const964 bool AbstractRef::operator==(const AbstractRef &other) const {
965 return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
966 }
967
operator ==(const AbstractBase & other) const968 bool AbstractRef::operator==(const AbstractBase &other) const {
969 if (other.isa<AbstractRef>()) {
970 auto other_conf = static_cast<const AbstractRef *>(&other);
971 return *this == *other_conf;
972 }
973 return false;
974 }
975
// Joins two ref-key abstracts.
// Returns this object when the two are equal, or when ValueJoin yields this object's own value.
// Otherwise returns a fresh AbstractRefKey carrying the joined value.
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  ValuePtr self_value = GetValueTrack();
  MS_EXCEPTION_IF_NULL(self_value);
  ValuePtr joined_value = ValueJoin(self_value, other->GetValueTrack());
  if (joined_value != self_value) {
    auto joined = std::make_shared<AbstractRefKey>();
    joined->set_value(joined_value);
    return joined;
  }
  return shared_from_base<AbstractBase>();
}
993
// Joins this Ref with `other`.
// - `other` is not a Ref: the result is AbstractTensor::Join(other), cast to AbstractTensorPtr.
// - `other` is an equal Ref with an equal key: this object is returned.
// - Otherwise: a new AbstractRef built from the joined keys and the joined tensor parts.
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_ref = other->cast<AbstractRefPtr>();
  if (other_ref == nullptr) {
    auto joined = AbstractTensor::Join(other);
    MS_EXCEPTION_IF_NULL(joined);
    return joined->cast<AbstractTensorPtr>();
  }
  MS_EXCEPTION_IF_NULL(ref_key_);
  MS_EXCEPTION_IF_NULL(other_ref->ref_key_);
  const bool same_ref = (*this == *other) && (*ref_key_ == *other_ref->ref_key_);
  if (same_ref) {
    return shared_from_base<AbstractBase>();
  }
  auto joined_key = ref_key_->Join(other_ref->ref_key_);
  auto other_tensor = other_ref->ref();
  MS_EXCEPTION_IF_NULL(other_tensor);
  auto joined_tensor_abs = AbstractTensor::Join(other_tensor);
  MS_EXCEPTION_IF_NULL(joined_tensor_abs);
  auto joined_tensor = joined_tensor_abs->cast<AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(joined_tensor);
  return std::make_shared<AbstractRef>(joined_key, joined_tensor);
}
1016
ToString() const1017 std::string AbstractRef::ToString() const {
1018 std::ostringstream buffer;
1019 MS_EXCEPTION_IF_NULL(ref_key_);
1020 buffer << type_name() << "("
1021 << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
1022 auto value = GetValueTrack();
1023 if (value != nullptr) {
1024 buffer << ", value: " << value->ToString();
1025 }
1026 buffer << ")";
1027 return buffer.str();
1028 }
1029
PartialBroaden() const1030 AbstractBasePtr AbstractRef::PartialBroaden() const { return Clone(); }
1031
operator ==(const AbstractNone &) const1032 bool AbstractNone::operator==(const AbstractNone &) const { return true; }
1033
operator ==(const AbstractBase & other) const1034 bool AbstractNone::operator==(const AbstractBase &other) const {
1035 if (other.isa<AbstractNone>()) {
1036 auto other_none = static_cast<const AbstractNone *>(&other);
1037 return *this == *other_none;
1038 }
1039 return false;
1040 }
1041
ToString() const1042 std::string AbstractNone::ToString() const {
1043 std::ostringstream buffer;
1044 buffer << type_name() << "(Value: None)";
1045 return buffer.str();
1046 }
1047
RealBuildValue() const1048 ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
1049
operator ==(const AbstractRefKey & other) const1050 bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
1051 ValuePtr value_self = GetValueTrack();
1052 ValuePtr value_other = other.GetValueTrack();
1053 if (value_self != nullptr && value_other != nullptr) {
1054 if (value_self->isa<AnyValue>() && value_other->isa<AnyValue>()) {
1055 return true;
1056 }
1057 if (!value_self->isa<RefKey>() || !value_other->isa<RefKey>()) {
1058 return false;
1059 }
1060 RefKeyPtr type_self = value_self->cast<RefKeyPtr>();
1061 RefKeyPtr type_other = value_other->cast<RefKeyPtr>();
1062 return *type_self == *type_other;
1063 } else if (value_self != nullptr || value_other != nullptr) {
1064 return false;
1065 }
1066 return true;
1067 }
1068
operator ==(const AbstractBase & other) const1069 bool AbstractRefKey::operator==(const AbstractBase &other) const {
1070 if (other.isa<AbstractRefKey>()) {
1071 auto other_confkey = static_cast<const AbstractRefKey *>(&other);
1072 return *this == *other_confkey;
1073 } else {
1074 return false;
1075 }
1076 }
1077
ToString() const1078 std::string AbstractRefKey::ToString() const {
1079 std::ostringstream buffer;
1080 buffer << type_name();
1081 auto value = GetValueTrack();
1082 if (value != nullptr) {
1083 buffer << "(value: " << value->ToString() << ")";
1084 }
1085 return buffer.str();
1086 }
1087
operator ==(const AbstractNull &) const1088 bool AbstractNull::operator==(const AbstractNull &) const { return true; }
1089
operator ==(const AbstractBase & other) const1090 bool AbstractNull::operator==(const AbstractBase &other) const {
1091 if (&other == this) {
1092 return true;
1093 }
1094 if (other.isa<AbstractNull>()) {
1095 auto other_none = static_cast<const AbstractNull *>(&other);
1096 return *this == *other_none;
1097 } else {
1098 return false;
1099 }
1100 }
1101
ToString() const1102 std::string AbstractNull::ToString() const {
1103 std::ostringstream buffer;
1104 buffer << type_name() << "(Value: Null)";
1105 return buffer.str();
1106 }
1107
operator ==(const AbstractTimeOut &) const1108 bool AbstractTimeOut::operator==(const AbstractTimeOut &) const { return true; }
1109
operator ==(const AbstractBase & other) const1110 bool AbstractTimeOut::operator==(const AbstractBase &other) const {
1111 if (&other == this) {
1112 return true;
1113 }
1114 if (other.isa<AbstractTimeOut>()) {
1115 auto other_none = static_cast<const AbstractTimeOut *>(&other);
1116 return *this == *other_none;
1117 } else {
1118 return false;
1119 }
1120 }
1121
ToString() const1122 std::string AbstractTimeOut::ToString() const {
1123 std::ostringstream buffer;
1124 buffer << "AbstractTimeOut "
1125 << "(Value: Null)";
1126 return buffer.str();
1127 }
1128
operator ==(const AbstractEllipsis &) const1129 bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
1130
operator ==(const AbstractBase & other) const1131 bool AbstractEllipsis::operator==(const AbstractBase &other) const {
1132 if (&other == this) {
1133 return true;
1134 }
1135 if (other.isa<AbstractEllipsis>()) {
1136 auto other_none = static_cast<const AbstractEllipsis *>(&other);
1137 return *this == *other_none;
1138 } else {
1139 return false;
1140 }
1141 }
1142
ToString() const1143 std::string AbstractEllipsis::ToString() const {
1144 std::ostringstream buffer;
1145 buffer << type_name() << "(Value: Ellipsis)";
1146 return buffer.str();
1147 }
1148
BuildType() const1149 TypePtr AbstractKeywordArg::BuildType() const {
1150 MS_EXCEPTION_IF_NULL(arg_value_);
1151 TypePtr type = arg_value_->BuildType();
1152 return std::make_shared<Keyword>(arg_name_, type);
1153 }
1154
Clone() const1155 AbstractBasePtr AbstractKeywordArg::Clone() const {
1156 MS_EXCEPTION_IF_NULL(arg_value_);
1157 return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
1158 }
1159
Broaden() const1160 AbstractBasePtr AbstractKeywordArg::Broaden() const {
1161 MS_EXCEPTION_IF_NULL(arg_value_);
1162 return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
1163 }
1164
hash() const1165 std::size_t AbstractKeywordArg::hash() const {
1166 MS_EXCEPTION_IF_NULL(arg_value_);
1167 return hash_combine({tid(), std::hash<std::string>{}(arg_name_), arg_value_->hash()});
1168 }
1169
ToString() const1170 std::string AbstractKeywordArg::ToString() const {
1171 std::ostringstream buffer;
1172 MS_EXCEPTION_IF_NULL(arg_value_);
1173 buffer << type_name() << "(";
1174 buffer << "key : " << arg_name_;
1175 buffer << "value : " << arg_value_->ToString();
1176 buffer << ")";
1177 return buffer.str();
1178 }
1179
operator ==(const AbstractBase & other) const1180 bool AbstractKeywordArg::operator==(const AbstractBase &other) const {
1181 if (&other == this) {
1182 return true;
1183 }
1184
1185 if (other.isa<AbstractKeywordArg>()) {
1186 auto other_tuple = static_cast<const AbstractKeywordArg *>(&other);
1187 return *this == *other_tuple;
1188 }
1189 return false;
1190 }
1191
operator ==(const AbstractKeywordArg & other) const1192 bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const {
1193 if (&other == this) {
1194 return true;
1195 }
1196 MS_EXCEPTION_IF_NULL(arg_value_);
1197 MS_EXCEPTION_IF_NULL(other.arg_value_);
1198 return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_;
1199 }
1200
RealBuildValue() const1201 ValuePtr AbstractKeywordArg::RealBuildValue() const {
1202 MS_EXCEPTION_IF_NULL(arg_value_);
1203 ValuePtr value = arg_value_->BuildValue();
1204 MS_EXCEPTION_IF_NULL(value);
1205 if (value->isa<AnyValue>()) {
1206 return kAnyValue;
1207 }
1208 return std::make_shared<KeywordArg>(arg_name_, value);
1209 }
1210
AbstractBasePtrListHash(const AbstractBasePtrList & args_spec_list)1211 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) {
1212 std::size_t hash_value = 0;
1213 // Hashing all elements is costly, so only take at most 4 elements into account based on
1214 // some experiments.
1215 constexpr auto kMaxElementsNum = 4;
1216 for (size_t i = 0; (i < args_spec_list.size()) && (i < kMaxElementsNum); i++) {
1217 MS_EXCEPTION_IF_NULL(args_spec_list[i]);
1218 hash_value = hash_combine(hash_value, args_spec_list[i]->hash());
1219 }
1220 return hash_value;
1221 }
1222
AbstractBasePtrListDeepEqual(const AbstractBasePtrList & lhs,const AbstractBasePtrList & rhs)1223 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
1224 if (lhs.size() != rhs.size()) {
1225 return false;
1226 }
1227 std::size_t size = lhs.size();
1228 for (std::size_t i = 0; i < size; i++) {
1229 MS_EXCEPTION_IF_NULL(lhs[i]);
1230 MS_EXCEPTION_IF_NULL(rhs[i]);
1231 if (lhs[i] == rhs[i]) {
1232 continue;
1233 }
1234 if (!(*lhs[i] == *rhs[i])) {
1235 return false;
1236 }
1237 }
1238 return true;
1239 }
1240
operator ()(const AbstractBasePtrList & args_spec_list) const1241 std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const {
1242 return AbstractBasePtrListHash(args_spec_list);
1243 }
1244
operator ()(const AbstractBasePtrList & lhs,const AbstractBasePtrList & rhs) const1245 bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
1246 return AbstractBasePtrListDeepEqual(lhs, rhs);
1247 }
1248
1249 // RowTensor
BuildType() const1250 TypePtr AbstractRowTensor::BuildType() const {
1251 MS_EXCEPTION_IF_NULL(element());
1252 TypePtr element_type = element()->BuildType();
1253 return std::make_shared<RowTensorType>(element_type);
1254 }
1255
Clone() const1256 AbstractBasePtr AbstractRowTensor::Clone() const {
1257 MS_EXCEPTION_IF_NULL(element());
1258 auto clone = std::make_shared<AbstractRowTensor>(element()->Clone());
1259 ShapePtr shp = shape();
1260 MS_EXCEPTION_IF_NULL(shp);
1261 clone->set_shape(shp->Clone());
1262 clone->set_value(GetValueTrack());
1263 MS_EXCEPTION_IF_NULL(indices_);
1264 MS_EXCEPTION_IF_NULL(values_);
1265 MS_EXCEPTION_IF_NULL(dense_shape_);
1266 auto indices_clone = indices_->Clone();
1267 auto value_clone = values_->Clone();
1268 auto dense_clone = dense_shape_->Clone();
1269 MS_EXCEPTION_IF_NULL(indices_clone);
1270 MS_EXCEPTION_IF_NULL(value_clone);
1271 MS_EXCEPTION_IF_NULL(dense_clone);
1272 clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
1273 clone->set_values(value_clone->cast<AbstractTensorPtr>());
1274 clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1275 return clone;
1276 }
1277
Broaden() const1278 AbstractBasePtr AbstractRowTensor::Broaden() const {
1279 MS_EXCEPTION_IF_NULL(element());
1280 auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
1281 auto shp = shape();
1282 MS_EXCEPTION_IF_NULL(shp);
1283 broaden->set_shape(shp->Clone());
1284 broaden->set_value(kAnyValue);
1285 MS_EXCEPTION_IF_NULL(indices_);
1286 MS_EXCEPTION_IF_NULL(values_);
1287 MS_EXCEPTION_IF_NULL(dense_shape_);
1288 auto indices_clone = indices_->Clone();
1289 auto value_clone = values_->Clone();
1290 auto dense_clone = dense_shape_->Clone();
1291 MS_EXCEPTION_IF_NULL(indices_clone);
1292 MS_EXCEPTION_IF_NULL(value_clone);
1293 MS_EXCEPTION_IF_NULL(dense_clone);
1294 broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1295 broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1296 broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1297 return broaden;
1298 }
1299
BroadenWithShape() const1300 AbstractBasePtr AbstractRowTensor::BroadenWithShape() const {
1301 MS_EXCEPTION_IF_NULL(element());
1302 auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
1303 auto shp = shape()->Clone();
1304 MS_EXCEPTION_IF_NULL(shp);
1305 shp->Broaden();
1306 broaden->set_shape(shp);
1307 broaden->set_value(kAnyValue);
1308 MS_EXCEPTION_IF_NULL(indices_);
1309 MS_EXCEPTION_IF_NULL(values_);
1310 MS_EXCEPTION_IF_NULL(dense_shape_);
1311 auto indices_clone = indices_->Clone();
1312 auto value_clone = values_->Clone();
1313 auto dense_clone = dense_shape_->Clone();
1314 MS_EXCEPTION_IF_NULL(indices_clone);
1315 MS_EXCEPTION_IF_NULL(value_clone);
1316 MS_EXCEPTION_IF_NULL(dense_clone);
1317 broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1318 broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1319 broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1320 return broaden;
1321 }
1322
ToString() const1323 std::string AbstractRowTensor::ToString() const {
1324 std::ostringstream buffer;
1325 BaseShapePtr shape_track = GetShapeTrack();
1326 MS_EXCEPTION_IF_NULL(shape_track);
1327 MS_EXCEPTION_IF_NULL(element());
1328 auto value_track = GetValueTrack();
1329 MS_EXCEPTION_IF_NULL(value_track);
1330 MS_EXCEPTION_IF_NULL(indices_);
1331 MS_EXCEPTION_IF_NULL(values_);
1332 MS_EXCEPTION_IF_NULL(dense_shape_);
1333 buffer << type_name() << "("
1334 << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
1335 << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
1336 << ", indices: " << indices_->ToString() << ", values" << values_->ToString()
1337 << ", dense_shape: " << dense_shape_->ToString();
1338 return buffer.str();
1339 }
1340
1341 // SparseTensor
BuildType() const1342 TypePtr AbstractSparseTensor::BuildType() const {
1343 MS_EXCEPTION_IF_NULL(element());
1344 TypePtr element_type = element()->BuildType();
1345 return std::make_shared<SparseTensorType>(element_type);
1346 }
1347
Clone() const1348 AbstractBasePtr AbstractSparseTensor::Clone() const {
1349 MS_EXCEPTION_IF_NULL(element());
1350 auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
1351 ShapePtr shp = shape();
1352 MS_EXCEPTION_IF_NULL(shp);
1353 clone->set_shape(shp->Clone());
1354 clone->set_value(GetValueTrack());
1355 MS_EXCEPTION_IF_NULL(indices_);
1356 MS_EXCEPTION_IF_NULL(values_);
1357 MS_EXCEPTION_IF_NULL(dense_shape_);
1358 auto indices_clone = indices_->Clone();
1359 auto value_clone = values_->Clone();
1360 auto dense_clone = dense_shape_->Clone();
1361 MS_EXCEPTION_IF_NULL(indices_clone);
1362 MS_EXCEPTION_IF_NULL(value_clone);
1363 MS_EXCEPTION_IF_NULL(dense_clone);
1364 clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
1365 clone->set_values(value_clone->cast<AbstractTensorPtr>());
1366 clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1367 return clone;
1368 }
1369
Broaden() const1370 AbstractBasePtr AbstractSparseTensor::Broaden() const {
1371 MS_EXCEPTION_IF_NULL(element());
1372 auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
1373 auto shp = shape();
1374 MS_EXCEPTION_IF_NULL(shp);
1375 MS_EXCEPTION_IF_NULL(indices_);
1376 MS_EXCEPTION_IF_NULL(values_);
1377 MS_EXCEPTION_IF_NULL(dense_shape_);
1378 auto indices_clone = indices_->Clone();
1379 auto value_clone = values_->Clone();
1380 auto dense_clone = dense_shape_->Clone();
1381 MS_EXCEPTION_IF_NULL(indices_clone);
1382 MS_EXCEPTION_IF_NULL(value_clone);
1383 MS_EXCEPTION_IF_NULL(dense_clone);
1384 broaden->set_shape(shp->Clone());
1385 broaden->set_value(kAnyValue);
1386 broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1387 broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1388 broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1389 return broaden;
1390 }
1391
BroadenWithShape() const1392 AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
1393 MS_EXCEPTION_IF_NULL(element());
1394 auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
1395 auto this_shape = shape();
1396 MS_EXCEPTION_IF_NULL(this_shape);
1397 auto shp = this_shape->Clone();
1398 MS_EXCEPTION_IF_NULL(shp);
1399 shp->Broaden();
1400 broaden->set_shape(shp);
1401 broaden->set_value(kAnyValue);
1402 MS_EXCEPTION_IF_NULL(indices_);
1403 MS_EXCEPTION_IF_NULL(values_);
1404 MS_EXCEPTION_IF_NULL(dense_shape_);
1405 auto indices_clone = indices_->Clone();
1406 auto value_clone = values_->Clone();
1407 auto dense_clone = dense_shape_->Clone();
1408 MS_EXCEPTION_IF_NULL(indices_clone);
1409 MS_EXCEPTION_IF_NULL(value_clone);
1410 MS_EXCEPTION_IF_NULL(dense_clone);
1411 broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1412 broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1413 broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1414 return broaden;
1415 }
1416
ToString() const1417 std::string AbstractSparseTensor::ToString() const {
1418 std::ostringstream buffer;
1419 BaseShapePtr shape_track = GetShapeTrack();
1420 MS_EXCEPTION_IF_NULL(shape_track);
1421 MS_EXCEPTION_IF_NULL(element());
1422 auto value_track = GetValueTrack();
1423 MS_EXCEPTION_IF_NULL(value_track);
1424 MS_EXCEPTION_IF_NULL(indices_);
1425 MS_EXCEPTION_IF_NULL(values_);
1426 MS_EXCEPTION_IF_NULL(dense_shape_);
1427 buffer << type_name() << "("
1428 << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
1429 << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
1430 << ", indices: " << indices_->ToString() << ", values" << values_->ToString()
1431 << ", dense_shape: " << dense_shape_->ToString();
1432 return buffer.str();
1433 }
1434
// Joins with another abstract.
// Joining with an AbstractUMonad returns this object unchanged. Any other kind is reported
// through TypeJoinLogging, which receives both type tracks.
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractUMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    // `other` was already checked above; the value that still needs checking is its type track.
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
1446
operator ==(const AbstractUMonad &) const1447 bool AbstractUMonad::operator==(const AbstractUMonad &) const { return true; }
1448
operator ==(const AbstractBase & other) const1449 bool AbstractUMonad::operator==(const AbstractBase &other) const {
1450 if (&other == this) {
1451 return true;
1452 }
1453 return other.isa<AbstractUMonad>();
1454 }
1455
// Joins with another abstract.
// Joining with an AbstractIOMonad returns this object unchanged. Any other kind is reported
// through TypeJoinLogging, which receives both type tracks.
AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractIOMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    // `other` was already checked above; the value that still needs checking is its type track.
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
1467
operator ==(const AbstractIOMonad &) const1468 bool AbstractIOMonad::operator==(const AbstractIOMonad &) const { return true; }
1469
operator ==(const AbstractBase & other) const1470 bool AbstractIOMonad::operator==(const AbstractBase &other) const {
1471 if (&other == this) {
1472 return true;
1473 }
1474 return other.isa<AbstractIOMonad>();
1475 }
1476 } // namespace abstract
1477 } // namespace mindspore
1478