• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_UTILS_SYMBOLIC_H_
20 #define MINDSPORE_CORE_UTILS_SYMBOLIC_H_
21 
22 #include <memory>
23 #include <algorithm>
24 #include <utility>
25 #include <string>
26 
27 #include "utils/hash_map.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "ir/anf.h"
30 #include "ir/func_graph.h"
31 #include "abstract/abstract_value.h"
32 
33 namespace mindspore {
34 class SymbolicKeyInstance : public Value {
35  public:
36   SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract, const int64_t index = -1)
node_(node)37       : node_(node), abstract_(abstract), index_(index) {}
38   ~SymbolicKeyInstance() override = default;
39   MS_DECLARE_PARENT(SymbolicKeyInstance, Value);
node()40   AnfNodePtr node() const { return node_; }
abstract()41   abstract::AbstractBasePtr abstract() const { return abstract_; }
42   bool operator==(const SymbolicKeyInstance &other) const {
43     return (*node_ == *other.node_) && (*abstract_ == *other.abstract_) && (index_ == other.index_);
44   }
45 
hash()46   std::size_t hash() const override {
47     auto hash_value = hash_combine(std::hash<AnfNodePtr>{}(node_), std::hash<int64_t>{}(index_));
48     return hash_value;
49   }
50 
51   friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<SymbolicKeyInstance> &inst) {
52     if (inst == nullptr) {
53       os << "[Key]["
54          << "Invalid symbolic key instance"
55          << "]";
56     } else {
57       os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString();
58     }
59     return os;
60   }
ToString()61   std::string ToString() const override {
62     std::ostringstream oss;
63     if (node_ == nullptr) {
64       return "Invalid node";
65     }
66     oss << "[Key][" << node_->type_name() + "]" << node_->ToString();
67     if (index_ != -1) {
68       oss << "[" << index_ << "]";
69     }
70     return oss.str();
71   }
72 
73   bool operator==(const Value &other) const override {
74     if (other.isa<SymbolicKeyInstance>()) {
75       auto &other_ = static_cast<const SymbolicKeyInstance &>(other);
76       return *this == other_;
77     } else {
78       return false;
79     }
80   }
ToAbstract()81   abstract::AbstractBasePtr ToAbstract() override {
82     return std::make_shared<abstract::AbstractScalar>(shared_from_base<SymbolicKeyInstance>(),
83                                                       std::make_shared<SymbolicKeyType>());
84   }
85 
86  private:
87   AnfNodePtr node_;
88   abstract::AbstractBasePtr abstract_;
89   // If the Value in EnvironGet/EnvironSet of one SymbolicKey is Tuple, that SymbolicKey will be split
90   // to multiple SymbolicKey, this index is used to discriminate those SymbolicKey derived from the same
91   // one.
92   int64_t index_{-1};
93 };
94 
95 using SymbolicKeyInstancePtr = std::shared_ptr<SymbolicKeyInstance>;
96 
97 struct SymbolicKeyInstanceHash {
operatorSymbolicKeyInstanceHash98   std::size_t operator()(const SymbolicKeyInstancePtr &s) const {
99     if (s == nullptr) {
100       return 0;
101     }
102     return s->hash();
103   }
104 };
105 
106 struct SymbolicKeyInstanceEqual {
operatorSymbolicKeyInstanceEqual107   bool operator()(const SymbolicKeyInstancePtr &lhs, const SymbolicKeyInstancePtr &rhs) const {
108     if (lhs == nullptr || rhs == nullptr) {
109       return false;
110     }
111     MS_EXCEPTION_IF_NULL(lhs->node());
112     MS_EXCEPTION_IF_NULL(rhs->node());
113     MS_EXCEPTION_IF_NULL(lhs->abstract());
114     MS_EXCEPTION_IF_NULL(rhs->abstract());
115     return *lhs == *rhs;
116   }
117 };
118 
NewEnviron(const FuncGraphPtr & fg)119 static inline AnfNodePtr NewEnviron(const FuncGraphPtr &fg) {
120   return fg->NewCNode({NewValueNode(prim::kPrimEnvironCreate)});
121 }
122 
IsNewEnvironNode(const AnfNodePtr & node)123 static inline bool IsNewEnvironNode(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimEnvironCreate); }
124 
MakeEnvironAbstract()125 static inline abstract::AbstractBasePtr MakeEnvironAbstract() {
126   return std::make_shared<abstract::AbstractScalar>(kValueAny, std::make_shared<EnvType>());
127 }
128 }  // namespace mindspore
129 
130 #endif  // MINDSPORE_CORE_UTILS_SYMBOLIC_H_
131