1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ 20 #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ 21 22 #include <utility> 23 #include <vector> 24 #include <string> 25 #include <unordered_map> 26 #include <memory> 27 28 #include "utils/log_adapter.h" 29 #include "utils/hashing.h" 30 #include "utils/any.h" 31 #include "utils/flags.h" 32 #include "base/base.h" 33 #include "ir/dtype.h" 34 #include "ir/value.h" 35 #include "ir/tensor.h" 36 #include "abstract/dshape.h" 37 #include "utils/shape_utils.h" 38 39 namespace mindspore { 40 namespace abstract { 41 class AbstractBase; 42 using AbstractBasePtrList = std::vector<AbstractBasePtr>; 43 44 // The base class for abstract value. The abstract value is used in evaluating 45 // to express the type, shape, and value of the real value. 46 class MS_CORE_API AbstractBase : public Base { 47 public: 48 using TraceNodeProvider = std::function<void(AnfNodePtr *node)>; 49 50 explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, 51 const BaseShapePtr &shape = kNoShape) value_(value)52 : value_(value), type_(type), shape_(shape) {} 53 ~AbstractBase() override = default; MS_DECLARE_PARENT(AbstractBase,Base)54 MS_DECLARE_PARENT(AbstractBase, Base) 55 56 std::size_t hash() const override { return tid(); } 57 std::string ToString() const override; 58 59 virtual bool operator==(const AbstractBase &other) const; set_value(const ValuePtr & value)60 void set_value(const ValuePtr &value) { value_ = value; } set_type(const TypePtr & type)61 void set_type(const TypePtr &type) { type_ = type; } set_shape(const BaseShapePtr & shape)62 virtual void set_shape(const BaseShapePtr &shape) { shape_ = shape; } set_value_desc(const std::string & desc)63 void set_value_desc(const std::string &desc) { value_desc_ = desc; } value_desc()64 const std::string &value_desc() const { return value_desc_; } GetValueTrack()65 ValuePtr GetValueTrack() const { return value_; } GetTypeTrack()66 TypePtr GetTypeTrack() const { return type_; } GetShapeTrack()67 BaseShapePtr GetShapeTrack() const { return shape_; } 68 69 // Try build a real value from an abstract value. If the value cannot be built, 70 // a default value (AnyValue) is returned. 71 ValuePtr BuildValue() const; 72 73 virtual TypePtr BuildType() const = 0; BuildShape()74 virtual BaseShapePtr BuildShape() const { return kNoShape; } 75 virtual AbstractBasePtr Clone() const = 0; 76 set_trace_node_provider(TraceNodeProvider trace_node_provider)77 static void set_trace_node_provider(TraceNodeProvider trace_node_provider) { 78 trace_node_provider_ = trace_node_provider; 79 } 80 81 inline static TraceNodeProvider trace_node_provider_ = nullptr; 82 virtual AbstractBasePtr Broaden() const; Join(const AbstractBasePtr &)83 virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); } 84 85 friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) { 86 os << a->ToString(); 87 return os; 88 } 89 // Broaden abstract with constraints, broaden only if cond_func is true.Now this interface is only used in pynative 90 // mode when abstract type is AbstractTensor(not include the derived abstract type). 91 virtual AbstractBasePtr PartialBroaden() const; 92 93 protected: 94 // default implementation, it can be overwritten by subclass; RealBuildValue()95 virtual ValuePtr RealBuildValue() const { return kAnyValue; } 96 97 private: 98 ValuePtr value_; 99 TypePtr type_; 100 BaseShapePtr shape_; 101 std::string value_desc_; // store initial value description for error report 102 }; 103 104 class MS_CORE_API AbstractScalar : public AbstractBase { 105 public: AbstractScalar()106 AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {} AbstractScalar(const ValuePtr & value,const TypePtr & type)107 AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} AbstractScalar(const ValuePtr & value)108 explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {} AbstractScalar(int value)109 explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {} AbstractScalar(int64_t value)110 explicit AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {} AbstractScalar(float value)111 explicit AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {} AbstractScalar(double value)112 explicit AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {} AbstractScalar(bool value)113 explicit AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {} AbstractScalar(const std::string & value)114 explicit AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {} AbstractScalar(const TypePtr & type)115 explicit AbstractScalar(const TypePtr &type) : AbstractBase(kAnyValue, type) {} 116 ~AbstractScalar() override = default; MS_DECLARE_PARENT(AbstractScalar,AbstractBase)117 MS_DECLARE_PARENT(AbstractScalar, AbstractBase) 118 119 std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); } 120 BuildType()121 TypePtr BuildType() const override { return GetTypeTrack(); } Clone()122 AbstractBasePtr Clone() const override { 123 return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone()); 124 } 125 AbstractBasePtr Broaden() const override; 126 AbstractBasePtr Join(const AbstractBasePtr &other) override; 127 }; 128 using AbstractScalarPtr = std::shared_ptr<AbstractScalar>; 129 130 class MS_CORE_API AbstractType : public AbstractBase { 131 public: AbstractType(const TypePtr & type)132 explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) { 133 if (type == nullptr) { 134 MS_LOG(EXCEPTION) << "type is nullptr"; 135 } 136 } 137 ~AbstractType() override = default; 138 MS_DECLARE_PARENT(AbstractType, AbstractBase) 139 140 std::string ToString() const override; 141 bool operator==(const AbstractBase &other) const override; 142 BuildType()143 TypePtr BuildType() const override { return std::make_shared<TypeType>(); } 144 AbstractBasePtr Clone() const override; Broaden()145 AbstractBasePtr Broaden() const override { return Clone(); } 146 }; 147 using AbstractTypePtr = std::shared_ptr<AbstractType>; 148 149 class MS_CORE_API AbstractError : public AbstractBase { 150 public: AbstractError(const StringImmPtr & err,const AnfNodePtr & node)151 AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) { 152 if (err == nullptr || node == nullptr) { 153 MS_LOG(EXCEPTION) << "err or node is nullptr"; 154 } 155 } 156 ~AbstractError() override = default; MS_DECLARE_PARENT(AbstractError,AbstractBase)157 MS_DECLARE_PARENT(AbstractError, AbstractBase) 158 159 TypePtr BuildType() const override { return std::make_shared<Problem>(); } Broaden()160 AbstractBasePtr Broaden() const override { return Clone(); } 161 Clone()162 AbstractBasePtr Clone() const override { 163 return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_); 164 } 165 166 std::string ToString() const override; 167 168 private: 169 // Origin node been specialized to AbstractError, for debug purpose only. 170 const AnfNodePtr node_; 171 }; 172 173 class MS_CORE_API AbstractScript : public AbstractBase { 174 public: AbstractScript()175 AbstractScript() : AbstractBase(kAnyValue, kAnyType) {} AbstractScript(const ValuePtr & value,const TypePtr & type)176 AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} AbstractScript(const ValuePtr & value)177 explicit AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {} 178 // explicit AbstractScript(const std::string &value) : AbstractBase(MakeValue(value), kString) {} 179 ~AbstractScript() override = default; MS_DECLARE_PARENT(AbstractScript,AbstractBase)180 MS_DECLARE_PARENT(AbstractScript, AbstractBase) 181 182 std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); } 183 BuildType()184 TypePtr BuildType() const override { return GetTypeTrack(); } Clone()185 AbstractBasePtr Clone() const override { 186 return std::make_shared<AbstractScript>(GetValueTrack(), GetTypeTrack()->Clone()); 187 } Broaden()188 AbstractBasePtr Broaden() const override { return Clone(); } 189 }; 190 using AbstractScriptPtr = std::shared_ptr<AbstractScript>; 191 192 class Evaluator; 193 using EvaluatorPtr = std::shared_ptr<Evaluator>; 194 class AnalysisEngine; 195 using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>; 196 197 class AbstractFunction; 198 using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>; 199 class AbstractFuncAtom; 200 using AbstractFuncAtomPtr = std::shared_ptr<AbstractFuncAtom>; 201 using AbstractFuncAtomPtrList = std::vector<AbstractFuncAtomPtr>; 202 203 class MS_CORE_API AbstractFunction : public AbstractBase { 204 public: 205 AbstractFunction() = default; 206 ~AbstractFunction() override = default; 207 MS_DECLARE_PARENT(AbstractFunction, AbstractBase) 208 209 // If there is exactly one possible function, return it. Otherwise, raise an Exception. 210 // Caller should ensure the uniqueness. 211 virtual AbstractFunctionPtr GetUnique() = 0; 212 BuildType()213 TypePtr BuildType() const override { return std::make_shared<Function>(); } Clone()214 AbstractBasePtr Clone() const override { return Copy(); } 215 // For Function, no need to broaden. Broaden()216 AbstractBasePtr Broaden() const override { 217 return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>(); 218 } 219 virtual AbstractFunctionPtr Copy() const = 0; 220 221 AbstractBasePtr Join(const AbstractBasePtr &other) final; 222 virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0; 223 224 virtual void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const = 0; 225 bool operator==(const AbstractBase &other) const final; 226 virtual bool operator==(const AbstractFunction &other) const = 0; 227 228 static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); 229 tracking_id()230 virtual AnfNodePtr tracking_id() const { return nullptr; } set_tracking_id(AnfNodePtr)231 virtual void set_tracking_id(AnfNodePtr) {} context()232 virtual AnalysisContextPtr context() const { return nullptr; } 233 }; 234 using AbstractFunctionPtrList = std::vector<AbstractFunctionPtr>; 235 236 // Represents a key-value pair used in function's parameters. 237 class MS_CORE_API AbstractKeywordArg : public AbstractBase { 238 public: AbstractKeywordArg(const std::string & key,const AbstractBasePtr & argument)239 AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument) : arg_name_(key), arg_value_(argument) {} 240 ~AbstractKeywordArg() override = default; 241 MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase) 242 243 TypePtr BuildType() const override; 244 AbstractBasePtr Clone() const override; 245 AbstractBasePtr Broaden() const override; 246 std::size_t hash() const override; 247 248 bool operator==(const AbstractKeywordArg &other) const; 249 bool operator==(const AbstractBase &other) const override; get_key()250 std::string get_key() const { return arg_name_; } get_arg()251 AbstractBasePtr get_arg() const { return arg_value_; } 252 253 std::string ToString() const override; 254 255 protected: 256 ValuePtr RealBuildValue() const override; 257 258 private: 259 std::string arg_name_; 260 AbstractBasePtr arg_value_; 261 }; 262 using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>; 263 264 class MS_CORE_API AbstractUndetermined : public AbstractBase { 265 public: 266 // shape and type are all unknown AbstractUndetermined()267 AbstractUndetermined() : AbstractBase(kAnyValue) {} 268 // only element_ and value, shape track are valid member, type track are unknown. 269 explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) AbstractBase(kAnyValue)270 : AbstractBase(kAnyValue), element_(element) { 271 if (element == nullptr) { 272 MS_LOG(EXCEPTION) << "element is nullptr"; 273 } 274 if (element->isa<AbstractUndetermined>()) { 275 MS_LOG(EXCEPTION) << "element type error"; 276 } 277 MS_EXCEPTION_IF_NULL(shape); 278 if (shape->isa<NoShape>()) { 279 MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape."; 280 } 281 AbstractBase::set_shape(shape); 282 } AbstractUndetermined(const TypePtr & element_type,const ShapeVector & shape)283 AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape) 284 : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) { 285 if (element_type == nullptr) { 286 MS_LOG(EXCEPTION) << "element_type is nullptr"; 287 } 288 AbstractBase::set_shape(std::make_shared<Shape>(shape)); 289 } 290 explicit AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>()) AbstractBase(kAnyValue)291 : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) { 292 if (element_type == nullptr) { 293 MS_LOG(EXCEPTION) << "element_type is nullptr"; 294 } 295 MS_EXCEPTION_IF_NULL(shape); 296 if (shape->isa<NoShape>()) { 297 MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape."; 298 } 299 AbstractBase::set_shape(shape); 300 } 301 ~AbstractUndetermined() override = default; MS_DECLARE_PARENT(AbstractUndetermined,AbstractBase)302 MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase) 303 TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); } Clone()304 AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); } element()305 AbstractBasePtr element() const { return element_; } 306 ShapePtr shape() const; 307 void set_shape(const BaseShapePtr &shape) override; 308 309 protected: 310 AbstractBasePtr element_; 311 }; 312 313 class MS_CORE_API AbstractTensor : public AbstractUndetermined { 314 public: 315 // only element_ and value, shape track are valid member, type track are unknown. 316 explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) AbstractUndetermined(element,shape)317 : AbstractUndetermined(element, shape) {} AbstractTensor(const TypePtr & element_type,const ShapeVector & shape)318 AbstractTensor(const TypePtr &element_type, const ShapeVector &shape) : AbstractUndetermined(element_type, shape) {} AbstractTensor(const tensor::TensorPtr & tensor)319 explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {} 320 explicit AbstractTensor(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>()) AbstractUndetermined(element_type,shape)321 : AbstractUndetermined(element_type, shape) {} 322 ~AbstractTensor() override = default; MS_DECLARE_PARENT(AbstractTensor,AbstractUndetermined)323 MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) 324 325 void set_value_range(const ValuePtr &min_value, const ValuePtr &max_value) { 326 min_value_ = min_value; 327 max_value_ = max_value; 328 } get_min_value()329 const ValuePtr &get_min_value() const { return min_value_; } get_max_value()330 const ValuePtr &get_max_value() const { return max_value_; } 331 332 TypePtr BuildType() const override; 333 BaseShapePtr BuildShape() const override; 334 AbstractBasePtr Clone() const override; 335 AbstractBasePtr Broaden() const override; 336 AbstractBasePtr BroadenWithShape() const; 337 AbstractBasePtr Join(const AbstractBasePtr &other) override; 338 bool operator==(const AbstractTensor &other) const; 339 bool operator==(const AbstractBase &other) const override; 340 std::string ToString() const override; hash()341 std::size_t hash() const override { 342 auto value = GetValueTrack(); 343 auto hash_sum = hash_combine(tid(), element_->hash()); 344 if (value != nullptr) { 345 auto tensor = value->cast<tensor::TensorPtr>(); 346 if (tensor != nullptr) { 347 hash_sum = hash_combine(hash_sum, LongToSize(tensor->DataSize())); 348 } 349 } 350 return hash_sum; 351 } 352 AbstractBasePtr PartialBroaden() const override; 353 354 protected: 355 bool equal_to(const AbstractTensor &other) const; 356 ValuePtr min_value_ = nullptr; 357 ValuePtr max_value_ = nullptr; 358 }; 359 using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; 360 using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; 361 362 class MS_CORE_API AbstractSequeue : public AbstractBase { 363 public: AbstractSequeue(const AbstractBasePtrList & elements)364 explicit AbstractSequeue(const AbstractBasePtrList &elements) : elements_(elements) {} 365 ~AbstractSequeue() override = default; 366 MS_DECLARE_PARENT(AbstractSequeue, AbstractBase) 367 368 TypePtrList ElementsType() const; 369 BaseShapePtrList ElementsShape() const; 370 AbstractBasePtrList ElementsClone() const; 371 AbstractBasePtrList ElementsBroaden() const; 372 AbstractBasePtrList ElementsPartialBroaden() const; 373 374 template <typename T> 375 ValuePtr ElementsBuildValue() const; 376 377 template <typename T> 378 AbstractBasePtr ElementsJoin(const AbstractBasePtr &other); 379 size()380 std::size_t size() const { return elements_.size(); } elements()381 const AbstractBasePtrList &elements() const { return elements_; } 382 383 std::size_t hash() const override; 384 std::string ToString() const override; 385 const AbstractBasePtr operator[](const std::size_t &dim) const; 386 virtual bool operator==(const AbstractSequeue &other) const; 387 388 protected: 389 AbstractBasePtrList elements_; 390 }; 391 using AbstractSequeuePtr = std::shared_ptr<AbstractSequeue>; 392 393 class MS_CORE_API AbstractTuple : public AbstractSequeue { 394 public: AbstractTuple(const AbstractBasePtrList & elements)395 explicit AbstractTuple(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} 396 397 ~AbstractTuple() override = default; MS_DECLARE_PARENT(AbstractTuple,AbstractSequeue)398 MS_DECLARE_PARENT(AbstractTuple, AbstractSequeue) 399 400 TypePtr BuildType() const override { return std::make_shared<Tuple>(ElementsType()); } 401 BuildShape()402 BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShape()); } 403 Clone()404 AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); } 405 Broaden()406 AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); } 407 PartialBroaden()408 AbstractBasePtr PartialBroaden() const override { return std::make_shared<AbstractTuple>(ElementsPartialBroaden()); } 409 Join(const AbstractBasePtr & other)410 AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); } 411 ToString()412 std::string ToString() const override { return type_name() + "(" + AbstractSequeue::ToString() + ")"; } 413 414 bool operator==(const AbstractTuple &other) const; 415 bool operator==(const AbstractBase &other) const override; 416 417 protected: RealBuildValue()418 ValuePtr RealBuildValue() const override { return ElementsBuildValue<ValueTuple>(); } 419 }; 420 using AbstractTuplePtr = std::shared_ptr<AbstractTuple>; 421 422 class MS_CORE_API AbstractList : public AbstractSequeue { 423 public: AbstractList(const AbstractBasePtrList & elements)424 explicit AbstractList(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} 425 426 ~AbstractList() override = default; MS_DECLARE_PARENT(AbstractList,AbstractSequeue)427 MS_DECLARE_PARENT(AbstractList, AbstractSequeue) 428 429 TypePtr BuildType() const override { return std::make_shared<List>(ElementsType()); } 430 BuildShape()431 BaseShapePtr BuildShape() const override { return std::make_shared<ListShape>(ElementsShape()); } 432 Clone()433 AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); } 434 Broaden()435 AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); } 436 PartialBroaden()437 AbstractBasePtr PartialBroaden() const override { return std::make_shared<AbstractList>(ElementsPartialBroaden()); } 438 Join(const AbstractBasePtr & other)439 AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); } 440 ToString()441 std::string ToString() const override { return type_name() + "[" + AbstractSequeue::ToString() + "]"; } 442 443 bool operator==(const AbstractList &other) const; 444 445 bool operator==(const AbstractBase &other) const override; 446 447 protected: RealBuildValue()448 ValuePtr RealBuildValue() const override { return ElementsBuildValue<ValueList>(); } 449 }; 450 using AbstractListPtr = std::shared_ptr<AbstractList>; 451 452 class MS_CORE_API AbstractClass : public AbstractBase { 453 public: AbstractClass(const Named & tag,const std::vector<AbstractAttribute> & attributes,const std::unordered_map<std::string,ValuePtr> & methods)454 AbstractClass(const Named &tag, const std::vector<AbstractAttribute> &attributes, 455 const std::unordered_map<std::string, ValuePtr> &methods) 456 : attributes_(attributes), tag_(tag), methods_(methods) {} 457 458 ~AbstractClass() override = default; 459 MS_DECLARE_PARENT(AbstractClass, AbstractBase) 460 461 TypePtr BuildType() const override; 462 bool operator==(const AbstractClass &other) const; 463 bool operator==(const AbstractBase &other) const override; attributes()464 const std::vector<AbstractAttribute> &attributes() const { return attributes_; } methods()465 std::unordered_map<std::string, ValuePtr> methods() { return methods_; } 466 AbstractBasePtr GetAttribute(const std::string &name); 467 ValuePtr GetMethod(const std::string &name); 468 AbstractBasePtr Clone() const override; 469 AbstractBasePtr Broaden() const override; 470 std::string ToString() const override; tag()471 Named tag() const { return tag_; } 472 std::size_t hash() const override; 473 474 protected: 475 ValuePtr RealBuildValue() const override; 476 477 private: 478 std::vector<AbstractAttribute> attributes_; 479 Named tag_; 480 std::unordered_map<std::string, ValuePtr> methods_; 481 }; 482 using AbstractClassPtr = std::shared_ptr<AbstractClass>; 483 484 class MS_CORE_API AbstractDictionary : public AbstractBase { 485 public: AbstractDictionary(const std::vector<AbstractAttribute> & key_values)486 explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {} 487 ~AbstractDictionary() override = default; 488 MS_DECLARE_PARENT(AbstractDictionary, AbstractBase) 489 490 TypePtr BuildType() const override; 491 bool operator==(const AbstractDictionary &other) const; 492 bool operator==(const AbstractBase &other) const override; 493 AbstractBasePtr Clone() const override; 494 AbstractBasePtr Broaden() const override; 495 std::string ToString() const override; 496 std::size_t hash() const override; size()497 std::size_t size() const { return key_values_.size(); } elements()498 const std::vector<AbstractAttribute> &elements() const { return key_values_; } 499 500 std::vector<AbstractAttribute> key_values_; 501 502 protected: 503 ValuePtr RealBuildValue() const override; 504 }; 505 using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>; 506 507 class MS_CORE_API AbstractSlice : public AbstractBase { 508 public: AbstractSlice(const AbstractBasePtr & start,const AbstractBasePtr & stop,const AbstractBasePtr & step)509 AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step) 510 : start_(start), stop_(stop), step_(step) {} 511 ~AbstractSlice() override = default; 512 MS_DECLARE_PARENT(AbstractSlice, AbstractBase) 513 514 TypePtr BuildType() const override; 515 bool operator==(const AbstractSlice &other) const; 516 bool operator==(const AbstractBase &other) const override; 517 AbstractBasePtr Clone() const override; 518 AbstractBasePtr Broaden() const override; 519 std::string ToString() const override; 520 std::size_t hash() const override; start()521 AbstractBasePtr start() const { return start_; } stop()522 AbstractBasePtr stop() const { return stop_; } step()523 AbstractBasePtr step() const { return step_; } 524 525 protected: 526 ValuePtr RealBuildValue() const override; 527 528 private: 529 AbstractBasePtr start_; 530 AbstractBasePtr stop_; 531 AbstractBasePtr step_; 532 }; 533 using AbstractSlicePtr = std::shared_ptr<AbstractSlice>; 534 535 class MS_CORE_API AbstractJTagged : public AbstractBase { 536 public: AbstractJTagged(const AbstractBasePtr & element)537 explicit AbstractJTagged(const AbstractBasePtr &element) : element_(element) {} 538 539 ~AbstractJTagged() override = default; 540 MS_DECLARE_PARENT(AbstractJTagged, AbstractBase) 541 542 TypePtr BuildType() const override; Clone()543 AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); } Broaden()544 AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); } 545 AbstractBasePtr Join(const AbstractBasePtr &other) override; 546 547 bool operator==(const AbstractJTagged &other) const; 548 bool operator==(const AbstractBase &other) const override; 549 std::string ToString() const override; element()550 AbstractBasePtr element() { return element_; } hash()551 std::size_t hash() const override { return hash_combine(tid(), element_->hash()); } 552 553 private: 554 AbstractBasePtr element_; 555 }; 556 using AbstractJTaggedPtr = std::shared_ptr<AbstractJTagged>; 557 558 class MS_CORE_API AbstractNone : public AbstractBase { 559 public: AbstractNone()560 AbstractNone() : AbstractBase() { set_type(std::make_shared<TypeNone>()); } 561 ~AbstractNone() override = default; MS_DECLARE_PARENT(AbstractNone,AbstractBase)562 MS_DECLARE_PARENT(AbstractNone, AbstractBase) 563 564 TypePtr BuildType() const override { return std::make_shared<TypeNone>(); } 565 bool operator==(const AbstractNone &other) const; 566 bool operator==(const AbstractBase &other) const override; Clone()567 AbstractBasePtr Clone() const override { return std::make_shared<AbstractNone>(); } 568 std::string ToString() const override; 569 570 protected: 571 ValuePtr RealBuildValue() const override; 572 }; 573 using AbstractNonePtr = std::shared_ptr<AbstractNone>; 574 575 // the un assigned state value for variable, which means the variable is not assigned 576 class MS_CORE_API AbstractNull : public AbstractBase { 577 public: AbstractNull()578 AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); } 579 ~AbstractNull() override = default; MS_DECLARE_PARENT(AbstractNull,AbstractBase)580 MS_DECLARE_PARENT(AbstractNull, AbstractBase) 581 582 TypePtr BuildType() const override { return std::make_shared<TypeNull>(); } 583 bool operator==(const AbstractNull &other) const; 584 bool operator==(const AbstractBase &other) const override; Clone()585 AbstractBasePtr Clone() const override { return std::make_shared<AbstractNull>(); } 586 std::string ToString() const override; 587 }; 588 using AbstractNullPtr = std::shared_ptr<AbstractNull>; 589 590 // the timeout state value for variable, which means the variable is not assigned because it is timeout 591 class MS_CORE_API AbstractTimeOut : public AbstractBase { 592 public: AbstractTimeOut()593 AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); } 594 ~AbstractTimeOut() override = default; MS_DECLARE_PARENT(AbstractTimeOut,AbstractBase)595 MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase) 596 597 TypePtr BuildType() const override { return std::make_shared<TypeNull>(); } 598 bool operator==(const AbstractTimeOut &other) const; 599 bool operator==(const AbstractBase &other) const override; Clone()600 AbstractBasePtr Clone() const override { return std::make_shared<AbstractTimeOut>(); } 601 std::string ToString() const override; 602 }; 603 using AbstractTimeOutPtr = std::shared_ptr<AbstractTimeOut>; 604 605 class MS_CORE_API AbstractEllipsis : public AbstractBase { 606 public: AbstractEllipsis()607 AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); } 608 ~AbstractEllipsis() override = default; MS_DECLARE_PARENT(AbstractEllipsis,AbstractBase)609 MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) 610 611 TypePtr BuildType() const override { return std::make_shared<TypeEllipsis>(); } 612 bool operator==(const AbstractEllipsis &other) const; 613 bool operator==(const AbstractBase &other) const override; Clone()614 AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); } 615 std::string ToString() const override; 616 }; 617 using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>; 618 619 class MS_CORE_API AbstractRefKey : public AbstractBase { 620 public: AbstractRefKey()621 AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); } 622 ~AbstractRefKey() override = default; MS_DECLARE_PARENT(AbstractRefKey,AbstractBase)623 MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) 624 625 TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); } 626 bool operator==(const AbstractRefKey &other) const; 627 bool operator==(const AbstractBase &other) const override; Clone()628 AbstractBasePtr Clone() const override { 629 auto cloned = std::make_shared<AbstractRefKey>(); 630 cloned->set_value(GetValueTrack()); 631 return cloned; 632 } set_value(const ValuePtr & value)633 inline void set_value(const ValuePtr &value) { 634 AbstractBase::set_value(value); 635 if (value != nullptr) { 636 ref_key_value_ = value->cast<RefKeyPtr>(); 637 } 638 } ref_key_value()639 RefKeyPtr ref_key_value() const { return ref_key_value_; } 640 AbstractBasePtr Join(const AbstractBasePtr &other) override; 641 std::string ToString() const override; 642 643 private: 644 // cache for ref_key after build value, when value is null, return nullptr. 645 RefKeyPtr ref_key_value_{nullptr}; 646 }; 647 using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; 648 649 class MS_CORE_API AbstractRef : public AbstractTensor { 650 public: 651 AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value); 652 ~AbstractRef() override = default; 653 MS_DECLARE_PARENT(AbstractRef, AbstractTensor) 654 655 TypePtr BuildType() const override; 656 bool operator==(const AbstractRef &other) const; 657 bool operator==(const AbstractBase &other) const override; Clone()658 AbstractBasePtr Clone() const override { 659 auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>(); 660 if (abs_tensor == nullptr) { 661 return nullptr; 662 } 663 return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor); 664 } CloneAsTensor()665 AbstractBasePtr CloneAsTensor() const { return AbstractTensor::Clone(); } 666 std::string ToString() const override; ref()667 inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); } ref_key()668 inline AbstractBasePtr ref_key() const { return ref_key_; } ref_key_value()669 inline RefKeyPtr ref_key_value() const { return ref_key_value_; } Broaden()670 AbstractBasePtr Broaden() const override { 671 // always broaden for ref 672 auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>(); 673 if (abs_tensor == nullptr) { 674 return nullptr; 675 } 676 return std::make_shared<AbstractRef>(ref_key_->Broaden(), abs_tensor); 677 } 678 AbstractBasePtr Join(const AbstractBasePtr &other) override; hash()679 std::size_t hash() const override { 680 return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ 681 } 682 AbstractBasePtr PartialBroaden() const override; 683 684 private: 685 AbstractBasePtr ref_key_; 686 // cache for ref_key after build value, when value is null, return nullptr. 687 RefKeyPtr ref_key_value_; 688 }; 689 using AbstractRefPtr = std::shared_ptr<AbstractRef>; 690 691 struct MS_CORE_API AbstractBasePtrListHasher { 692 std::size_t operator()(const AbstractBasePtrList &args_spec_list) const; 693 }; 694 695 struct MS_CORE_API AbstractBasePtrListEqual { 696 bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const; 697 }; 698 699 MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); 700 MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); 701 702 // RowTensor 703 class MS_CORE_API AbstractRowTensor : public AbstractUndetermined { 704 public: 705 explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) AbstractUndetermined(element,shape)706 : AbstractUndetermined(element, shape) {} AbstractRowTensor(const TypePtr & element_type,const ShapeVector & shape)707 AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape) 708 : AbstractUndetermined(element_type, shape) {} 709 ~AbstractRowTensor() override = default; MS_DECLARE_PARENT(AbstractRowTensor,AbstractUndetermined)710 MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined) 711 712 const AbstractTensorPtr indices() const { return indices_; } set_indices(const AbstractTensorPtr & indices)713 void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } values()714 const AbstractTensorPtr values() const { return values_; } set_values(const AbstractTensorPtr & values)715 void set_values(const AbstractTensorPtr &values) { values_ = values; } dense_shape()716 const AbstractTuplePtr dense_shape() const { return dense_shape_; } set_dense_shape(const AbstractTuplePtr & dense_shape)717 void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } 718 TypePtr BuildType() const override; 719 AbstractBasePtr Clone() const override; 720 AbstractBasePtr Broaden() const override; 721 AbstractBasePtr BroadenWithShape() const; 722 723 std::string ToString() const override; 724 725 private: 726 AbstractTensorPtr indices_; 727 AbstractTensorPtr values_; 728 AbstractTuplePtr dense_shape_; 729 }; 730 731 // SparseTensor 732 class MS_CORE_API AbstractSparseTensor : public AbstractUndetermined { 733 public: 734 explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) AbstractUndetermined(element,shape)735 : AbstractUndetermined(element, shape) {} AbstractSparseTensor(const TypePtr & element_type,const ShapeVector & shape)736 AbstractSparseTensor(const TypePtr &element_type, const ShapeVector &shape) 737 : AbstractUndetermined(element_type, shape) {} 738 ~AbstractSparseTensor() override = default; MS_DECLARE_PARENT(AbstractSparseTensor,AbstractUndetermined)739 MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined) 740 741 const AbstractTensorPtr indices() const { return indices_; } set_indices(const AbstractTensorPtr & indices)742 void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } values()743 const AbstractTensorPtr values() const { return values_; } set_values(const AbstractTensorPtr & values)744 void set_values(const AbstractTensorPtr &values) { values_ = values; } dense_shape()745 const AbstractTuplePtr dense_shape() const { return dense_shape_; } set_dense_shape(const AbstractTuplePtr & dense_shape)746 void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } 747 TypePtr BuildType() const override; 748 AbstractBasePtr Clone() const override; 749 AbstractBasePtr Broaden() const override; 750 AbstractBasePtr BroadenWithShape() const; 751 752 std::string ToString() const override; 753 754 private: 755 AbstractTensorPtr indices_; 756 AbstractTensorPtr values_; 757 AbstractTuplePtr dense_shape_; 758 }; 759 760 class AbstractMonad : public AbstractBase { 761 public: 762 ~AbstractMonad() override = default; MS_DECLARE_PARENT(AbstractMonad,AbstractBase)763 MS_DECLARE_PARENT(AbstractMonad, AbstractBase) 764 765 std::size_t hash() const override { return hash_combine({tid()}); } BuildType()766 TypePtr BuildType() const override { return GetTypeTrack(); } Broaden()767 AbstractBasePtr Broaden() const override { return AbstractBase::Broaden(); } 768 AbstractBasePtr Join(const AbstractBasePtr &other) override = 0; ToString()769 std::string ToString() const override { 770 std::ostringstream buffer; 771 buffer << type_name() << "(" << GetValueTrack()->ToString() << ")"; 772 return buffer.str(); 773 } 774 775 protected: AbstractMonad(const ValuePtr & value,const TypePtr & type)776 AbstractMonad(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} 777 }; 778 using AbstractMonadPtr = std::shared_ptr<AbstractMonad>; 779 780 class AbstractUMonad : public AbstractMonad { 781 public: AbstractMonad(value,kUMonadType)782 explicit AbstractUMonad(const ValuePtr &value = kUMonad) : AbstractMonad(value, kUMonadType) {} 783 ~AbstractUMonad() override = default; MS_DECLARE_PARENT(AbstractUMonad,AbstractMonad)784 MS_DECLARE_PARENT(AbstractUMonad, AbstractMonad) 785 786 AbstractBasePtr Clone() const override { return std::make_shared<AbstractUMonad>(GetValueTrack()); } 787 AbstractBasePtr Join(const AbstractBasePtr &other) override; 788 bool operator==(const AbstractUMonad &other) const; 789 bool operator==(const AbstractBase &other) const override; 790 }; 791 using AbstractUMonadPtr = std::shared_ptr<AbstractUMonad>; 792 793 class AbstractIOMonad : public AbstractMonad { 794 public: AbstractMonad(value,kIOMonadType)795 explicit AbstractIOMonad(const ValuePtr &value = kIOMonad) : AbstractMonad(value, kIOMonadType) {} 796 ~AbstractIOMonad() override = default; MS_DECLARE_PARENT(AbstractIOMonad,AbstractMonad)797 MS_DECLARE_PARENT(AbstractIOMonad, AbstractMonad) 798 799 AbstractBasePtr Clone() const override { return std::make_shared<AbstractIOMonad>(GetValueTrack()); } 800 AbstractBasePtr Join(const AbstractBasePtr &other) override; 801 bool operator==(const AbstractIOMonad &other) const; 802 bool operator==(const AbstractBase &other) const override; 803 }; 804 using AbstractIOMonadPtr = std::shared_ptr<AbstractIOMonad>; 805 806 AnfNodePtr GetTraceNode(const AbstractBasePtr &abs); 807 std::string ExtractLoggingInfo(const std::string &info); 808 } // namespace abstract 809 } // namespace mindspore 810 #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ 811