• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_UTILS_SYMBOLIC_H_
20 #define MINDSPORE_CORE_UTILS_SYMBOLIC_H_
21 
22 #include <unordered_map>
23 #include <memory>
24 #include <algorithm>
25 #include <utility>
26 #include <string>
27 
28 #include "ir/anf.h"
29 #include "abstract/abstract_value.h"
30 
31 namespace mindspore {
32 class SymbolicKeyInstance : public Value {
33  public:
SymbolicKeyInstance(const AnfNodePtr & node,const abstract::AbstractBasePtr & abstract)34   SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract)
35       : node_(node), abstract_(abstract) {}
36   ~SymbolicKeyInstance() override = default;
37   MS_DECLARE_PARENT(SymbolicKeyInstance, Value);
node()38   AnfNodePtr node() const { return node_; }
abstract()39   abstract::AbstractBasePtr abstract() const { return abstract_; }
40   bool operator==(const SymbolicKeyInstance &other) const {
41     return (*node_ == *other.node_) && (*abstract_ == *other.abstract_);
42   }
43 
hash()44   std::size_t hash() const override { return std::hash<AnfNodePtr>{}(node_); }
45   friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<SymbolicKeyInstance> &inst) {
46     if (inst == nullptr) {
47       os << "[Key]["
48          << "Invalid symbolic key instance"
49          << "]";
50     } else {
51       os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString();
52     }
53     return os;
54   }
ToString()55   std::string ToString() const override {
56     return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString();
57   }
58   bool operator==(const Value &other) const override {
59     if (other.isa<SymbolicKeyInstance>()) {
60       auto other_ = static_cast<const SymbolicKeyInstance &>(other);
61       return *this == other_;
62     } else {
63       return false;
64     }
65   }
ToAbstract()66   abstract::AbstractBasePtr ToAbstract() override {
67     return std::make_shared<abstract::AbstractScalar>(shared_from_base<SymbolicKeyInstance>(),
68                                                       std::make_shared<SymbolicKeyType>());
69   }
70 
71  private:
72   AnfNodePtr node_;
73   abstract::AbstractBasePtr abstract_;
74 };
75 
76 using SymbolicKeyInstancePtr = std::shared_ptr<SymbolicKeyInstance>;
77 
78 struct SymbolicKeyInstanceHash {
operatorSymbolicKeyInstanceHash79   std::size_t operator()(const SymbolicKeyInstancePtr s) const {
80     if (s == nullptr) {
81       return 0;
82     }
83     return s->abstract()->hash();
84   }
85 };
86 
87 struct SymbolicKeyInstanceEqual {
operatorSymbolicKeyInstanceEqual88   bool operator()(const SymbolicKeyInstancePtr lhs, const SymbolicKeyInstancePtr rhs) const {
89     if (lhs == nullptr || rhs == nullptr) {
90       return false;
91     }
92     MS_EXCEPTION_IF_NULL(lhs->node());
93     MS_EXCEPTION_IF_NULL(rhs->node());
94     MS_EXCEPTION_IF_NULL(lhs->abstract());
95     MS_EXCEPTION_IF_NULL(rhs->abstract());
96     return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract());
97   }
98 };
99 
100 using EnvInstanceContentsMap =
101   std::unordered_map<SymbolicKeyInstancePtr, Any, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>;
102 
103 // Environment mapping keys to values.
104 // Keys are SymbolicKeyInstances, which represent nodes in the graph along
105 // with inferred properties.
106 class EnvInstance : public Value {
107  public:
108   friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &env);
109 
contents_(contents)110   explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {}
111   ~EnvInstance() override = default;
112   MS_DECLARE_PARENT(EnvInstance, Value);
ToAbstract()113   abstract::AbstractBasePtr ToAbstract() override {
114     return std::make_shared<abstract::AbstractScalar>(shared_from_base<EnvInstance>(), std::make_shared<EnvType>());
115   }
116   bool operator==(const EnvInstance &other) const;
117   bool operator==(const Value &other) const override;
EnvInstance(const EnvInstance & v)118   EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {}
119   EnvInstance(EnvInstance &&v) = default;
120   EnvInstance &operator=(EnvInstance &&src) noexcept {
121     if (&src != this) {
122       contents_ = src.contents_;
123     }
124     return *this;
125   };
126 
127   // Get the sensitivity list for the given key
Get(const SymbolicKeyInstancePtr & key,const Any & def)128   const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const {
129     auto iterator = contents_.find(key);
130     if (iterator != contents_.end()) {
131       return iterator->second;
132     }
133     return def;
134   }
135 
136   // Set a value for the given key.
Set(const SymbolicKeyInstancePtr & key,const Any & value)137   EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const {
138     EnvInstance rval(contents_);
139     rval.contents_[key] = value;
140     return rval;
141   }
142 
143   // Add two EnvInstances.
Add(const EnvInstance & other)144   EnvInstance Add(const EnvInstance &other) const {
145     EnvInstance rval(contents_);
146     for (auto iter_other : other.contents_) {
147       auto item_self = contents_.find(iter_other.first);
148       if (item_self != contents_.end()) {
149         MS_LOG(DEBUG) << "Need to use add";
150       } else {
151         rval.contents_[iter_other.first] = iter_other.second;
152       }
153     }
154     return rval;
155   }
156 
Len()157   size_t Len() const { return contents_.size(); }
hash()158   std::size_t hash() const override {
159     // deterministic characteristic of member variables.
160     return Len();
161   }
162 
163  private:
164   EnvInstanceContentsMap contents_;
165 };
166 
167 using EnvInstancePtr = std::shared_ptr<EnvInstance>;
168 
169 extern std::shared_ptr<EnvInstance> newenv;
170 }  // namespace mindspore
171 
172 #endif  // MINDSPORE_CORE_UTILS_SYMBOLIC_H_
173