/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <set>
#include <stack>
#include <string>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace opt {
namespace {
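// Collect the normalized reduce axes of a ReduceSum node as a set. An empty axis attribute means a full
// reduction over every axis; negative axes are shifted by the input rank. Illustrative example (assumed
// shapes, not from the original sources): for a 4-D input, axis -1 normalizes to 3, and an empty axis
// list yields {0, 1, 2, 3}.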
std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
  if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
    MS_LOG(EXCEPTION) << "Only process for reduce sum!";
  }

  auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);
  ShapeVector src_shape_vec;
  if (is_ascend) {
    src_shape_vec = GetDeviceShape(input);
  } else {
    src_shape_vec = GetShape(input);
  }
  auto axis_vec = GetReduceAxis(node);
  if (axis_vec.empty()) {
    for (size_t i = 0; i < src_shape_vec.size(); ++i) {
      (void)axis_vec.emplace_back(i);
    }
  } else {
    (void)std::transform(axis_vec.begin(), axis_vec.end(), axis_vec.begin(), [&src_shape_vec](int64_t axis) -> int64_t {
      return axis < 0 ? axis + SizeToLong(src_shape_vec.size()) : axis;
    });
  }

  std::set<int64_t> axis_set(axis_vec.begin(), axis_vec.end());
  return axis_set;
}

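// Depth-first walk over the data inputs of `node` inside the sub-graph; returns true if any predecessor
// CNode is another ReduceSum. Used below to reject candidates that would violate the "no ReduceSum before
// the output ReduceSum" compile limitation.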
bool HaveReduceInPredecessors(const AnfNodePtr &node) {
  std::stack<AnfNodePtr> st;
  st.push(node);
  while (!st.empty()) {
    auto n = st.top();
    st.pop();

    if (n != node) {
      if (!n->isa<CNode>()) {
        continue;
      }
      if (IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
        return true;
      }
    }

    auto n_inputs = n->cast<CNodePtr>()->inputs();
    (void)std::for_each(n_inputs.cbegin() + 1, n_inputs.cend(), [&st](const AnfNodePtr &n) -> void { st.push(n); });
  }

  return false;
}

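// Map an output index of the original composite node to its index after the ReduceSum output has been
// removed from the output tuple. Worked example (hypothetical indices, not from the original sources):
// with reduce_index == 1, the old indices {0, 2, 3} map to {0, 1, 2}, i.e. CalNewIndex(2, 1) == 1.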
inline int64_t CalNewIndex(int64_t old_index, int64_t reduce_index) {
  return old_index - (old_index > reduce_index ? 1 : 0);
}
}  // namespace
std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
  auto processor = kernel::GetProcessorFromContext();
  if (processor == kernel::Processor::AICORE) {
    return std::make_shared<AtomicAddCheckerAscend>();
  } else if (processor == kernel::Processor::CUDA) {
    return std::make_shared<AtomicAddCheckerGPU>();
  }
  return nullptr;
}

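// Locate the atomic-add candidate: there must be exactly one ReduceSum (target_type_) among the sub-graph's
// real outputs, and it must not feed any other node inside the sub-graph (it is only returned). On success
// the node, its real output index and the real output count are recorded in atomic_add_info_.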
bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // Rule: Only one ReduceSum inside the sub-graph.
  auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
  if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
    size_t target_cnt = 0;
    const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      if (IsPrimitiveCNode(inputs[i], target_type_)) {
        atomic_add_info_.atomic_add_node = inputs[i]->cast<CNodePtr>();
        atomic_add_info_.reduce_real_output_index = i - 1;
        target_cnt++;
      }
    }

    if (target_cnt != 1) {
      return false;
    }
    atomic_add_info_.real_output_num = inputs.size() - 1;
  } else if (IsPrimitiveCNode(real_return_node, target_type_)) {
    atomic_add_info_.atomic_add_node = real_return_node->cast<CNodePtr>();
    atomic_add_info_.real_output_num = 1;
  } else {
    return false;
  }

  // Rule: The ReduceSum should not fuse any other ops in the output direction, which means it should be in the
  // output list.
  return (mng_sub->node_users()[atomic_add_info_.atomic_add_node].size() <= 1);
}

bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
  // Rules to activate atomic add:
  // 1. There is exactly one ReduceSum inside the sub-graph, and it does not fuse any other ops in the output
  //    direction, which means it should be in the output list.
  // 2. The reduce axes and reduce size should meet the condition:
  //    (GPU) for all-reduce or reduce-x (last axis reduced), the reduce size must be greater than or equal to 1024;
  //    reduce-y is always suitable.
  //    (Ascend) the first valid axis of the input data is a reduce axis, or the non-reduce axes cannot make
  //    full use of the multi-core resources.
  // 3. No other ReduceSum among the output ReduceSum's predecessors (reduce compile limitation).

  // Rule 1.
  if (!FindCandidate(anf_node)) {
    return false;
  }

  // Rule 2.
  if (!SuitableForAtomicAdd(atomic_add_info_.atomic_add_node)) {
    return false;
  }

  // Rule 3.
  return !HaveReduceInPredecessors(atomic_add_info_.atomic_add_node);
}

bool AtomicAddChecker::Check(const AnfNodePtr &node) {
  return (AnfAlgo::IsGraphKernel(node) && CanActivateAtomicAdd(node));
}

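// GPU rule. Illustrative example (assumed shapes, not from the original sources): reducing the last axis of a
// (32, 64) input gives a reduce size of 64 < 1024, so atomic add is rejected; reducing the last axis of a
// (32, 2048) input gives 2048 >= 1024, so atomic add is accepted. Reductions that keep the last axis are
// always accepted.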
bool AtomicAddCheckerGPU::SuitableForAtomicAdd(const AnfNodePtr &node) {
  auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);
  auto src_shape_vec = GetShape(input);
  std::set<int64_t> axis_set = GetUniqReduceAxes(node);

  // For a reduce whose last dim is reduced (including all-reduce),
  // it is suitable for atomic add only when the reduce size is greater than or equal to 1024.
  if (axis_set.count(src_shape_vec.size() - 1) != 0) {
    size_t reduce_size = std::accumulate(
      axis_set.begin(), axis_set.end(), LongToSize(1),
      [&src_shape_vec](size_t size, int64_t axis) { return size * LongToSize(src_shape_vec[LongToSize(axis)]); });
    return reduce_size >= 1024;
  }

  // For a reduce whose last dim is not reduced, it is always suitable.
  return true;
}

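// Ascend rule: only fp32 outputs are eligible. Atomic add is enabled when the first non-1 axis of the device
// shape is a reduce axis, or when the product of the leading non-reduce axes is smaller than the 32-core
// budget assumed below. Illustrative example (assumed shapes, not from the original sources): a device shape
// of (8, 1024) reduced over axis 1 leaves only 8 parallel slices, and 8 < 32, so atomic add is enabled.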
bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
  auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);

  // Atomic addition is enabled only when the data type is fp32.
  auto type = AnfAlgo::GetOutputDeviceDataType(input, 0);
  if (type != kNumberTypeFloat32) {
    return false;
  }

  // If the first valid axis of the input data is a reduce axis, enable atomic addition.
  auto src_shape_vec = GetDeviceShape(input);
  std::set<int64_t> reduce_axis_set = GetUniqReduceAxes(node, true);
  auto start_with_reduce = false;
  for (size_t i = 0; i < src_shape_vec.size(); ++i) {
    auto dim = src_shape_vec[i];
    if (dim != 1) {
      if (reduce_axis_set.count(i)) {
        start_with_reduce = true;
      }
      break;
    }
  }
  if (start_with_reduce) {
    return true;
  }

  // If the non-reduce axes cannot make full use of the multi-core resources, enable atomic addition.
  constexpr auto processor_core_num = 32LL;
  auto start_non_reduce_dim = 1LL;
  for (size_t i = 0; i < src_shape_vec.size(); ++i) {
    auto dim = src_shape_vec[i];
    if (reduce_axis_set.count(i)) {
      break;
    }
    start_non_reduce_dim = start_non_reduce_dim * dim;
  }
  if (start_non_reduce_dim < processor_core_num) {
    return true;
  }

  return false;
}

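// Rebuild the composite node's kernel build info after the clean buffer has been attached: the buffer's
// format/type is appended as an extra input, and (when bypass is set and there is more than one real output)
// the ReduceSum output slot is dropped from the output lists to match the corrected abstract.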
void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
                                                  bool bypass) {
  // Change kernel build info.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();

  std::vector<std::string> &new_inputs_format = origin_inputs_format;
  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    if (bypass && real_output_num_ > 1 && i == reduce_real_output_index_) {
      continue;
    }
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }

  auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));

  kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
  new_info_builder.SetInputsFormat(new_inputs_format);
  new_info_builder.SetInputsDeviceType(new_inputs_type);
  new_info_builder.SetOutputsFormat(new_outputs_format);
  new_info_builder.SetOutputsDeviceType(new_outputs_type);
  new_info_builder.SetProcessor(origin_processor);
  new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto new_selected_info = new_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}

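// Route the ReduceSum result into the clean buffer through an InplaceAssign and rebuild the sub-graph output.
// Illustrative sketch (hypothetical 3-output sub-graph, not from the original sources):
//   before: MakeTuple(out0, reduce_out, out2)                       with reduce_real_output_index_ == 1
//   after:  MakeTuple(InplaceAssign(param, reduce_out, out0), out2)
// When the ReduceSum is the only output, the InplaceAssign reuses the reduce result itself and is marked
// as "fake_output".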
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
                                                                   const AnfNodePtr &new_parameter) {
  // Add InplaceAssign.
  AnfNodePtr out_node;
  bool fake_out = false;
  size_t replace_index = 0;
  auto return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
  if (IsPrimitiveCNode(return_node, prim::kPrimMakeTuple)) {
    const auto &outs = return_node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < outs.size(); ++i) {
      if (i != reduce_real_output_index_ + 1) {
        out_node = outs[i];
        replace_index = i;
        break;
      }
    }
  } else {
    out_node = atomic_add_node_;  // Use the result data itself, and set the attr "fake_out" to true.
    fake_out = true;
  }

  auto inplace_assign_node =
    CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
                {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
  SetNodeAttrSafely("fake_output", MakeValue(fake_out), inplace_assign_node);

  CNodePtr new_out_node;
  if (real_output_num_ > 2) {
    std::vector<AnfNodePtr> output_args = {NewValueNode(prim::kPrimMakeTuple)};
    const auto &outs = return_node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < outs.size(); ++i) {
      if (i == reduce_real_output_index_ + 1) {
        continue;
      } else if (i == replace_index) {
        output_args.push_back(inplace_assign_node);
      } else {
        output_args.push_back(outs[i]);
      }
    }
    // Set the output for the AnfGraph.
    new_out_node = sub_graph->NewCNode(output_args);
  } else {
    new_out_node = inplace_assign_node;
  }
  sub_graph->set_output(new_out_node);
}

void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) const {
  // If there is only one output (the ReduceSum), it becomes a fake output with the same abstract as the origin
  // output, so nothing needs to change.
  if (real_output_num_ <= 1) {
    return;
  }

  // Change the abstract: drop the ReduceSum element from the output tuple.
  auto origin_out_spec = composite_node->abstract()->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(origin_out_spec);
  const auto &origin_out_specs = origin_out_spec->elements();
  AbstractBasePtrList new_out_specs;
  for (size_t i = 0; i < origin_out_specs.size(); ++i) {
    if (i != reduce_real_output_index_) {
      new_out_specs.push_back(origin_out_specs[i]);
    }
  }
  composite_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_out_specs));
}

void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // Add the atomic attribute to the ReduceSum node.
  SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_node_);

  // Add the clean buffer as an extra input of the composite node.
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // Add the corresponding parameter to the sub-graph.
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

  CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);

  CorrectAbstract(composite_node);
  CorrectKernelBuildInfo(composite_node, new_input);

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
                                     const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) const {
  // Create a Depend node to hold the execution order.
  AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
  auto depend_cnode = main_graph->NewCNode(d_inputs);
  depend_cnode->set_abstract(clean_node->abstract());
  main_graph->AddNode(depend_cnode);

  auto user_cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  user_cnode->set_input(IntToSize(index), depend_cnode);
}

CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph,
                                                 const CNodePtr &composite_node) const {
  // Insert an UpdateState node, which needs to mount a monad node.
  auto u = NewValueNode(kUMonad);
  u->set_abstract(kUMonad->ToAbstract());
  AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, composite_node};
  auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
  update_state_cnode->set_abstract(kUMonad->ToAbstract());
  main_graph->AddNode(update_state_cnode);
  return update_state_cnode;
}

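// Build the "clean" composite node: a small graph-kernel sub-graph that broadcasts a zero scalar to the shape
// of the ReduceSum output, producing the buffer the atomic ReduceSum accumulates into. fp16 outputs go through
// an extra Cast inside the sub-graph; fp32/fp64 use the zero value node directly.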
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
  std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};

  if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
    MS_LOG(EXCEPTION) << "Atomic add does not support data type " << dst_type;
  }

  // Create the zero value which will be broadcast to the target shape.
  auto format = GetFormat(atomic_add_node_);
  auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
  ValueNodePtr value_node;
  if (dtype == kNumberTypeFloat32) {
    value_node = CreateScalarTensorValueNode<float>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
                                                    static_cast<float>(0), sizeof(float));
  } else {
    value_node = CreateScalarTensorValueNode<double>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
                                                     static_cast<double>(0), sizeof(double));
  }

  // Create the composite op's sub-graph.
  auto new_sub_graph = std::make_shared<FuncGraph>();

  AnfNodePtr broadcast_input_node;
  if (dst_type == kNumberTypeFloat16) {
    AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
    auto cast_node_inner =
      CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
    SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
    broadcast_input_node = cast_node_inner;
  } else {
    broadcast_input_node = value_node;
  }

  // Create the broadcast basic op.
  auto dst_shape_vec = GetShape(atomic_add_node_);
  AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
  auto broadcast_to_node_inner = CreateCNode(
    atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)});
  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_node_)), broadcast_to_node_inner);

  // Make up the sub-graph.
  new_sub_graph->set_output(broadcast_to_node_inner);
  auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
  broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
  SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
  auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));

  return broadcast_to_composite_node;
}

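// Collect the (user, input index) pairs that consume the ReduceSum result of the composite node. With a single
// real output the direct users are returned; with multiple outputs the users of the matching TupleGetItem are
// returned instead, and, when correct_index is set, the remaining TupleGetItem indices are rewritten (via
// CalNewIndex) to account for the removed ReduceSum output.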
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
                                                                                    const AnfNodePtr &composite_node,
                                                                                    const FuncGraphManagerPtr &mng,
                                                                                    bool correct_index) const {
  std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes;
  if (real_output_num_ <= 1) {
    auto users = mng->node_users()[composite_node];
    (void)std::transform(users.cbegin(), users.cend(), std::back_inserter(reduce_user_nodes),
                         [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  } else {
    std::vector<std::pair<AnfNodePtr, int> > getitem_user_nodes;
    auto users = mng->node_users()[composite_node];
    for (const auto &node_index : users) {
      const auto &user_node = node_index.first;
      if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto get_item_cnode = user_node->cast<CNodePtr>();
      auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(value_input);
      auto value_node = value_input->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto item_idx = GetValue<int64_t>(value_node->value());
      if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) {
        getitem_user_nodes.push_back(node_index);
      } else if (correct_index) {
        if (real_output_num_ > 2) {
          // Re-correct the other getitem indices.
          int64_t new_item_idx = CalNewIndex(item_idx, SizeToLong(reduce_real_output_index_));
          AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node,
                                       NewValueNode(new_item_idx)};
          auto new_out = main_graph->NewCNode(new_inputs);
          new_out->set_abstract(get_item_cnode->abstract());
          for (const auto &[user, index] : mng->node_users()[get_item_cnode]) {
            auto user_cnode = user->cast<CNodePtr>();
            MS_EXCEPTION_IF_NULL(user_cnode);
            user_cnode->set_input(IntToSize(index), new_out);
          }
        } else {
          for (const auto &[user, index] : mng->node_users()[node_index.first]) {
            auto user_cnode = user->cast<CNodePtr>();
            MS_EXCEPTION_IF_NULL(user_cnode);
            user_cnode->set_input(IntToSize(index), composite_node);
          }
        }
      }
    }
    for (auto &pair : getitem_user_nodes) {
      // Directly find the real users through the getitem node.
      auto real_users = mng->node_users()[pair.first];
      (void)reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end());
    }
  }

  return reduce_user_nodes;
}

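// Redirect every user of the original ReduceSum result to the clean buffer. The execution order is kept by the
// chain: clean (broadcast) node -> composite node (which consumes the buffer as an extra input) -> UpdateState ->
// Load -> user, so the buffer is zeroed before the atomic accumulation and read only afterwards.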
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                                                  const AnfNodePtr &broadcast_to_node,
                                                  const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
  // 1. Find the users, changing the getitem indices if needed.
  std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
    FindOriginCNodeUsers(main_graph, composite_node, mng, true);
  for (const auto &[user_node, index] : reduce_user_nodes) {
    // 2. Make sure the modified composite node runs first: create a load_node, then add edges connecting
    // update_state_node, broadcast_to_node and load_node to keep the order.
    AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node};
    auto load_node = main_graph->NewCNode(load_inputs);
    load_node->set_abstract(broadcast_to_node->abstract());
    main_graph->AddNode(load_node);
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(IntToSize(index), load_node);
  }
}

void AtomicCleanInsertter::UpdateAtomicAddInfo(const AtomicAddInfo &atomic_add_info) {
  atomic_add_node_ = atomic_add_info.atomic_add_node;
  reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
  real_output_num_ = atomic_add_info.real_output_num;
}

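// Top-level rewrite for one candidate composite node: build the broadcast (clean) node, feed it into the
// composite node where the ReduceSum becomes an InplaceAssign into that buffer, then chain UpdateState/Load
// so every former user reads the buffer after the atomic accumulation finishes.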
void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
                                             const FuncGraphManagerPtr &mng) {
  auto origin_composite_node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_composite_node);

  // Create the broadcast node.
  auto out_type = GetType(atomic_add_node_)->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(out_type);
  auto broadcast_to_node = CreateAtomicCleanCompositeNode(main_graph, out_type->element()->type_id());

  // Insert an extra input (the broadcast node output) into the composite node, and make the ReduceSum
  // InplaceAssign to it.
  // Note: if there is only a single output, this will increase the total memory because of the fake out.
  ProcessOriginCNode(origin_composite_node, broadcast_to_node);

  // Insert an UpdateState node to keep the execution order.
  auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);

  // Replace the origin ReduceSum's users with the atomic clean output.
  ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng);
  MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
               << ", clean node: " << broadcast_to_node->fullname_with_scope();
}

bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
                                                     const FuncGraphManagerPtr &mng) {
  auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
  // If a reduce user is a MakeTuple and it is not the last node, there is no cheap way to set the right running
  // order between the reduce node and the user node. If a reduce user is a Depend node, the origin node may be
  // wrong!
  return std::all_of(
    reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
      auto &user = user_info.first;
      if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) &&
          !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
        return false;
      } else {
        return true;
      }
    });
}

bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  std::shared_ptr<AtomicAddChecker> atomic_add_checker = AtomicAddChecker::Init();
  if (atomic_add_checker == nullptr) {
    return changed;
  }

  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    if (!atomic_add_checker->Check(node) || !IsExistStructuralObstacle(kernel_graph, node, mng)) {
      continue;
    }
    changed = true;
    auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
    UpdateAtomicAddInfo(atomic_add_info);
    InsertAtomicClean(kernel_graph, node, mng);
  }

  if (changed) {
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }

  return changed;
}
}  // namespace opt
}  // namespace mindspore