1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_VM_VMIMPL_H_ 20 #define MINDSPORE_CCSRC_VM_VMIMPL_H_ 21 22 #include <set> 23 #include <memory> 24 #include <vector> 25 26 #include "utils/hash_map.h" 27 #include "ir/anf.h" 28 #include "ir/manager.h" 29 #include "ir/tensor.h" 30 #include "pybind_api/ir/base_ref_py.h" 31 32 namespace mindspore { 33 namespace compile { 34 35 using AnfNodePtrList = std::vector<AnfNodePtr>; 36 using AnfNodePtrToBaseRefMap = mindspore::HashMap<AnfNodePtr, BaseRef>; 37 using AnfNodePtrToAnfNodePtrMap = mindspore::HashMap<AnfNodePtr, AnfNodePtr>; 38 39 using FuncGraphPtrToBaseRefMap = mindspore::HashMap<FuncGraphPtr, BaseRef>; 40 41 using TensorList = std::vector<tensor::TensorPtr>; 42 43 class Closure; 44 using ClosurePtr = std::shared_ptr<Closure>; 45 46 class VMFrame; 47 using VMFramePtr = std::shared_ptr<VMFrame>; 48 using VMFramePtrList = std::vector<VMFramePtr>; 49 50 class VM; 51 using VMPtr = std::shared_ptr<VM>; 52 53 class Partial; 54 using PartialPtr = std::shared_ptr<Partial>; 55 56 using RunFunc = std::function<VectorRef(const VectorRef &args)>; 57 using RunFuncPtr = std::shared_ptr<RunFunc>; 58 59 class VMImpl { 60 public: 61 virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0; 62 virtual ~VMImpl() = default; 63 }; 64 65 // An execution frame. 66 // This holds the state for an application of a graph. The nodes list 67 // must contain free variables of graphs encountered before the 68 // graph themselves. 69 // You can index a frame with a node to get its value in the context 70 // of this frame (if it has already been evaluated). 71 // Attributes: 72 // nodes: list of nodes remaining to execute 73 // values: Mapping of node to their values in this application 74 // closure: values for the closure if the current application is a closure 75 class VMFrame { 76 public: 77 VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure); 78 const BaseRef operator[](const AnfNodePtr &node); todo()79 const AnfNodePtrList &todo() const { return todo_; } 80 values()81 AnfNodePtrToBaseRefMap &values() { return values_; } 82 83 virtual ~VMFrame() = default; 84 85 AnfNodePtrToBaseRefMap values_; 86 87 private: 88 AnfNodePtrList todo_; 89 AnfNodePtrToBaseRefMap closure_; 90 }; 91 92 // Representation of a closure. 93 class Closure : public Base { 94 public: 95 Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values); 96 BaseRef operator()(const VectorRef &args); 97 vm()98 const VMPtr &vm() const { return vm_; } 99 set_vm(const VMPtr & vm)100 void set_vm(const VMPtr &vm) { vm_ = vm; } 101 func_graph()102 const FuncGraphPtr &func_graph() const { return func_graph_; } 103 values()104 const AnfNodePtrToBaseRefMap &values() const { return values_; } 105 106 virtual ~Closure() = default; 107 108 MS_DECLARE_PARENT(Closure, Base) 109 110 private: 111 FuncGraphPtr func_graph_; 112 AnfNodePtrToBaseRefMap values_; 113 VMPtr vm_; 114 }; 115 116 // Representation of a partial application. 117 class Partial : public Base { 118 public: 119 Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm); 120 BaseRef operator()(const VectorRef &nodes); fn()121 const BaseRef &fn() const { return fn_; } 122 args()123 const VectorRef &args() const { return args_; } 124 125 virtual ~Partial() = default; 126 MS_DECLARE_PARENT(Partial, Base) 127 128 private: 129 BaseRef fn_; 130 VectorRef args_; 131 VMPtr vm_; 132 }; 133 134 // Virtual Machine interface. 135 class VM : public std::enable_shared_from_this<VM>, public VMImpl { 136 public: 137 SetRef ComputeFvs(const FuncGraphPtr &graph) const; 138 139 void AcquireGraph(const FuncGraphPtr &graph); 140 141 VectorRef ExportSequence(const VectorRef &seq); 142 ExportPrimitive(const PrimitivePtr &)143 BaseRef ExportPrimitive(const PrimitivePtr &) const { return kValueAny; } 144 145 ClosurePtr ExportClosure(const ClosurePtr &clos); 146 147 // Return an object that executes `fg` when called on arguments. 148 ClosurePtr ExportGraph(const FuncGraphPtr &g); 149 150 BaseRef ExportObj(const BaseRef &obj) const; 151 152 BaseRef Export(const BaseRef &value); 153 154 // Run a graph. 155 // This will evaluate the passed-in graph and return the 156 // resulting value. 157 BaseRef Evaluate(const FuncGraphPtr &graph, const VectorRef &args, 158 const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap()); 159 160 // Return a visitor for the graph. 161 SuccFunc SuccVm(const FuncGraphPtr &graph); 162 163 // Call the `fn` object. 164 // `fn` can be anything that would be valid as the first element of an apply. 165 BaseRef Call(const BaseRef &fn, const VectorRef &args); 166 167 BaseRef _Call(const BaseRef &graph, const VectorRef &args); 168 169 ClosurePtr MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame); 170 171 BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args); 172 173 BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame); 174 175 VectorRef RunGraph(const FuncGraphPtr &g, const VectorRef &args) override; 176 177 private: 178 FuncGraphManagerPtr manager_; 179 FuncGraphPtrToBaseRefMap vars_; 180 }; 181 182 extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args); 183 184 } // namespace compile 185 } // namespace mindspore 186 187 #endif // MINDSPORE_CCSRC_VM_VMIMPL_H_ 188