/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <set>
#include <stack>
#include <string>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace opt {
namespace {
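// Return the set of axes reduced by a ReduceSum node. Negative axes are
// normalized against the input rank; an empty axis list means all-reduce,
// so every axis of the input shape is returned.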
std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
  if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
    MS_LOG(EXCEPTION) << "Only ReduceSum is supported!";
  }

  auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);
  ShapeVector src_shape_vec;
  if (is_ascend) {
    src_shape_vec = GetDeviceShape(input);
  } else {
    src_shape_vec = GetShape(input);
  }
  auto axis_vec = GetReduceAxis(node);
  if (axis_vec.empty()) {
    for (size_t i = 0; i < src_shape_vec.size(); ++i) {
      (void)axis_vec.emplace_back(i);
    }
  } else {
    (void)std::transform(axis_vec.begin(), axis_vec.end(), axis_vec.begin(), [&src_shape_vec](int64_t axis) -> int64_t {
      return axis < 0 ? axis + SizeToLong(src_shape_vec.size()) : axis;
    });
  }

  std::set<int64_t> axis_set(axis_vec.begin(), axis_vec.end());
  return axis_set;
}

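// DFS the predecessors of `node`; return true if any ancestor is another
// ReduceSum, in which case atomic add cannot be enabled.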
bool HaveReduceInPredecessors(const AnfNodePtr &node) {
  std::stack<AnfNodePtr> st;
  st.push(node);
  while (!st.empty()) {
    auto n = st.top();
    st.pop();

    if (n != node) {
      if (!n->isa<CNode>()) {
        continue;
      }
      if (IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
        return true;
      }
    }

    auto n_inputs = n->cast<CNodePtr>()->inputs();
    (void)std::for_each(n_inputs.cbegin() + 1, n_inputs.cend(), [&st](const AnfNodePtr &n) -> void { st.push(n); });
  }

  return false;
}

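// After the ReduceSum output at `reduce_index` is removed from the output
// tuple, compute the new index of an output originally at `old_index`.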
inline int64_t CalNewIndex(int64_t old_index, int64_t reduce_index) {
  return old_index - (old_index > reduce_index ? 1 : 0);
}
} // namespace
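// Create the checker matching the current backend: Ascend (AICORE) and GPU
// (CUDA) are supported; other targets get nullptr and the pass is skipped.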
std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
  auto processor = kernel::GetProcessorFromContext();
  if (processor == kernel::Processor::AICORE) {
    return std::make_shared<AtomicAddCheckerAscend>();
  } else if (processor == kernel::Processor::CUDA) {
    return std::make_shared<AtomicAddCheckerGPU>();
  }
  return nullptr;
}

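// Scan the sub-graph outputs for exactly one ReduceSum candidate and record
// its output index and the total number of real outputs in atomic_add_info_.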
bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // Rule: Only one ReduceSum inside the sub-graph.
  auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
  if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
    size_t target_cnt = 0;
    const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      if (IsPrimitiveCNode(inputs[i], target_type_)) {
        atomic_add_info_.atomic_add_node = inputs[i]->cast<CNodePtr>();
        atomic_add_info_.reduce_real_output_index = i - 1;
        target_cnt++;
      }
    }

    if (target_cnt != 1) {
      return false;
    }
    atomic_add_info_.real_output_num = inputs.size() - 1;
  } else if (IsPrimitiveCNode(real_return_node, target_type_)) {
    atomic_add_info_.atomic_add_node = real_return_node->cast<CNodePtr>();
    atomic_add_info_.real_output_num = 1;
  } else {
    return false;
  }

  // Rule: ReduceSum should not be fused with any other op in the output direction,
  // which means it must appear directly in the output list.
  return (mng_sub->node_users()[atomic_add_info_.atomic_add_node].size() <= 1);
}

bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
  // Rules to activate atomic add:
  // 1. There is exactly one ReduceSum inside the sub-graph, and it is not fused with any other
  //    op in the output direction, which means it must be in the output list.
  // 2. The reduce axes and reduce size meet the condition:
  //    (GPU) all-reduce or a reduction along the last axis (reduce-x) with a reduce size of at
  //    least 1024, or a reduction along other axes (reduce-y).
  //    (Ascend) the first valid axis of the input data is a reduce axis, or the non-reduce axes
  //    cannot make full use of the multiple cores.
  // 3. No other ReduceSum among the output ReduceSum's predecessors (reduce compile limitation).

  // Rule 1.
  if (!FindCandidate(anf_node)) {
    return false;
  }

  // Rule 2.
  if (!SuitableForAtomicAdd(atomic_add_info_.atomic_add_node)) {
    return false;
  }

  // Rule 3.
  return !HaveReduceInPredecessors(atomic_add_info_.atomic_add_node);
}

bool AtomicAddChecker::Check(const AnfNodePtr &node) {
  return (AnfAlgo::IsGraphKernel(node) && CanActivateAtomicAdd(node));
}

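// GPU policy: a reduction touching the last (contiguous) axis pays off with
// atomic add only when the reduce size is large; other reductions always do.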
bool AtomicAddCheckerGPU::SuitableForAtomicAdd(const AnfNodePtr &node) {
  auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);
  auto src_shape_vec = GetShape(input);
  std::set<int64_t> axis_set = GetUniqReduceAxes(node);

  // For a reduce whose last dim is reduced (including all-reduce),
  // it is suitable for atomic add only when the reduce size is greater than or equal to 1024.
  if (axis_set.count(src_shape_vec.size() - 1) != 0) {
    size_t reduce_size = std::accumulate(
      axis_set.begin(), axis_set.end(), LongToSize(1),
      [&src_shape_vec](size_t size, int64_t axis) { return size * LongToSize(src_shape_vec[LongToSize(axis)]); });
    return reduce_size >= 1024;
  }

  // For a reduce whose last dim is not reduced, it is always suitable.
  return true;
}

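// Ascend policy: atomic add requires fp32 data, and pays off when the reduction
// starts at the first non-1 axis, or when the leading non-reduce axes are too
// small to fill all AICORE cores.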
bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
  auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);

  // Atomic addition is enabled only when the data type is fp32.
  auto type = AnfAlgo::GetOutputDeviceDataType(input, 0);
  if (type != kNumberTypeFloat32) {
    return false;
  }

  // If the first valid axis of the input data is a reduce axis, enable atomic addition.
  auto src_shape_vec = GetDeviceShape(input);
  std::set<int64_t> reduce_axis_set = GetUniqReduceAxes(node, true);
  auto start_with_reduce = false;
  for (size_t i = 0; i < src_shape_vec.size(); ++i) {
    auto dim = src_shape_vec[i];
    if (dim != 1) {
      if (reduce_axis_set.count(i)) {
        start_with_reduce = true;
      }
      break;
    }
  }
  if (start_with_reduce) {
    return true;
  }

  // If the non-reduce axes cannot make full use of the multiple cores, enable atomic addition.
  constexpr auto processor_core_num = 32LL;
  auto start_non_reduce_dim = 1LL;
  for (size_t i = 0; i < src_shape_vec.size(); ++i) {
    auto dim = src_shape_vec[i];
    if (reduce_axis_set.count(i)) {
      break;
    }
    start_non_reduce_dim = start_non_reduce_dim * dim;
  }
  if (start_non_reduce_dim < processor_core_num) {
    return true;
  }

  return false;
}

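// Rebuild the composite node's kernel build info: append the format/type of the
// new clean-buffer input and, when bypassing, drop the ReduceSum output slot.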
void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
                                                  bool bypass) {
  // Change kernel build info.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();

  std::vector<std::string> &new_inputs_format = origin_inputs_format;
  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    if (bypass && real_output_num_ > 1 && i == reduce_real_output_index_) {
      continue;
    }
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }

  auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));

  kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
  new_info_builder.SetInputsFormat(new_inputs_format);
  new_info_builder.SetInputsDeviceType(new_inputs_type);
  new_info_builder.SetOutputsFormat(new_outputs_format);
  new_info_builder.SetOutputsDeviceType(new_outputs_type);
  new_info_builder.SetProcessor(origin_processor);
  new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto new_selected_info = new_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}

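// Insert an InplaceAssign that writes the ReduceSum result into the new
// parameter (the clean buffer), then rebuild the sub-graph output so the
// ReduceSum slot is removed (or faked when it is the only output).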
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
                                                                   const AnfNodePtr &new_parameter) {
  // Add InplaceAssign.
  AnfNodePtr out_node;
  bool fake_out = false;
  size_t replace_index = 0;
  auto return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
  if (IsPrimitiveCNode(return_node, prim::kPrimMakeTuple)) {
    const auto &outs = return_node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < outs.size(); ++i) {
      if (i != reduce_real_output_index_ + 1) {
        out_node = outs[i];
        replace_index = i;
        break;
      }
    }
  } else {
    out_node = atomic_add_node_;  // Use the result data itself, and set attr "fake_output" true.
    fake_out = true;
  }

  auto inplace_assign_node =
    CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
                {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
  SetNodeAttrSafely("fake_output", MakeValue(fake_out), inplace_assign_node);

  CNodePtr new_out_node;
  if (real_output_num_ > 2) {
    std::vector<AnfNodePtr> output_args = {NewValueNode(prim::kPrimMakeTuple)};
    const auto &outs = return_node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < outs.size(); ++i) {
      if (i == reduce_real_output_index_ + 1) {
        continue;
      } else if (i == replace_index) {
        output_args.push_back(inplace_assign_node);
      } else {
        output_args.push_back(outs[i]);
      }
    }
    // Set output for the AnfGraph.
    new_out_node = sub_graph->NewCNode(output_args);
  } else {
    new_out_node = inplace_assign_node;
  }
  sub_graph->set_output(new_out_node);
}

void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) const {
  // If there is only one output (ReduceSum), it should be a fake output with the same abstract as the origin output.
  if (real_output_num_ <= 1) {
    return;
  }

  // Change abstract.
  auto origin_out_spec = composite_node->abstract()->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(origin_out_spec);
  const auto &origin_out_specs = origin_out_spec->elements();
  AbstractBasePtrList new_out_specs;
  for (size_t i = 0; i < origin_out_specs.size(); ++i) {
    if (i != reduce_real_output_index_) {
      new_out_specs.push_back(origin_out_specs[i]);
    }
  }
  composite_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_out_specs));
}

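// Rewrite the composite node for atomic add: mark the ReduceSum, feed the clean
// buffer in as an extra input/parameter, redirect the sub-graph output through
// InplaceAssign, and refresh the abstract, kernel info and graph-kernel name.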
void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // Add atomic attribute to the ReduceSum node.
  SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_node_);

  // Add the new input to the composite node.
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // Add a corresponding parameter to the sub-graph.
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

  CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);

  CorrectAbstract(composite_node);
  CorrectKernelBuildInfo(composite_node, new_input);

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

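// Create a Depend that forwards clean_node's value while forcing composite_node
// to execute first, then reroute the user's input through it.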
void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
                                     const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) const {
  // Create depend node to hold the execution order.
  AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
  auto depend_cnode = main_graph->NewCNode(d_inputs);
  depend_cnode->set_abstract(clean_node->abstract());
  main_graph->AddNode(depend_cnode);

  auto user_cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  user_cnode->set_input(IntToSize(index), depend_cnode);
}

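// Attach an UpdateState(U, composite_node) to the main graph; Load nodes hang
// on it later so users of the clean buffer run after the composite node.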
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph,
                                                 const CNodePtr &composite_node) const {
  // Insert update_state_node; it needs to mount a monad node.
  auto u = NewValueNode(kUMonad);
  u->set_abstract(kUMonad->ToAbstract());
  AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, composite_node};
  auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
  update_state_cnode->set_abstract(kUMonad->ToAbstract());
  main_graph->AddNode(update_state_cnode);
  return update_state_cnode;
}

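// Build the "clean" composite kernel: a small sub-graph that broadcasts a zero
// scalar to the ReduceSum output shape. Its output is the accumulation buffer
// the atomic ReduceSum will add into; fp16 goes through an fp32 zero plus Cast.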
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
  std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};

  if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
    MS_LOG(EXCEPTION) << "Atomic add does not support data type " << dst_type;
  }

  // Create the zero value which will be broadcast to the target shape.
  auto format = GetFormat(atomic_add_node_);
  auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
  ValueNodePtr value_node;
  if (dtype == kNumberTypeFloat32) {
    value_node = CreateScalarTensorValueNode<float>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
                                                    static_cast<float>(0), sizeof(float));
  } else {
    value_node = CreateScalarTensorValueNode<double>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
                                                     static_cast<double>(0), sizeof(double));
  }

  // Create the composite op's sub-graph.
  auto new_sub_graph = std::make_shared<FuncGraph>();

  AnfNodePtr broadcast_input_node;
  if (dst_type == kNumberTypeFloat16) {
    AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
    auto cast_node_inner =
      CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
    SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
    broadcast_input_node = cast_node_inner;
  } else {
    broadcast_input_node = value_node;
  }

  // Create the broadcast basic op.
  auto dst_shape_vec = GetShape(atomic_add_node_);
  AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
  auto broadcast_to_node_inner = CreateCNode(
    atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)});
  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_node_)), broadcast_to_node_inner);

  // Make up the sub-graph.
  new_sub_graph->set_output(broadcast_to_node_inner);
  auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
  broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
  SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
  auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));

  return broadcast_to_composite_node;
}

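// Collect the (user, input-index) pairs that consume the ReduceSum result. For
// multi-output composites this looks through the matching TupleGetItem nodes
// and, when correct_index is set, renumbers the remaining getitem indices.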
std::vector<std::pair<AnfNodePtr, int>> AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
                                                                                   const AnfNodePtr &composite_node,
                                                                                   const FuncGraphManagerPtr &mng,
                                                                                   bool correct_index) const {
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes;
  if (real_output_num_ <= 1) {
    auto users = mng->node_users()[composite_node];
    (void)std::transform(users.cbegin(), users.cend(), std::back_inserter(reduce_user_nodes),
                         [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  } else {
    std::vector<std::pair<AnfNodePtr, int>> getitem_user_nodes;
    auto users = mng->node_users()[composite_node];
    for (const auto &node_index : users) {
      const auto &user_node = node_index.first;
      if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto get_item_cnode = user_node->cast<CNodePtr>();
      auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(value_input);
      auto value_node = value_input->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto item_idx = GetValue<int64_t>(value_node->value());
      if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) {
        getitem_user_nodes.push_back(node_index);
      } else if (correct_index) {
        if (real_output_num_ > 2) {
          // Correct the other getitem indices.
          int64_t new_item_idx = CalNewIndex(item_idx, SizeToLong(reduce_real_output_index_));
          AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node,
                                       NewValueNode(new_item_idx)};
          auto new_out = main_graph->NewCNode(new_inputs);
          new_out->set_abstract(get_item_cnode->abstract());
          for (const auto &[user, index] : mng->node_users()[get_item_cnode]) {
            auto user_cnode = user->cast<CNodePtr>();
            MS_EXCEPTION_IF_NULL(user_cnode);
            user_cnode->set_input(IntToSize(index), new_out);
          }
        } else {
          for (const auto &[user, index] : mng->node_users()[node_index.first]) {
            auto user_cnode = user->cast<CNodePtr>();
            MS_EXCEPTION_IF_NULL(user_cnode);
            user_cnode->set_input(IntToSize(index), composite_node);
          }
        }
      }
    }
    for (auto &pair : getitem_user_nodes) {
      // Look through the getitem to find the real users.
      auto real_users = mng->node_users()[pair.first];
      (void)reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end());
    }
  }

  return reduce_user_nodes;
}

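// Redirect every user of the old ReduceSum output to a Load(broadcast_to_node,
// update_state_node), so they read the atomically updated clean buffer in order.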
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                                                  const AnfNodePtr &broadcast_to_node,
                                                  const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
  // 1. Find users, and change the getitem index if needed.
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
    FindOriginCNodeUsers(main_graph, composite_node, mng, true);
  for (const auto &[user_node, index] : reduce_user_nodes) {
    // 2. Make sure the modified composite node runs first: create a load_node, then add edges to connect
    //    update_state_node, broadcast_to_node and load_node to keep the order.
    AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node};
    auto load_node = main_graph->NewCNode(load_inputs);
    load_node->set_abstract(broadcast_to_node->abstract());
    main_graph->AddNode(load_node);
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(IntToSize(index), load_node);
  }
}

void AtomicCleanInsertter::UpdateAtomicAddInfo(const AtomicAddInfo &atomic_add_info) {
  atomic_add_node_ = atomic_add_info.atomic_add_node;
  reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
  real_output_num_ = atomic_add_info.real_output_num;
}

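// Top-level rewrite for one candidate: build the clean (broadcast-zero) node,
// rewire the composite node onto it, and chain an UpdateState/Load pair so the
// clean -> accumulate -> read order is preserved in the main graph.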
void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
                                             const FuncGraphManagerPtr &mng) {
  auto origin_composite_node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_composite_node);

  // Create broadcast node.
  auto out_type = GetType(atomic_add_node_)->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(out_type);
  auto broadcast_to_node = CreateAtomicCleanCompositeNode(main_graph, out_type->element()->type_id());

  // Insert an extra input (the broadcast node's output) to the composite node, and make the ReduceSum
  // InplaceAssign to it.
  // Note: for a single output this increases total memory because of the fake output.
  ProcessOriginCNode(origin_composite_node, broadcast_to_node);

  // Insert update_state_node to keep the execution order.
  auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);

  // Replace the origin ReduceSum's users with the atomic clean output.
  ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng);
  MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
               << ", clean node: " << broadcast_to_node->fullname_with_scope();
}

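// Check users of the reduce result: return true only when none of them is a
// MakeTuple or Depend other than the graph output, since in those cases the
// execution order could not be fixed cheaply.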
bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
                                                     const FuncGraphManagerPtr &mng) {
  auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
  // If the reduce user is a MakeTuple that is not the last node, there is no cheap way to enforce the right
  // execution order between the reduce node and the user node. If the reduce user is a Depend node, the origin
  // node may be wrong!
  return std::all_of(
    reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
      auto &user = user_info.first;
      if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) &&
          !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
        return false;
      } else {
        return true;
      }
    });
}

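// Pass entry: for every graph-kernel node that passes the backend checker and
// has no structural obstacle, rewrite it to use atomic add with a clean buffer.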
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  std::shared_ptr<AtomicAddChecker> atomic_add_checker = AtomicAddChecker::Init();
  if (atomic_add_checker == nullptr) {
    return changed;
  }

  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    if (!atomic_add_checker->Check(node) || !IsExistStructuralObstacle(kernel_graph, node, mng)) {
      continue;
    }
    changed = true;
    auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
    UpdateAtomicAddInfo(atomic_add_info);
    InsertAtomicClean(kernel_graph, node, mng);
  }

  if (changed) {
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }

  return changed;
}
} // namespace opt
} // namespace mindspore