1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "backend/common/graph_kernel/tsa_atomic_add_to_first_tensor.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ir/tensor.h"
#include "include/common/utils/utils.h"
#include "utils/log_adapter.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/kernel_graph.h"
#include "kernel/kernel.h"
#include "kernel/framework_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
31
32 namespace mindspore::graphkernel {
33 constexpr auto kTsaInputIndex = 2;
34 class TsaChecker : public AtomicAddChecker {
35 public:
TsaChecker(const PrimitivePtr & target)36 explicit TsaChecker(const PrimitivePtr &target) { target_type_ = target; }
37 virtual ~TsaChecker() = default;
38
39 protected:
CanActivateAtomicAdd(const AnfNodePtr & anf_node)40 bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override {
41 if (!FindCandidate(anf_node)) {
42 return false;
43 }
44
45 for (auto atomic_add_info : atomic_add_infos_) {
46 auto tsa_cnode = atomic_add_info.op_node;
47 if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
48 return false;
49 }
50 }
51
52 return true;
53 }
54 };
55
FindTsaFirstRealInputInGraph(const KernelGraphPtr &,const CNodePtr & tsa_node,const AnfNodePtr & node) const56 std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &,
57 const CNodePtr &tsa_node,
58 const AnfNodePtr &node) const {
59 auto cnode = node->cast<CNodePtr>();
60 MS_EXCEPTION_IF_NULL(cnode);
61 auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
62 auto mng_sub = sub_graph->manager();
63 if (mng_sub == nullptr) {
64 mng_sub = Manage(sub_graph, false);
65 sub_graph->set_manager(mng_sub);
66 }
67
68 auto first_input = tsa_node->input(1)->cast<ParameterPtr>();
69 MS_EXCEPTION_IF_NULL(first_input);
70 auto parameters = sub_graph->parameters();
71 bool hit = false;
72 size_t tsa_first_input_index = 0;
73 for (size_t i = 0; i < parameters.size(); ++i) {
74 if (parameters[i] == first_input) {
75 tsa_first_input_index = i;
76 hit = true;
77 break;
78 }
79 }
80 if (!hit) {
81 MS_LOG(EXCEPTION) << "Cannot find tensor scatter add first input in sub-graph parameters!";
82 }
83
84 return {cnode->input(tsa_first_input_index + 1), tsa_first_input_index}; // CNode input have a primitive, so add 1.
85 }
86
GetOrCreateNewTsaFirstNode(const KernelGraphPtr & main_graph,const InplaceAssignerInfo & atomic_add_info,const AnfNodePtr & node) const87 std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
88 const KernelGraphPtr &main_graph, const InplaceAssignerInfo &atomic_add_info, const AnfNodePtr &node) const {
89 auto mng = main_graph->manager();
90 if (mng == nullptr) {
91 mng = Manage(main_graph, true);
92 main_graph->set_manager(mng);
93 }
94
95 // Find first input of tsa
96 auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.op_node, node);
97 auto users = mng->node_users()[tsa_first_input.first];
98 if (users.size() == 1 &&
99 !(utils::isa<ValueNodePtr>(tsa_first_input.first) || utils::isa<ParameterPtr>(tsa_first_input.first))) {
100 // If current composite node is only user, and first input is not Parameter or Tensor Value, then use itself.
101 return tsa_first_input;
102 }
103
104 // Create a copy of first input to atomic add to.
105 // Create composite op's sub-graph.
106 auto new_sub_graph = std::make_shared<FuncGraph>();
107 auto parameter = new_sub_graph->add_parameter();
108 auto kernel_with_index = common::AnfAlgo::VisitKernel(tsa_first_input.first, 0);
109 parameter->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
110 parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
111 std::string parameter_format;
112 TypeId parameter_type;
113 if (utils::isa<ValueNodePtr>(kernel_with_index.first)) {
114 auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
115 MS_EXCEPTION_IF_NULL(tensor);
116 parameter_format = kOpFormat_DEFAULT;
117 parameter_type = tensor->data_type();
118 } else {
119 parameter_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
120 parameter_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
121 }
122
123 kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
124 para_info_builder.SetOutputsFormat({parameter_format});
125 para_info_builder.SetOutputsDeviceType({parameter_type});
126 para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
127 para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
128 AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), parameter.get());
129
130 // Create inner op.
131 auto identity_node = CreateCNode({NewValueNode(std::make_shared<Primitive>("Reshape")), parameter}, new_sub_graph,
132 {GetFormat(parameter), GetShape(parameter), GetType(parameter)});
133 SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(parameter)), identity_node);
134
135 // Makeup sub-graph.
136 new_sub_graph->set_output(identity_node);
137 auto new_copy_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input.first});
138 new_copy_composite_node->set_abstract(identity_node->abstract());
139 Callback::Instance()->SetGraphKernelNodeKernelInfo(new_copy_composite_node);
140 auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
141 new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
142 new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));
143
144 return {new_copy_composite_node, tsa_first_input.second};
145 }
146
ChangeKernelBuildInfo(const AnfNodePtr & composite_node,const std::vector<std::tuple<InplaceAssignerInfo,AnfNodePtr,size_t>> & outer_infos) const147 void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo(
148 const AnfNodePtr &composite_node,
149 const std::vector<std::tuple<InplaceAssignerInfo, AnfNodePtr, size_t>> &outer_infos) const {
150 // Change kernel build info with modify input
151 auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
152 MS_EXCEPTION_IF_NULL(kernel_info);
153 const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
154 MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
155 auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
156 auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
157
158 std::vector<std::string> &modified_inputs_format = origin_inputs_format;
159 std::vector<TypeId> &modified_inputs_type = origin_inputs_type;
160
161 for (const auto &outer_info : outer_infos) {
162 auto &modified_input = std::get<1>(outer_info);
163 auto tsa_first_input_index = std::get<kTsaInputIndex>(outer_info);
164 auto kernel_with_index = common::AnfAlgo::VisitKernel(modified_input, 0);
165 modified_inputs_format[tsa_first_input_index] =
166 AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
167 modified_inputs_type[tsa_first_input_index] =
168 AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
169 }
170
171 auto new_selected_info = BuildSelectKernelBuildInfo(
172 modified_inputs_format, modified_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
173 origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
174 AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
175 }
176
ProcessOriginalCNode(const AnfNodePtr & composite_node,const std::vector<std::tuple<InplaceAssignerInfo,AnfNodePtr,size_t>> & outer_nodes) const177 void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
178 const AnfNodePtr &composite_node,
179 const std::vector<std::tuple<InplaceAssignerInfo, AnfNodePtr, size_t>> &outer_nodes) const {
180 auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
181 auto mng_sub = sub_graph->manager();
182 if (mng_sub == nullptr) {
183 mng_sub = Manage(sub_graph, false);
184 sub_graph->set_manager(mng_sub);
185 }
186
187 // Modify input
188 std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> parameters_infos;
189 std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> info_and_tsa_outers;
190 for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) {
191 composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index + 1, outer_node);
192 auto parameter = sub_graph->parameters()[tsa_first_input_index];
193 (void)parameters_infos.emplace_back(atomic_add_info, parameter);
194 (void)info_and_tsa_outers.emplace_back(atomic_add_info, outer_node);
195 }
196
197 CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
198 ChangeKernelBuildInfo(composite_node, outer_nodes);
199
200 auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
201 auto new_graph_name =
202 GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
203 sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
204 MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name;
205 }
206
ProcessTsa(const KernelGraphPtr & main_graph,const AnfNodePtr & anf_node,const std::vector<InplaceAssignerInfo> & atomic_add_infos,const FuncGraphManagerPtr & mng) const207 void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
208 const std::vector<InplaceAssignerInfo> &atomic_add_infos,
209 const FuncGraphManagerPtr &mng) const {
210 auto origin_composite_node = anf_node->cast<CNodePtr>();
211 MS_EXCEPTION_IF_NULL(origin_composite_node);
212
213 // Create identity node.
214 std::vector<std::tuple<InplaceAssignerInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
215 std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> info_and_outer_nodes;
216 for (auto atomic_add_info : atomic_add_infos) {
217 auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node);
218 (void)info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second);
219 (void)info_and_outer_nodes.emplace_back(atomic_add_info, outer.first);
220 }
221
222 // Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplace-assign to it.
223 ProcessOriginalCNode(origin_composite_node, info_and_outer_nodes_with_index);
224
225 // Insert Depend before origin TensorScatterAdd's user to keep execution order.
226 ProcessOriginCNodeUser(main_graph, origin_composite_node, info_and_outer_nodes, mng);
227 std::stringstream ss;
228 ss << "Target node: " << origin_composite_node->fullname_with_scope() << ", outer nodes: ";
229 for (auto iter : info_and_outer_nodes) {
230 ss << iter.second->fullname_with_scope() << ", ";
231 }
232 }
233
Run(const FuncGraphPtr & func_graph)234 bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
235 auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
236 MS_EXCEPTION_IF_NULL(kernel_graph);
237 auto mng = kernel_graph->manager();
238 if (mng == nullptr) {
239 mng = Manage(kernel_graph, true);
240 kernel_graph->set_manager(mng);
241 }
242
243 bool changed = false;
244 std::shared_ptr<AtomicAddChecker> atomic_add_checker =
245 std::make_shared<TsaChecker>(std::make_shared<Primitive>("TensorScatterAdd"));
246 if (atomic_add_checker == nullptr) {
247 return changed;
248 }
249
250 auto topo_nodes = TopoSort(kernel_graph->get_return());
251 for (const auto &node : topo_nodes) {
252 if (!atomic_add_checker->Check(node)) {
253 continue;
254 }
255 auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
256 ProcessTsa(kernel_graph, node, atomic_add_infos, mng);
257 changed = true;
258 }
259
260 if (changed) {
261 GkUtils::UpdateFuncGraphManager(mng, func_graph);
262 }
263
264 return changed;
265 }
266 } // namespace mindspore::graphkernel
267