• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_VM_VMIMPL_H_
20 #define MINDSPORE_CCSRC_VM_VMIMPL_H_
21 
#include <functional>
#include <memory>
#include <set>
#include <vector>

#include "utils/hash_map.h"
#include "ir/anf.h"
#include "ir/manager.h"
#include "ir/tensor.h"
#include "pybind_api/ir/base_ref_py.h"
31 
32 namespace mindspore {
33 namespace compile {
34 
35 using AnfNodePtrList = std::vector<AnfNodePtr>;
36 using AnfNodePtrToBaseRefMap = mindspore::HashMap<AnfNodePtr, BaseRef>;
37 using AnfNodePtrToAnfNodePtrMap = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
38 
39 using FuncGraphPtrToBaseRefMap = mindspore::HashMap<FuncGraphPtr, BaseRef>;
40 
41 using TensorList = std::vector<tensor::TensorPtr>;
42 
43 class Closure;
44 using ClosurePtr = std::shared_ptr<Closure>;
45 
46 class VMFrame;
47 using VMFramePtr = std::shared_ptr<VMFrame>;
48 using VMFramePtrList = std::vector<VMFramePtr>;
49 
50 class VM;
51 using VMPtr = std::shared_ptr<VM>;
52 
53 class Partial;
54 using PartialPtr = std::shared_ptr<Partial>;
55 
56 using RunFunc = std::function<VectorRef(const VectorRef &args)>;
57 using RunFuncPtr = std::shared_ptr<RunFunc>;
58 
59 class VMImpl {
60  public:
61   virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0;
62   virtual ~VMImpl() = default;
63 };
64 
65 // An execution frame.
66 // This holds the state for an application of a graph. The nodes list
67 // must contain free variables of graphs encountered before the
68 // graph themselves.
69 // You can index a frame with a node to get its value in the context
70 // of this frame (if it has already been evaluated).
71 // Attributes:
72 //   nodes: list of nodes remaining to execute
73 //   values: Mapping of node to their values in this application
74 //   closure: values for the closure if the current application is a closure
75 class VMFrame {
76  public:
77   VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure);
78   const BaseRef operator[](const AnfNodePtr &node);
todo()79   const AnfNodePtrList &todo() const { return todo_; }
80 
values()81   AnfNodePtrToBaseRefMap &values() { return values_; }
82 
83   virtual ~VMFrame() = default;
84 
85   AnfNodePtrToBaseRefMap values_;
86 
87  private:
88   AnfNodePtrList todo_;
89   AnfNodePtrToBaseRefMap closure_;
90 };
91 
92 // Representation of a closure.
93 class Closure : public Base {
94  public:
95   Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values);
96   BaseRef operator()(const VectorRef &args);
97 
vm()98   const VMPtr &vm() const { return vm_; }
99 
set_vm(const VMPtr & vm)100   void set_vm(const VMPtr &vm) { vm_ = vm; }
101 
func_graph()102   const FuncGraphPtr &func_graph() const { return func_graph_; }
103 
values()104   const AnfNodePtrToBaseRefMap &values() const { return values_; }
105 
106   virtual ~Closure() = default;
107 
108   MS_DECLARE_PARENT(Closure, Base)
109 
110  private:
111   FuncGraphPtr func_graph_;
112   AnfNodePtrToBaseRefMap values_;
113   VMPtr vm_;
114 };
115 
116 // Representation of a partial application.
117 class Partial : public Base {
118  public:
119   Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm);
120   BaseRef operator()(const VectorRef &nodes);
fn()121   const BaseRef &fn() const { return fn_; }
122 
args()123   const VectorRef &args() const { return args_; }
124 
125   virtual ~Partial() = default;
126   MS_DECLARE_PARENT(Partial, Base)
127 
128  private:
129   BaseRef fn_;
130   VectorRef args_;
131   VMPtr vm_;
132 };
133 
134 // Virtual Machine interface.
135 class VM : public std::enable_shared_from_this<VM>, public VMImpl {
136  public:
137   SetRef ComputeFvs(const FuncGraphPtr &graph) const;
138 
139   void AcquireGraph(const FuncGraphPtr &graph);
140 
141   VectorRef ExportSequence(const VectorRef &seq);
142 
ExportPrimitive(const PrimitivePtr &)143   BaseRef ExportPrimitive(const PrimitivePtr &) const { return kValueAny; }
144 
145   ClosurePtr ExportClosure(const ClosurePtr &clos);
146 
147   // Return an object that executes `fg` when called on arguments.
148   ClosurePtr ExportGraph(const FuncGraphPtr &g);
149 
150   BaseRef ExportObj(const BaseRef &obj) const;
151 
152   BaseRef Export(const BaseRef &value);
153 
154   // Run a graph.
155   // This will evaluate the passed-in graph and return the
156   // resulting value.
157   BaseRef Evaluate(const FuncGraphPtr &graph, const VectorRef &args,
158                    const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap());
159 
160   // Return a visitor for the graph.
161   SuccFunc SuccVm(const FuncGraphPtr &graph);
162 
163   // Call the `fn` object.
164   // `fn` can be anything that would be valid as the first element of an apply.
165   BaseRef Call(const BaseRef &fn, const VectorRef &args);
166 
167   BaseRef _Call(const BaseRef &graph, const VectorRef &args);
168 
169   ClosurePtr MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame);
170 
171   BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args);
172 
173   BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame);
174 
175   VectorRef RunGraph(const FuncGraphPtr &g, const VectorRef &args) override;
176 
177  private:
178   FuncGraphManagerPtr manager_;
179   FuncGraphPtrToBaseRefMap vars_;
180 };
181 
182 extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args);
183 
184 }  // namespace compile
185 }  // namespace mindspore
186 
187 #endif  // MINDSPORE_CCSRC_VM_VMIMPL_H_
188