1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2023 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ 20 #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ 21 22 #include <cstdint> 23 #include <utility> 24 #include <vector> 25 #include <string> 26 #include <memory> 27 #include "utils/log_adapter.h" 28 #include "utils/hashing.h" 29 #include "utils/any.h" 30 #include "utils/hash_map.h" 31 #include "base/base.h" 32 #include "base/user_data.h" 33 #include "ir/dtype.h" 34 #include "ir/value.h" 35 #include "ir/tensor.h" 36 #include "ir/map_tensor.h" 37 #include "abstract/dshape.h" 38 #include "abstract/utils.h" 39 #include "utils/shape_utils.h" 40 #include "mindspore/core/symbolic_shape/symbol.h" 41 42 namespace mindspore { 43 namespace abstract { 44 class AbstractBase; 45 using AbstractBasePtrList = std::vector<AbstractBasePtr>; 46 /// \brief The base class for abstract value of an anf node. 47 /// 48 /// The abstract value is used in evaluator to express 49 /// the type, shape and value of an anf node. 50 class MS_CORE_API AbstractBase : public Base { 51 public: 52 using TraceNodeProvider = std::function<void(AnfNodePtr *node)>; 53 54 /// \brief Constructor of AbstractBase. 55 /// 56 /// \param[in] value The real value (if any) of an anf node. Default: nullptr. 57 /// \param[in] type The type of an anf node. Default: kTypeAny. 58 /// \param[in] shape The dimension of an anf node. Default: kNoShape. 59 explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kTypeAny, 60 const BaseShapePtr &shape = kNoShape); 61 62 /// \brief Copy constructor 63 /// \param[in] abstract_base an abstract 64 AbstractBase(const AbstractBase &other); 65 66 /// \brief Overloads operator '=' for Named. 67 /// 68 /// \param[in] other An an abstract. 69 /// \return An abstract set with the same type, value and shape as abstract_base. 70 AbstractBase &operator=(const AbstractBase &other); 71 72 /// \brief Destructor of AbstractBase. 73 ~AbstractBase() override = default; 74 MS_DECLARE_PARENT(AbstractBase, Base) 75 76 /// \brief Get the hash number of the abstract. 77 /// 78 /// \return The hash of the object. 79 std::size_t hash() const override; 80 81 /// \brief Get the formatted text to describe the abstract. 82 /// 83 /// \return A string. 84 std::string ToString() const override; 85 86 /// \brief Get the formatted text to describe the abstract. 87 /// 88 /// \return A string. 89 virtual std::string ToString(bool verbose) const; 90 91 /// \brief Overwrite the operator '==' to compare other abstract. 92 /// 93 /// \param[in] other The other abstract to be joined. 94 /// 95 /// \return A boolean, which indicates whether the other abstract is same. 96 virtual bool operator==(const AbstractBase &other) const; 97 98 /// \brief Set the value for the AbstractBase. 99 /// 100 /// \param[in] value The value of an anf node. 101 void set_value(const ValuePtr &value); 102 103 /// \brief Set the type for the AbstractBase. 104 /// 105 /// \param[in] type The type of an anf node. 106 void set_type(const TypePtr &type); 107 108 /// \brief Set the shape for the AbstractBase. 109 /// 110 /// \param[in] shape The shape of an anf node. 111 virtual void set_shape(const BaseShapePtr &shape); 112 113 /// \brief Set the value description for the AbstractBase. 114 /// 115 /// \param[in] desc The description of value. 116 void set_value_desc(const std::string &desc); 117 118 /// \brief Get the value description. 119 /// 120 /// \return A string of the value description. 121 const std::string &value_desc() const; 122 123 /// \brief Get the abstract value, which is tracked. 124 /// 125 /// \return A pointer to the Value. 126 const ValuePtr &GetValueTrack() const; 127 128 /// \brief Get the abstract type, which is tracked. 129 /// 130 /// \return A pointer to the Type. 131 const TypePtr &GetTypeTrack() const; 132 133 /// \brief Get the abstract shape, which is tracked. 134 /// 135 /// \return A pointer to the BaseShape. 136 const BaseShapePtr &GetShapeTrack() const; 137 138 /// \brief Try to build a real value from an abstract value. 139 /// 140 /// \note This is a deprecated function, please do not call it, use GetValue instead. 141 /// \note If the value cannot be built, a default value (ValueAny) is returned. 142 /// 143 /// \return A pointer to the Value. 144 ValuePtr BuildValue() const; 145 146 /// \brief Build the type of the abstract. 147 /// 148 /// \note This is a deprecated function, please do not call it, use GetType instead. 149 /// \note Use this function to get the actual type, while track type is not enough accurate. 150 /// 151 /// \return A pointer to the Type. BuildType()152 virtual TypePtr BuildType() const { MS_LOG(EXCEPTION) << "The method 'BuildType()' doesn't implement"; } 153 154 /// \brief Build the shape of the abstract. 155 /// 156 /// \note This is a deprecated function, please do not call it, use GetShape instead. 157 /// \note Use this function to get the actual shape, while track shape is not enough accurate. 158 /// 159 /// \return A pointer to the BaseShape. 160 virtual BaseShapePtr BuildShape() const; 161 162 /// \brief Get or build the shape of AbstractBase. 163 /// 164 /// \return The base shape got or built. 165 virtual BaseShapePtr GetShape() const; 166 167 /// \brief Get or build the object type of the AbstractBase. 168 /// 169 /// \return The object type. 170 virtual TypePtr GetType() const; 171 172 /// \brief Get or build the value of the AbstractBase. 173 /// 174 /// \return The value of the AbstractBase if exists, else return kValueAny. 175 virtual ValuePtr GetValue() const; 176 177 /// \brief Set the symbolic shape of the abstract. SetSymbolicShape(const ListSymbolPtr & s)178 void SetSymbolicShape(const ListSymbolPtr &s) { symbolic_shape_ = s; } 179 180 /// \brief Get the symbolic shape of the abstract. 181 /// 182 /// \return The symbolic shape if exists, else return nullptr. GetSymbolicShape()183 const ListSymbolPtr &GetSymbolicShape() const { return symbolic_shape_; } 184 185 /// \brief Set the symbolic shape of the abstract. SetSymbolicValue(const SymbolPtr & s)186 void SetSymbolicValue(const SymbolPtr &s) { symbolic_value_ = s; } 187 188 /// \brief Get the symbolic value of the abstract. 189 /// 190 /// \return The symbolic value if exists, else return nullptr. GetSymbolicValue()191 const SymbolPtr &GetSymbolicValue() const { return symbolic_value_; } 192 193 /// \brief Clone an abstract from the abstract. 194 /// 195 /// \return A pointer to the cloned abstract. Clone()196 virtual AbstractBasePtr Clone() const { MS_LOG(EXCEPTION) << "The method 'Clone()' doesn't implement"; } 197 198 /// \brief Set the function, which prints the debug info. 199 /// 200 /// \param[in] trace_node_provider The function. 201 static void set_trace_node_provider(const TraceNodeProvider &trace_node_provider); 202 203 static TraceNodeProvider trace_node_provider_; 204 205 /// \brief Broaden the abstract. It will upgrade the abstract to a higher level. 206 /// 207 /// \return A pointer to the broadened abstract. 208 virtual AbstractBasePtr Broaden() const; 209 210 /// \brief Combine two abstracts. If two abstracts are different, it will broaden the abstract value. 211 /// 212 /// \param[in] other The other abstract to be joined. 213 /// 214 /// \return A pointer to the combined abstract. 215 virtual AbstractBasePtr Join(const AbstractBasePtr &other); 216 bool IsBroaden() const; 217 218 /// \brief Write the abstract's string to the std::ostream. 219 /// 220 /// \param[in] os A std::ostream. 221 /// \param[in] a An abstract. 222 /// 223 /// \return A std::ostream. 224 #ifndef _MSC_VER 225 friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) { 226 os << a->ToString(); 227 return os; 228 } 229 #endif 230 /// \brief Broaden abstract with constraints. 231 /// 232 /// \return A pointer to the broadened abstract. 233 virtual AbstractBasePtr PartialBroaden() const; 234 235 /// \brief Process the abstract with InterpretedObject. 236 using InterpretBoolChecker = std::pair<bool, bool> (*)(const AbstractBasePtr &cond); 237 static inline InterpretBoolChecker interpret_bool_checker_ = nullptr; 238 static void set_interpret_bool_checker(InterpretBoolChecker checker); 239 static InterpretBoolChecker interpret_bool_checker(); 240 241 /// \brief Process the user date of abstract with PyExecute node. 242 using PyExecuteUserDataCatcher = std::pair<bool, ValuePtr> (*)(const AbstractBasePtr &element_abs); 243 static inline PyExecuteUserDataCatcher pyexecute_user_data_catcher_ = nullptr; 244 static void set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher); 245 static inline PyExecuteUserDataCatcher pyexecute_user_data_catcher(); 246 247 /// \brief Store for mindir input and output names. 248 std::string name() const; 249 void set_name(const std::string &name); 250 251 /// \brief Cover *this abstract for inplace primitive. If inplace_abstract() is not null, use it as real abstract. 252 AbstractBasePtr inplace_abstract() const; 253 void set_inplace_abstract(const AbstractBasePtr &inplace_abstract); 254 255 protected: 256 /// \brief Build a value when value is not set. 257 /// 258 /// \return A pointer to the Value. 259 virtual ValuePtr RealBuildValue() const; 260 261 ValuePtr value_; 262 TypePtr type_; 263 BaseShapePtr shape_; 264 std::string value_desc_; // Store initial value description for error report. 265 std::string name_; // Store for mindir input and output names. 266 ListSymbolPtr symbolic_shape_{nullptr}; 267 SymbolPtr symbolic_value_{nullptr}; 268 269 private: 270 AbstractBasePtr inplace_abstract_{nullptr}; // Cover *this abstract for inplace primitive. 271 }; 272 273 /// \brief Class AbstractScalar describes a scalar's type and value. 274 class MS_CORE_API AbstractScalar final : public AbstractBase { 275 public: 276 /// \brief Constructor of AbstractScalar. 277 AbstractScalar(); 278 279 /// \brief Constructor of AbstractScalar. 280 /// 281 /// \param[in] value The real value of an anf node. 282 /// \param[in] type The type of an anf node. 283 AbstractScalar(const ValuePtr &value, const TypePtr &type); 284 285 /// \brief Constructor of AbstractScalar. 286 /// 287 /// \param[in] value The real value of an anf node. 288 explicit AbstractScalar(const ValuePtr &value); 289 290 /// \brief Constructor of AbstractScalar, inited with an int number. 291 /// 292 /// \param[in] value An int number. 293 explicit AbstractScalar(int value); 294 295 /// \brief Constructor of AbstractScalar, inited with an int64 number. 296 /// 297 /// \param[in] value An int64 number. 298 explicit AbstractScalar(int64_t value); 299 300 /// \brief Constructor of AbstractScalar, inited with a float number. 301 /// 302 /// \param[in] value A float number. 303 explicit AbstractScalar(float value); 304 305 /// \brief Constructor of AbstractScalar, inited with a double number. 306 /// 307 /// \param[in] value A double number. 308 explicit AbstractScalar(double value); 309 310 /// \brief Constructor of AbstractScalar, inited with a bool. 311 /// 312 /// \param[in] value A boolean variable. 313 explicit AbstractScalar(bool value); 314 315 /// \brief Constructor of AbstractScalar, inited with a string. 316 /// 317 /// \param[in] value A string. 318 explicit AbstractScalar(const std::string &value); 319 320 /// \brief Constructor of AbstractScalar, inited with a type. 321 /// 322 /// \param[in] type The type. 323 explicit AbstractScalar(const TypePtr &type); 324 325 /// \brief Destructor of AbstractScalar. 326 ~AbstractScalar() override = default; 327 MS_DECLARE_PARENT(AbstractScalar, AbstractBase) 328 329 /// \brief Set the flag 'is_variable_' for scalar. 330 /// 331 /// \param[in] is_variable Boolean value for flag 'is_variable_'. 332 void set_is_variable(bool is_variable); 333 334 std::size_t hash() const override; 335 336 TypePtr BuildType() const override; 337 338 AbstractBasePtr Clone() const override; 339 340 AbstractBasePtr Broaden() const override; 341 342 AbstractBasePtr Join(const AbstractBasePtr &other) override; 343 344 private: 345 bool is_variable_{false}; 346 }; 347 using AbstractScalarPtr = std::shared_ptr<AbstractScalar>; 348 349 /// \brief Class AbstractType describes the abstract value from a Typeof node. 350 class MS_CORE_API AbstractType final : public AbstractBase { 351 public: 352 /// \brief Constructor of AbstractType. 353 /// 354 /// \param[in] type The type of an anf node. AbstractType(const TypePtr & type)355 explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) { 356 if (type == nullptr) { 357 MS_LOG(EXCEPTION) << "type is nullptr"; 358 } 359 } 360 361 /// \brief Destructor of AbstractType. 362 ~AbstractType() override = default; 363 MS_DECLARE_PARENT(AbstractType, AbstractBase) 364 365 std::string ToString() const override; 366 367 bool operator==(const AbstractBase &other) const override; 368 369 TypePtr BuildType() const override; 370 371 AbstractBasePtr Clone() const override; 372 373 AbstractBasePtr Broaden() const override; 374 }; 375 using AbstractTypePtr = std::shared_ptr<AbstractType>; 376 377 /// \brief Class AbstractClass describes the abstract value from a class. 378 class MS_CORE_API AbstractClass final : public AbstractBase { 379 public: 380 /// \brief Constructor of AbstractClass. 381 /// 382 /// \param[in] value A class value. AbstractClass(const ValuePtr & value)383 explicit AbstractClass(const ValuePtr &value) 384 : AbstractBase(value, kClassType), 385 hash_(hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()})) {} 386 387 /// \brief Destructor of AbstractClass. 388 ~AbstractClass() override = default; 389 MS_DECLARE_PARENT(AbstractClass, AbstractBase) 390 391 std::string ToString() const override; 392 hash()393 std::size_t hash() const override { return hash_; } 394 395 bool operator==(const AbstractBase &other) const override; 396 BuildType()397 TypePtr BuildType() const override { return std::make_shared<MsClassType>(); } 398 399 AbstractBasePtr Clone() const override; 400 Broaden()401 AbstractBasePtr Broaden() const override { return Clone(); } 402 403 AbstractBasePtr Join(const AbstractBasePtr &other) override; 404 405 private: 406 std::size_t hash_; 407 }; 408 using AbstractClassPtr = std::shared_ptr<AbstractClass>; 409 410 /// \brief Class AbstractProblem describes the abstract value from an error. 411 class MS_CORE_API AbstractProblem final : public AbstractBase { 412 public: 413 /// \brief Constructor of AbstractProblem. 414 /// 415 /// \param[in] err the error string. 416 /// \param[in] node the binding anf node. 417 AbstractProblem(const ValueProblemPtr &err, const AnfNodePtr &node); 418 419 /// \brief Destructor of AbstractProblem. 420 ~AbstractProblem() override = default; 421 MS_DECLARE_PARENT(AbstractProblem, AbstractBase) 422 423 TypePtr BuildType() const override; 424 425 AbstractBasePtr Broaden() const override; 426 427 AbstractBasePtr Clone() const override; 428 429 std::string ToString() const override; 430 431 private: 432 // Origin node been specialized to AbstractProblem, for debug purpose only. 433 const AnfNodePtr node_; 434 }; 435 436 /// \brief Class AbstractScript describes the script node's type, shape and value. 437 class MS_CORE_API AbstractScript final : public AbstractBase { 438 public: 439 /// \brief Constructor of AbstractScript. 440 AbstractScript(); 441 442 /// \brief Constructor of AbstractScript. 443 /// 444 /// \param[in] value The real value of an anf node. 445 /// \param[in] type The type of an anf node. 446 AbstractScript(const ValuePtr &value, const TypePtr &type); 447 448 /// \brief Constructor of AbstractScript. 449 /// 450 /// \param[in] value The real value to be set. 451 explicit AbstractScript(const ValuePtr &value); 452 453 /// \brief Destructor of AbstractScript. 454 ~AbstractScript() override = default; 455 MS_DECLARE_PARENT(AbstractScript, AbstractBase) 456 457 std::size_t hash() const override; 458 459 TypePtr BuildType() const override; 460 461 AbstractBasePtr Clone() const override; 462 463 AbstractBasePtr Broaden() const override; 464 }; 465 using AbstractScriptPtr = std::shared_ptr<AbstractScript>; 466 467 class Evaluator; 468 using EvaluatorPtr = std::shared_ptr<Evaluator>; 469 class AnalysisEngine; 470 using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>; 471 472 class AbstractFunction; 473 using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>; 474 class AbstractFuncAtom; 475 using AbstractFuncAtomPtr = std::shared_ptr<AbstractFuncAtom>; 476 using AbstractFuncAtomPtrList = std::vector<AbstractFuncAtomPtr>; 477 478 /// \brief The base class for the abstract value of the function node. 479 class MS_CORE_API AbstractFunction : public AbstractBase { 480 public: 481 /// \brief Constructor of AbstractFunction. 482 AbstractFunction() = default; 483 /// \brief Destructor of AbstractFunction. 484 ~AbstractFunction() override = default; 485 MS_DECLARE_PARENT(AbstractFunction, AbstractBase) 486 487 /// \brief Get the unique AbstractFunction. 488 /// 489 /// If there is exactly one possible function, return it. Otherwise, raise an Exception. 490 /// Caller should ensure the uniqueness. 491 /// 492 /// \return A pointer to AbstractFunction. 493 virtual AbstractFunctionPtr GetUnique() = 0; 494 495 TypePtr BuildType() const override; 496 497 AbstractBasePtr Clone() const override; 498 499 AbstractBasePtr Broaden() const override; 500 501 /// \brief Copy an AbstractFunction. 502 /// 503 /// \return A pointer to the copied abstract. 504 virtual AbstractFunctionPtr Copy() const = 0; 505 506 /// \brief Combine two abstracts. If two abstracts are different, it will broaden the abstract value. 507 /// 508 /// \param[in] other The other abstract to be joined. 509 /// 510 /// \return A pointer to the combined abstract. 511 AbstractBasePtr Join(const AbstractBasePtr &other) final; 512 513 /// \brief Combine two abstracts. If two abstracts are different, it will broaden the abstract value. 514 /// 515 /// \param[in] other The other abstract to be joined. 516 /// 517 /// \return A pointer to the combined abstract. 518 virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0; 519 520 /// \brief Handle something with the outer visit function. 521 virtual void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const = 0; 522 523 /// \brief Overwrite the operator '==' to compare other abstract. 524 /// 525 /// \param[in] other The other abstract to be joined. 526 /// 527 /// \return A boolean, which indicates whether the other abstract is same. 528 bool operator==(const AbstractBase &other) const final; 529 530 /// \brief Overwrite the operator '==' to compare other AbstractFunction. 531 /// 532 /// \param[in] other The other instance of AbstractFunction. 533 /// 534 /// \return A boolean, which indicates whether the other AbstractFunction is same. 535 virtual bool operator==(const AbstractFunction &other) const = 0; 536 537 /// \brief Make a AbstractFuncUnion from a list of AbstractFuncAtom. 538 /// 539 /// \param[in] func_list A list of AbstractFuncAtomPtrList. 540 /// \return A point to the AbstractFunction. 541 static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); 542 543 /// \brief Get the tracking id as the memory address of the anf node. 544 /// 545 /// \return The memory address of to the anf node. 546 virtual std::uintptr_t tracking_id() const; 547 548 /// \brief Copy an AbstractFunction without copying tracking id. 549 /// 550 /// \return A pointer to the copied abstract. 551 virtual AbstractFunctionPtr CopyWithoutTrackingId() const; 552 553 /// \brief Get the context which manages the abstract. 554 /// 555 /// \return A point to the context. 556 virtual AnalysisContextPtr context() const; 557 558 static std::uintptr_t ToTrackingId(const AnfNodePtr &node); 559 }; 560 561 using AbstractFunctionPtrList = std::vector<AbstractFunctionPtr>; 562 563 /// \brief Class AbstractKeywordArg describes an abstract value from a key-value node. 564 /// 565 /// Represents a key-value pair used in function's parameters. 566 class MS_CORE_API AbstractKeywordArg final : public AbstractBase { 567 public: 568 /// \brief Constructor of AbstractKeywordArg. 569 /// 570 /// \param[in] key The key name of the key-value pair. 571 /// \param[in] argument The key value of the key-value pair. 572 AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument); 573 574 /// \brief Destructor of AbstractKeywordArg. 575 ~AbstractKeywordArg() override = default; 576 MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase) 577 578 TypePtr BuildType() const override; 579 580 AbstractBasePtr Clone() const override; 581 582 AbstractBasePtr Broaden() const override; 583 584 std::size_t hash() const override; 585 586 /// \brief Overwrite the operator '==' to compare other key-value abstract. 587 /// 588 /// \param[in] other The other abstract to be joined. 589 /// 590 /// \return A boolean, which indicates whether the other abstract is same. 591 bool operator==(const AbstractKeywordArg &other) const; 592 593 bool operator==(const AbstractBase &other) const override; 594 595 /// \brief Get the key name of the key-value pair. 596 /// 597 /// \return A string. 598 std::string get_key() const; 599 600 /// \brief Get the key value of the key-value pair. 601 /// 602 /// \return A point to the abstract. 603 AbstractBasePtr get_arg() const; 604 605 std::string ToString() const override; 606 607 protected: 608 ValuePtr RealBuildValue() const override; 609 610 private: 611 std::string arg_name_; 612 AbstractBasePtr arg_value_; 613 }; 614 using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>; 615 616 /// \brief Class AbstractUndetermined describes the abstract if anf node has unknown shape, type or value. 617 class MS_CORE_API AbstractUndetermined : public AbstractBase { 618 public: 619 /// \brief Constructor of AbstractUndetermined. 620 /// 621 /// Shape and type are all unknown. 622 AbstractUndetermined(); 623 624 /// \brief Constructor of AbstractUndetermined. 625 /// 626 /// Only element, value and shape track are valid member, type track are unknown. 627 /// 628 /// \param[in] element The abstract which is undetermined. 629 /// \param[in] shape The dimension of value. 630 explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()); 631 632 /// \brief Constructor of AbstractUndetermined. 633 /// 634 /// \param[in] element_type A type of the undetermined abstract. 635 /// \param[in] shape A vector of shape. 636 AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape); 637 638 /// \brief Constructor of AbstractUndetermined. 639 /// 640 /// \param[in] element_type A type of the undetermined abstract. 641 /// \param[in] shape A shape of the undetermined abstract. 642 explicit AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>()); 643 644 /// \brief Destructor of AbstractUndetermined. 645 ~AbstractUndetermined() override = default; 646 MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase) 647 648 TypePtr BuildType() const override; 649 650 AbstractBasePtr Clone() const override; 651 652 /// \brief Get the element, which is the tracked undetermined abstract. 653 /// 654 /// \return A pointer to the bind abstract, which is undetermined. 655 AbstractBasePtr element() const; 656 657 /// \brief Get the shape of the undetermined abstract. 658 /// 659 /// \return A pointer to the shape. 660 ShapePtr shape() const; 661 662 void set_shape(const BaseShapePtr &shape) override; 663 664 protected: 665 AbstractBasePtr element_; 666 }; 667 668 /// \brief Class AbstractTensor describes a tensor's type, shape and value. 669 class MS_CORE_API AbstractTensor : public AbstractUndetermined { 670 public: 671 /// \brief Constructor of AbstractTensor. 672 /// 673 /// \param[in] element The abstract to be wrapper as a abstract tensor. 674 /// \param[in] shape The dimension of abstract tensor. 675 explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()); 676 677 /// \brief Constructor of AbstractTensor. 678 /// 679 /// \param[in] element_type The type of abstract tensor. 680 /// \param[in] shape A vector of the tensor's shape. 681 AbstractTensor(const TypePtr &element_type, const ShapeVector &shape); 682 683 /// \brief Constructor of AbstractTensor. 684 /// 685 /// \param[in] tensor The tensor to be abstracted. 686 explicit AbstractTensor(const tensor::TensorPtr &tensor); 687 688 /// \brief Constructor of AbstractTensor. 689 /// 690 /// \param[in] element_type The type of a tensor. 691 /// \param[in] shape The dimension of a tensor. 692 explicit AbstractTensor(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>()); 693 694 /// \brief Destructor of AbstractTensor. 695 ~AbstractTensor() override = default; 696 MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) 697 698 TypePtr BuildType() const override; 699 700 BaseShapePtr BuildShape() const override; 701 702 AbstractBasePtr Clone() const override; 703 704 AbstractBasePtr Broaden() const override; 705 706 /// \brief Broaden the abstract. It will upgrade the abstract to a higher level. 707 /// 708 /// \note The shape will be remained. 709 /// 710 /// \return A pointer to the broadened abstract. 711 AbstractBasePtr BroadenWithShape() const; 712 713 AbstractBasePtr Join(const AbstractBasePtr &other) override; 714 715 /// \brief Overwrite the operator '==' to compare other abstract tensor. 716 /// 717 /// \param[in] other The other instance of AbstractTensor. 718 /// 719 /// \return A boolean, which indicates whether the other abstract is same. 720 virtual bool operator==(const AbstractTensor &other) const; 721 722 bool operator==(const AbstractBase &other) const override; 723 724 std::string ToString() const override; 725 726 std::size_t hash() const override; 727 728 AbstractBasePtr PartialBroaden() const override; 729 730 bool is_adapter() const; 731 void set_is_adapter(bool is_adapter); 732 733 protected: 734 bool equal_to(const AbstractTensor &other) const; 735 bool is_adapter_ = false; 736 }; 737 using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; 738 using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; 739 740 /// \brief Class AbstractAny describes a type, whose shape and value is unknown. 741 /// 742 /// AbstractAny is even not a Tensor type, but any type. 743 class MS_CORE_API AbstractAny : public AbstractTensor { 744 public: 745 /// \brief Constructor of AbstractAny. 746 /// 747 /// \param[in] element The abstract to be wrapper as a abstract tensor. 748 /// \param[in] shape The dimension of abstract tensor. 749 AbstractAny(); 750 751 /// \brief Destructor of AbstractAny. 752 ~AbstractAny() override = default; 753 MS_DECLARE_PARENT(AbstractAny, AbstractTensor) 754 755 AbstractBasePtr Join(const AbstractBasePtr &other) override; 756 757 AbstractBasePtr Broaden() const override; 758 759 AbstractBasePtr Clone() const override; 760 761 TypePtr BuildType() const override; 762 763 std::string ToString() const override; 764 765 bool supposed_tensor_dtype() const; 766 767 void set_supposed_tensor_dtype(bool flag); 768 769 static TypePtr DefaultDtype(); 770 771 private: 772 bool supposed_tensor_dtype_{false}; 773 }; 774 using AbstractAnyPtr = std::shared_ptr<AbstractAny>; 775 using AbstractAnyPtrList = std::vector<AbstractAnyPtr>; 776 777 /// \brief Class AbstractNegligible describes a type, whose shape and value is unknown, 778 /// and should choose other branch in control flow. 779 /// 780 /// AbstractNegligible is even not a Tensor type, but any type. 781 class MS_CORE_API AbstractNegligible : public AbstractAny { 782 public: 783 /// \brief Constructor of AbstractNegligible. 784 /// 785 /// \param[in] element The abstract to be wrapper as a abstract tensor. 786 /// \param[in] shape The dimension of abstract tensor. AbstractNegligible()787 AbstractNegligible() : AbstractAny() {} 788 789 /// \brief Destructor of AbstractNegligible. 790 ~AbstractNegligible() override = default; 791 MS_DECLARE_PARENT(AbstractNegligible, AbstractAny) 792 793 AbstractBasePtr Join(const AbstractBasePtr &other) override; 794 795 AbstractBasePtr Broaden() const override; 796 797 AbstractBasePtr Clone() const override; 798 799 TypePtr BuildType() const override; 800 801 std::string ToString() const override; 802 }; 803 using AbstractNegligiblePtr = std::shared_ptr<AbstractNegligible>; 804 using AbstractNegligiblePtrList = std::vector<AbstractNegligiblePtr>; 805 806 /// \brief Class AbstractJoinedAny describes a type, whose shape and value is unknown. 807 /// 808 /// AbstractJoinedAny is even not a Tensor type, but any type. 809 class MS_CORE_API AbstractJoinedAny : public AbstractAny { 810 public: 811 /// \brief Constructor of AbstractJoinedAny. AbstractJoinedAny()812 AbstractJoinedAny() : AbstractAny() {} 813 814 /// \brief Destructor of AbstractJoinedAny. 815 ~AbstractJoinedAny() override = default; 816 MS_DECLARE_PARENT(AbstractJoinedAny, AbstractAny) 817 818 enum ExceptionType { 819 kDefault, 820 kTypeError, 821 kValueError, 822 }; 823 824 const std::string &message() const; 825 void set_message(const std::string &message); 826 ExceptionType exception() const; 827 void set_exception(ExceptionType exception); 828 829 void ThrowException() const; 830 831 private: 832 std::string message_; 833 ExceptionType exception_{kDefault}; 834 }; 835 using AbstractJoinedAnyPtr = std::shared_ptr<AbstractJoinedAny>; 836 using AbstractJoinedAnyPtrList = std::vector<AbstractJoinedAnyPtr>; 837 838 /// \brief Class AbstractSequence describes the abstract value of a tuple or list. 839 class MS_CORE_API AbstractSequence : public AbstractBase { 840 public: 841 /// \brief Constructor of AbstractSequence. 842 /// 843 /// \param[in] elements A list of abstracts. 844 /// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. 845 explicit AbstractSequence(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes); 846 847 /// \brief Constructor of AbstractSequence. 848 /// 849 /// \param[in] elements A list of abstracts. 850 /// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. 851 explicit AbstractSequence(const AbstractBasePtrList &elements, 852 const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes); 853 854 /// \brief Destructor of AbstractSequence. 855 ~AbstractSequence() override = default; 856 MS_DECLARE_PARENT(AbstractSequence, AbstractBase) 857 858 /// \brief Get the all of types. 859 /// 860 /// \return A vector of types. 861 TypePtrList ElementsType() const; 862 863 /// \brief Get the all of shapes. 864 /// 865 /// \return A vector of shapes. 866 BaseShapePtrList ElementsShape() const; 867 868 /// \brief Clone all of the abstracts. 869 /// 870 /// \return A vector of the cloned abstracts. 871 AbstractBasePtrList ElementsClone() const; 872 873 /// \brief Broaden the list of abstracts. 874 /// 875 /// \return A vector of the broadened abstracts. 876 AbstractBasePtrList ElementsBroaden() const; 877 878 /// \brief Broaden abstract with constraints, only when cond_func is true. 879 /// 880 /// \return A pointer to the broadened abstract. 881 AbstractBasePtrList ElementsPartialBroaden() const; 882 883 /// \brief Get real value by specific template. 884 /// 885 /// \tparam T the class type of value. 886 /// \return A point to value. 887 template <typename T> 888 ValuePtr ElementsBuildValue() const; 889 890 /// \brief Combine other abstract to the sequence of abstracts. 891 /// 892 /// \tparam T param other's class type. 893 /// \param[in] other The other abstract to be joined. 894 /// \return A pointer to the combined abstract. 895 template <typename T> 896 AbstractBasePtr ElementsJoin(const std::shared_ptr<AbstractSequence> &other); 897 898 /// \brief Combine other sequence nodes with this one. 899 /// 900 /// \param[in] other The other abstract to be joined. 901 /// \return A sequence nodes list combined. 902 AnfNodeWeakPtrList SequenceNodesJoin(const AbstractBasePtr &other); 903 904 /// \brief Get the size of the stored elements. 905 /// 906 /// \return A size_t. 907 std::size_t size() const; 908 909 /// \brief Get the size of the stored elements. 910 /// 911 /// \return A size_t. 912 bool empty() const; 913 914 /// \brief Get the stored elements. 915 /// 916 /// \return A vector of elements. 917 const AbstractBasePtrList &elements() const; 918 919 /// \brief Purify the elements list, and clean unused elements. 920 /// 921 /// \return A boolean, which indicates whether success. 922 bool PurifyElements(); 923 924 /// \brief Get the sequence nodes where these 'AbstractSequence' evaluated from. 925 /// 926 /// \return The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. 927 const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes() const; 928 929 /// \brief Set the sequence nodes where these 'AbstractSequence' evaluated from. 930 /// 931 /// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. 932 void set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes); 933 934 /// \brief Insert a node into the sequence nodes. 935 /// 936 /// \param[in] sequence_node The node to intert into sequence nodes. 937 void InsertSequenceNode(const AnfNodePtr &sequence_node); 938 939 /// \brief Insert nodes into the sequence nodes. 940 /// 941 /// \param[in] sequence_nodes The nodes to intert into sequence nodes. 942 void InsertSequenceNodes(const AnfNodeWeakPtrList &sequence_nodes); 943 944 /// \brief Update the sequence nodes. 945 /// 946 /// \param[in] old_sequence_node The old node in sequence nodes. 947 /// \param[in] new_sequence_node The new node to replace old node in sequence nodes. 948 void UpdateSequenceNode(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node); 949 950 /// \brief Check whether all elements of the tuple are tensors. 951 /// 952 /// \return Whether all elements of the tuple are tensors. 953 bool ContainsAllBroadenTensors() const; 954 955 std::size_t hash() const override; 956 957 std::string ToStringInternal() const; 958 std::string ToString() const override; 959 std::string ToString(bool verbose) const override; 960 961 /// \brief Overwrite the operator '[]' to get an element. 962 /// 963 /// \param[in] dim The index. 964 /// \return A pointer to the abstract. 965 const AbstractBasePtr operator[](const std::size_t &dim) const; 966 967 /// \brief Overwrite the operator '==' to compare other abstract sequence. 968 /// 969 /// \param[in] other The other instance of AbstractSequence. 970 /// 971 /// \return A boolean, which indicates whether the other abstract is same. 972 bool operator==(const AbstractBase &other) const override; 973 974 /// \brief Indicate whether the sequence is dynamic length. 975 /// 976 /// \return Boolean value indicates whether the sequence is dynamic length. 977 bool dynamic_len() const; 978 979 /// \brief Set the sequence to be dynamic length or not. 980 /// 981 /// \param[in] dynamic_len Boolean value to decide whether the sequence is dynamic length. 982 void set_dynamic_len(bool dynamic_len); 983 984 /// \brief Return the abstract of element for variable len sequence. 985 /// 986 /// \return Abstract for element for variable len sequence. 987 AbstractBasePtr dynamic_len_element_abs() const; 988 989 /// \brief Set the abstract of element for variable len sequence. 990 /// 991 /// \param[in] dynamic_len_element_abs Abstract of element for variable len sequence. 992 void set_dynamic_len_element_abs(const AbstractBasePtr &dynamic_len_element_abs); 993 994 /// \brief Check and convert the sequence to dynamic length sequence. 995 virtual void CheckAndConvertToDynamicLenSequence(bool raise_exception = true); 996 997 std::shared_ptr<AbstractSequence> BroadenToDynamicLenSequence(); 998 999 std::shared_ptr<AbstractSequence> DynamicLenSequenceJoin(const std::shared_ptr<AbstractSequence> &other); 1000 1001 void set_dyn_len_arg(); 1002 bool dyn_len_arg() const; 1003 1004 protected: 1005 AbstractBasePtrList elements_; 1006 // Since there're not too many nodes, we just use vector here. 1007 std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes_; 1008 // Dynamic len sequence related. 1009 bool dynamic_len_ = false; 1010 size_t space_num_{0}; 1011 AbstractBasePtr dynamic_len_element_abs_ = nullptr; 1012 bool dyn_len_arg_ = false; 1013 }; 1014 using AbstractSequencePtr = std::shared_ptr<AbstractSequence>; 1015 1016 class MS_CORE_API ExtraInfoHolder { 1017 public: 1018 ~ExtraInfoHolder() = default; 1019 1020 /// \brief Set data to ExtraInfoHolder. 1021 /// 1022 /// \param[in] key The key for data in ExtraInfoHolder. 1023 /// \param[in] data The data to store in ExtraInfoHolder. 1024 template <typename T> SetData(const std::string & key,const std::shared_ptr<T> & data)1025 void SetData(const std::string &key, const std::shared_ptr<T> &data) { 1026 extra_info_->set<T>(key, data); 1027 } 1028 1029 /// \brief Get data from ExtraInfoHolder using key. 1030 /// 1031 /// \param[in] key The key for data in ExtraInfoHolder. 1032 /// \return The corresponding data. 1033 template <typename T> GetData(const std::string & key)1034 std::shared_ptr<T> GetData(const std::string &key) const { 1035 return extra_info_->get<T>(key); 1036 } 1037 1038 /// \brief Check whether ExtraInfoHolder has specific data. 1039 /// 1040 /// \param[in] key The key for data in ExtraInfoHolder. 1041 /// \return True if it exists, otherwise false. HasData(const std::string & key)1042 bool HasData(const std::string &key) const { return extra_info_->has(key); } 1043 1044 /// \brief Get corresponding extra info user data. 1045 /// 1046 /// \return The corresponding extra info user data. extra_info()1047 UserDataPtr extra_info() const { return extra_info_; } 1048 1049 /// \brief Set corresponding extra info user data. 1050 /// 1051 /// \param[in] extra_info The corresponding extra info user data. set_extra_info(const UserDataPtr & extra_info)1052 void set_extra_info(const UserDataPtr &extra_info) { extra_info_ = extra_info; } 1053 1054 /// \brief Clear corresponding extra info user data. ClearExtraInfo()1055 void ClearExtraInfo() { extra_info_ = std::make_shared<UserData>(); } 1056 1057 protected: 1058 UserDataPtr extra_info_ = std::make_shared<UserData>(); 1059 }; 1060 1061 /// \brief Class AbstractTuple describes a tuple. 1062 class MS_CORE_API AbstractTuple : public AbstractSequence, public ExtraInfoHolder { 1063 public: 1064 /// \brief Constructor of AbstractTuple. 1065 /// 1066 /// \param[in] elements A list of abstracts. 1067 /// \param[in] tuple_nodes The nodes of tuple, usually are MakeTuple CNodes or tuple ValueNodes. 1068 explicit AbstractTuple(AbstractBasePtrList &&elements, 1069 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1070 1071 /// \brief Constructor of AbstractTuple. 1072 /// 1073 /// \param[in] elements A list of abstracts. 1074 /// \param[in] tuple_nodes The nodes of tuple, usually are MakeTuple CNodes or tuple ValueNodes. 1075 explicit AbstractTuple(const AbstractBasePtrList &elements, 1076 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1077 1078 /// \brief Destructor of AbstractTuple. 1079 ~AbstractTuple() override = default; 1080 MS_DECLARE_PARENT(AbstractTuple, AbstractSequence) 1081 1082 /// \brief Set the shape for the AbstractTuple, only use for dynamic shape. 1083 /// 1084 /// \param[in] shape The shape that will be set to the AbstractTuple. 1085 void set_shape(const BaseShapePtr &shape) override; 1086 1087 TypePtr BuildType() const override; 1088 1089 BaseShapePtr BuildShape() const override; 1090 1091 AbstractBasePtr Clone() const override; 1092 1093 AbstractBasePtr Broaden() const override; 1094 1095 AbstractBasePtr PartialBroaden() const override; 1096 1097 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1098 1099 /// \brief Overwrite the operator '==' to compare other abstract tuple. 1100 /// 1101 /// \param[in] other The other instance of AbstractTuple. 1102 /// 1103 /// \return A boolean, which indicates whether the other abstract is same. 1104 bool operator==(const AbstractBase &other) const override; 1105 1106 protected: 1107 ValuePtr RealBuildValue() const override; 1108 }; 1109 using AbstractTuplePtr = std::shared_ptr<AbstractTuple>; 1110 1111 /// \brief Class AbstractList describes a list. 1112 class MS_CORE_API AbstractList final : public AbstractSequence, public ExtraInfoHolder { 1113 public: 1114 /// \brief Constructor of AbstractList. 1115 /// 1116 /// \param[in] elements A list of abstracts. 1117 /// \param[in] list_nodes The nodes of list, usually are MakeList CNodes or list ValueNodes. 1118 explicit AbstractList(AbstractBasePtrList &&elements, 1119 const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes = nullptr); 1120 1121 /// \brief Constructor of AbstractList. 1122 /// 1123 /// \param[in] elements A list of abstracts. 1124 /// \param[in] list_nodes The nodes of list, usually are MakeList CNodes or list ValueNodes. 1125 explicit AbstractList(const AbstractBasePtrList &elements, 1126 const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes = nullptr); 1127 1128 /// \brief Destructor of AbstractList. 1129 ~AbstractList() override = default; 1130 MS_DECLARE_PARENT(AbstractList, AbstractSequence) 1131 1132 TypePtr BuildType() const override; 1133 1134 BaseShapePtr BuildShape() const override; 1135 1136 AbstractBasePtr Clone() const override; 1137 1138 AbstractBasePtr Broaden() const override; 1139 1140 AbstractBasePtr PartialBroaden() const override; 1141 1142 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1143 1144 /// \brief Overwrite the operator '==' to compare other abstract list. 1145 /// 1146 /// \param[in] other The other instance of AbstractList. 1147 /// 1148 /// \return A boolean, which indicates whether the other abstract is same. 1149 bool operator==(const AbstractBase &other) const override; 1150 1151 /// \brief Check and convert the list to dynamic length list. 1152 void CheckAndConvertToDynamicLenSequence(bool raise_exception = true) override; 1153 1154 protected: 1155 ValuePtr RealBuildValue() const override; 1156 }; 1157 using AbstractListPtr = std::shared_ptr<AbstractList>; 1158 1159 /// \brief Class AbstractNamedTuple describes a namedtuple node's abstract value. 1160 class MS_CORE_API AbstractNamedTuple final : public AbstractTuple { 1161 public: 1162 /// \brief Constructor of AbstractNamedTuple. 1163 /// 1164 /// \param[in] name The name of a namedtuple. 1165 /// \param[in] values A List of data in namedtuple. 1166 /// \param[in] keys A list of label in namedtuple. AbstractNamedTuple(const std::string & sub_class_name,const AbstractBasePtrList & keys,const AbstractBasePtrList & values)1167 AbstractNamedTuple(const std::string &sub_class_name, const AbstractBasePtrList &keys, 1168 const AbstractBasePtrList &values) 1169 : AbstractTuple(values), sub_class_name_{sub_class_name}, keys_(keys) {} 1170 1171 /// \brief Destructor of AbstractNamedTuple. 1172 ~AbstractNamedTuple() override = default; MS_DECLARE_PARENT(AbstractNamedTuple,AbstractTuple)1173 MS_DECLARE_PARENT(AbstractNamedTuple, AbstractTuple) 1174 /// \brief Get the stored label. 1175 /// 1176 /// \return A vector of label. 1177 const AbstractBasePtrList &key() const { return keys_; } 1178 /// \brief Get the name of namedtuple object. 1179 /// 1180 /// \return A string of namedtuple's type name. sub_class_name()1181 const std::string &sub_class_name() const { return sub_class_name_; } 1182 1183 protected: 1184 ValuePtr RealBuildValue() const override; 1185 1186 private: 1187 std::string sub_class_name_; 1188 AbstractBasePtrList keys_; 1189 }; 1190 using AbstractNamedTuplePtr = std::shared_ptr<AbstractNamedTuple>; 1191 1192 /// \brief Class AbstractDictionary describes a dictionary node's abstract value. 1193 class MS_CORE_API AbstractDictionary final : public AbstractBase, public ExtraInfoHolder { 1194 public: 1195 /// \brief Constructor of AbstractDictionary. 1196 /// 1197 /// \param[in] key_values The vector of AbstractElementPair. 1198 explicit AbstractDictionary(const std::vector<AbstractElementPair> &key_values); 1199 1200 /// \brief Destructor of AbstractDictionary. 1201 ~AbstractDictionary() override = default; 1202 MS_DECLARE_PARENT(AbstractDictionary, AbstractBase) 1203 1204 TypePtr BuildType() const override; 1205 1206 bool operator==(const AbstractBase &other) const override; 1207 1208 AbstractBasePtr Clone() const override; 1209 1210 AbstractBasePtr Broaden() const override; 1211 1212 std::string ToString() const override; 1213 1214 std::size_t hash() const override; 1215 1216 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1217 1218 /// \brief Get the size of key values. 1219 /// 1220 /// \return A size_t. 1221 std::size_t size() const; 1222 1223 /// \brief Get the key values. 1224 /// 1225 /// \return A vector of AbstractElementPair. 1226 const std::vector<AbstractElementPair> &elements() const; 1227 1228 protected: 1229 ValuePtr RealBuildValue() const override; 1230 std::vector<AbstractElementPair> key_values_; 1231 }; 1232 using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>; 1233 1234 /// \brief Class AbstractSlice describes a slice node's abstract value. 1235 class MS_CORE_API AbstractSlice final : public AbstractBase { 1236 public: 1237 /// \brief Constructor of AbstractSlice. 1238 /// 1239 /// \param[in] start The start index of slice. 1240 /// \param[in] stop The stop index of slice. 1241 /// \param[in] step The step size of slice. 1242 AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step); 1243 1244 /// \brief Destructor of AbstractSlice. 1245 ~AbstractSlice() override = default; 1246 MS_DECLARE_PARENT(AbstractSlice, AbstractBase) 1247 1248 TypePtr BuildType() const override; 1249 1250 /// \brief Overwrite the operator '==' to compare other abstract lice. 1251 /// 1252 /// \param[in] other The other instance of AbstractSlice. 1253 /// 1254 /// \return A boolean, which indicates whether the other abstract is same. 1255 bool operator==(const AbstractBase &other) const override; 1256 1257 AbstractBasePtr Clone() const override; 1258 1259 AbstractBasePtr Broaden() const override; 1260 1261 std::string ToString() const override; 1262 1263 std::size_t hash() const override; 1264 1265 /// \brief Get the start index of slice. 1266 /// 1267 /// \return A point to the abstract of start index. 1268 AbstractBasePtr start() const; 1269 1270 /// \brief Get the stop index of slice. 1271 /// 1272 /// \return A point to the abstract of stop index. 1273 AbstractBasePtr stop() const; 1274 1275 /// \brief Get the step size of slice. 1276 /// 1277 /// \return A point to the abstract of step number. 1278 AbstractBasePtr step() const; 1279 1280 protected: 1281 ValuePtr RealBuildValue() const override; 1282 1283 private: 1284 AbstractBasePtr start_; 1285 AbstractBasePtr stop_; 1286 AbstractBasePtr step_; 1287 }; 1288 using AbstractSlicePtr = std::shared_ptr<AbstractSlice>; 1289 1290 /// \brief Class AbstractJTagged describes a J node's abstract value. 1291 class MS_CORE_API AbstractJTagged final : public AbstractBase { 1292 public: 1293 /// \brief Constructor of AbstractJTagged. 1294 /// 1295 /// \param[in] element The value to be processed. 1296 explicit AbstractJTagged(const AbstractBasePtr &element); 1297 1298 /// \brief Destructor of AbstractJTagged. 1299 ~AbstractJTagged() override = default; 1300 MS_DECLARE_PARENT(AbstractJTagged, AbstractBase) 1301 1302 TypePtr BuildType() const override; 1303 1304 AbstractBasePtr Clone() const override; 1305 1306 AbstractBasePtr Broaden() const override; 1307 1308 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1309 1310 /// \brief Overwrite the operator '==' to compare other AbstractJTagged. 1311 /// 1312 /// \param[in] other The other abstract to be joined. 1313 /// 1314 /// \return A boolean, which indicates whether the other abstract is same. 1315 bool operator==(const AbstractBase &other) const override; 1316 1317 std::string ToString() const override; 1318 1319 /// \brief Get the element. 1320 /// 1321 /// \return A pointer to a abstract, which is the element_. 1322 AbstractBasePtr element(); 1323 1324 std::size_t hash() const override; 1325 1326 private: 1327 AbstractBasePtr element_; 1328 }; 1329 using AbstractJTaggedPtr = std::shared_ptr<AbstractJTagged>; 1330 1331 /// \brief Class AbstractNone describes a None node's abstract value. 1332 class MS_CORE_API AbstractNone final : public AbstractBase { 1333 public: 1334 /// \brief Constructor of AbstractNone. 1335 AbstractNone(); 1336 1337 /// \brief Destructor of AbstractNone. 1338 ~AbstractNone() override = default; 1339 MS_DECLARE_PARENT(AbstractNone, AbstractBase) 1340 1341 TypePtr BuildType() const override; 1342 1343 /// \brief Overwrite the operator '==' to compare other AbstractNone. 1344 /// 1345 /// \param[in] other The other instance of AbstractNone. 1346 /// 1347 /// \return A boolean, which indicates whether the other abstract is same. 1348 bool operator==(const AbstractBase &other) const override; 1349 1350 AbstractBasePtr Clone() const override; 1351 1352 std::string ToString() const override; 1353 1354 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1355 1356 protected: 1357 ValuePtr RealBuildValue() const override; 1358 }; 1359 using AbstractNonePtr = std::shared_ptr<AbstractNone>; 1360 1361 /// \brief Class AbstractNull describes a Null node's abstract value. 1362 /// 1363 /// The unassigned state value for variable, 1364 /// which means the variable is not assigned. 1365 class MS_CORE_API AbstractNull final : public AbstractBase { 1366 public: 1367 /// \brief Constructor of AbstractNull. 1368 AbstractNull(); 1369 1370 /// \brief Destructor of AbstractNull. 1371 ~AbstractNull() override = default; 1372 MS_DECLARE_PARENT(AbstractNull, AbstractBase) 1373 1374 TypePtr BuildType() const override; 1375 1376 /// \brief Overwrite the operator '==' to compare other AbstractNull. 1377 /// 1378 /// \param[in] other The other instance of AbstractNull. 1379 /// 1380 /// \return A boolean, which indicates whether the other abstract is same. 1381 bool operator==(const AbstractBase &other) const override; 1382 1383 AbstractBasePtr Clone() const override; 1384 1385 std::string ToString() const override; 1386 }; 1387 using AbstractNullPtr = std::shared_ptr<AbstractNull>; 1388 1389 /// \brief Class AbstractTimeOut describes a TimeOut node's abstract value. 1390 /// 1391 /// The timeout state value for variable, which means 1392 /// the variable is not assigned because it is timeout. 1393 class MS_CORE_API AbstractTimeOut final : public AbstractBase { 1394 public: 1395 /// \brief Constructor of AbstractTimeOut. 1396 AbstractTimeOut(); 1397 1398 /// \brief Destructor of AbstractTimeOut. 1399 ~AbstractTimeOut() override = default; 1400 MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase) 1401 1402 TypePtr BuildType() const override; 1403 1404 /// \brief Overwrite the operator '==' to compare other AbstractTimeOut. 1405 /// 1406 /// \param[in] other The other instance of AbstractTimeOut. 1407 /// 1408 /// \return A boolean, which indicates whether the other abstract is same. 1409 bool operator==(const AbstractBase &other) const override; 1410 1411 AbstractBasePtr Clone() const override; 1412 1413 std::string ToString() const override; 1414 }; 1415 using AbstractTimeOutPtr = std::shared_ptr<AbstractTimeOut>; 1416 1417 /// \brief Class AbstractEllipsis describes a Ellipsis node's abstract value. 1418 class MS_CORE_API AbstractEllipsis final : public AbstractBase { 1419 public: 1420 /// \brief Constructor of AbstractEllipsis. 1421 AbstractEllipsis(); 1422 1423 /// \brief Destructor of AbstractEllipsis. 1424 ~AbstractEllipsis() override = default; 1425 MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) 1426 1427 TypePtr BuildType() const override; 1428 1429 /// \brief Overwrite the operator '==' to compare other AbstractEllipsis. 1430 /// 1431 /// \param[in] other The other instance of AbstractTimeOut. 1432 /// 1433 /// \return A boolean, which indicates whether the other abstract is same. 1434 bool operator==(const AbstractBase &other) const override; 1435 1436 AbstractBasePtr Clone() const override; 1437 1438 std::string ToString() const override; 1439 }; 1440 using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>; 1441 1442 /// \brief Class AbstractRefTensor describes a RefTensor's abstract value. 1443 class MS_CORE_API AbstractRefTensor final : public AbstractTensor { 1444 public: 1445 /// \brief Constructor of AbstractRef. 1446 /// 1447 /// \param[in] ref_value The tensor. 1448 /// \param[in] ref_key_value The ref key of tensor. 1449 AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value); 1450 1451 /// \brief Destructor of AbstractEllipsis. 1452 ~AbstractRefTensor() override = default; 1453 MS_DECLARE_PARENT(AbstractRefTensor, AbstractTensor) 1454 1455 TypePtr BuildType() const override; 1456 1457 /// \brief Overwrite the operator '==' to compare other AbstractRef. 1458 /// 1459 /// \param[in] other The other instance of AbstractTimeOut. 1460 /// 1461 /// \return A boolean, which indicates whether the other abstract is same. 1462 bool operator==(const AbstractBase &other) const override; 1463 1464 AbstractBasePtr Clone() const override; 1465 1466 /// \brief Use parent's AbstractTensor::Clone() to clone an abstract. 1467 /// 1468 /// \return A pointer to the cloned abstract. 1469 AbstractBasePtr CloneAsTensor() const; 1470 1471 std::string ToString() const override; 1472 1473 /// \brief Get the abstract tensor, which is referenced. 1474 /// 1475 /// \return A pointer to the abstract tensor. 1476 AbstractTensorPtr ref(); 1477 1478 /// \brief Get the ref key value, ref key is string actually. 1479 /// 1480 /// \return A point to the RefKey. 1481 ValuePtr ref_key_value() const; 1482 1483 AbstractBasePtr Broaden() const override; 1484 1485 virtual AbstractBasePtr Join(const std::shared_ptr<AbstractRefTensor> &other); 1486 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1487 1488 AbstractBasePtr PartialBroaden() const override; 1489 1490 private: 1491 // ref_key_value is the reference key of AbstractRef, the value can be a string value or kValueAny 1492 ValuePtr ref_key_value_; 1493 }; 1494 using AbstractRefPtr = std::shared_ptr<AbstractRefTensor>; 1495 1496 /// \brief Compute the hash of a list of abstracts. 1497 /// 1498 /// \param[in] args_abs_list A list of abstracts. 1499 /// \return A hash number. 1500 MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_abs_list); 1501 1502 /// \brief Determine whether a list of abstracts is equal to another. 1503 /// 1504 /// \param[in] lhs The first list of abstracts. 1505 /// \param[in] rhs The second list of abstracts. 1506 /// \return A boolean. 1507 MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); 1508 1509 /// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts. 1510 struct AbstractBasePtrListHasher { operatorAbstractBasePtrListHasher1511 std::size_t operator()(const AbstractBasePtrList &args_abs_list) const { 1512 return AbstractBasePtrListHash(args_abs_list); 1513 } 1514 }; 1515 1516 /// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to 1517 /// another. 1518 struct AbstractBasePtrListEqual { operatorAbstractBasePtrListEqual1519 bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { 1520 return AbstractBasePtrListDeepEqual(lhs, rhs); 1521 } 1522 }; 1523 1524 class MS_CORE_API AbstractSparseTensor : public AbstractTuple { 1525 public: 1526 explicit AbstractSparseTensor(AbstractBasePtrList &&elements, 1527 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1528 1529 explicit AbstractSparseTensor(const AbstractBasePtrList &elements, 1530 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1531 1532 ~AbstractSparseTensor() override = default; 1533 MS_DECLARE_PARENT(AbstractSparseTensor, AbstractTuple) 1534 1535 template <typename T> 1536 const T GetAbsPtrAt(size_t index) const; 1537 /// \brief If any element is a tuple, get every element shape in it. 1538 BaseShapePtrList ElementsShapeTupleRecursive() const; 1539 TypePtr BuildType() const override; 1540 BaseShapePtr BuildShape() const override; 1541 1542 /// \brief Return the TypeId of a Tensor element in SparseTensor. 1543 /// 1544 /// \param[in] index The index of element to choose. 1545 /// \return A TypeId. 1546 const TypeId GetTensorTypeIdAt(size_t index) const; 1547 1548 /// \brief Return the TypeId of a shape element in SparseTensor. Note that each element in shape will be transformed 1549 /// to Tensor(scalar) in the backend. 1550 /// \param[in] index The index of element to choose. 1551 /// \return A TypeId. 1552 const TypeId GetShapeTypeIdAt(size_t index) const; 1553 1554 const AbstractTuplePtr shape() const; 1555 }; 1556 using AbstractSparseTensorPtr = std::shared_ptr<AbstractSparseTensor>; 1557 1558 /// \brief Class AbstractRowTensor describes a RowTensor's abstract value. 1559 class MS_CORE_API AbstractRowTensor final : public AbstractUndetermined { 1560 public: 1561 /// \brief Constructor of AbstractRowTensor. 1562 /// 1563 /// \param[in] element The abstract which is wrapped to a RowTensor's abstract value. 1564 /// \param[in] shape A dimension of the abstract. 1565 explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()); 1566 1567 /// \brief Constructor of AbstractRowTensor. 1568 /// 1569 /// \param[in] element_type The type of RowTensor. 1570 /// \param[in] shape A dimension of RowTensor. 1571 AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape); 1572 1573 /// \brief Destructor of AbstractRowTensor. 1574 ~AbstractRowTensor() override = default; 1575 MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined) 1576 1577 /// \brief Get the indices of RowTensor. 1578 /// 1579 /// \return A pointer to the abstract tensor. 1580 const AbstractTensorPtr indices() const; 1581 1582 /// \brief Set the indices for abstract. 1583 /// 1584 /// \param[in] indices The indices. 1585 void set_indices(const AbstractTensorPtr &indices); 1586 1587 /// \brief Get the values. 1588 /// 1589 /// \return A pointer to the abstract tensor. 1590 const AbstractTensorPtr values() const; 1591 1592 /// \brief Set the values. 1593 /// 1594 /// \param[in] values The values of tensor. 1595 void set_values(const AbstractTensorPtr &values); 1596 1597 /// \brief Get the dense shape. 1598 /// 1599 /// \return A pointer to the tuple of abstracts. 1600 const AbstractTuplePtr dense_shape() const; 1601 1602 /// \brief Set the dense shape. 1603 /// 1604 /// \param[in] dense_shape The dense shape of RowTensor. 1605 void set_dense_shape(const AbstractTuplePtr &dense_shape); 1606 1607 TypePtr BuildType() const override; 1608 1609 AbstractBasePtr Clone() const override; 1610 1611 AbstractBasePtr Broaden() const override; 1612 1613 /// \brief Broaden the abstract with the shape not changing. 1614 /// 1615 /// \return A pointer to the broadened abstract. 1616 AbstractBasePtr BroadenWithShape() const; 1617 1618 std::string ToString() const override; 1619 1620 private: 1621 std::shared_ptr<AbstractRowTensor> MakeAbstract(const BaseShapePtr &shp) const; 1622 AbstractTensorPtr indices_; 1623 AbstractTensorPtr values_; 1624 AbstractTuplePtr dense_shape_; 1625 }; 1626 using AbstractRowTensorPtr = std::shared_ptr<AbstractRowTensor>; 1627 1628 // COOTensor is a Tuple with fixed number of elements and specific meaning of each position. 1629 class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor { 1630 public: 1631 explicit AbstractCOOTensor(AbstractBasePtrList &&elements, 1632 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1633 1634 explicit AbstractCOOTensor(const AbstractBasePtrList &elements, 1635 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1636 1637 ~AbstractCOOTensor() override = default; 1638 MS_DECLARE_PARENT(AbstractCOOTensor, AbstractSparseTensor) 1639 1640 const AbstractTensorPtr indices() const; 1641 const AbstractTensorPtr values() const; 1642 1643 TypePtr BuildType() const override; 1644 AbstractBasePtr Clone() const override; 1645 AbstractBasePtr Broaden() const override; 1646 AbstractBasePtr PartialBroaden() const override; 1647 std::string ToString() const override; 1648 1649 static constexpr size_t kIndicesIdx = 0; 1650 static constexpr size_t kValuesIdx = 1; 1651 }; 1652 using AbstractCOOTensorPtr = std::shared_ptr<AbstractCOOTensor>; 1653 1654 // CSRTensor is a Tuple with fixed number of elements and specific meaning of each position. 1655 class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor { 1656 public: 1657 explicit AbstractCSRTensor(AbstractBasePtrList &&elements, 1658 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1659 1660 explicit AbstractCSRTensor(const AbstractBasePtrList &elements, 1661 const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr); 1662 1663 ~AbstractCSRTensor() override = default; 1664 MS_DECLARE_PARENT(AbstractCSRTensor, AbstractSparseTensor) 1665 1666 const AbstractTensorPtr indptr() const; 1667 const AbstractTensorPtr indices() const; 1668 const AbstractTensorPtr values() const; 1669 1670 TypePtr BuildType() const override; 1671 AbstractBasePtr Clone() const override; 1672 AbstractBasePtr Broaden() const override; 1673 AbstractBasePtr PartialBroaden() const override; 1674 std::string ToString() const override; 1675 1676 static constexpr size_t kIndptrIdx = 0; 1677 static constexpr size_t kIndicesIdx = 1; 1678 static constexpr size_t kValuesIdx = 2; 1679 }; 1680 using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>; 1681 1682 class MS_CORE_API AbstractMonad : public AbstractBase { 1683 public: 1684 ~AbstractMonad() override = default; 1685 MS_DECLARE_PARENT(AbstractMonad, AbstractBase) 1686 1687 std::size_t hash() const override; 1688 TypePtr BuildType() const override; 1689 AbstractBasePtr Broaden() const override; 1690 AbstractBasePtr Join(const AbstractBasePtr &other) override = 0; 1691 std::string ToString() const override; 1692 1693 protected: 1694 AbstractMonad(const ValuePtr &value, const TypePtr &type); 1695 }; 1696 using AbstractMonadPtr = std::shared_ptr<AbstractMonad>; 1697 1698 class MS_CORE_API AbstractUMonad final : public AbstractMonad { 1699 public: 1700 explicit AbstractUMonad(const ValuePtr &value = kUMonad); 1701 ~AbstractUMonad() override = default; 1702 MS_DECLARE_PARENT(AbstractUMonad, AbstractMonad) 1703 1704 AbstractBasePtr Clone() const override; 1705 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1706 bool operator==(const AbstractBase &other) const override; 1707 }; 1708 using AbstractUMonadPtr = std::shared_ptr<AbstractUMonad>; 1709 1710 class MS_CORE_API AbstractIOMonad final : public AbstractMonad { 1711 public: 1712 explicit AbstractIOMonad(const ValuePtr &value = kIOMonad); 1713 ~AbstractIOMonad() override = default; 1714 MS_DECLARE_PARENT(AbstractIOMonad, AbstractMonad) 1715 1716 AbstractBasePtr Clone() const override; 1717 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1718 bool operator==(const AbstractBase &other) const override; 1719 }; 1720 using AbstractIOMonadPtr = std::shared_ptr<AbstractIOMonad>; 1721 using tensor::MapTensorPtr; 1722 /// \brief Class AbstractMapTensor describes a MapTensor's abstract value. 1723 class MS_CORE_API AbstractMapTensor final : public AbstractBase { 1724 public: 1725 explicit AbstractMapTensor(const MapTensorPtr &map_tensor); 1726 AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value); 1727 AbstractMapTensor(const AbstractMapTensor &other); 1728 AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value, 1729 const ValuePtr &ref_key_value, const ValuePtr &default_value, const ValuePtr &permit_filter_value, 1730 const ValuePtr &evict_filter_value); 1731 ~AbstractMapTensor() override = default; 1732 1733 MS_DECLARE_PARENT(AbstractMapTensor, AbstractBase) 1734 1735 AbstractMapTensor &operator=(const AbstractMapTensor &other); 1736 1737 MapTensorTypePtr map_tensor_type() const; 1738 ShapePtr shape() const; 1739 const ShapePtr &value_shape() const; 1740 const ValuePtr &ref_key_value() const; 1741 const ValuePtr &default_value() const; 1742 const ValuePtr &permit_filter_value() const; 1743 const ValuePtr &evict_filter_value() const; 1744 TypePtr BuildType() const override; 1745 BaseShapePtr BuildShape() const override; 1746 1747 AbstractBasePtr Clone() const override; 1748 AbstractBasePtr Join(const AbstractBasePtr &other) override; 1749 bool operator==(const AbstractBase &other) const override; 1750 std::size_t hash() const override; 1751 std::string ToString() const override; 1752 1753 private: 1754 // The reference key value, can be a string value or kValueAny. 1755 ValuePtr ref_key_value_; 1756 // The default value, a scalar or string with initializer name. 1757 ValuePtr default_value_; 1758 // Permission threshold. 1759 ValuePtr permit_filter_value_; 1760 // Remove threshold. 1761 ValuePtr evict_filter_value_; 1762 // The value shape. 1763 ShapePtr value_shape_; 1764 }; 1765 using AbstractMapTensorPtr = std::shared_ptr<AbstractMapTensor>; 1766 1767 // Define attribute value map 1768 using AttrValueMap = mindspore::HashMap<std::string, ValuePtr>; 1769 using AttrValueMapPtr = std::shared_ptr<AttrValueMap>; 1770 1771 // The class to save evaluated result: abstract value and modified attribute 1772 class EvalResult : public Base { 1773 public: EvalResult(const AbstractBasePtr & abs,const AttrValueMapPtr & attr)1774 EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) 1775 : abstract_(abs), attribute_(attr), has_side_effect_node_(false) {} 1776 ~EvalResult() override = default; 1777 MS_DECLARE_PARENT(EvalResult, Base); abstract()1778 const AbstractBasePtr &abstract() const { return abstract_; } attribute()1779 const AttrValueMapPtr &attribute() const { return attribute_; } has_side_effect_node()1780 bool has_side_effect_node() const { return has_side_effect_node_; } set_has_side_effect_node(bool has_side_effect_node)1781 void set_has_side_effect_node(bool has_side_effect_node) { has_side_effect_node_ = has_side_effect_node; } 1782 1783 private: 1784 AbstractBasePtr abstract_; 1785 // Attribute related to PrimEvaluator; 1786 AttrValueMapPtr attribute_; 1787 1788 bool has_side_effect_node_; 1789 }; 1790 using EvalResultPtr = std::shared_ptr<EvalResult>; 1791 1792 // Superclass for AnfNodeConfig and VirtualConfig. 1793 class Config : public Base { 1794 public: 1795 Config() = default; 1796 ~Config() override = default; 1797 MS_DECLARE_PARENT(Config, Base); 1798 virtual EvalResultPtr ObtainEvalResult() = 0; 1799 }; 1800 1801 // Config will be stored in AnalysisCache 1802 using ConfigPtr = std::shared_ptr<Config>; 1803 using ConfigPtrList = std::vector<ConfigPtr>; 1804 1805 MS_CORE_API std::string ExtractLoggingInfo(const std::string &info); 1806 MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence, 1807 const AbstractSequencePtr &rhs_sequence); 1808 MS_CORE_API ValuePtr GetRefKeyValue(const AbstractBasePtr &abs); 1809 } // namespace abstract 1810 } // namespace mindspore 1811 #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ 1812