1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2020 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_UTILS_SYMBOLIC_H_ 20 #define MINDSPORE_CORE_UTILS_SYMBOLIC_H_ 21 22 #include <unordered_map> 23 #include <memory> 24 #include <algorithm> 25 #include <utility> 26 #include <string> 27 28 #include "ir/anf.h" 29 #include "abstract/abstract_value.h" 30 31 namespace mindspore { 32 class SymbolicKeyInstance : public Value { 33 public: SymbolicKeyInstance(const AnfNodePtr & node,const abstract::AbstractBasePtr & abstract)34 SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) 35 : node_(node), abstract_(abstract) {} 36 ~SymbolicKeyInstance() override = default; 37 MS_DECLARE_PARENT(SymbolicKeyInstance, Value); node()38 AnfNodePtr node() const { return node_; } abstract()39 abstract::AbstractBasePtr abstract() const { return abstract_; } 40 bool operator==(const SymbolicKeyInstance &other) const { 41 return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); 42 } 43 hash()44 std::size_t hash() const override { return std::hash<AnfNodePtr>{}(node_); } 45 friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<SymbolicKeyInstance> &inst) { 46 if (inst == nullptr) { 47 os << "[Key][" 48 << "Invalid symbolic key instance" 49 << "]"; 50 } else { 51 os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString(); 52 } 53 return os; 54 } ToString()55 std::string ToString() const override { 56 return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); 57 } 58 bool operator==(const Value &other) const override { 59 if (other.isa<SymbolicKeyInstance>()) { 60 auto other_ = static_cast<const SymbolicKeyInstance &>(other); 61 return *this == other_; 62 } else { 63 return false; 64 } 65 } ToAbstract()66 abstract::AbstractBasePtr ToAbstract() override { 67 return std::make_shared<abstract::AbstractScalar>(shared_from_base<SymbolicKeyInstance>(), 68 std::make_shared<SymbolicKeyType>()); 69 } 70 71 private: 72 AnfNodePtr node_; 73 abstract::AbstractBasePtr abstract_; 74 }; 75 76 using SymbolicKeyInstancePtr = std::shared_ptr<SymbolicKeyInstance>; 77 78 struct SymbolicKeyInstanceHash { operatorSymbolicKeyInstanceHash79 std::size_t operator()(const SymbolicKeyInstancePtr s) const { 80 if (s == nullptr) { 81 return 0; 82 } 83 return s->abstract()->hash(); 84 } 85 }; 86 87 struct SymbolicKeyInstanceEqual { operatorSymbolicKeyInstanceEqual88 bool operator()(const SymbolicKeyInstancePtr lhs, const SymbolicKeyInstancePtr rhs) const { 89 if (lhs == nullptr || rhs == nullptr) { 90 return false; 91 } 92 MS_EXCEPTION_IF_NULL(lhs->node()); 93 MS_EXCEPTION_IF_NULL(rhs->node()); 94 MS_EXCEPTION_IF_NULL(lhs->abstract()); 95 MS_EXCEPTION_IF_NULL(rhs->abstract()); 96 return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract()); 97 } 98 }; 99 100 using EnvInstanceContentsMap = 101 std::unordered_map<SymbolicKeyInstancePtr, Any, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>; 102 103 // Environment mapping keys to values. 104 // Keys are SymbolicKeyInstances, which represent nodes in the graph along 105 // with inferred properties. 106 class EnvInstance : public Value { 107 public: 108 friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &env); 109 contents_(contents)110 explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} 111 ~EnvInstance() override = default; 112 MS_DECLARE_PARENT(EnvInstance, Value); ToAbstract()113 abstract::AbstractBasePtr ToAbstract() override { 114 return std::make_shared<abstract::AbstractScalar>(shared_from_base<EnvInstance>(), std::make_shared<EnvType>()); 115 } 116 bool operator==(const EnvInstance &other) const; 117 bool operator==(const Value &other) const override; EnvInstance(const EnvInstance & v)118 EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} 119 EnvInstance(EnvInstance &&v) = default; 120 EnvInstance &operator=(EnvInstance &&src) noexcept { 121 if (&src != this) { 122 contents_ = src.contents_; 123 } 124 return *this; 125 }; 126 127 // Get the sensitivity list for the given key Get(const SymbolicKeyInstancePtr & key,const Any & def)128 const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { 129 auto iterator = contents_.find(key); 130 if (iterator != contents_.end()) { 131 return iterator->second; 132 } 133 return def; 134 } 135 136 // Set a value for the given key. Set(const SymbolicKeyInstancePtr & key,const Any & value)137 EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { 138 EnvInstance rval(contents_); 139 rval.contents_[key] = value; 140 return rval; 141 } 142 143 // Add two EnvInstances. Add(const EnvInstance & other)144 EnvInstance Add(const EnvInstance &other) const { 145 EnvInstance rval(contents_); 146 for (auto iter_other : other.contents_) { 147 auto item_self = contents_.find(iter_other.first); 148 if (item_self != contents_.end()) { 149 MS_LOG(DEBUG) << "Need to use add"; 150 } else { 151 rval.contents_[iter_other.first] = iter_other.second; 152 } 153 } 154 return rval; 155 } 156 Len()157 size_t Len() const { return contents_.size(); } hash()158 std::size_t hash() const override { 159 // deterministic characteristic of member variables. 160 return Len(); 161 } 162 163 private: 164 EnvInstanceContentsMap contents_; 165 }; 166 167 using EnvInstancePtr = std::shared_ptr<EnvInstance>; 168 169 extern std::shared_ptr<EnvInstance> newenv; 170 } // namespace mindspore 171 172 #endif // MINDSPORE_CORE_UTILS_SYMBOLIC_H_ 173