• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.h"
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <map>
22 #include <memory>
23 #include <utility>
24 #include <set>
25 #include <stack>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29 #include "base/core_ops.h"
30 #include "ir/tensor.h"
31 #include "utils/utils.h"
32 #include "utils/log_adapter.h"
33 #include "backend/session/anf_runtime_algorithm.h"
34 #include "backend/session/kernel_graph.h"
35 #include "backend/kernel_compiler/kernel.h"
36 #include "backend/kernel_compiler/common_utils.h"
37 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
38 
39 namespace mindspore {
40 namespace opt {
41 class TsaChecker : public AtomicAddChecker {
42  public:
TsaChecker(const PrimitivePtr & target)43   explicit TsaChecker(const PrimitivePtr &target) { target_type_ = target; }
44   virtual ~TsaChecker() = default;
45 
46  protected:
CanActivateAtomicAdd(const AnfNodePtr & anf_node)47   bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override {
48     if (!FindCandidate(anf_node)) {
49       return false;
50     }
51 
52     auto tsa_cnode = atomic_add_info_.atomic_add_node;
53     if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
54       return false;
55     }
56 
57     return true;
58   }
59 };
60 
// Locate the outer-graph node that feeds the first input of the
// TensorScatterAdd inside the composite node's sub-graph.
// Side effect: records the parameter position in tsa_first_input_index_.
// Throws (MS_LOG(EXCEPTION)) when the first input is not among the
// sub-graph's parameters.
AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  auto first_input = atomic_add_node_->input(1)->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(first_input);
  const auto &params = sub_graph->parameters();
  auto iter = std::find(params.begin(), params.end(), first_input);
  if (iter == params.end()) {
    MS_LOG(EXCEPTION) << "Cannot find tensor scatter add first input in sub-graph parameters!";
  }
  tsa_first_input_index_ = static_cast<size_t>(std::distance(params.begin(), iter));

  return cnode->input(tsa_first_input_index_ + 1);  // CNode input have a primitive, so add 1.
}
88 
// Produce the outer node that will serve as the accumulation target for the
// TensorScatterAdd's first input. If the input is exclusively consumed here
// and is a computed node, it is returned directly; otherwise a small
// identity-like composite (a Reshape to the same shape) is created so the
// atomic add gets a private tensor to accumulate into.
AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &main_graph, const AnfNodePtr &node) {
  auto mng = main_graph->manager();
  if (mng == nullptr) {
    mng = Manage(main_graph, true);
    main_graph->set_manager(mng);
  }
  // find first input of tsa
  auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, node);
  auto users = mng->node_users()[tsa_first_input];
  // Reuse the input directly only when it has a single user and is neither a
  // constant nor a graph parameter (those must not be mutated in place).
  if (users.size() == 1 && !(utils::isa<ValueNodePtr>(tsa_first_input) || utils::isa<ParameterPtr>(tsa_first_input))) {
    return tsa_first_input;
  }
  // Create composite op's sub-graph.
  auto new_sub_graph = std::make_shared<FuncGraph>();
  auto parameter = new_sub_graph->add_parameter();
  auto kernel_with_index = AnfAlgo::VisitKernel(tsa_first_input, 0);
  parameter->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
  parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
  // Derive format/type for the new parameter: value nodes use the default
  // format and their tensor dtype; other nodes keep their selected device info.
  std::string parameter_format;
  TypeId parameter_type;
  if (utils::isa<ValueNodePtr>(kernel_with_index.first)) {
    auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
    MS_EXCEPTION_IF_NULL(tensor);
    parameter_format = kOpFormat_DEFAULT;
    parameter_type = tensor->data_type();
  } else {
    parameter_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    parameter_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
  para_info_builder.SetOutputsFormat({parameter_format});
  para_info_builder.SetOutputsDeviceType({parameter_type});
  para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), parameter.get());

  // Create inner op. A Reshape to the identical device shape acts as an
  // identity that copies the input into a fresh tensor.
  auto identity_node =
    CreateCNode({NewValueNode(std::make_shared<Primitive>("Reshape")), parameter}, new_sub_graph,
                {.format = GetFormat(parameter), .shape = GetShape(parameter), .type = GetType(parameter)});
  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(parameter)), identity_node);

  // Makeup sub-graph.
  new_sub_graph->set_output(identity_node);
  auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input});
  new_composite_node->set_abstract(identity_node->abstract());
  SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node});
  auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));

  return new_composite_node;
}
143 
CorrectKernelBuildInfo(const AnfNodePtr & composite_node,const AnfNodePtr & modified_input,bool)144 void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
145                                                        const AnfNodePtr &modified_input, bool) {
146   // Change kernel build info with modify input
147   auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
148   MS_EXCEPTION_IF_NULL(kernel_info);
149   const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
150   auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
151   auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
152   auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
153   auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
154   auto origin_processor = origin_kernel_build_info->processor();
155 
156   std::vector<std::string> &modified_inputs_format = origin_inputs_format;
157   std::vector<TypeId> &modified_inputs_type = origin_inputs_type;
158   std::vector<std::string> new_outputs_format;
159   std::vector<TypeId> new_outputs_type;
160   for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
161     if (real_output_num_ > 1 && i == reduce_real_output_index_) {
162       continue;
163     }
164     new_outputs_format.push_back(origin_outputs_format[i]);
165     new_outputs_type.push_back(origin_outputs_type[i]);
166   }
167 
168   auto kernel_with_index = AnfAlgo::VisitKernel(modified_input, 0);
169   modified_inputs_format[tsa_first_input_index_] =
170     AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
171   modified_inputs_type[tsa_first_input_index_] =
172     AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
173 
174   kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
175   new_info_builder.SetInputsFormat(modified_inputs_format);
176   new_info_builder.SetInputsDeviceType(modified_inputs_type);
177   new_info_builder.SetOutputsFormat(new_outputs_format);
178   new_info_builder.SetOutputsDeviceType(new_outputs_type);
179   new_info_builder.SetProcessor(origin_processor);
180   new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
181   new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
182   auto new_selected_info = new_info_builder.Build();
183   AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
184 }
185 
// Rewire the composite node so the TensorScatterAdd accumulates into
// outter_node: replaces the tsa first input, converts the sub-graph return
// into an InplaceAssign on the matching parameter, then corrects the
// composite's abstract and kernel build info and renames the graph kernel.
void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &outter_node) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // modify input (+1 because CNode input(0) is the primitive)
  composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index_ + 1, outter_node);
  CreateInplaceAssignNodeAndCorrectReturn(sub_graph, sub_graph->parameters()[tsa_first_input_index_]);

  CorrectAbstract(composite_node);
  CorrectKernelBuildInfo(composite_node, outter_node);

  // Re-derive the graph-kernel name so the modification is visible in logs/IR.
  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name;
}
206 
// Apply the full conversion to one composite node containing a
// TensorScatterAdd candidate: create/select the outer target tensor, rewire
// the composite to accumulate into it, and fix up execution order and users.
void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
                                           const FuncGraphManagerPtr &mng) {
  auto origin_composite_node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_composite_node);

  // Create identity node (or reuse the input directly when safe).
  auto outter_node = ProcessTsaFirstNode(main_graph, anf_node);

  // Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplaceassign to it.
  // Note: if it's single output, this will increase total memory because of a fake out.
  ProcessOriginCNode(origin_composite_node, outter_node);

  // Insert update_state_node to keep execution order.
  auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);

  // Replace origin ReduceSum's user with atomic clean output
  ProcessOriginCNodeUser(main_graph, origin_composite_node, outter_node, update_state_node, mng);
  MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
               << ", outer node: " << outter_node->fullname_with_scope();
}
227 
Run(const FuncGraphPtr & func_graph)228 bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
229   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
230   MS_EXCEPTION_IF_NULL(kernel_graph);
231   auto mng = kernel_graph->manager();
232   if (mng == nullptr) {
233     mng = Manage(kernel_graph, true);
234     kernel_graph->set_manager(mng);
235   }
236 
237   bool changed = false;
238   std::shared_ptr<AtomicAddChecker> atomic_add_checker =
239     std::make_shared<TsaChecker>(std::make_shared<Primitive>("TensorScatterAdd"));
240   if (atomic_add_checker == nullptr) {
241     return changed;
242   }
243 
244   auto topo_nodes = TopoSort(kernel_graph->get_return());
245   for (const auto &node : topo_nodes) {
246     if (!atomic_add_checker->Check(node)) {
247       continue;
248     }
249     auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
250     atomic_add_node_ = atomic_add_info.atomic_add_node;
251     reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
252     real_output_num_ = atomic_add_info.real_output_num;
253     ProcessTsa(kernel_graph, node, mng);
254     changed = true;
255   }
256 
257   if (changed) {
258     UpdateMng(mng, func_graph);
259   }
260 
261   return changed;
262 }
263 }  // namespace opt
264 }  // namespace mindspore
265