• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_VM_VMIMPL_H_
20 #define MINDSPORE_CCSRC_VM_VMIMPL_H_
21 
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>
26 
27 #include "ir/anf.h"
28 #include "ir/manager.h"
29 #include "ir/tensor.h"
30 #include "pybind_api/ir/base_ref_py.h"
31 
32 namespace mindspore {
33 namespace compile {
34 
35 using AnfNodePtrList = std::vector<AnfNodePtr>;
36 using AnfNodePtrToBaseRefMap = std::unordered_map<AnfNodePtr, BaseRef>;
37 using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
38 
39 using FuncGraphPtrToBaseRefMap = std::unordered_map<FuncGraphPtr, BaseRef>;
40 
41 using TensorList = std::vector<tensor::TensorPtr>;
42 
43 class Closure;
44 using ClosurePtr = std::shared_ptr<Closure>;
45 
46 class VMFrame;
47 using VMFramePtr = std::shared_ptr<VMFrame>;
48 using VMFramePtrList = std::vector<VMFramePtr>;
49 
50 class VM;
51 using VMPtr = std::shared_ptr<VM>;
52 
53 class Partial;
54 using PartialPtr = std::shared_ptr<Partial>;
55 
56 using RunFunc = std::function<VectorRef(const VectorRef &args)>;
57 using RunFuncPtr = std::shared_ptr<RunFunc>;
58 
59 using SuccFunc = std::function<AnfNodePtrList(AnfNodePtr node)>;
60 
61 class VMImpl {
62  public:
63   virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0;
64   virtual ~VMImpl() = default;
65 };
66 
67 // An execution frame.
68 // This holds the state for an application of a graph. The nodes list
69 // must contain free variables of graphs encountered before the
70 // graph themselves.
71 // You can index a frame with a node to get its value in the context
72 // of this frame (if it has already been evaluated).
73 // Attributes:
74 //   nodes: list of nodes remaining to execute
75 //   values: Mapping of node to their values in this application
76 //   closure: values for the closure if the current application is a closure
77 class VMFrame {
78  public:
79   VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure);
80   const BaseRef operator[](const AnfNodePtr &node);
todo()81   const AnfNodePtrList &todo() const { return todo_; }
82 
values()83   AnfNodePtrToBaseRefMap &values() { return values_; }
84 
85   virtual ~VMFrame() = default;
86 
87   AnfNodePtrToBaseRefMap values_;
88 
89  private:
90   AnfNodePtrList todo_;
91   AnfNodePtrToBaseRefMap closure_;
92 };
93 
94 // Representation of a closure.
95 class Closure : public Base {
96  public:
97   Closure(const FuncGraphPtr &func_graph, const AnfNodePtrToBaseRefMap &values);
98   BaseRef operator()(const VectorRef &args);
99 
vm()100   const VMPtr &vm() const { return vm_; }
101 
set_vm(const VMPtr & vm)102   void set_vm(const VMPtr &vm) { vm_ = vm; }
103 
func_graph()104   const FuncGraphPtr &func_graph() const { return func_graph_; }
105 
values()106   const AnfNodePtrToBaseRefMap &values() const { return values_; }
107 
108   virtual ~Closure() = default;
109 
110   MS_DECLARE_PARENT(Closure, Base)
111 
112  private:
113   FuncGraphPtr func_graph_;
114   AnfNodePtrToBaseRefMap values_;
115   VMPtr vm_;
116 };
117 
118 // Representation of a partial application.
119 class Partial : public Base {
120  public:
121   Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm);
122   BaseRef operator()(const VectorRef &nodes);
fn()123   const BaseRef &fn() const { return fn_; }
124 
args()125   const VectorRef &args() const { return args_; }
126 
127   virtual ~Partial() = default;
128   MS_DECLARE_PARENT(Partial, Base)
129 
130  private:
131   BaseRef fn_;
132   VectorRef args_;
133   VMPtr vm_;
134 };
135 
136 // Virtual Machine interface.
137 class VM : public std::enable_shared_from_this<VM>, public VMImpl {
138  public:
139   SetRef ComputeFvs(const FuncGraphPtr &func_graph);
140 
141   void AcquireGraph(const FuncGraphPtr &func_graph);
142 
143   VectorRef ExportSequence(const VectorRef &seq);
144 
ExportPrimitive(const PrimitivePtr &)145   BaseRef ExportPrimitive(const PrimitivePtr &) const { return kAnyValue; }
146 
147   ClosurePtr ExportClosure(const ClosurePtr &clos);
148 
149   // Return an object that executes `fg` when called on arguments.
150   ClosurePtr ExportGraph(const FuncGraphPtr &fg);
151 
152   BaseRef ExportObj(const BaseRef &obj) const;
153 
154   BaseRef Export(const BaseRef &value);
155 
156   // Run a graph.
157   // This will evaluate the passed-in graph and return the
158   // resulting value.
159   BaseRef Evaluate(const FuncGraphPtr &func_graph, const VectorRef &args,
160                    const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap());
161 
162   // Return a visitor for the graph.
163   SuccFunc SuccVm(const FuncGraphPtr &func_graph);
164 
165   // Call the `fn` object.
166   // `fn` can be anything that would be valid as the first element of an apply.
167   BaseRef Call(const BaseRef &fn, const VectorRef &args);
168 
169   BaseRef _Call(const BaseRef &graph, const VectorRef &args);
170 
171   ClosurePtr MakeClosure(const FuncGraphPtr &func_graph, const VMFramePtr &frame);
172 
173   BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args);
174 
175   BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame);
176 
177   VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) override;
178 
179  private:
180   FuncGraphManagerPtr manager_;
181   FuncGraphPtrToBaseRefMap vars_;
182 };
183 
184 extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args);
185 
186 }  // namespace compile
187 }  // namespace mindspore
188 
189 #endif  // MINDSPORE_CCSRC_VM_VMIMPL_H_
190