1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.h"
17 #include <vector>
18 #include <set>
19 #include <string>
20 #include "utils/utils.h"
21 #include "backend/session/anf_runtime_algorithm.h"
22 #include "frontend/optimizer/opt.h"
23 #include "backend/optimizer/ascend/ascend_helper.h"
24 #include "utils/trace_base.h"
25 namespace mindspore {
26 namespace opt {
27 namespace {
28 // insert tensormove for some cnode even if not a Ref cnode
29 const std::set<std::string> kNeedInsertTensorMoveOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName,
30 kLambUpdateWithLROpName, kGetNextOpName};
31
IsParameterOrValueNode(const AnfNodePtr & node)32 bool IsParameterOrValueNode(const AnfNodePtr &node) {
33 MS_EXCEPTION_IF_NULL(node);
34 auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
35 auto real_node = kernel_with_index.first;
36 MS_EXCEPTION_IF_NULL(real_node);
37 if (real_node->isa<Parameter>()) {
38 return true;
39 }
40 return real_node->isa<ValueNode>();
41 }
42
43 // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
IsNodeOutPutUsedByOtherRealKernel(const FuncGraphPtr & graph,const AnfNodePtr & input)44 bool IsNodeOutPutUsedByOtherRealKernel(const FuncGraphPtr &graph, const AnfNodePtr &input) {
45 auto manager = graph->manager();
46 MS_EXCEPTION_IF_NULL(manager);
47 auto &node_users = manager->node_users();
48 auto iter = node_users.find(input);
49 if (iter == node_users.end()) {
50 MS_LOG(EXCEPTION) << "node has no output in manager, trace: " << trace::DumpSourceLines(input);
51 }
52 auto user_items = iter->second;
53 if (user_items.size() == 1) {
54 MS_LOG(INFO) << "This node only used once, no need to insert tensormove node.";
55 return false;
56 }
57 for (const auto &node_pair : user_items) {
58 auto node = node_pair.first;
59 MS_EXCEPTION_IF_NULL(node);
60 if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
61 MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
62 return true;
63 }
64 }
65 MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert tensormove node.";
66 return false;
67 }
68 } // namespace
69
NeedInsertTensorMove(const FuncGraphPtr & graph,const AnfNodePtr & input,const CNodePtr & cur_node) const70 bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, const AnfNodePtr &input,
71 const CNodePtr &cur_node) const {
72 MS_EXCEPTION_IF_NULL(graph);
73 MS_EXCEPTION_IF_NULL(input);
74 MS_EXCEPTION_IF_NULL(cur_node);
75 if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) {
76 return false;
77 }
78 // visited nop node if exist.
79 auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input, 0, false);
80 auto real_input = kernel_with_index.first;
81 // when input is a parameter or is a value node
82 if (IsParameterOrValueNode(real_input)) {
83 return true;
84 }
85 // when input is a Ref cnode
86 if (kernel_query_->IsTbeRef(real_input)) {
87 return true;
88 }
89 // when input is some special cnodes: kLambNextMVOpName, kLambNextMVWithDecayOpName, kLambUpdateWithLROpName,
90 // kGetNextOpName
91 if (kNeedInsertTensorMoveOpSet.find(AnfAlgo::GetCNodeName(real_input)) != kNeedInsertTensorMoveOpSet.end()) {
92 return true;
93 }
94 // example1: NodeA --> Allreduce
95 // NodeA --> other RealNode(!Allreude)
96 // example2: NodeA --> NopNode --> Allreduce
97 // NodeA --> other RealNode(!Allreude)
98 // example3: NodeA --> NopNode --> Allreduce
99 // --> other RealNode(!Allreude)
100 // when input is used by others
101 if (IsNodeOutPutUsedByOtherRealKernel(graph, input)) {
102 return true;
103 }
104 if (opt::IsNopNode(real_input)) {
105 auto cnode = real_input->cast<CNodePtr>();
106 MS_EXCEPTION_IF_NULL(cnode);
107 return NeedInsertTensorMove(graph, cnode->input(1), cur_node);
108 }
109 return false;
110 }
111
InsertTensorMove(const FuncGraphPtr & graph,const CNodePtr & hccl_node) const112 void InsertTensorMoveForHcclOp::InsertTensorMove(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
113 MS_EXCEPTION_IF_NULL(graph);
114 MS_EXCEPTION_IF_NULL(hccl_node);
115 bool need_tensor_move_async = false;
116 std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
117 for (size_t i = 1; i < hccl_node->size(); ++i) {
118 auto input = hccl_node->input(i);
119 if (NeedInsertTensorMove(graph, input, hccl_node)) {
120 auto tensor_move = CreateTensorMoveOp(graph, input);
121 if (tensor_move == nullptr) {
122 MS_LOG(EXCEPTION) << "Create tensor_move op failed.";
123 }
124 if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) {
125 AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), tensor_move);
126 }
127 new_inputs.push_back(tensor_move);
128 need_tensor_move_async = true;
129 } else {
130 new_inputs.push_back(input);
131 }
132 }
133
134 if (need_tensor_move_async) {
135 CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
136 new_hccl_node->set_inputs(new_inputs);
137 auto manager = graph->manager();
138 MS_EXCEPTION_IF_NULL(manager);
139 MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
140 auto kernel_graph = graph->cast<KernelGraphPtr>();
141 if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(hccl_node)) {
142 kernel_graph->ReplaceInternalOutput(hccl_node, new_hccl_node);
143 }
144 (void)manager->Replace(hccl_node, new_hccl_node);
145 MS_LOG(DEBUG) << "end replace";
146 }
147 }
148
// Pattern-pass entry point: runs the TensorMove insertion on communication cnodes only.
// Always returns nullptr — replacement is performed in-place via the graph manager.
const AnfNodePtr InsertTensorMoveForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  const bool is_comm_cnode =
    func_graph != nullptr && node != nullptr && node->isa<CNode>() && AnfAlgo::IsCommunicationOp(node);
  if (is_comm_cnode) {
    InsertTensorMove(func_graph, node->cast<CNodePtr>());
  }
  return nullptr;
}
160 } // namespace opt
161 } // namespace mindspore
162