/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/adapter/graph_kernel_splitter_with_py.h"

#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <map>
#include <set>
#include <nlohmann/json.hpp>
#include "mindspore/core/ops/sequence_ops.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "include/common/utils/python_adapter.h"
#include "kernel/graph_kernel/graph_kernel_json_flags.h"
#include "kernel/graph_kernel/graph_kernel_json_generator.h"
#include "kernel/framework_utils.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"

namespace mindspore::graphkernel {
struct StitchInfo {
  std::vector<std::string> stitch_ops;
  std::vector<std::string> stitch_atomic_ops;
};

class SplitNodesDecoder {
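  // Extract the buffer-stitch information (common stitch ops and atomic-add stitch ops) from the kernel json.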
  StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const {
    StitchInfo info;
    if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
      nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
      if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
        std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
        info.stitch_ops = stitch_ops;
      }
      if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
        std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
        info.stitch_atomic_ops = stitch_atomic_ops;
      }
    }
    return info;
  }

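  // Collect the output tensor names of ops that are marked for recomputation in the kernel json.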
  std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) const {
    if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
      std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
      return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
    }
    return std::set<std::string>();
  }

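  // An op is a recompute op if the tensor name of its first output is in the recompute set.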
  bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) const {
    std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
    if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
      return false;
    }
    std::string tensor_name = output_descs[0][kJsonKeyTensorName];
    return recompute_ops.count(tensor_name) > 0;
  }

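  // Create a copy of orig_node for recomputation. Inputs that were copied earlier are remapped
  // to their copies, and the original->copy pair is recorded in node_map.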
  CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) const {
    auto func_graph = orig_node->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto cnode = orig_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
    auto orig_inputs = cnode->inputs();
    std::vector<AnfNodePtr> inputs;
    for (auto inp : orig_inputs) {
      if (node_map->find(inp) == node_map->end()) {
        inputs.push_back(inp);
        continue;
      }
      inputs.push_back((*node_map)[inp]);
    }
    CNodePtr cp_node = func_graph->NewCNode(inputs);
    func_graph->AddNode(cp_node);
    ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
    cp_node->set_scope(scope);
    cp_node->CloneCNodeInfo(cnode);
    (*node_map)[orig_node] = cp_node;
    return cp_node;
  }

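  // Attach the stitch attribute ("common" or "atomic") to the node if its first output tensor
  // is listed in the stitch info.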
  void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const {
    std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
    if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
      return;
    }
    std::string tensor_name = output_descs[0][kJsonKeyTensorName];
    if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
      AnfUtils::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
      MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
    }
    if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
        info.stitch_atomic_ops.end()) {
      AnfUtils::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
      MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
    }
  }

  // Replace the original region root op with its copy in res_graphs.
  void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
                           const AnfNodePtr &cp_region_root) const {
    for (auto &node : *res_graphs) {
      auto cnode = node->cast<CNodePtr>();
      auto inputs = cnode->inputs();
      for (size_t i = 1; i < inputs.size(); ++i) {
        if (inputs[i] != orig_region_root) {
          continue;
        }
        cnode->set_input(i, cp_region_root);
      }
    }
  }

 public:
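  // Decode one sub-graph of the split plan: map each op desc back to its anf node via the
  // ptr_address, copy the recompute ops, and collect the resulting nodes into res_graphs.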
  bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
                        AnfNodePtrList *res_graphs) const {
    MS_EXCEPTION_IF_NULL(res_graphs);
    MS_LOG(DEBUG) << "Start decode, " << kernel_json;
    // decode cnodes in graph.
    std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
    if (op_node_descs.empty()) {
      MS_LOG(ERROR) << "Error decode, no cnodes for graph: " << kernel_json;
      return false;
    }
    StitchInfo info = GetStitchInfo(kernel_json);
    auto recompute_ops = GetRecomputeOps(kernel_json);
    // maps an original node to its copied node
    std::map<AnfNodePtr, AnfNodePtr> node_map;
    // nodes to be copied
    AnfNodePtrList orig_region_nodes;
    // nodes that will not be copied
    AnfNodePtrList no_cp_nodes;
    for (const auto &op_desc : op_node_descs) {
      if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
        MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
        return false;
      }

      std::string ptr_address = op_desc[kJsonKeyPtrAddress];
      if (address_node_map.count(ptr_address) == 0) {
        MS_LOG(ERROR) << "Decode failed, ptr_address not found in map: " << ptr_address;
        return false;
      }
      auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
      if (IsRecomputeOp(op_desc, recompute_ops)) {
        auto cp_node = NewRecomputeNode(node, &node_map);
        orig_region_nodes.push_back(node);
        SetStitchAttr(op_desc, info, cp_node);
        res_graphs->push_back(cp_node);
        continue;
      }
      SetStitchAttr(op_desc, info, node);
      res_graphs->push_back(node);
      no_cp_nodes.push_back(node);
    }
    for (auto orig_node : orig_region_nodes) {
      ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
    }
    MS_LOG(DEBUG) << "Decode cnodes success, size: " << res_graphs->size();
    return true;
  }
};

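// Decode the split plan returned by the python cost model, then drop nodes that became
// unreachable after recompute copying.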
bool SplitByJsonSchemer::SplitByJsonStr(const std::map<std::string, AnfNodePtr> &address_node_map,
                                        std::string split_graphs_str) {
  if (!DecodeJson(split_graphs_str, address_node_map)) {
    MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
    return false;
  }

  if (split_plan_.size() > 1 && GraphKernelFlags::GetInstance().enable_recompute_fusion) {
    RemoveHangingNodes();
  }
  return true;
}

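// Remove nodes that are no longer reachable from the graph's return node (they were replaced
// by recompute copies), and erase groups that became empty as a result.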
void SplitByJsonSchemer::RemoveHangingNodes() {
  auto todo = TopoSort(func_graph_->get_return());
  std::set<AnfNodePtr> new_all_nodes(todo.begin(), todo.end());
  std::vector<size_t> empty_groups;
  for (size_t i = 0; i < split_plan_.size(); i++) {
    for (int j = SizeToInt(split_plan_[i].size()) - 1; j >= 0; j--) {
      if (new_all_nodes.count(split_plan_[i][j]) == 0) {
        MS_LOG(INFO) << "Recompute remove hanging node " << split_plan_[i][j]->fullname_with_scope();
        (void)split_plan_[i].erase(split_plan_[i].begin() + j);
      }
    }
    if (split_plan_[i].empty()) {
      empty_groups.push_back(i);
    }
  }
  if (!empty_groups.empty()) {
    MS_LOG(INFO) << "Recompute remove empty groups " << empty_groups;
    std::reverse(empty_groups.begin(), empty_groups.end());
    for (auto i : empty_groups) {
      (void)split_plan_.erase(split_plan_.begin() + i);
      (void)need_inline_.erase(need_inline_.begin() + i);
    }
  }
}

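// Parse the split-plan json: decode each sub-graph desc into a node list of split_plan_,
// and mark sub-graphs in "basic" mode to be inlined.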
bool SplitByJsonSchemer::DecodeJson(const std::string &json_desc,
                                    const std::map<std::string, AnfNodePtr> &address_node_map) {
  auto kernel_json = nlohmann::json::parse(json_desc);
  std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
  std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
  if (graph_modes.size() != graph_descs.size()) {
    MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatches graph_desc " << graph_descs.size();
    return false;
  }

  // recover json to anf nodes.
  split_plan_.clear();
  for (const auto &graph_desc : graph_descs) {
    AnfNodePtrList res_graph;
    if (!SplitNodesDecoder().DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
      MS_LOG(ERROR) << "Failed to decode sub graph, " << graph_desc;
      return false;
    }
    (void)split_plan_.emplace_back(std::move(res_graph));
  }

  // ops to be inlined.
  need_inline_.clear();
  (void)std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
                       [](const std::string &mode) { return mode == "basic" ? 1 : 0; });
  return true;
}

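// Entry of the schemer: collect the valid kernel nodes, ask the cost model for a split plan,
// then assign groups to the nodes, the return node and the virtual nodes.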
void SplitByJsonSchemer::Run() {
  auto mng = func_graph_->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph_, true);
    func_graph_->set_manager(mng);
  }
  GetValidKernelNodes();
  // call CostModel to get a split plan.
  if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) {
    split_plan_.clear();
    need_inline_.clear();
    return;
  } else if (split_plan_.size() == 1 && !NeedInline(0)) {
    // In this case, the CostModel decided to keep the whole graph unchanged.
    split_plan_.clear();
    need_inline_.clear();
    return;
  } else {
    MS_LOG(DEBUG) << "CostModel split succeeded. The kernel is split to " << split_plan_.size() << " parts.";
  }
  MapNodeGroup();
  GroupReturnNode();
  GroupVirtualNodes();
}

bool SplitByJsonSchemer::IsValidKernelNode(const AnfNodePtr &node) const {
  return node->isa<CNode>() && AnfUtils::IsRealKernel(node);
}

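// Collect the real kernel cnodes in topological order.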
void SplitByJsonSchemer::GetValidKernelNodes() {
  topo_all_nodes_ = TopoSort(func_graph_->get_return());
  topo_valid_nodes_.clear();
  (void)std::copy_if(topo_all_nodes_.begin(), topo_all_nodes_.end(), std::back_inserter(topo_valid_nodes_),
                     [this](const AnfNodePtr &node) { return IsValidKernelNode(node); });
}

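// Record the group id of every node in the split plan.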
void SplitByJsonSchemer::MapNodeGroup() {
  node_group_.clear();
  for (size_t i = 0; i < split_plan_.size(); ++i) {
    for (const auto &node : split_plan_[i]) {
      node_group_[node] = i;
    }
  }
}

// group the return node with the graph output, and the trailing MakeTuple node (if one exists).
void SplitByJsonSchemer::GroupReturnNode() {
  AnfNodePtrList outputs;
  kernel::GetFuncGraphOutputNodes(func_graph_, &outputs);
  auto ret_node = func_graph_->get_return();
  auto output = func_graph_->output();
  MS_EXCEPTION_IF_NULL(output);

  if (IsValidKernelNode(output)) {
    auto group_id = node_group_[output];
    node_group_[ret_node] = group_id;
    (void)split_plan_[group_id].emplace_back(ret_node);
    return;
  }
  // assign the make_tuple node to a new group.
  if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
    auto group_id = split_plan_.size();
    (void)split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
    (void)need_inline_.emplace_back(1);
    node_group_[output] = group_id;
    node_group_[ret_node] = group_id;
    return;
  }
}

// assign each virtual node to the same group as its first grouped input.
void SplitByJsonSchemer::GroupVirtualNodes() {
  for (const auto &node : topo_all_nodes_) {
    if (node_group_.count(node) != 0) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    bool found = false;
    for (const auto &input : cnode->inputs()) {
      auto iter = node_group_.find(input);
      if (iter != node_group_.end()) {
        auto group_id = iter->second;
        node_group_[node] = group_id;
        (void)split_plan_[group_id].emplace_back(node);
        found = true;
        break;
      }
    }
    if (!found) {
      MS_LOG(WARNING) << cnode->fullname_with_scope() << " is ungrouped.";
    }
  }
}

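// Convert the kernel graph to json, call the python cost model to get a split plan, and
// decode the returned plan back to anf nodes.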
bool CostModelSplitSchemer::SplitByCostModel() {
  // Use an address map to record each anf node's address when converting to json;
  // it is used to recover the original node after the split.
  std::map<std::string, AnfNodePtr> address_node_map;

  // convert anf-ir to json
  nlohmann::json json_desc;
  DumpOption dump_option;
  dump_option.is_before_select_kernel = false;
  dump_option.save_ptr_address = true;
  if (!AnfToJsonDesc(topo_valid_nodes_, dump_option, &json_desc, &address_node_map)) {
    MS_LOG(ERROR) << "Collect json desc failed.";
    return false;
  }
  // set the "node_name" for tracing the split result.
  std::string node_name = json_desc["op"];
  func_graph_->set_attr(kAttrNodeName, MakeValue(node_name));
  // call the costmodel split function.
  auto json_desc_str = json_desc.dump();
  auto flags_str = GraphKernelFlags::GetInstance().DumpAllFlags();
  MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json: " << json_desc_str
                << ". flag: " << flags_str;
  auto ret = python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str, flags_str);
  if (py::isinstance<py::none>(ret)) {
    MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
                  << json_desc_str << ". flag: " << flags_str;
    return false;
  }
  std::string split_graphs_str = py::cast<std::string>(ret);
  if (split_graphs_str.empty()) {
    MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
                  << json_desc_str << ". flag: " << flags_str;
    return false;
  }
  return SplitByJsonStr(address_node_map, split_graphs_str);
}

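// Use the python cost model for GPU kernels, unless dynamic shape fusion is enabled for a
// dynamic-shape kernel; otherwise fall back to the C++ split model.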
std::shared_ptr<SplitSchemer> GraphKernelSplitterWithPy::GetSplitSchema(const std::string &processor) {
  bool use_py_split =
      (processor == kGPUDevice) && (!is_dynamic_ || !GraphKernelFlags::GetInstance().enable_dynamic_shape_fusion);
  if (use_py_split) {
    MS_LOG(DEBUG) << "use py split model";
    return std::make_shared<CostModelSplitSchemer>();
  } else {
    MS_LOG(DEBUG) << "use c++ split model";
    return GraphKernelSplitter::GetSplitSchema(processor);
  }
}
}  // namespace mindspore::graphkernel