• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "vm/segment_runner.h"
20 
21 #include <algorithm>
22 #include <functional>
23 #include <memory>
24 #include <set>
25 #include <unordered_set>
26 #include <tuple>
27 #include <unordered_map>
28 #include <utility>
29 #include <string>
30 
31 #include "utils/log_adapter.h"
32 #include "utils/utils.h"
33 #include "ir/manager.h"
34 #include "ir/func_graph_cloner.h"
35 #include "frontend/operator/ops.h"
36 
37 namespace mindspore {
38 namespace compile {
39 namespace {
40 // Return the list of nodes whose values are required beyond this segment.
41 // Arguments:
42 //   nodes: list of nodes in the segment
43 //   users: dict mapping each node to its users (globally)
44 //   seen: set of nodes that are part of the segment
GetOutput(const AnfNodePtrList & nodes,const NodeUsersMap & users,const std::unordered_set<AnfNodePtr> & seen)45 AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
46                          const std::unordered_set<AnfNodePtr> &seen) {
47   AnfNodePtrList output;
48   if (users.size() == 0) {
49     return output;
50   }
51   for (auto &node : nodes) {
52     MS_EXCEPTION_IF_NULL(node);
53     if (!node->isa<CNode>()) {
54       continue;
55     }
56     auto iter = users.find(node);
57     if (iter == users.end()) {
58       continue;
59     }
60     auto &node_users = iter->second;
61     const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users),
62                                             [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
63                                               const bool is_outer_user = (seen.find(u.first) == seen.end());
64                                               return is_outer_user;
65                                             });
66     if (has_outer_user) {
67       output.emplace_back(node);
68     }
69   }
70   return output;
71 }
72 
// Returns the node that stands for `node` inside the segment graph `fg`, recording the mapping in `*eqv_ptr`.
//   - A ValueNode that does not hold a FuncGraph is mapped to itself (constants are shared, not parameterized).
//   - Any other node that has no mapping yet is appended to `*inputs_ptr` and mapped to a new parameter of `fg`,
//     which receives the abstract and kernel info of `node`.
//   - A node that is already mapped returns its existing mapping unchanged.
// `fg`, `node`, `inputs_ptr` and `eqv_ptr` must be non-null (checked with MS_EXCEPTION_IF_NULL).
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
                           AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(inputs_ptr);
  MS_EXCEPTION_IF_NULL(eqv_ptr);
  MS_EXCEPTION_IF_NULL(node);
  auto &mapping = *eqv_ptr;

  const bool is_plain_constant = node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node);
  if (is_plain_constant) {
    mapping[node] = node;
    return mapping[node];
  }

  const auto existing = mapping.find(node);
  if (existing != mapping.end()) {
    return existing->second;
  }

  // First time this outside value is referenced: expose it as a new graph parameter.
  inputs_ptr->push_back(node);
  AnfNodePtr param = fg->add_parameter();
  mapping[node] = param;
  param->set_abstract(node->abstract());
  param->set_kernel_info(node->kernel_info_ptr());
  return param;
}
91 }  // namespace
92 
TransformSegmentToAnfGraph(const AnfNodePtrList & lst)93 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
94   if (lst.empty()) {
95     MS_LOG(EXCEPTION) << "Input anf node list is empty";
96   }
97   FuncGraphPtr fg = nullptr;
98   {
99     // limit the lifetime of guard.
100     MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>());
101     MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>()->func_graph());
102     TraceGuard guard(std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
103     fg = std::make_shared<FuncGraph>();
104   }
105   AnfNodePtrList inputs;
106   AnfNodePtrToAnfNodePtrMap eqv;
107   // Merge CNodes into a AnfGraph that represents a linear instruction segment
108   for (auto n : lst) {
109     MS_EXCEPTION_IF_NULL(n);
110     if (!n->isa<CNode>()) {
111       MS_LOG(EXCEPTION) << "Inst is not CNode";
112     }
113     auto &inps = n->cast<CNodePtr>()->inputs();
114     if (inps.empty()) {
115       MS_LOG(EXCEPTION) << "Input is empty";
116     }
117     if (!IsValueNode<Primitive>(inps[0]) &&
118         !(IsValueNode<FuncGraph>(inps[0]) &&
119           inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
120       MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive ValueNode";
121     }
122     auto fn = inps[0];
123     std::vector<AnfNodePtr> args{fn};
124     if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= kDependInputSize &&
125         eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
126       args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
127       const size_t value_start_index = 2;
128       for (size_t i = value_start_index; i < inps.size(); ++i) {
129         args.emplace_back(NewValueNode(MakeValue(0)));
130       }
131     } else {
132       (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
133                            [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
134     }
135     TraceGuard tg(std::make_shared<TraceSegmentTransform>(n->debug_info()));
136     MS_EXCEPTION_IF_NULL(fg);
137     eqv[n] = fg->NewCNode(args);
138     eqv[n]->set_abstract(n->abstract());
139     eqv[n]->set_kernel_info(n->kernel_info_ptr());
140   }
141   std::unordered_set<AnfNodePtr> eqv_keys;
142   (void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()),
143                        [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
144   auto mgr = lst[0]->func_graph()->manager();
145   MS_EXCEPTION_IF_NULL(mgr);
146   auto outputs = GetOutput(lst, mgr->node_users(), eqv_keys);
147   AnfNodePtr fg_output;
148   if (outputs.size() > 1) {
149     std::vector<AnfNodePtr> output_args;
150     output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
151     (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
152                          [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
153     // Set output for AnfGraph
154     fg_output = fg->NewCNode(output_args);
155   } else {
156     if (outputs.empty()) {
157       MS_LOG(EXCEPTION) << "Output is empty.";
158     }
159     fg_output = eqv[outputs[0]];
160   }
161   fg->set_output(fg_output);
162   return std::make_tuple(fg, inputs, outputs);
163 }
164 
165 // Converts the list of nodes to a runnable form.
166 // All the nodes in the list must represent linear flow (no calls, branches, ...)
167 // Returns:
168 //  (fn, inputs, outputs):
169 //  - fn: A callable function
170 //  - inputs: the list of inputs nodes whose values should be
171 //             provided to the function
172 //  - outputs: the list of output nodes corresponding to the
173 //             outputs of the function
174 // Notes:
175 //   This implementation will convert the nodes into a subgraph
176 //   that will run using the MsVM.
177 template <typename T>
Convert(const GraphSegmentPtr & segment,const std::string &)178 LinConvertResult Convert(const GraphSegmentPtr &segment, const std::string &) {
179   MS_EXCEPTION_IF_NULL(segment);
180   LinConvertResult result;
181 
182   FuncGraphPtr fg = nullptr;
183   AnfNodePtrList inputs;
184   AnfNodePtrList outputs;
185 
186   std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
187 
188   // Clone in case g contains subgraphs that have a different manager
189   fg = BasicClone(fg);
190 
191   std::shared_ptr<VMImpl> vm = std::make_shared<T>();
192 
193   result.run =
194     std::make_shared<RunFunc>([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); });
195   result.inputs = inputs;
196   result.outputs = outputs;
197   result.graph_id = UINT32_MAX;
198 
199   return result;
200 }
201 
202 LinkFuncType MsVmConvert = Convert<VM>;
203 
204 std::set<std::string> backend_list = {
205   kMsConvert,
206   kMsVm,
207 };
208 }  // namespace compile
209 }  // namespace mindspore
210