• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
#include "abstract/abstract_value.h"

#include <regex>
#include <algorithm>
#include <iterator>
#include <numeric>
#include <sstream>
#include <string>

#include "utils/symbolic.h"
#include "abstract/utils.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
28 
29 namespace mindspore {
30 namespace abstract {
GetTraceNode(const AbstractBasePtr & abs)31 AnfNodePtr GetTraceNode(const AbstractBasePtr &abs) {
32   AnfNodePtr node = nullptr;
33   if (mindspore::abstract::AbstractBase::trace_node_provider_ != nullptr) {
34     mindspore::abstract::AbstractBase::trace_node_provider_(&node);
35   }
36   return node;
37 }
38 
AbstractTypeJoinLogging(const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)39 inline void AbstractTypeJoinLogging(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) {
40   std::ostringstream oss;
41   oss << "Type Join Failed: abstract type " << abstract1->type_name() << " cannot not join with "
42       << abstract2->type_name() << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
43       << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
44   auto node = GetTraceNode(abstract1);
45   if (node != nullptr) {
46     oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
47   }
48   MS_EXCEPTION(TypeError) << oss.str();
49 }
50 
TypeJoinLogging(const TypePtr & type1,const TypePtr & type2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)51 inline void TypeJoinLogging(const TypePtr &type1, const TypePtr &type2, const AbstractBasePtr &abstract1,
52                             const AbstractBasePtr &abstract2) {
53   std::ostringstream oss;
54   oss << "Type Join Failed: dtype1 = " << type1->ToString() << ", dtype2 = " << type2->ToString()
55       << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
56       << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
57   auto node = GetTraceNode(abstract1);
58   if (node != nullptr) {
59     oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
60   }
61   MS_EXCEPTION(TypeError) << oss.str();
62 }
63 
ShapeJoinLogging(const BaseShapePtr & shape1,const BaseShapePtr & shape2,const AbstractBasePtr & abstract1,const AbstractBasePtr & abstract2)64 inline void ShapeJoinLogging(const BaseShapePtr &shape1, const BaseShapePtr &shape2, const AbstractBasePtr &abstract1,
65                              const AbstractBasePtr &abstract2) {
66   std::ostringstream oss;
67   oss << "Shape Join Failed: shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString()
68       << ". For more details, please refer to the FAQ at https://www.mindspore.cn. "
69       << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString();
70   auto node = GetTraceNode(abstract1);
71   if (node != nullptr) {
72     oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node);
73   }
74   MS_EXCEPTION(ValueError) << oss.str();
75 }
76 
// Extracts the first "Type Join Failed ..." or "Shape Join Failed ..." fragment from `info`.
// The fragment runs up to and including the first '.' that follows on the same line.
// Returns an empty string when neither keyword is found.
// The pattern is compiled once: constructing a std::regex is expensive, and a function-local
// static is initialized thread-safely. Searching with a const regex does not modify it.
std::string ExtractLoggingInfo(const std::string &info) {
  static const std::regex kJoinFailedPattern("(Type Join Failed|Shape Join Failed).*?\\.");
  std::smatch result;
  if (std::regex_search(info, result, kJoinFailedPattern)) {
    return result.str();
  }
  return "";
}
87 
operator ==(const AbstractBase & other) const88 bool AbstractBase::operator==(const AbstractBase &other) const {
89   if (tid() != other.tid()) {
90     return false;
91   }
92   auto type = BuildType();
93   auto other_type = BuildType();
94   MS_EXCEPTION_IF_NULL(other_type);
95   MS_EXCEPTION_IF_NULL(type);
96   if (type->type_id() == kObjectTypeUndeterminedType && other_type->type_id() == kObjectTypeUndeterminedType) {
97     return true;
98   }
99   if (value_ == nullptr || other.value_ == nullptr) {
100     MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
101                       << this->ToString() << ", other: " << other.ToString();
102   }
103 
104   bool value_equal = false;
105   if (value_ == other.value_) {
106     value_equal = true;
107   } else if (*value_ == *other.value_) {
108     value_equal = true;
109   }
110   bool type_equal = false;
111   MS_EXCEPTION_IF_NULL(type_);
112   MS_EXCEPTION_IF_NULL(other.type_);
113   if (type_ == other.type_) {
114     type_equal = true;
115   } else if (*type_ == *other.type_) {
116     type_equal = true;
117   }
118   bool shape_equal = false;
119   MS_EXCEPTION_IF_NULL(shape_);
120   MS_EXCEPTION_IF_NULL(other.shape_);
121   if (shape_ == other.shape_) {
122     shape_equal = true;
123   } else if (*shape_ == *other.shape_) {
124     shape_equal = true;
125   }
126   return value_equal && type_equal && shape_equal;
127 }
128 
BuildValue() const129 ValuePtr AbstractBase::BuildValue() const {
130   if (value_ == nullptr) {
131     return RealBuildValue();
132   }
133   return value_;
134 }
135 
Broaden() const136 AbstractBasePtr AbstractBase::Broaden() const {
137   AbstractBasePtr clone = Clone();
138   MS_EXCEPTION_IF_NULL(clone);
139   clone->set_value(kAnyValue);
140   return clone;
141 }
142 
PartialBroaden() const143 AbstractBasePtr AbstractBase::PartialBroaden() const { return Clone(); }
144 
ToString() const145 std::string AbstractBase::ToString() const {
146   std::ostringstream buffer;
147   std::string value = std::string("value is null");
148   if (value_ != nullptr) {
149     value = value_->ToString();
150   }
151   MS_EXCEPTION_IF_NULL(type_);
152   MS_EXCEPTION_IF_NULL(shape_);
153   buffer << type_name() << "("
154          << "Type: " << type_->ToString() << ", Value: " << value << ", Shape: " << shape_->ToString() << ")";
155   return buffer.str();
156 }
157 
Broaden() const158 AbstractBasePtr AbstractScalar::Broaden() const {
159   auto context = MsContext::GetInstance();
160   MS_EXCEPTION_IF_NULL(context);
161   if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
162     return AbstractBase::Broaden();
163   }
164   auto type_id = GetTypeTrack()->type_id();
165   if (type_id == kObjectTypeEnvType) {
166     return AbstractBase::Broaden();
167   }
168   return Clone();
169 }
170 
// Joins this scalar with `other`.
// - If the abstracts are equal, this object is returned.
// - Otherwise the types are joined; a TypeError is raised when TypeJoin yields kAnyType.
// - The values are then joined. If the joined value is this scalar's own value, this object is
//   returned; otherwise a new AbstractScalar holding the joined value and type is returned.
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto self_type = GetTypeTrack();
  auto peer_type = other->GetTypeTrack();
  TypePtr joined_type = TypeJoin(self_type, peer_type);
  if (joined_type == kAnyType) {
    TypeJoinLogging(self_type, peer_type, shared_from_base<AbstractBase>(), other);
  }
  auto self_value = GetValueTrack();
  auto peer_value = other->GetValueTrack();
  ValuePtr joined_value = ValueJoin(self_value, peer_value);
  if (joined_value != self_value) {
    return std::make_shared<AbstractScalar>(joined_value, joined_type);
  }
  return shared_from_base<AbstractBase>();
}
190 
Clone() const191 AbstractBasePtr AbstractType::Clone() const {
192   ValuePtr value_self = GetValueTrack();
193   if (value_self == nullptr || !value_self->isa<Type>()) {
194     return nullptr;
195   }
196   TypePtr type_self = value_self->cast<TypePtr>();
197   return std::make_shared<AbstractType>(type_self->Clone());
198 }
199 
operator ==(const AbstractBase & other) const200 bool AbstractType::operator==(const AbstractBase &other) const {
201   if (tid() != other.tid()) {
202     return false;
203   }
204   // Have to compare TypePtr with value;
205   ValuePtr value_self = GetValueTrack();
206   ValuePtr value_other = other.GetValueTrack();
207   if (value_self == nullptr || value_other == nullptr) {
208     MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString()
209                       << ", other: " << other.ToString();
210   }
211   if (!value_self->isa<Type>() || !value_other->isa<Type>()) {
212     return false;
213   }
214   TypePtr type_self = value_self->cast<TypePtr>();
215   TypePtr type_other = value_other->cast<TypePtr>();
216   bool value_equal = *type_self == *type_other;
217   return value_equal;
218 }
219 
ToString() const220 std::string AbstractType::ToString() const {
221   std::ostringstream buffer;
222   ValuePtr value_self = GetValueTrack();
223   if (value_self == nullptr) {
224     buffer << "AbstractType value: nullptr";
225     return buffer.str();
226   }
227   if (!value_self->isa<Type>()) {
228     buffer << type_name() << "(Value: nullptr)";
229     return buffer.str();
230   }
231   TypePtr type_self = value_self->cast<TypePtr>();
232   buffer << type_name() << "("
233          << "Value: " << type_self->ToString() << ")";
234   return buffer.str();
235 }
236 
ToString() const237 std::string AbstractError::ToString() const {
238   std::ostringstream buffer;
239   auto value_track = GetValueTrack();
240   MS_EXCEPTION_IF_NULL(value_track);
241   MS_EXCEPTION_IF_NULL(node_);
242   buffer << type_name() << "("
243          << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
244   return buffer.str();
245 }
246 
// Joins with another abstract. `other` must itself be an AbstractFunction; any other kind raises
// a TypeError through AbstractTypeJoinLogging. The actual join is delegated to the
// Join(AbstractFunctionPtr) overload.
AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto peer_function = dyn_cast<AbstractFunction>(other);
  const bool is_function = (peer_function != nullptr);
  if (!is_function) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  return Join(peer_function);
}
255 
operator ==(const AbstractBase & other) const256 bool AbstractFunction::operator==(const AbstractBase &other) const {
257   if (!other.isa<AbstractFunction>()) {
258     return false;
259   }
260   const auto &other_func = static_cast<const AbstractFunction &>(other);
261   bool value_equal = (*this == other_func);
262   return value_equal;
263 }
264 
operator [](const std::size_t & dim) const265 const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const {
266   if (dim >= size()) {
267     MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list.";
268   }
269   return elements_[dim];
270 }
271 
ToString() const272 std::string AbstractSequeue::ToString() const {
273   std::ostringstream buffer;
274   size_t i = 0;
275   size_t size = elements_.size();
276   for (const auto &ele : elements_) {
277     MS_EXCEPTION_IF_NULL(ele);
278     buffer << "element[" << i << "]: " << ele->ToString();
279     if (i < size - 1) {
280       buffer << ", ";
281     }
282     i++;
283   }
284   return buffer.str();
285 }
286 
ElementsType() const287 TypePtrList AbstractSequeue::ElementsType() const {
288   TypePtrList element_type_list;
289   for (const auto &ele : elements_) {
290     MS_EXCEPTION_IF_NULL(ele);
291     TypePtr element_type = ele->BuildType();
292     element_type_list.push_back(element_type);
293   }
294   return element_type_list;
295 }
296 
ElementsShape() const297 BaseShapePtrList AbstractSequeue::ElementsShape() const {
298   BaseShapePtrList element_shape_list;
299   for (const auto &ele : elements_) {
300     MS_EXCEPTION_IF_NULL(ele);
301     BaseShapePtr element_shape = ele->BuildShape();
302     element_shape_list.push_back(element_shape);
303   }
304   return element_shape_list;
305 }
306 
ElementsClone() const307 AbstractBasePtrList AbstractSequeue::ElementsClone() const {
308   AbstractBasePtrList ele_list;
309   for (const auto &ele : elements_) {
310     MS_EXCEPTION_IF_NULL(ele);
311     AbstractBasePtr clone = ele->Clone();
312     ele_list.push_back(clone);
313   }
314   return ele_list;
315 }
316 
ElementsBroaden() const317 AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
318   AbstractBasePtrList ele_list;
319   for (const auto &ele : elements_) {
320     MS_EXCEPTION_IF_NULL(ele);
321     AbstractBasePtr broadend = ele->Broaden();
322     ele_list.push_back(broadend);
323   }
324   return ele_list;
325 }
326 
ElementsPartialBroaden() const327 AbstractBasePtrList AbstractSequeue::ElementsPartialBroaden() const {
328   AbstractBasePtrList ele_list;
329   for (const auto &ele : elements_) {
330     MS_EXCEPTION_IF_NULL(ele);
331     AbstractBasePtr broadend = ele->PartialBroaden();
332     ele_list.push_back(broadend);
333   }
334   return ele_list;
335 }
336 
337 template <typename T>
ElementsBuildValue() const338 ValuePtr AbstractSequeue::ElementsBuildValue() const {
339   std::vector<ValuePtr> element_value_list;
340   for (const auto &ele : elements_) {
341     MS_EXCEPTION_IF_NULL(ele);
342     ValuePtr element_value = ele->BuildValue();
343     MS_EXCEPTION_IF_NULL(element_value);
344     if (element_value->isa<AnyValue>()) {
345       return kAnyValue;
346     }
347     element_value_list.push_back(element_value);
348   }
349   return std::make_shared<T>(element_value_list);
350 }
351 template ValuePtr AbstractSequeue::ElementsBuildValue<ValueTuple>() const;
352 template ValuePtr AbstractSequeue::ElementsBuildValue<ValueList>() const;
353 
354 template <typename T>
ElementsJoin(const AbstractBasePtr & other)355 AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) {
356   MS_EXCEPTION_IF_NULL(other);
357   auto other_sequeue = dyn_cast<T>(other);
358   if (other_sequeue == nullptr) {
359     AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
360   }
361   auto joined_list = AbstractJoin(elements_, other_sequeue->elements_);
362   bool changes = false;
363   for (std::size_t i = 0; i < elements_.size(); i++) {
364     if (elements_[i] != joined_list[i]) {
365       changes = true;
366       break;
367     }
368   }
369   if (!changes) {
370     return shared_from_base<AbstractBase>();
371   }
372   return std::make_shared<T>(joined_list);
373 }
374 template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractList>(const AbstractBasePtr &);
375 template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractTuple>(const AbstractBasePtr &);
376 
hash() const377 std::size_t AbstractSequeue::hash() const {
378   std::size_t hash_sum = hash_combine(tid(), std::hash<size_t>{}(elements_.size()));
379   // Hashing all elements is costly, so only take at most 4 elements into account based on
380   // some experiments.
381   constexpr size_t max_elements_cnt = 4;
382   for (size_t i = 0; (i < elements_.size()) && (i < max_elements_cnt); i++) {
383     hash_sum = hash_combine(hash_sum, elements_[i]->hash());
384   }
385   return hash_sum;
386 }
387 
operator ==(const AbstractSequeue & other) const388 bool AbstractSequeue::operator==(const AbstractSequeue &other) const {
389   if (&other == this) {
390     return true;
391   }
392 
393   if (elements_.size() != other.elements_.size()) {
394     return false;
395   }
396   for (size_t i = 0; i < elements_.size(); i++) {
397     MS_EXCEPTION_IF_NULL(elements_[i]);
398     MS_EXCEPTION_IF_NULL(other.elements_[i]);
399     if (!(*(elements_[i]) == *(other.elements_[i]))) {
400       return false;
401     }
402   }
403   return true;
404 }
405 
operator ==(const AbstractTuple & other) const406 bool AbstractTuple::operator==(const AbstractTuple &other) const { return AbstractSequeue::operator==(other); }
407 
operator ==(const AbstractBase & other) const408 bool AbstractTuple::operator==(const AbstractBase &other) const {
409   if (&other == this) {
410     return true;
411   }
412 
413   if (other.isa<AbstractTuple>()) {
414     auto other_tuple = static_cast<const AbstractTuple *>(&other);
415     return *this == *other_tuple;
416   }
417 
418   return false;
419 }
420 
operator ==(const AbstractList & other) const421 bool AbstractList::operator==(const AbstractList &other) const { return AbstractSequeue::operator==(other); }
422 
operator ==(const AbstractBase & other) const423 bool AbstractList::operator==(const AbstractBase &other) const {
424   if (&other == this) {
425     return true;
426   }
427 
428   if (other.isa<AbstractList>()) {
429     auto other_list = static_cast<const AbstractList *>(&other);
430     return *this == *other_list;
431   }
432   return false;
433 }
434 
BuildType() const435 TypePtr AbstractSlice::BuildType() const {
436   MS_EXCEPTION_IF_NULL(start_);
437   MS_EXCEPTION_IF_NULL(stop_);
438   MS_EXCEPTION_IF_NULL(step_);
439   TypePtr start = start_->BuildType();
440   TypePtr stop = stop_->BuildType();
441   TypePtr step = step_->BuildType();
442   return std::make_shared<Slice>(start, stop, step);
443 }
444 
operator ==(const AbstractSlice & other) const445 bool AbstractSlice::operator==(const AbstractSlice &other) const {
446   if (&other == this) {
447     return true;
448   }
449   return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_);
450 }
451 
operator ==(const AbstractBase & other) const452 bool AbstractSlice::operator==(const AbstractBase &other) const {
453   if (&other == this) {
454     return true;
455   }
456   if (!other.isa<AbstractSlice>()) {
457     return false;
458   }
459   auto other_slice = static_cast<const AbstractSlice *>(&other);
460   return *this == *other_slice;
461 }
462 
Clone() const463 AbstractBasePtr AbstractSlice::Clone() const {
464   MS_EXCEPTION_IF_NULL(start_);
465   MS_EXCEPTION_IF_NULL(stop_);
466   MS_EXCEPTION_IF_NULL(step_);
467   AbstractBasePtr start = start_->Clone();
468   AbstractBasePtr stop = stop_->Clone();
469   AbstractBasePtr step = step_->Clone();
470   return std::make_shared<AbstractSlice>(start, stop, step);
471 }
472 
Broaden() const473 AbstractBasePtr AbstractSlice::Broaden() const {
474   MS_EXCEPTION_IF_NULL(start_);
475   MS_EXCEPTION_IF_NULL(stop_);
476   MS_EXCEPTION_IF_NULL(step_);
477   AbstractBasePtr start = start_->Broaden();
478   AbstractBasePtr stop = stop_->Broaden();
479   AbstractBasePtr step = step_->Broaden();
480   return std::make_shared<AbstractSlice>(start, stop, step);
481 }
482 
ToString() const483 std::string AbstractSlice::ToString() const {
484   std::ostringstream buffer;
485   buffer << type_name() << "[";
486   MS_EXCEPTION_IF_NULL(start_);
487   buffer << start_->ToString() << " : ";
488   MS_EXCEPTION_IF_NULL(stop_);
489   buffer << stop_->ToString() << " : ";
490   MS_EXCEPTION_IF_NULL(step_);
491   buffer << step_->ToString();
492   buffer << "]";
493   return buffer.str();
494 }
495 
RealBuildValue() const496 ValuePtr AbstractSlice::RealBuildValue() const {
497   MS_EXCEPTION_IF_NULL(start_);
498   MS_EXCEPTION_IF_NULL(stop_);
499   MS_EXCEPTION_IF_NULL(step_);
500   ValuePtr start = start_->BuildValue();
501   ValuePtr stop = stop_->BuildValue();
502   ValuePtr step = step_->BuildValue();
503   if (start->isa<AnyValue>() || stop->isa<AnyValue>() || step->isa<AnyValue>()) {
504     return kAnyValue;
505   }
506   return std::make_shared<ValueSlice>(start, stop, step);
507 }
508 
hash() const509 std::size_t AbstractSlice::hash() const {
510   MS_EXCEPTION_IF_NULL(start_);
511   MS_EXCEPTION_IF_NULL(stop_);
512   MS_EXCEPTION_IF_NULL(step_);
513   return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
514 }
515 
shape() const516 ShapePtr AbstractUndetermined::shape() const {
517   auto shp = dyn_cast<Shape>(GetShapeTrack());
518   if (shp == nullptr) {
519     MS_LOG(EXCEPTION) << "Tensor should have a shape.";
520   }
521   return shp;
522 }
523 
set_shape(const BaseShapePtr & shape)524 void AbstractUndetermined::set_shape(const BaseShapePtr &shape) {
525   MS_EXCEPTION_IF_NULL(shape);
526   if (shape->isa<NoShape>()) {
527     MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
528   }
529   AbstractBase::set_shape(shape);
530 }
531 
BuildType() const532 TypePtr AbstractTensor::BuildType() const {
533   MS_EXCEPTION_IF_NULL(element_);
534   TypePtr element_type = element_->BuildType();
535   return std::make_shared<TensorType>(element_type);
536 }
537 
BuildShape() const538 BaseShapePtr AbstractTensor::BuildShape() const {
539   auto shape = GetShapeTrack();
540   // Guard from using set_shape(nullptr)
541   if (shape == nullptr) {
542     return kNoShape;
543   }
544   return shape;
545 }
546 
// Joins this tensor abstract with `other`.
// - If `other` builds an UndeterminedType, shapes and elements are joined and a new
//   AbstractUndetermined is returned. A ValueError is raised via ShapeJoinLogging when the
//   shapes cannot be joined.
// - Otherwise `other` must be an AbstractTensor; a TypeError is raised if it is not. Equal
//   tensors join to this object. Otherwise a new AbstractTensor with the joined element and
//   shape is returned.
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto type = other->BuildType();
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(element_);

  const bool other_is_undetermined = (type->type_id() == kObjectTypeUndeterminedType);
  if (other_is_undetermined) {
    // AbstractTensor join with AbstractUndetermined
    auto other_undetermined_tensor = dyn_cast<AbstractUndetermined>(other);
    MS_EXCEPTION_IF_NULL(other_undetermined_tensor);
    auto joined_shape = ShapeJoin(shape(), other_undetermined_tensor->shape());
    if (joined_shape == nullptr) {
      ShapeJoinLogging(shape(), other_undetermined_tensor->shape(), shared_from_base<AbstractBase>(), other);
    }
    auto element = element_->Join(other_undetermined_tensor->element());
    MS_EXCEPTION_IF_NULL(element);
    return std::make_shared<AbstractUndetermined>(element, joined_shape);
  }

  // AbstractTensor join with AbstractTensor
  auto peer_tensor = dyn_cast<AbstractTensor>(other);
  if (peer_tensor == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto joined_shape = ShapeJoin(this->shape(), peer_tensor->shape());
  if (joined_shape == nullptr) {
    ShapeJoinLogging(shape(), peer_tensor->shape(), shared_from_base<AbstractBase>(), other);
  }
  auto element = element_->Join(peer_tensor->element_);
  MS_EXCEPTION_IF_NULL(element);
  return std::make_shared<AbstractTensor>(element, joined_shape);
}
586 
equal_to(const AbstractTensor & other) const587 bool AbstractTensor::equal_to(const AbstractTensor &other) const {
588   if (&other == this) {
589     return true;
590   }
591 
592   auto v1 = GetValueTrack();
593   auto v2 = other.GetValueTrack();
594   if (v1 == nullptr || v2 == nullptr) {
595     MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr";
596   }
597 
598   bool is_value_equal = (v1 == v2);
599   if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
600     is_value_equal = true;
601   }
602   MS_EXCEPTION_IF_NULL(element_);
603   MS_EXCEPTION_IF_NULL(other.element_);
604   return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
605 }
606 
operator ==(const AbstractTensor & other) const607 bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
608 
operator ==(const AbstractBase & other) const609 bool AbstractTensor::operator==(const AbstractBase &other) const {
610   if (&other == this) {
611     return true;
612   }
613 
614   if (other.tid() == tid()) {
615     auto other_tensor = static_cast<const AbstractTensor *>(&other);
616     return *this == *other_tensor;
617   } else {
618     return false;
619   }
620 }
621 
Clone() const622 AbstractBasePtr AbstractTensor::Clone() const {
623   MS_EXCEPTION_IF_NULL(element_);
624   auto clone = std::make_shared<AbstractTensor>(element_->Clone());
625   ShapePtr shp = shape();
626   clone->set_shape(shp->Clone());
627   clone->set_value(GetValueTrack());
628   clone->set_value_range(get_min_value(), get_max_value());
629   return clone;
630 }
631 
Broaden() const632 AbstractBasePtr AbstractTensor::Broaden() const {
633   MS_EXCEPTION_IF_NULL(element_);
634   auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
635   auto shp = shape();
636   MS_EXCEPTION_IF_NULL(shp);
637   broaden->set_shape(shp->Clone());
638   broaden->set_value(kAnyValue);
639   return broaden;
640 }
641 
BroadenWithShape() const642 AbstractBasePtr AbstractTensor::BroadenWithShape() const {
643   MS_EXCEPTION_IF_NULL(element_);
644   auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
645   auto shp = shape()->Clone();
646   MS_EXCEPTION_IF_NULL(shp);
647   shp->Broaden();
648   broaden->set_shape(shp);
649   broaden->set_value(kAnyValue);
650   return broaden;
651 }
652 
PartialBroaden() const653 AbstractBasePtr AbstractTensor::PartialBroaden() const { return Broaden(); }
654 
ToString() const655 std::string AbstractTensor::ToString() const {
656   std::ostringstream buffer;
657   BaseShapePtr shape_track = GetShapeTrack();
658   MS_EXCEPTION_IF_NULL(shape_track);
659   MS_EXCEPTION_IF_NULL(element_);
660   auto value_track = GetValueTrack();
661   MS_EXCEPTION_IF_NULL(value_track);
662   buffer << type_name() << "("
663          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
664          << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
665   return buffer.str();
666 }
667 
BuildType() const668 TypePtr AbstractDictionary::BuildType() const {
669   std::vector<std::pair<std::string, TypePtr>> key_values;
670   for (const auto &item : key_values_) {
671     MS_EXCEPTION_IF_NULL(item.second);
672     TypePtr type = item.second->BuildType();
673     key_values.emplace_back(item.first, type);
674   }
675   return std::make_shared<Dictionary>(key_values);
676 }
677 
operator ==(const AbstractDictionary & other) const678 bool AbstractDictionary::operator==(const AbstractDictionary &other) const {
679   if (key_values_.size() != other.key_values_.size()) {
680     return false;
681   }
682 
683   for (size_t index = 0; index < key_values_.size(); index++) {
684     if (key_values_[index].first != other.key_values_[index].first) {
685       return false;
686     }
687     MS_EXCEPTION_IF_NULL(key_values_[index].second);
688     MS_EXCEPTION_IF_NULL(other.key_values_[index].second);
689     if (!(*key_values_[index].second == *other.key_values_[index].second)) {
690       return false;
691     }
692   }
693   return true;
694 }
695 
operator ==(const AbstractBase & other) const696 bool AbstractDictionary::operator==(const AbstractBase &other) const {
697   if (&other == this) {
698     return true;
699   }
700   if (other.isa<AbstractDictionary>()) {
701     auto other_class = static_cast<const AbstractDictionary *>(&other);
702     return *this == *other_class;
703   }
704   return false;
705 }
706 
Clone() const707 AbstractBasePtr AbstractDictionary::Clone() const {
708   std::vector<AbstractAttribute> kv;
709   (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
710                        [](const AbstractAttribute &item) {
711                          MS_EXCEPTION_IF_NULL(item.second);
712                          return std::make_pair(item.first, item.second->Clone());
713                        });
714   return std::make_shared<AbstractDictionary>(kv);
715 }
716 
Broaden() const717 AbstractBasePtr AbstractDictionary::Broaden() const {
718   std::vector<AbstractAttribute> kv;
719   (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
720                        [](const AbstractAttribute &item) {
721                          MS_EXCEPTION_IF_NULL(item.second);
722                          return std::make_pair(item.first, item.second->Broaden());
723                        });
724   return std::make_shared<AbstractDictionary>(kv);
725 }
726 
ToString() const727 std::string AbstractDictionary::ToString() const {
728   std::ostringstream buffer;
729   buffer << type_name() << "{ ";
730   for (const auto &kv : key_values_) {
731     MS_EXCEPTION_IF_NULL(kv.second);
732     buffer << "(" << kv.first << ": " << kv.second->ToString() << ") ";
733   }
734   buffer << "}";
735   return buffer.str();
736 }
737 
hash() const738 std::size_t AbstractDictionary::hash() const {
739   std::size_t hash_sum = std::accumulate(key_values_.begin(), key_values_.end(), tid(),
740                                          [](std::size_t hash_sum, const AbstractAttribute &item) {
741                                            hash_sum = hash_combine(hash_sum, std::hash<std::string>()(item.first));
742                                            MS_EXCEPTION_IF_NULL(item.second);
743                                            hash_sum = hash_combine(hash_sum, item.second->hash());
744                                            return hash_sum;
745                                          });
746   return hash_sum;
747 }
748 
RealBuildValue() const749 ValuePtr AbstractDictionary::RealBuildValue() const {
750   std::vector<std::pair<std::string, ValuePtr>> key_values;
751   for (const auto &item : key_values_) {
752     MS_EXCEPTION_IF_NULL(item.second);
753     auto element_value = item.second->BuildValue();
754     MS_EXCEPTION_IF_NULL(element_value);
755     if (element_value->isa<AnyValue>()) {
756       return kAnyValue;
757     }
758     key_values.emplace_back(item.first, element_value);
759   }
760   return std::make_shared<ValueDictionary>(key_values);
761 }
762 
BuildType() const763 TypePtr AbstractClass::BuildType() const {
764   ClassAttrVector attributes_type;
765   for (const auto &attr : attributes_) {
766     MS_EXCEPTION_IF_NULL(attr.second);
767     TypePtr type = attr.second->BuildType();
768     std::pair<std::string, TypePtr> elem(attr.first, type);
769     attributes_type.push_back(elem);
770   }
771 
772   return std::make_shared<Class>(tag_, attributes_type, methods_);
773 }
774 
operator ==(const AbstractClass & other) const775 bool AbstractClass::operator==(const AbstractClass &other) const {
776   if (!(tag_ == other.tag_)) {
777     return false;
778   }
779   if (attributes_.size() != other.attributes_.size()) {
780     return false;
781   }
782   for (size_t i = 0; i < attributes_.size(); i++) {
783     MS_EXCEPTION_IF_NULL(attributes_[i].second);
784     MS_EXCEPTION_IF_NULL(other.attributes_[i].second);
785     if (!(*attributes_[i].second == *other.attributes_[i].second)) {
786       MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString()
787                     << " arg2:" << other.attributes_[i].second->ToString();
788       return false;
789     }
790   }
791   // method compare;
792   if (methods_.size() != other.methods_.size()) {
793     return false;
794   }
795   for (const auto &iter : methods_) {
796     auto iter_other = other.methods_.find(iter.first);
797     if (iter_other == other.methods_.end()) {
798       return false;
799     }
800     if (!(*iter.second == *iter_other->second)) {
801       return false;
802     }
803   }
804   return true;
805 }
806 
operator ==(const AbstractBase & other) const807 bool AbstractClass::operator==(const AbstractBase &other) const {
808   if (other.isa<AbstractClass>()) {
809     auto other_class = static_cast<const AbstractClass *>(&other);
810     return *this == *other_class;
811   }
812   return false;
813 }
814 
GetAttribute(const std::string & name)815 AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) {
816   auto it = std::find_if(attributes_.begin(), attributes_.end(),
817                          [name](const AbstractAttribute &pair) -> bool { return pair.first == name; });
818   if (it != attributes_.end()) {
819     return it->second;
820   }
821   return nullptr;
822 }
823 
GetMethod(const std::string & name)824 ValuePtr AbstractClass::GetMethod(const std::string &name) {
825   auto method_pair = methods_.find(name);
826   if (method_pair != methods_.end()) {
827     return method_pair->second;
828   }
829   return kAnyValue;
830 }
831 
Clone() const832 AbstractBasePtr AbstractClass::Clone() const {
833   std::vector<AbstractAttribute> attributes_clone;
834   for (const auto &attr : attributes_) {
835     MS_EXCEPTION_IF_NULL(attr.second);
836     AbstractBasePtr clone = attr.second->Clone();
837     AbstractAttribute elem(attr.first, clone);
838     attributes_clone.push_back(elem);
839   }
840   return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
841 }
842 
Broaden() const843 AbstractBasePtr AbstractClass::Broaden() const {
844   std::vector<AbstractAttribute> attributes_clone;
845   for (const auto &attr : attributes_) {
846     MS_EXCEPTION_IF_NULL(attr.second);
847     AbstractBasePtr clone = attr.second->Broaden();
848     AbstractAttribute elem(attr.first, clone);
849     attributes_clone.push_back(elem);
850   }
851   return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
852 }
853 
ToString() const854 std::string AbstractClass::ToString() const {
855   std::ostringstream buffer;
856   buffer << type_name() << "(tag: " << tag_ << ") attrs:(";
857   bool append_comma = false;
858   for (const auto &attr : attributes_) {
859     if (append_comma) {
860       buffer << ", ";
861     } else {
862       append_comma = true;
863     }
864     MS_EXCEPTION_IF_NULL(attr.second);
865     buffer << attr.first << ":" << attr.second->ToString();
866   }
867   buffer << ") method:(";
868   append_comma = false;
869   for (const auto &iter : methods_) {
870     if (append_comma) {
871       buffer << ", ";
872     } else {
873       append_comma = true;
874     }
875     MS_EXCEPTION_IF_NULL(iter.second);
876     buffer << iter.first << ":" << iter.second->ToString();
877   }
878   buffer << ")";
879   return buffer.str();
880 }
881 
hash() const882 std::size_t AbstractClass::hash() const {
883   std::size_t hash_sum = std::accumulate(attributes_.begin(), attributes_.end(), hash_combine(tid(), tag_.hash()),
884                                          [](std::size_t hash_sum, const AbstractAttribute &item) {
885                                            MS_EXCEPTION_IF_NULL(item.second);
886                                            return hash_combine(hash_sum, item.second->hash());
887                                          });
888 
889   return hash_sum;
890 }
891 
RealBuildValue() const892 ValuePtr AbstractClass::RealBuildValue() const {
893   auto type = BuildType();
894   MS_EXCEPTION_IF_NULL(type);
895   auto cls = type->cast<ClassPtr>();
896   std::unordered_map<std::string, ValuePtr> attributes_value_map;
897   for (const auto &attr : attributes_) {
898     MS_EXCEPTION_IF_NULL(attr.second);
899     ValuePtr value = attr.second->BuildValue();
900     MS_EXCEPTION_IF_NULL(value);
901     if (value->isa<AnyValue>()) {
902       return kAnyValue;
903     }
904     attributes_value_map[attr.first] = value;
905   }
906   cls->set_value(attributes_value_map);
907   return cls;
908 }
909 
BuildType() const910 TypePtr AbstractJTagged::BuildType() const {
911   MS_EXCEPTION_IF_NULL(element_);
912   TypePtr subtype = element_->BuildType();
913   return std::make_shared<JTagged>(subtype);
914 }
915 
// Joins two JTagged abstracts by joining their wrapped elements.
// If `other` is not an AbstractJTagged, AbstractTypeJoinLogging reports the mismatch.
AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_jtagged = dyn_cast<AbstractJTagged>(other);
  if (other_jtagged == nullptr) {
    AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
  }
  // Never dereference a failed cast, even if the logging helper were to return normally.
  MS_EXCEPTION_IF_NULL(other_jtagged);
  MS_EXCEPTION_IF_NULL(element_);
  auto joined_elem = element_->Join(other_jtagged->element_);
  return std::make_shared<AbstractJTagged>(joined_elem);
}
926 
operator ==(const AbstractJTagged & other) const927 bool AbstractJTagged::operator==(const AbstractJTagged &other) const {
928   MS_EXCEPTION_IF_NULL(element_);
929   MS_EXCEPTION_IF_NULL(other.element_);
930   return (*element_ == *other.element_);
931 }
932 
operator ==(const AbstractBase & other) const933 bool AbstractJTagged::operator==(const AbstractBase &other) const {
934   if (other.isa<AbstractJTagged>()) {
935     auto other_jtagged = static_cast<const AbstractJTagged *>(&other);
936     return *this == *other_jtagged;
937   }
938   return false;
939 }
940 
ToString() const941 std::string AbstractJTagged::ToString() const {
942   std::ostringstream buffer;
943   MS_EXCEPTION_IF_NULL(element_);
944   buffer << type_name() << "("
945          << "element: " << element_->ToString() << ")";
946   return buffer.str();
947 }
948 
AbstractRef(const AbstractBasePtr & ref_key,const AbstractTensorPtr & ref_value)949 AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
950     : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
951   set_type(std::make_shared<RefType>());
952   if (ref_key && ref_key->isa<AbstractRefKey>()) {
953     ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
954   }
955 }
956 
BuildType() const957 TypePtr AbstractRef::BuildType() const {
958   auto type = AbstractTensor::BuildType();
959   MS_EXCEPTION_IF_NULL(type);
960   auto subtype = type->cast<TensorTypePtr>();
961   return std::make_shared<RefType>(subtype);
962 }
963 
operator ==(const AbstractRef & other) const964 bool AbstractRef::operator==(const AbstractRef &other) const {
965   return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
966 }
967 
operator ==(const AbstractBase & other) const968 bool AbstractRef::operator==(const AbstractBase &other) const {
969   if (other.isa<AbstractRef>()) {
970     auto other_conf = static_cast<const AbstractRef *>(&other);
971     return *this == *other_conf;
972   }
973   return false;
974 }
975 
// Joins two ref-key abstracts.
// Returns this object when the two compare equal, or when ValueJoin yields this object's
// own value. Otherwise returns a fresh AbstractRefKey carrying the joined value.
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto value_self = GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_self);
  const ValuePtr joined_value = ValueJoin(value_self, other->GetValueTrack());
  if (joined_value == value_self) {
    return shared_from_base<AbstractBase>();
  }
  auto joined_key = std::make_shared<AbstractRefKey>();
  joined_key->set_value(joined_value);
  return joined_key;
}
993 
// Joins this Ref abstract with `other`.
//  - If `other` is not an AbstractRef, the result is AbstractTensor::Join(other), cast to a
//    tensor abstract.
//  - If both the refs and their keys compare equal, this object is returned.
//  - Otherwise the keys are joined, the tensor parts are joined, and a new AbstractRef is
//    built from the two results.
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  const auto other_ref = other->cast<AbstractRefPtr>();
  if (other_ref == nullptr) {
    const auto join_abs = AbstractTensor::Join(other);
    MS_EXCEPTION_IF_NULL(join_abs);
    return join_abs->cast<AbstractTensorPtr>();
  }

  MS_EXCEPTION_IF_NULL(ref_key_);
  MS_EXCEPTION_IF_NULL(other_ref->ref_key_);
  const bool unchanged = (*this == *other) && (*ref_key_ == *other_ref->ref_key_);
  if (unchanged) {
    return shared_from_base<AbstractBase>();
  }

  const auto ref_key = ref_key_->Join(other_ref->ref_key_);
  const auto joined_abs_tensor = other_ref->ref();
  MS_EXCEPTION_IF_NULL(joined_abs_tensor);
  const auto ref = AbstractTensor::Join(joined_abs_tensor);
  MS_EXCEPTION_IF_NULL(ref);
  const auto ref_tensor = ref->cast<AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(ref_tensor);
  return std::make_shared<AbstractRef>(ref_key, ref_tensor);
}
1016 
ToString() const1017 std::string AbstractRef::ToString() const {
1018   std::ostringstream buffer;
1019   MS_EXCEPTION_IF_NULL(ref_key_);
1020   buffer << type_name() << "("
1021          << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
1022   auto value = GetValueTrack();
1023   if (value != nullptr) {
1024     buffer << ", value: " << value->ToString();
1025   }
1026   buffer << ")";
1027   return buffer.str();
1028 }
1029 
PartialBroaden() const1030 AbstractBasePtr AbstractRef::PartialBroaden() const { return Clone(); }
1031 
operator ==(const AbstractNone &) const1032 bool AbstractNone::operator==(const AbstractNone &) const { return true; }
1033 
operator ==(const AbstractBase & other) const1034 bool AbstractNone::operator==(const AbstractBase &other) const {
1035   if (other.isa<AbstractNone>()) {
1036     auto other_none = static_cast<const AbstractNone *>(&other);
1037     return *this == *other_none;
1038   }
1039   return false;
1040 }
1041 
ToString() const1042 std::string AbstractNone::ToString() const {
1043   std::ostringstream buffer;
1044   buffer << type_name() << "(Value: None)";
1045   return buffer.str();
1046 }
1047 
RealBuildValue() const1048 ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
1049 
operator ==(const AbstractRefKey & other) const1050 bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
1051   ValuePtr value_self = GetValueTrack();
1052   ValuePtr value_other = other.GetValueTrack();
1053   if (value_self != nullptr && value_other != nullptr) {
1054     if (value_self->isa<AnyValue>() && value_other->isa<AnyValue>()) {
1055       return true;
1056     }
1057     if (!value_self->isa<RefKey>() || !value_other->isa<RefKey>()) {
1058       return false;
1059     }
1060     RefKeyPtr type_self = value_self->cast<RefKeyPtr>();
1061     RefKeyPtr type_other = value_other->cast<RefKeyPtr>();
1062     return *type_self == *type_other;
1063   } else if (value_self != nullptr || value_other != nullptr) {
1064     return false;
1065   }
1066   return true;
1067 }
1068 
operator ==(const AbstractBase & other) const1069 bool AbstractRefKey::operator==(const AbstractBase &other) const {
1070   if (other.isa<AbstractRefKey>()) {
1071     auto other_confkey = static_cast<const AbstractRefKey *>(&other);
1072     return *this == *other_confkey;
1073   } else {
1074     return false;
1075   }
1076 }
1077 
ToString() const1078 std::string AbstractRefKey::ToString() const {
1079   std::ostringstream buffer;
1080   buffer << type_name();
1081   auto value = GetValueTrack();
1082   if (value != nullptr) {
1083     buffer << "(value: " << value->ToString() << ")";
1084   }
1085   return buffer.str();
1086 }
1087 
operator ==(const AbstractNull &) const1088 bool AbstractNull::operator==(const AbstractNull &) const { return true; }
1089 
operator ==(const AbstractBase & other) const1090 bool AbstractNull::operator==(const AbstractBase &other) const {
1091   if (&other == this) {
1092     return true;
1093   }
1094   if (other.isa<AbstractNull>()) {
1095     auto other_none = static_cast<const AbstractNull *>(&other);
1096     return *this == *other_none;
1097   } else {
1098     return false;
1099   }
1100 }
1101 
ToString() const1102 std::string AbstractNull::ToString() const {
1103   std::ostringstream buffer;
1104   buffer << type_name() << "(Value: Null)";
1105   return buffer.str();
1106 }
1107 
operator ==(const AbstractTimeOut &) const1108 bool AbstractTimeOut::operator==(const AbstractTimeOut &) const { return true; }
1109 
operator ==(const AbstractBase & other) const1110 bool AbstractTimeOut::operator==(const AbstractBase &other) const {
1111   if (&other == this) {
1112     return true;
1113   }
1114   if (other.isa<AbstractTimeOut>()) {
1115     auto other_none = static_cast<const AbstractTimeOut *>(&other);
1116     return *this == *other_none;
1117   } else {
1118     return false;
1119   }
1120 }
1121 
ToString() const1122 std::string AbstractTimeOut::ToString() const {
1123   std::ostringstream buffer;
1124   buffer << "AbstractTimeOut "
1125          << "(Value: Null)";
1126   return buffer.str();
1127 }
1128 
operator ==(const AbstractEllipsis &) const1129 bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
1130 
operator ==(const AbstractBase & other) const1131 bool AbstractEllipsis::operator==(const AbstractBase &other) const {
1132   if (&other == this) {
1133     return true;
1134   }
1135   if (other.isa<AbstractEllipsis>()) {
1136     auto other_none = static_cast<const AbstractEllipsis *>(&other);
1137     return *this == *other_none;
1138   } else {
1139     return false;
1140   }
1141 }
1142 
ToString() const1143 std::string AbstractEllipsis::ToString() const {
1144   std::ostringstream buffer;
1145   buffer << type_name() << "(Value: Ellipsis)";
1146   return buffer.str();
1147 }
1148 
BuildType() const1149 TypePtr AbstractKeywordArg::BuildType() const {
1150   MS_EXCEPTION_IF_NULL(arg_value_);
1151   TypePtr type = arg_value_->BuildType();
1152   return std::make_shared<Keyword>(arg_name_, type);
1153 }
1154 
Clone() const1155 AbstractBasePtr AbstractKeywordArg::Clone() const {
1156   MS_EXCEPTION_IF_NULL(arg_value_);
1157   return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
1158 }
1159 
Broaden() const1160 AbstractBasePtr AbstractKeywordArg::Broaden() const {
1161   MS_EXCEPTION_IF_NULL(arg_value_);
1162   return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
1163 }
1164 
hash() const1165 std::size_t AbstractKeywordArg::hash() const {
1166   MS_EXCEPTION_IF_NULL(arg_value_);
1167   return hash_combine({tid(), std::hash<std::string>{}(arg_name_), arg_value_->hash()});
1168 }
1169 
ToString() const1170 std::string AbstractKeywordArg::ToString() const {
1171   std::ostringstream buffer;
1172   MS_EXCEPTION_IF_NULL(arg_value_);
1173   buffer << type_name() << "(";
1174   buffer << "key : " << arg_name_;
1175   buffer << "value : " << arg_value_->ToString();
1176   buffer << ")";
1177   return buffer.str();
1178 }
1179 
operator ==(const AbstractBase & other) const1180 bool AbstractKeywordArg::operator==(const AbstractBase &other) const {
1181   if (&other == this) {
1182     return true;
1183   }
1184 
1185   if (other.isa<AbstractKeywordArg>()) {
1186     auto other_tuple = static_cast<const AbstractKeywordArg *>(&other);
1187     return *this == *other_tuple;
1188   }
1189   return false;
1190 }
1191 
operator ==(const AbstractKeywordArg & other) const1192 bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const {
1193   if (&other == this) {
1194     return true;
1195   }
1196   MS_EXCEPTION_IF_NULL(arg_value_);
1197   MS_EXCEPTION_IF_NULL(other.arg_value_);
1198   return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_;
1199 }
1200 
RealBuildValue() const1201 ValuePtr AbstractKeywordArg::RealBuildValue() const {
1202   MS_EXCEPTION_IF_NULL(arg_value_);
1203   ValuePtr value = arg_value_->BuildValue();
1204   MS_EXCEPTION_IF_NULL(value);
1205   if (value->isa<AnyValue>()) {
1206     return kAnyValue;
1207   }
1208   return std::make_shared<KeywordArg>(arg_name_, value);
1209 }
1210 
AbstractBasePtrListHash(const AbstractBasePtrList & args_spec_list)1211 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) {
1212   std::size_t hash_value = 0;
1213   // Hashing all elements is costly, so only take at most 4 elements into account based on
1214   // some experiments.
1215   constexpr auto kMaxElementsNum = 4;
1216   for (size_t i = 0; (i < args_spec_list.size()) && (i < kMaxElementsNum); i++) {
1217     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
1218     hash_value = hash_combine(hash_value, args_spec_list[i]->hash());
1219   }
1220   return hash_value;
1221 }
1222 
AbstractBasePtrListDeepEqual(const AbstractBasePtrList & lhs,const AbstractBasePtrList & rhs)1223 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
1224   if (lhs.size() != rhs.size()) {
1225     return false;
1226   }
1227   std::size_t size = lhs.size();
1228   for (std::size_t i = 0; i < size; i++) {
1229     MS_EXCEPTION_IF_NULL(lhs[i]);
1230     MS_EXCEPTION_IF_NULL(rhs[i]);
1231     if (lhs[i] == rhs[i]) {
1232       continue;
1233     }
1234     if (!(*lhs[i] == *rhs[i])) {
1235       return false;
1236     }
1237   }
1238   return true;
1239 }
1240 
operator ()(const AbstractBasePtrList & args_spec_list) const1241 std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const {
1242   return AbstractBasePtrListHash(args_spec_list);
1243 }
1244 
operator ()(const AbstractBasePtrList & lhs,const AbstractBasePtrList & rhs) const1245 bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
1246   return AbstractBasePtrListDeepEqual(lhs, rhs);
1247 }
1248 
1249 // RowTensor
BuildType() const1250 TypePtr AbstractRowTensor::BuildType() const {
1251   MS_EXCEPTION_IF_NULL(element());
1252   TypePtr element_type = element()->BuildType();
1253   return std::make_shared<RowTensorType>(element_type);
1254 }
1255 
Clone() const1256 AbstractBasePtr AbstractRowTensor::Clone() const {
1257   MS_EXCEPTION_IF_NULL(element());
1258   auto clone = std::make_shared<AbstractRowTensor>(element()->Clone());
1259   ShapePtr shp = shape();
1260   MS_EXCEPTION_IF_NULL(shp);
1261   clone->set_shape(shp->Clone());
1262   clone->set_value(GetValueTrack());
1263   MS_EXCEPTION_IF_NULL(indices_);
1264   MS_EXCEPTION_IF_NULL(values_);
1265   MS_EXCEPTION_IF_NULL(dense_shape_);
1266   auto indices_clone = indices_->Clone();
1267   auto value_clone = values_->Clone();
1268   auto dense_clone = dense_shape_->Clone();
1269   MS_EXCEPTION_IF_NULL(indices_clone);
1270   MS_EXCEPTION_IF_NULL(value_clone);
1271   MS_EXCEPTION_IF_NULL(dense_clone);
1272   clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
1273   clone->set_values(value_clone->cast<AbstractTensorPtr>());
1274   clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1275   return clone;
1276 }
1277 
Broaden() const1278 AbstractBasePtr AbstractRowTensor::Broaden() const {
1279   MS_EXCEPTION_IF_NULL(element());
1280   auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
1281   auto shp = shape();
1282   MS_EXCEPTION_IF_NULL(shp);
1283   broaden->set_shape(shp->Clone());
1284   broaden->set_value(kAnyValue);
1285   MS_EXCEPTION_IF_NULL(indices_);
1286   MS_EXCEPTION_IF_NULL(values_);
1287   MS_EXCEPTION_IF_NULL(dense_shape_);
1288   auto indices_clone = indices_->Clone();
1289   auto value_clone = values_->Clone();
1290   auto dense_clone = dense_shape_->Clone();
1291   MS_EXCEPTION_IF_NULL(indices_clone);
1292   MS_EXCEPTION_IF_NULL(value_clone);
1293   MS_EXCEPTION_IF_NULL(dense_clone);
1294   broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1295   broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1296   broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1297   return broaden;
1298 }
1299 
BroadenWithShape() const1300 AbstractBasePtr AbstractRowTensor::BroadenWithShape() const {
1301   MS_EXCEPTION_IF_NULL(element());
1302   auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
1303   auto shp = shape()->Clone();
1304   MS_EXCEPTION_IF_NULL(shp);
1305   shp->Broaden();
1306   broaden->set_shape(shp);
1307   broaden->set_value(kAnyValue);
1308   MS_EXCEPTION_IF_NULL(indices_);
1309   MS_EXCEPTION_IF_NULL(values_);
1310   MS_EXCEPTION_IF_NULL(dense_shape_);
1311   auto indices_clone = indices_->Clone();
1312   auto value_clone = values_->Clone();
1313   auto dense_clone = dense_shape_->Clone();
1314   MS_EXCEPTION_IF_NULL(indices_clone);
1315   MS_EXCEPTION_IF_NULL(value_clone);
1316   MS_EXCEPTION_IF_NULL(dense_clone);
1317   broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1318   broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1319   broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1320   return broaden;
1321 }
1322 
ToString() const1323 std::string AbstractRowTensor::ToString() const {
1324   std::ostringstream buffer;
1325   BaseShapePtr shape_track = GetShapeTrack();
1326   MS_EXCEPTION_IF_NULL(shape_track);
1327   MS_EXCEPTION_IF_NULL(element());
1328   auto value_track = GetValueTrack();
1329   MS_EXCEPTION_IF_NULL(value_track);
1330   MS_EXCEPTION_IF_NULL(indices_);
1331   MS_EXCEPTION_IF_NULL(values_);
1332   MS_EXCEPTION_IF_NULL(dense_shape_);
1333   buffer << type_name() << "("
1334          << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
1335          << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
1336          << ", indices: " << indices_->ToString() << ", values" << values_->ToString()
1337          << ", dense_shape: " << dense_shape_->ToString();
1338   return buffer.str();
1339 }
1340 
1341 // SparseTensor
BuildType() const1342 TypePtr AbstractSparseTensor::BuildType() const {
1343   MS_EXCEPTION_IF_NULL(element());
1344   TypePtr element_type = element()->BuildType();
1345   return std::make_shared<SparseTensorType>(element_type);
1346 }
1347 
Clone() const1348 AbstractBasePtr AbstractSparseTensor::Clone() const {
1349   MS_EXCEPTION_IF_NULL(element());
1350   auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
1351   ShapePtr shp = shape();
1352   MS_EXCEPTION_IF_NULL(shp);
1353   clone->set_shape(shp->Clone());
1354   clone->set_value(GetValueTrack());
1355   MS_EXCEPTION_IF_NULL(indices_);
1356   MS_EXCEPTION_IF_NULL(values_);
1357   MS_EXCEPTION_IF_NULL(dense_shape_);
1358   auto indices_clone = indices_->Clone();
1359   auto value_clone = values_->Clone();
1360   auto dense_clone = dense_shape_->Clone();
1361   MS_EXCEPTION_IF_NULL(indices_clone);
1362   MS_EXCEPTION_IF_NULL(value_clone);
1363   MS_EXCEPTION_IF_NULL(dense_clone);
1364   clone->set_indices(indices_clone->cast<AbstractTensorPtr>());
1365   clone->set_values(value_clone->cast<AbstractTensorPtr>());
1366   clone->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1367   return clone;
1368 }
1369 
Broaden() const1370 AbstractBasePtr AbstractSparseTensor::Broaden() const {
1371   MS_EXCEPTION_IF_NULL(element());
1372   auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
1373   auto shp = shape();
1374   MS_EXCEPTION_IF_NULL(shp);
1375   MS_EXCEPTION_IF_NULL(indices_);
1376   MS_EXCEPTION_IF_NULL(values_);
1377   MS_EXCEPTION_IF_NULL(dense_shape_);
1378   auto indices_clone = indices_->Clone();
1379   auto value_clone = values_->Clone();
1380   auto dense_clone = dense_shape_->Clone();
1381   MS_EXCEPTION_IF_NULL(indices_clone);
1382   MS_EXCEPTION_IF_NULL(value_clone);
1383   MS_EXCEPTION_IF_NULL(dense_clone);
1384   broaden->set_shape(shp->Clone());
1385   broaden->set_value(kAnyValue);
1386   broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1387   broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1388   broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1389   return broaden;
1390 }
1391 
BroadenWithShape() const1392 AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
1393   MS_EXCEPTION_IF_NULL(element());
1394   auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
1395   auto this_shape = shape();
1396   MS_EXCEPTION_IF_NULL(this_shape);
1397   auto shp = this_shape->Clone();
1398   MS_EXCEPTION_IF_NULL(shp);
1399   shp->Broaden();
1400   broaden->set_shape(shp);
1401   broaden->set_value(kAnyValue);
1402   MS_EXCEPTION_IF_NULL(indices_);
1403   MS_EXCEPTION_IF_NULL(values_);
1404   MS_EXCEPTION_IF_NULL(dense_shape_);
1405   auto indices_clone = indices_->Clone();
1406   auto value_clone = values_->Clone();
1407   auto dense_clone = dense_shape_->Clone();
1408   MS_EXCEPTION_IF_NULL(indices_clone);
1409   MS_EXCEPTION_IF_NULL(value_clone);
1410   MS_EXCEPTION_IF_NULL(dense_clone);
1411   broaden->set_indices(indices_clone->cast<AbstractTensorPtr>());
1412   broaden->set_values(value_clone->cast<AbstractTensorPtr>());
1413   broaden->set_dense_shape(dense_clone->cast<AbstractTuplePtr>());
1414   return broaden;
1415 }
1416 
ToString() const1417 std::string AbstractSparseTensor::ToString() const {
1418   std::ostringstream buffer;
1419   BaseShapePtr shape_track = GetShapeTrack();
1420   MS_EXCEPTION_IF_NULL(shape_track);
1421   MS_EXCEPTION_IF_NULL(element());
1422   auto value_track = GetValueTrack();
1423   MS_EXCEPTION_IF_NULL(value_track);
1424   MS_EXCEPTION_IF_NULL(indices_);
1425   MS_EXCEPTION_IF_NULL(values_);
1426   MS_EXCEPTION_IF_NULL(dense_shape_);
1427   buffer << type_name() << "("
1428          << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
1429          << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
1430          << ", indices: " << indices_->ToString() << ", values" << values_->ToString()
1431          << ", dense_shape: " << dense_shape_->ToString();
1432   return buffer.str();
1433 }
1434 
// Joining a UMonad with another UMonad yields this object.
// Joining with anything else is reported through TypeJoinLogging.
AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractUMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    // Validate both type tracks before logging. `other` itself was already checked above.
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
1446 
operator ==(const AbstractUMonad &) const1447 bool AbstractUMonad::operator==(const AbstractUMonad &) const { return true; }
1448 
operator ==(const AbstractBase & other) const1449 bool AbstractUMonad::operator==(const AbstractBase &other) const {
1450   if (&other == this) {
1451     return true;
1452   }
1453   return other.isa<AbstractUMonad>();
1454 }
1455 
// Joining an IOMonad with another IOMonad yields this object.
// Joining with anything else is reported through TypeJoinLogging.
AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractIOMonad>()) {
    auto this_type = GetTypeTrack();
    auto other_type = other->GetTypeTrack();
    MS_EXCEPTION_IF_NULL(this_type);
    // Validate both type tracks before logging. `other` itself was already checked above.
    MS_EXCEPTION_IF_NULL(other_type);
    TypeJoinLogging(this_type, other_type, shared_from_base<AbstractBase>(), other);
  }
  return shared_from_base<AbstractBase>();
}
1467 
operator ==(const AbstractIOMonad &) const1468 bool AbstractIOMonad::operator==(const AbstractIOMonad &) const { return true; }
1469 
operator ==(const AbstractBase & other) const1470 bool AbstractIOMonad::operator==(const AbstractBase &other) const {
1471   if (&other == this) {
1472     return true;
1473   }
1474   return other.isa<AbstractIOMonad>();
1475 }
1476 }  // namespace abstract
1477 }  // namespace mindspore
1478