/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/adapter/graph_kernel_splitter_with_py.h"

#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <map>
#include <set>
#include <nlohmann/json.hpp>
#include "mindspore/core/ops/sequence_ops.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "include/common/utils/python_adapter.h"
#include "kernel/graph_kernel/graph_kernel_json_flags.h"
#include "kernel/graph_kernel/graph_kernel_json_generator.h"
#include "kernel/framework_utils.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"

namespace mindspore::graphkernel {
struct StitchInfo {
  std::vector<std::string> stitch_ops;
  std::vector<std::string> stitch_atomic_ops;
};

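// SplitNodesDecoder rebuilds lists of AnfNodes from the split-plan json that the
// Python cost model returns. Each op_desc carries a kJsonKeyPtrAddress entry
// that is looked up in address_node_map to recover the original node; ops
// marked for recompute are copied instead of reused.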
class SplitNodesDecoder {
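  // Extract the optional buffer-stitch section of the kernel json. A sketch of
  // the expected shape, assuming the usual key spellings behind
  // kJsonKeyBufferStitch / kJsonKeyStitchOp / kJsonKeyStitchAtomicOp:
  //   {"buffer_stitch": {"stitch_op": ["output_0_0"], "stitch_atomic_op": ["output_0_1"]}}
  // Missing keys simply leave the corresponding lists empty.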
  StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const {
    StitchInfo info;
    if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
      nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
      if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
        std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
        info.stitch_ops = stitch_ops;
      }
      if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
        std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
        info.stitch_atomic_ops = stitch_atomic_ops;
      }
    }
    return info;
  }

  std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) const {
    if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
      std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
      return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
    }
    return std::set<std::string>();
  }

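  // An op counts as a recompute op when the tensor name of its first output
  // appears in the recompute_ops set decoded from the json.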
  bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) const {
    std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
    if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
      return false;
    }
    std::string tensor_name = output_descs[0][kJsonKeyTensorName];
    return recompute_ops.count(tensor_name) > 0;
  }

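  // Copy orig_node, remapping every input that was already copied (via
  // node_map) so the duplicated region stays self-consistent. The copy inherits
  // the original's scope and cnode info and is registered in node_map.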
  CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) const {
    auto func_graph = orig_node->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto cnode = orig_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
    auto orig_inputs = cnode->inputs();
    std::vector<AnfNodePtr> inputs;
    for (auto inp : orig_inputs) {
      if (node_map->find(inp) == node_map->end()) {
        inputs.push_back(inp);
        continue;
      }
      inputs.push_back((*node_map)[inp]);
    }
    CNodePtr cp_node = func_graph->NewCNode(inputs);
    func_graph->AddNode(cp_node);
    ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
    cp_node->set_scope(scope);
    cp_node->CloneCNodeInfo(cnode);
    (*node_map)[orig_node] = cp_node;
    return cp_node;
  }

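  // Mark a node for stitch fusion when its first output tensor name is listed
  // in the stitch info: kAttrStitch is set to "common" or "atomic" accordingly.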
  void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const {
    std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
    if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
      return;
    }
    std::string tensor_name = output_descs[0][kJsonKeyTensorName];
    if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
      AnfUtils::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
      MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
    }
    if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
        info.stitch_atomic_ops.end()) {
      AnfUtils::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
      MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
    }
  }

  // Replace the original region-root op with its copy inside res_graphs.
  void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
                           const AnfNodePtr &cp_region_root) const {
    for (auto &node : *res_graphs) {
      auto cnode = node->cast<CNodePtr>();
      auto inputs = cnode->inputs();
      for (size_t i = 1; i < inputs.size(); ++i) {
        if (inputs[i] != orig_region_root) {
          continue;
        }
        cnode->set_input(i, cp_region_root);
      }
    }
  }

 public:
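  // Decode one subgraph desc into res_graphs. Fails when the json contains no
  // op list, when an op_desc lacks the kJsonKeyPtrAddress entry, or when an
  // address cannot be found in address_node_map.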
  bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
                        AnfNodePtrList *res_graphs) const {
    MS_EXCEPTION_IF_NULL(res_graphs);
    MS_LOG(DEBUG) << "Start decode, " << kernel_json;
    // Decode the cnodes in the graph.
    std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
    if (op_node_descs.empty()) {
      MS_LOG(ERROR) << "Decode failed, no cnodes for graph: " << kernel_json;
      return false;
    }
    StitchInfo info = GetStitchInfo(kernel_json);
    auto recompute_ops = GetRecomputeOps(kernel_json);
    // Maps an original node to its copy.
    std::map<AnfNodePtr, AnfNodePtr> node_map;
    // Original nodes that get a recomputed copy.
    AnfNodePtrList orig_region_nodes;
    // Nodes that are not copied.
    AnfNodePtrList no_cp_nodes;
    for (const auto &op_desc : op_node_descs) {
      if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
        MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
        return false;
      }

      std::string ptr_address = op_desc[kJsonKeyPtrAddress];
      if (address_node_map.count(ptr_address) == 0) {
        MS_LOG(ERROR) << "Decode failed, ptr_address not found in map: " << ptr_address;
        return false;
      }
      auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
      if (IsRecomputeOp(op_desc, recompute_ops)) {
        auto cp_node = NewRecomputeNode(node, &node_map);
        orig_region_nodes.push_back(node);
        SetStitchAttr(op_desc, info, cp_node);
        res_graphs->push_back(cp_node);
        continue;
      }
      SetStitchAttr(op_desc, info, node);
      res_graphs->push_back(node);
      no_cp_nodes.push_back(node);
    }
    for (auto orig_node : orig_region_nodes) {
      ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
    }
    MS_LOG(DEBUG) << "Decode cnodes succeeded, size: " << res_graphs->size();
    return true;
  }
};

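// Decode the split-plan json string. When recompute fusion is enabled and the
// plan contains more than one group, nodes left hanging by recompute are pruned.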
bool SplitByJsonSchemer::SplitByJsonStr(const std::map<std::string, AnfNodePtr> &address_node_map,
                                        std::string split_graphs_str) {
  if (!DecodeJson(split_graphs_str, address_node_map)) {
    MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
    return false;
  }

  if (split_plan_.size() > 1 && GraphKernelFlags::GetInstance().enable_recompute_fusion) {
    RemoveHangingNodes();
  }
  return true;
}

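// After recompute, the original root of a duplicated region may become
// unreachable from the graph return. Drop such nodes from split_plan_ and erase
// any groups (together with their need_inline_ entries) that end up empty.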
void SplitByJsonSchemer::RemoveHangingNodes() {
  auto todo = TopoSort(func_graph_->get_return());
  std::set<AnfNodePtr> new_all_nodes(todo.begin(), todo.end());
  std::vector<size_t> empty_groups;
  for (size_t i = 0; i < split_plan_.size(); i++) {
    for (int j = SizeToInt(split_plan_[i].size()) - 1; j >= 0; j--) {
      if (new_all_nodes.count(split_plan_[i][j]) == 0) {
        MS_LOG(INFO) << "Recompute remove hanging node " << split_plan_[i][j]->fullname_with_scope();
        (void)split_plan_[i].erase(split_plan_[i].begin() + j);
      }
    }
    if (split_plan_[i].empty()) {
      empty_groups.push_back(i);
    }
  }
  if (!empty_groups.empty()) {
    MS_LOG(INFO) << "Recompute remove empty groups " << empty_groups;
    std::reverse(empty_groups.begin(), empty_groups.end());
    for (auto i : empty_groups) {
      (void)split_plan_.erase(split_plan_.begin() + i);
      (void)need_inline_.erase(need_inline_.begin() + i);
    }
  }
}

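// Decode the split-plan json into split_plan_ and need_inline_. The json is
// expected to carry two parallel arrays, one graph desc and one mode string per
// subgraph; an illustrative (hypothetical) skeleton:
//   {"graph_desc": [{...}, {...}], "graph_mode": ["basic", "composite"]}
// Subgraphs in "basic" mode are flagged for inlining.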
bool SplitByJsonSchemer::DecodeJson(const std::string &json_desc,
                                    const std::map<std::string, AnfNodePtr> &address_node_map) {
  auto kernel_json = nlohmann::json::parse(json_desc);
  std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
  std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
  if (graph_modes.size() != graph_descs.size()) {
    MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatches graph_desc " << graph_descs.size();
    return false;
  }

  // Recover each subgraph json into AnfNodes.
  split_plan_.clear();
  for (const auto &graph_desc : graph_descs) {
    AnfNodePtrList res_graph;
    if (!SplitNodesDecoder().DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
      MS_LOG(ERROR) << "Failed to decode sub graph, " << graph_desc;
      return false;
    }
    (void)split_plan_.emplace_back(std::move(res_graph));
  }

  // Ops to be inlined: subgraphs in "basic" mode.
  need_inline_.clear();
  (void)std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
                       [](const std::string &mode) { return mode == "basic" ? 1 : 0; });
  return true;
}

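// Driver of the schemer: ensure a FuncGraphManager exists, collect real kernel
// nodes in topological order, ask the cost model for a split plan, and then
// attach the return node and any virtual nodes to groups. Leaving split_plan_
// empty signals that the kernel should be kept unchanged.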
void SplitByJsonSchemer::Run() {
  auto mng = func_graph_->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph_, true);
    func_graph_->set_manager(mng);
  }
  GetValidKernelNodes();
  // Call the cost model to get a split plan.
  if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) {
    split_plan_.clear();
    need_inline_.clear();
    return;
  } else if (split_plan_.size() == 1 && !NeedInline(0)) {
    // In this case, the cost model decided to keep the whole graph unchanged.
    split_plan_.clear();
    need_inline_.clear();
    return;
  } else {
    MS_LOG(DEBUG) << "CostModel split succeeded. The kernel is split into " << split_plan_.size() << " parts.";
  }
  MapNodeGroup();
  GroupReturnNode();
  GroupVirtualNodes();
}

bool SplitByJsonSchemer::IsValidKernelNode(const AnfNodePtr &node) const {
  if (!node->isa<CNode>()) {
    return false;
  }
  if (AnfUtils::IsRealKernel(node)) {
    return true;
  }
  return false;
}

void SplitByJsonSchemer::GetValidKernelNodes() {
  topo_all_nodes_ = TopoSort(func_graph_->get_return());
  topo_valid_nodes_.clear();
  (void)std::copy_if(topo_all_nodes_.begin(), topo_all_nodes_.end(), std::back_inserter(topo_valid_nodes_),
                     [this](const AnfNodePtr &node) { return IsValidKernelNode(node); });
}

void SplitByJsonSchemer::MapNodeGroup() {
  node_group_.clear();
  for (size_t i = 0; i < split_plan_.size(); ++i) {
    for (const auto &node : split_plan_[i]) {
      node_group_[node] = i;
    }
  }
}

// Group the return node together with the last MakeTuple node (if one exists).
void SplitByJsonSchemer::GroupReturnNode() {
  AnfNodePtrList outputs;
  kernel::GetFuncGraphOutputNodes(func_graph_, &outputs);
  auto ret_node = func_graph_->get_return();
  auto output = func_graph_->output();
  MS_EXCEPTION_IF_NULL(output);

  if (IsValidKernelNode(output)) {
    auto group_id = node_group_[output];
    node_group_[ret_node] = group_id;
    (void)split_plan_[group_id].emplace_back(ret_node);
    return;
  }
  // Assign the MakeTuple node to a new group.
  if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
    auto group_id = split_plan_.size();
    (void)split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
    (void)need_inline_.emplace_back(1);
    node_group_[output] = group_id;
    node_group_[ret_node] = group_id;
    return;
  }
}

// Assign each virtual node to the same group as its first grouped input.
void SplitByJsonSchemer::GroupVirtualNodes() {
  for (const auto &node : topo_all_nodes_) {
    if (node_group_.count(node) != 0) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    bool found = false;
    for (const auto &input : cnode->inputs()) {
      auto iter = node_group_.find(input);
      if (iter != node_group_.end()) {
        auto group_id = iter->second;
        node_group_[node] = group_id;
        (void)split_plan_[group_id].emplace_back(node);
        found = true;
        break;
      }
    }
    if (!found) {
      MS_LOG(WARNING) << cnode->fullname_with_scope() << " is ungrouped.";
    }
  }
}

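// Round trip through the Python cost model: dump the valid nodes to json while
// recording each node's address, call kGraphKernelSplitFunc through the python
// adapter, then decode the returned split-plan string back into node groups.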
bool CostModelSplitSchemer::SplitByCostModel() {
  // Record each AnfNode's address while converting to json; the map is used to
  // recover the original nodes after the split.
  std::map<std::string, AnfNodePtr> address_node_map;

  // Convert ANF IR to json.
  nlohmann::json json_desc;
  DumpOption dump_option;
  dump_option.is_before_select_kernel = false;
  dump_option.save_ptr_address = true;
  if (!AnfToJsonDesc(topo_valid_nodes_, dump_option, &json_desc, &address_node_map)) {
    MS_LOG(ERROR) << "Collect json desc failed.";
    return false;
  }
  // Set the "node_name" attr for tracing the split result.
  std::string node_name = json_desc["op"];
  func_graph_->set_attr(kAttrNodeName, MakeValue(node_name));
  // Call the cost model split function.
  auto json_desc_str = json_desc.dump();
  auto flags_str = GraphKernelFlags::GetInstance().DumpAllFlags();
  MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json: " << json_desc_str
                << ". flag: " << flags_str;
  auto ret = python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str, flags_str);
  if (py::isinstance<py::none>(ret)) {
    MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] returned an invalid result. input json:\n"
                  << json_desc_str << ". flag: " << flags_str;
    return false;
  }
  std::string split_graphs_str = py::cast<std::string>(ret);
  if (split_graphs_str.empty()) {
    MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] returned an empty result. input json:\n"
                  << json_desc_str << ". flag: " << flags_str;
    return false;
  }
  return SplitByJsonStr(address_node_map, split_graphs_str);
}

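// The Python split model is used only on GPU, and only when the graph is
// static-shape or dynamic-shape fusion is disabled; all other cases fall back
// to the C++ split model of the base class.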
std::shared_ptr<SplitSchemer> GraphKernelSplitterWithPy::GetSplitSchema(const std::string &processor) {
  bool using_py_split =
    (processor == kGPUDevice) && (!is_dynamic_ || !GraphKernelFlags::GetInstance().enable_dynamic_shape_fusion);
  if (using_py_split) {
    MS_LOG(DEBUG) << "use py split model";
    return std::make_shared<CostModelSplitSchemer>();
  } else {
    MS_LOG(DEBUG) << "use c++ split model";
    return GraphKernelSplitter::GetSplitSchema(processor);
  }
}
}  // namespace mindspore::graphkernel