1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.h"
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <map>
22 #include <memory>
23 #include <utility>
24 #include <set>
25 #include <stack>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29 #include "base/core_ops.h"
30 #include "ir/tensor.h"
31 #include "utils/utils.h"
32 #include "utils/log_adapter.h"
33 #include "backend/session/anf_runtime_algorithm.h"
34 #include "backend/session/kernel_graph.h"
35 #include "backend/kernel_compiler/kernel.h"
36 #include "backend/kernel_compiler/common_utils.h"
37 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
38
39 namespace mindspore {
40 namespace opt {
41 class TsaChecker : public AtomicAddChecker {
42 public:
TsaChecker(const PrimitivePtr & target)43 explicit TsaChecker(const PrimitivePtr &target) { target_type_ = target; }
44 virtual ~TsaChecker() = default;
45
46 protected:
CanActivateAtomicAdd(const AnfNodePtr & anf_node)47 bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override {
48 if (!FindCandidate(anf_node)) {
49 return false;
50 }
51
52 auto tsa_cnode = atomic_add_info_.atomic_add_node;
53 if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
54 return false;
55 }
56
57 return true;
58 }
59 };
60
/**
 * Find the real (outer) node feeding the TensorScatterAdd's first input.
 *
 * Looks up which sub-graph parameter the TensorScatterAdd's first input is,
 * records its position in tsa_first_input_index_, and returns the composite
 * CNode's corresponding real input on the main graph.
 * Raises an exception if the parameter cannot be located.
 */
AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  auto first_input = atomic_add_node_->input(1)->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(first_input);
  const auto &parameters = sub_graph->parameters();
  auto iter = std::find(parameters.begin(), parameters.end(), first_input);
  if (iter == parameters.end()) {
    MS_LOG(EXCEPTION) << "Cannot find tensor scatter add first input in sub-graph parameters!";
  }
  tsa_first_input_index_ = static_cast<size_t>(iter - parameters.begin());

  return cnode->input(tsa_first_input_index_ + 1);  // CNode input have a primitive, so add 1.
}
88
/**
 * Obtain (or build) the main-graph node that will serve as the writable first
 * tensor for the TensorScatterAdd.
 *
 * If the existing producer of the first input has exactly one user and is
 * neither a value node nor a parameter, it can be modified in place and is
 * returned directly. Otherwise a small "identity" composite sub-graph
 * (a Reshape to the same shape) is created on the main graph to make a fresh
 * copy that the atomic add may safely overwrite.
 */
AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &main_graph, const AnfNodePtr &node) {
  auto mng = main_graph->manager();
  if (mng == nullptr) {
    mng = Manage(main_graph, true);
    main_graph->set_manager(mng);
  }
  // find first input of tsa
  auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, node);
  auto users = mng->node_users()[tsa_first_input];
  // Single-user, non-constant, non-parameter producer: safe to write in place, reuse as-is.
  if (users.size() == 1 && !(utils::isa<ValueNodePtr>(tsa_first_input) || utils::isa<ParameterPtr>(tsa_first_input))) {
    return tsa_first_input;
  }
  // Create composite op's sub-graph.
  auto new_sub_graph = std::make_shared<FuncGraph>();
  auto parameter = new_sub_graph->add_parameter();
  auto kernel_with_index = AnfAlgo::VisitKernel(tsa_first_input, 0);
  parameter->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
  parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
  // Value nodes have no selected kernel build info, so derive format/type from
  // the tensor itself; other nodes expose them through AnfAlgo.
  std::string parameter_format;
  TypeId parameter_type;
  if (utils::isa<ValueNodePtr>(kernel_with_index.first)) {
    auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
    MS_EXCEPTION_IF_NULL(tensor);
    parameter_format = kOpFormat_DEFAULT;
    parameter_type = tensor->data_type();
  } else {
    parameter_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    parameter_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
  para_info_builder.SetOutputsFormat({parameter_format});
  para_info_builder.SetOutputsDeviceType({parameter_type});
  para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), parameter.get());

  // Create inner op: a Reshape to the parameter's own shape acts as an identity copy.
  auto identity_node =
    CreateCNode({NewValueNode(std::make_shared<Primitive>("Reshape")), parameter}, new_sub_graph,
                {.format = GetFormat(parameter), .shape = GetShape(parameter), .type = GetType(parameter)});
  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(parameter)), identity_node);

  // Makeup sub-graph.
  new_sub_graph->set_output(identity_node);
  auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input});
  new_composite_node->set_abstract(identity_node->abstract());
  SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node});
  // Mark the new sub-graph as a graph kernel so downstream passes treat it like
  // any other fused composite.
  auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));

  return new_composite_node;
}
143
/**
 * Rebuild the composite node's kernel build info after its first input has
 * been redirected to the (possibly new) outer tensor.
 *
 * Inputs: the format/device-type of the tsa first input slot is refreshed from
 * the modified input's real producer. Outputs: when the composite has multiple
 * real outputs, the reduce output's entry is dropped (it is replaced by the
 * in-place write). The unnamed bool parameter is unused in this override.
 */
