1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2020 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_VM_VMIMPL_H_ 20 #define MINDSPORE_CCSRC_VM_VMIMPL_H_ 21 22 #include <set> 23 #include <unordered_map> 24 #include <memory> 25 #include <vector> 26 27 #include "ir/anf.h" 28 #include "ir/manager.h" 29 #include "ir/tensor.h" 30 #include "pybind_api/ir/base_ref_py.h" 31 32 namespace mindspore { 33 namespace compile { 34 35 using AnfNodePtrList = std::vector<AnfNodePtr>; 36 using AnfNodePtrToBaseRefMap = std::unordered_map<AnfNodePtr, BaseRef>; 37 using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>; 38 39 using FuncGraphPtrToBaseRefMap = std::unordered_map<FuncGraphPtr, BaseRef>; 40 41 using TensorList = std::vector<tensor::TensorPtr>; 42 43 class Closure; 44 using ClosurePtr = std::shared_ptr<Closure>; 45 46 class VMFrame; 47 using VMFramePtr = std::shared_ptr<VMFrame>; 48 using VMFramePtrList = std::vector<VMFramePtr>; 49 50 class VM; 51 using VMPtr = std::shared_ptr<VM>; 52 53 class Partial; 54 using PartialPtr = std::shared_ptr<Partial>; 55 56 using RunFunc = std::function<VectorRef(const VectorRef &args)>; 57 using RunFuncPtr = std::shared_ptr<RunFunc>; 58 59 using SuccFunc = std::function<AnfNodePtrList(AnfNodePtr node)>; 60 61 class VMImpl { 62 public: 63 virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0; 64 virtual ~VMImpl() = default; 65 }; 66 67 // An execution frame. 68 // This holds the state for an application of a graph. The nodes list 69 // must contain free variables of graphs encountered before the 70 // graph themselves. 71 // You can index a frame with a node to get its value in the context 72 // of this frame (if it has already been evaluated). 73 // Attributes: 74 // nodes: list of nodes remaining to execute 75 // values: Mapping of node to their values in this application 76 // closure: values for the closure if the current application is a closure 77 class VMFrame { 78 public: 79 VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure); 80 const BaseRef operator[](const AnfNodePtr &node); todo()81 const AnfNodePtrList &todo() const { return todo_; } 82 values()83 AnfNodePtrToBaseRefMap &values() { return values_; } 84 85 virtual ~VMFrame() = default; 86 87 AnfNodePtrToBaseRefMap values_; 88 89 private: 90 AnfNodePtrList todo_; 91 AnfNodePtrToBaseRefMap closure_; 92 }; 93 94 // Representation of a closure. 95 class Closure : public Base { 96 public: 97 Closure(const FuncGraphPtr &func_graph, const AnfNodePtrToBaseRefMap &values); 98 BaseRef operator()(const VectorRef &args); 99 vm()100 const VMPtr &vm() const { return vm_; } 101 set_vm(const VMPtr & vm)102 void set_vm(const VMPtr &vm) { vm_ = vm; } 103 func_graph()104 const FuncGraphPtr &func_graph() const { return func_graph_; } 105 values()106 const AnfNodePtrToBaseRefMap &values() const { return values_; } 107 108 virtual ~Closure() = default; 109 110 MS_DECLARE_PARENT(Closure, Base) 111 112 private: 113 FuncGraphPtr func_graph_; 114 AnfNodePtrToBaseRefMap values_; 115 VMPtr vm_; 116 }; 117 118 // Representation of a partial application. 119 class Partial : public Base { 120 public: 121 Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm); 122 BaseRef operator()(const VectorRef &nodes); fn()123 const BaseRef &fn() const { return fn_; } 124 args()125 const VectorRef &args() const { return args_; } 126 127 virtual ~Partial() = default; 128 MS_DECLARE_PARENT(Partial, Base) 129 130 private: 131 BaseRef fn_; 132 VectorRef args_; 133 VMPtr vm_; 134 }; 135 136 // Virtual Machine interface. 137 class VM : public std::enable_shared_from_this<VM>, public VMImpl { 138 public: 139 SetRef ComputeFvs(const FuncGraphPtr &func_graph); 140 141 void AcquireGraph(const FuncGraphPtr &func_graph); 142 143 VectorRef ExportSequence(const VectorRef &seq); 144 ExportPrimitive(const PrimitivePtr &)145 BaseRef ExportPrimitive(const PrimitivePtr &) const { return kAnyValue; } 146 147 ClosurePtr ExportClosure(const ClosurePtr &clos); 148 149 // Return an object that executes `fg` when called on arguments. 150 ClosurePtr ExportGraph(const FuncGraphPtr &fg); 151 152 BaseRef ExportObj(const BaseRef &obj) const; 153 154 BaseRef Export(const BaseRef &value); 155 156 // Run a graph. 157 // This will evaluate the passed-in graph and return the 158 // resulting value. 159 BaseRef Evaluate(const FuncGraphPtr &func_graph, const VectorRef &args, 160 const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap()); 161 162 // Return a visitor for the graph. 163 SuccFunc SuccVm(const FuncGraphPtr &func_graph); 164 165 // Call the `fn` object. 166 // `fn` can be anything that would be valid as the first element of an apply. 167 BaseRef Call(const BaseRef &fn, const VectorRef &args); 168 169 BaseRef _Call(const BaseRef &graph, const VectorRef &args); 170 171 ClosurePtr MakeClosure(const FuncGraphPtr &func_graph, const VMFramePtr &frame); 172 173 BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args); 174 175 BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame); 176 177 VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) override; 178 179 private: 180 FuncGraphManagerPtr manager_; 181 FuncGraphPtrToBaseRefMap vars_; 182 }; 183 184 extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args); 185 186 } // namespace compile 187 } // namespace mindspore 188 189 #endif // MINDSPORE_CCSRC_VM_VMIMPL_H_ 190