• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.h"
17 #include <vector>
18 #include <set>
19 #include <string>
20 #include "utils/utils.h"
21 #include "backend/session/anf_runtime_algorithm.h"
22 #include "frontend/optimizer/opt.h"
23 #include "backend/optimizer/ascend/ascend_helper.h"
24 #include "utils/trace_base.h"
25 namespace mindspore {
26 namespace opt {
27 namespace {
28 // insert tensormove for some cnode even if not a Ref cnode
29 const std::set<std::string> kNeedInsertTensorMoveOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName,
30                                                           kLambUpdateWithLROpName, kGetNextOpName};
31 
IsParameterOrValueNode(const AnfNodePtr & node)32 bool IsParameterOrValueNode(const AnfNodePtr &node) {
33   MS_EXCEPTION_IF_NULL(node);
34   auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
35   auto real_node = kernel_with_index.first;
36   MS_EXCEPTION_IF_NULL(real_node);
37   if (real_node->isa<Parameter>()) {
38     return true;
39   }
40   return real_node->isa<ValueNode>();
41 }
42 
43 // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
IsNodeOutPutUsedByOtherRealKernel(const FuncGraphPtr & graph,const AnfNodePtr & input)44 bool IsNodeOutPutUsedByOtherRealKernel(const FuncGraphPtr &graph, const AnfNodePtr &input) {
45   auto manager = graph->manager();
46   MS_EXCEPTION_IF_NULL(manager);
47   auto &node_users = manager->node_users();
48   auto iter = node_users.find(input);
49   if (iter == node_users.end()) {
50     MS_LOG(EXCEPTION) << "node has no output in manager, trace: " << trace::DumpSourceLines(input);
51   }
52   auto user_items = iter->second;
53   if (user_items.size() == 1) {
54     MS_LOG(INFO) << "This node only used once, no need to insert tensormove node.";
55     return false;
56   }
57   for (const auto &node_pair : user_items) {
58     auto node = node_pair.first;
59     MS_EXCEPTION_IF_NULL(node);
60     if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
61       MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
62       return true;
63     }
64   }
65   MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert tensormove node.";
66   return false;
67 }
68 }  // namespace
69 
NeedInsertTensorMove(const FuncGraphPtr & graph,const AnfNodePtr & input,const CNodePtr & cur_node) const70 bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, const AnfNodePtr &input,
71                                                      const CNodePtr &cur_node) const {
72   MS_EXCEPTION_IF_NULL(graph);
73   MS_EXCEPTION_IF_NULL(input);
74   MS_EXCEPTION_IF_NULL(cur_node);
75   if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) {
76     return false;
77   }
78   // visited nop node if exist.
79   auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input, 0, false);
80   auto real_input = kernel_with_index.first;
81   // when input is a parameter or is a value node
82   if (IsParameterOrValueNode(real_input)) {
83     return true;
84   }
85   // when input is a Ref cnode
86   if (kernel_query_->IsTbeRef(real_input)) {
87     return true;
88   }
89   // when input is some special cnodes: kLambNextMVOpName, kLambNextMVWithDecayOpName, kLambUpdateWithLROpName,
90   // kGetNextOpName
91   if (kNeedInsertTensorMoveOpSet.find(AnfAlgo::GetCNodeName(real_input)) != kNeedInsertTensorMoveOpSet.end()) {
92     return true;
93   }
94   // example1: NodeA --> Allreduce
95   //           NodeA --> other RealNode(!Allreude)
96   // example2: NodeA --> NopNode --> Allreduce
97   //           NodeA --> other RealNode(!Allreude)
98   // example3: NodeA --> NopNode --> Allreduce
99   //                             --> other RealNode(!Allreude)
100   // when input is used by others
101   if (IsNodeOutPutUsedByOtherRealKernel(graph, input)) {
102     return true;
103   }
104   if (opt::IsNopNode(real_input)) {
105     auto cnode = real_input->cast<CNodePtr>();
106     MS_EXCEPTION_IF_NULL(cnode);
107     return NeedInsertTensorMove(graph, cnode->input(1), cur_node);
108   }
109   return false;
110 }
111 
InsertTensorMove(const FuncGraphPtr & graph,const CNodePtr & hccl_node) const112 void InsertTensorMoveForHcclOp::InsertTensorMove(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
113   MS_EXCEPTION_IF_NULL(graph);
114   MS_EXCEPTION_IF_NULL(hccl_node);
115   bool need_tensor_move_async = false;
116   std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
117   for (size_t i = 1; i < hccl_node->size(); ++i) {
118     auto input = hccl_node->input(i);
119     if (NeedInsertTensorMove(graph, input, hccl_node)) {
120       auto tensor_move = CreateTensorMoveOp(graph, input);
121       if (tensor_move == nullptr) {
122         MS_LOG(EXCEPTION) << "Create tensor_move op failed.";
123       }
124       if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) {
125         AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), tensor_move);
126       }
127       new_inputs.push_back(tensor_move);
128       need_tensor_move_async = true;
129     } else {
130       new_inputs.push_back(input);
131     }
132   }
133 
134   if (need_tensor_move_async) {
135     CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
136     new_hccl_node->set_inputs(new_inputs);
137     auto manager = graph->manager();
138     MS_EXCEPTION_IF_NULL(manager);
139     MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
140     auto kernel_graph = graph->cast<KernelGraphPtr>();
141     if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(hccl_node)) {
142       kernel_graph->ReplaceInternalOutput(hccl_node, new_hccl_node);
143     }
144     (void)manager->Replace(hccl_node, new_hccl_node);
145     MS_LOG(DEBUG) << "end replace";
146   }
147 }
148 
// Pass entry point: applies InsertTensorMove to every communication cnode.
// Always returns nullptr because the replacement (if any) is done in place
// through the graph manager rather than by returning a substitute node.
const AnfNodePtr InsertTensorMoveForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    return nullptr;
  }
  if (!node->isa<CNode>() || !AnfAlgo::IsCommunicationOp(node)) {
    return nullptr;
  }
  InsertTensorMove(func_graph, node->cast<CNodePtr>());
  return nullptr;
}
160 }  // namespace opt
161 }  // namespace mindspore
162