void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
                                                       const AnfNodePtr &modified_input, bool) {
  // Change kernel build info with modify input
  auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();

  // Aliases: the input lists are edited in place (only the tsa slot changes),
  // while the output lists are rebuilt to skip the removed reduce output.
  std::vector<std::string> &modified_inputs_format = origin_inputs_format;
  std::vector<TypeId> &modified_inputs_type = origin_inputs_type;
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    // Drop the reduce output only in the multi-output case; a single output is kept.
    if (real_output_num_ > 1 && i == reduce_real_output_index_) {
      continue;
    }
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }

  auto kernel_with_index = AnfAlgo::VisitKernel(modified_input, 0);
  modified_inputs_format[tsa_first_input_index_] =
    AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  modified_inputs_type[tsa_first_input_index_] =
    AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);

  kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
  new_info_builder.SetInputsFormat(modified_inputs_format);
  new_info_builder.SetInputsDeviceType(modified_inputs_type);
  new_info_builder.SetOutputsFormat(new_outputs_format);
  new_info_builder.SetOutputsDeviceType(new_outputs_type);
  new_info_builder.SetProcessor(origin_processor);
  new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto new_selected_info = new_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
185
/**
 * Rewire the original composite node to consume the outer tensor.
 *
 * Replaces the composite's tsa-first-input with outter_node, converts the
 * inner TensorScatterAdd into an InplaceAssign onto the matching sub-graph
 * parameter, then corrects the node's abstract and kernel build info and
 * renames the fused graph for logging/debugging.
 */
void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &outter_node) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // modify input (+1 skips the CNode's primitive input)
  composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index_ + 1, outter_node);
  CreateInplaceAssignNodeAndCorrectReturn(sub_graph, sub_graph->parameters()[tsa_first_input_index_]);

  CorrectAbstract(composite_node);
  CorrectKernelBuildInfo(composite_node, outter_node);

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name;
}
206
/**
 * Apply the full TensorScatterAdd-to-first-tensor transformation for one
 * candidate composite node: obtain a writable outer tensor, rewire the
 * composite to assign into it in place, then patch execution order and users.
 */
void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
                                           const FuncGraphManagerPtr &mng) {
  auto origin_composite_node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_composite_node);

  // Create identity node.
  auto outter_node = ProcessTsaFirstNode(main_graph, anf_node);

  // Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplaceassign to it.
  // Note: if it's single output, this will increase total memory because of a fake out.
  ProcessOriginCNode(origin_composite_node, outter_node);

  // Insert update_state_node to keep execution order.
  auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);

  // Replace origin ReduceSum's user with atomic clean output
  ProcessOriginCNodeUser(main_graph, origin_composite_node, outter_node, update_state_node, mng);
  MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
               << ", outer node: " << outter_node->fullname_with_scope();
}
227
Run(const FuncGraphPtr & func_graph)228 bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
229 auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
230 MS_EXCEPTION_IF_NULL(kernel_graph);
231 auto mng = kernel_graph->manager();
232 if (mng == nullptr) {
233 mng = Manage(kernel_graph, true);
234 kernel_graph->set_manager(mng);
235 }
236
237 bool changed = false;
238 std::shared_ptr<AtomicAddChecker> atomic_add_checker =
239 std::make_shared<TsaChecker>(std::make_shared<Primitive>("TensorScatterAdd"));
240 if (atomic_add_checker == nullptr) {
241 return changed;
242 }
243
244 auto topo_nodes = TopoSort(kernel_graph->get_return());
245 for (const auto &node : topo_nodes) {
246 if (!atomic_add_checker->Check(node)) {
247 continue;
248 }
249 auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
250 atomic_add_node_ = atomic_add_info.atomic_add_node;
251 reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
252 real_output_num_ = atomic_add_info.real_output_num;
253 ProcessTsa(kernel_graph, node, mng);
254 changed = true;
255 }
256
257 if (changed) {
258 UpdateMng(mng, func_graph);
259 }
260
261 return changed;
262 }
263 } // namespace opt
264 } // namespace mindspore
265