1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ 20 #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ 21 22 #include <memory> 23 #include <string> 24 25 #include "abstract/abstract_value.h" 26 #include "abstract/analysis_context.h" 27 #include "ir/meta_func_graph.h" 28 29 namespace mindspore { 30 namespace abstract { 31 class MS_CORE_API AbstractFuncAtom : public AbstractFunction { 32 public: 33 AbstractFuncAtom() = default; 34 ~AbstractFuncAtom() override = default; MS_DECLARE_PARENT(AbstractFuncAtom,AbstractFunction)35 MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) 36 37 AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); } 38 AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; 39 void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final; 40 bool operator==(const AbstractFunction &other) const override; 41 hash()42 std::size_t hash() const override { return tid(); } 43 }; 44 45 class MS_CORE_API AbstractFuncUnion : public AbstractFunction { 46 public: 47 explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list); 48 AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second); 49 ~AbstractFuncUnion() override = default; 50 MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction) 51 52 std::string ToString() const override; 53 GetUnique()54 AbstractFunctionPtr GetUnique() override { 55 MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; 56 AbstractFunctionPtr result; 57 return result; 58 } 59 bool IsSuperSet(const AbstractFunctionPtr &other); 60 AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; 61 void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final; 62 bool operator==(const AbstractFunction &other) const override; 63 std::size_t hash() const override; Copy()64 AbstractFunctionPtr Copy() const override { 65 MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; 66 AbstractFunctionPtr result; 67 return result; 68 } 69 70 private: 71 AbstractFuncAtomPtrList func_list_; 72 }; 73 74 class MS_CORE_API PrimitiveAbstractClosure : public AbstractFuncAtom { 75 public: 76 // Represents a Primitive. 77 // prim: The primitive 78 // tracking_id: Identifies different uses of the same primitive. 79 explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr) prim_(prim)80 : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {} 81 ~PrimitiveAbstractClosure() override = default; MS_DECLARE_PARENT(PrimitiveAbstractClosure,AbstractFuncAtom)82 MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) 83 84 PrimitivePtr prim() { return prim_; } 85 tracking_id()86 AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } 87 set_tracking_id(AnfNodePtr node)88 void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } 89 Copy()90 AbstractFunctionPtr Copy() const override { return std::make_shared<PrimitiveAbstractClosure>(prim_, tracking_id()); } 91 92 bool operator==(const AbstractFunction &other) const override; 93 std::size_t hash() const override; 94 ToString()95 std::string ToString() const override { return "Prim: " + prim_->name(); } 96 RealBuildValue()97 ValuePtr RealBuildValue() const override { return prim_; } 98 99 private: 100 PrimitivePtr prim_; 101 // store it as weak_ptr to break reference cycle. 102 // one reference cycle example is Graph::set_output() input0 local variable. 103 AnfNodeWeakPtr tracking_id_; 104 }; 105 using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>; 106 107 class MS_CORE_API FuncGraphAbstractClosure : public AbstractFuncAtom { 108 public: 109 // Represents a Graph in a certain Context. 110 // context: The context, or Context.empty() 111 FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, 112 const AnfNodePtr &tracking_id = nullptr) func_graph_(func_graph)113 : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) { 114 MS_EXCEPTION_IF_NULL(func_graph); 115 MS_EXCEPTION_IF_NULL(context); 116 } 117 ~FuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(FuncGraphAbstractClosure,AbstractFuncAtom)118 MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) 119 120 FuncGraphPtr func_graph() { return func_graph_; } 121 context()122 AnalysisContextPtr context() const override { return context_; } 123 tracking_id()124 AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } 125 set_tracking_id(AnfNodePtr node)126 void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } 127 Copy()128 AbstractFunctionPtr Copy() const override { 129 return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id()); 130 } 131 132 bool operator==(const AbstractFunction &other) const override; 133 std::size_t hash() const override; 134 135 std::string ToString() const override; 136 137 private: 138 FuncGraphPtr func_graph_; 139 AnalysisContextPtr context_; 140 // To discriminate different usage of same graph by using this tracking_id, 141 // so different tracking_id will produce different FuncGraphAbstractClosure, 142 // different FuncGraphEvaluator. 143 // Espcecially useful for recursive func graph call, so it will not mess up 144 // the `context_` in FuncGraphEvaluator. 145 // Notes: Be careful to use nullptr for this variable. 146 // store it as weak_ptr to break reference cycle. 147 AnfNodeWeakPtr tracking_id_; 148 }; 149 using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>; 150 151 class MS_CORE_API MetaFuncGraphAbstractClosure : public AbstractFuncAtom { 152 public: 153 explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, 154 const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope) meta_func_graph_(meta_func_graph)155 : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {} 156 ~MetaFuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure,AbstractFuncAtom)157 MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) 158 159 MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; } 160 context()161 AnalysisContextPtr context() const override { return kDummyAnalysisContext; } 162 GetScope()163 ScopePtr GetScope() { return scope_; } 164 tracking_id()165 AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } 166 Copy()167 AbstractFunctionPtr Copy() const override { 168 return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id()); 169 } 170 bool operator==(const AbstractFunction &other) const override; 171 std::size_t hash() const override; 172 173 std::string ToString() const override; 174 175 private: 176 MetaFuncGraphPtr meta_func_graph_; 177 // refer the comment in FuncGraphAbstractClosure; 178 // store it as weak_ptr to break reference cycle. 179 AnfNodeWeakPtr tracking_id_; 180 ScopePtr scope_; 181 }; 182 using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>; 183 184 class MS_CORE_API PartialAbstractClosure : public AbstractFuncAtom { 185 public: 186 // Represents a partial application. 187 // args_spec_list: The first few arguments of that function 188 PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, 189 const AnfNodePtr &node = nullptr) fn_(fn)190 : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} 191 ~PartialAbstractClosure() override = default; MS_DECLARE_PARENT(PartialAbstractClosure,AbstractFuncAtom)192 MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) 193 194 AbstractFunctionPtr fn() { return fn_; } args()195 const AbstractBasePtrList &args() { return args_spec_list_; } RealBuildValue()196 ValuePtr RealBuildValue() const override { return fn_->BuildValue(); } node()197 AnfNodePtr node() { return node_.lock(); } set_node(const AnfNodePtr & node)198 void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } Copy()199 AbstractFunctionPtr Copy() const override { 200 return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock()); 201 } 202 bool operator==(const AbstractFunction &other) const override; 203 std::size_t hash() const override; 204 205 std::string ToString() const override; 206 207 private: 208 AbstractFuncAtomPtr fn_; 209 AbstractBasePtrList args_spec_list_; 210 // The CNode which this PartialAbstractClosure evaluated from. 211 AnfNodeWeakPtr node_; 212 }; 213 using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>; 214 215 class MS_CORE_API JTransformedAbstractClosure : public AbstractFuncAtom { 216 public: 217 // Represents a Function transformed through the application of J. JTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)218 explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} 219 ~JTransformedAbstractClosure() override = default; MS_DECLARE_PARENT(JTransformedAbstractClosure,AbstractFuncAtom)220 MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) 221 222 AbstractFuncAtomPtr fn() { return fn_; } Copy()223 AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); } 224 bool operator==(const AbstractFunction &other) const override; 225 std::size_t hash() const override; 226 ToString()227 std::string ToString() const override { return "J(" + fn_->ToString() + ")"; } 228 229 private: 230 AbstractFuncAtomPtr fn_; 231 }; 232 233 class MS_CORE_API VirtualAbstractClosure : public AbstractFuncAtom { 234 public: 235 // Represents some function with an explicitly fixed type signature. 236 // args_spec_list: The arguments as abstract value given to the function 237 // output: The output which is abstract value. VirtualAbstractClosure(const AbstractBasePtrList & args_spec_list,const AbstractBasePtr & output_spec)238 VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec) 239 : args_spec_list_(args_spec_list), output_(output_spec) {} VirtualAbstractClosure(const AbstractBasePtr & args_spec,const AbstractBasePtr & output_spec)240 VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec) 241 : args_spec_list_({args_spec}), output_(output_spec) {} 242 ~VirtualAbstractClosure() override = default; MS_DECLARE_PARENT(VirtualAbstractClosure,AbstractFuncAtom)243 MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) 244 245 AbstractBasePtrList args_spec_list() { return args_spec_list_; } 246 output()247 AbstractBasePtr output() { return output_; } Copy()248 AbstractFunctionPtr Copy() const override { 249 return std::make_shared<VirtualAbstractClosure>(args_spec_list_, output_); 250 } 251 bool operator==(const AbstractFunction &other) const override; 252 std::size_t hash() const override; 253 254 std::string ToString() const override; 255 256 private: 257 AbstractBasePtrList args_spec_list_; 258 AbstractBasePtr output_; 259 }; 260 using VirtualAbstractClosurePtr = std::shared_ptr<VirtualAbstractClosure>; 261 262 class MS_CORE_API TypedPrimitiveAbstractClosure : public AbstractFuncAtom { 263 public: 264 // Represents a Primitive with an explicitly fixed type signature. 265 // args_spec_list: The arguments as abstract value given to the Primitive 266 // output: The output which is abstract value. TypedPrimitiveAbstractClosure(const PrimitivePtr prim,const AbstractBasePtrList & args_spec_list,const AbstractBasePtr & output_spec)267 TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list, 268 const AbstractBasePtr &output_spec) 269 : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {} 270 ~TypedPrimitiveAbstractClosure() override = default; MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure,AbstractFuncAtom)271 MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) 272 273 PrimitivePtr prim() { return prim_; } args_spec_list()274 AbstractBasePtrList args_spec_list() { return args_spec_list_; } output()275 AbstractBasePtr output() { return output_; } Copy()276 AbstractFunctionPtr Copy() const override { 277 return std::make_shared<TypedPrimitiveAbstractClosure>(prim_, args_spec_list_, output_); 278 } 279 bool operator==(const AbstractFunction &other) const override; 280 std::size_t hash() const override; 281 282 std::string ToString() const override; 283 284 private: 285 PrimitivePtr prim_; 286 AbstractBasePtrList args_spec_list_; 287 AbstractBasePtr output_; 288 }; 289 290 class PyInterpretAbstractClosure : public AbstractFuncAtom { 291 public: 292 PyInterpretAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, 293 const AnfNodePtr &node = nullptr) fn_(fn)294 : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} 295 ~PyInterpretAbstractClosure() override = default; MS_DECLARE_PARENT(PyInterpretAbstractClosure,AbstractFuncAtom)296 MS_DECLARE_PARENT(PyInterpretAbstractClosure, AbstractFuncAtom) 297 298 AbstractFunctionPtr fn() { return fn_; } args()299 AbstractBasePtrList args() { return args_spec_list_; } RealBuildValue()300 ValuePtr RealBuildValue() const override { return fn_->BuildValue(); } node()301 AnfNodePtr node() { return node_.lock(); } set_node(const AnfNodePtr & node)302 void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } Copy()303 AbstractFunctionPtr Copy() const override { 304 return std::make_shared<PyInterpretAbstractClosure>(fn_, args_spec_list_, node_.lock()); 305 } 306 bool operator==(const AbstractFunction &other) const override; 307 std::size_t hash() const override; 308 309 std::string ToString() const override; 310 311 private: 312 AbstractFuncAtomPtr fn_; 313 AbstractBasePtrList args_spec_list_; 314 AnfNodeWeakPtr node_; 315 }; 316 using PyInterpretAbstractClosurePtr = std::shared_ptr<PyInterpretAbstractClosure>; 317 318 // Represents a function that can't be called. 319 class MS_CORE_API DummyAbstractClosure : public AbstractFuncAtom { 320 public: 321 DummyAbstractClosure() = default; 322 ~DummyAbstractClosure() override = default; MS_DECLARE_PARENT(DummyAbstractClosure,AbstractFuncAtom)323 MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) 324 325 AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); } 326 bool operator==(const AbstractFunction &other) const override; 327 ToString()328 std::string ToString() const override { return "DummyAbstractClosure()"; } 329 }; 330 331 struct MS_CORE_API AbstractFunctionHasher { operatorAbstractFunctionHasher332 std::size_t operator()(const AbstractFunctionPtr &t) const { 333 std::size_t hash = t->hash(); 334 return hash; 335 } 336 }; 337 338 struct MS_CORE_API AbstractFunctionEqual { operatorAbstractFunctionEqual339 bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; } 340 }; 341 } // namespace abstract 342 } // namespace mindspore 343 #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ 344