• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/tsa_atomic_add_to_first_tensor.h"
18 #include <memory>
19 #include <string>
20 #include <vector>
21 #include "ir/tensor.h"
22 #include "include/common/utils/utils.h"
23 #include "utils/log_adapter.h"
24 #include "include/backend/anf_runtime_algorithm.h"
25 #include "include/common/utils/anfalgo.h"
26 #include "include/backend/kernel_graph.h"
27 #include "kernel/kernel.h"
28 #include "kernel/framework_utils.h"
29 #include "backend/common/graph_kernel/graph_kernel_helper.h"
30 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
31 
32 namespace mindspore::graphkernel {
33 constexpr auto kTsaInputIndex = 2;
34 class TsaChecker : public AtomicAddChecker {
35  public:
TsaChecker(const PrimitivePtr & target)36   explicit TsaChecker(const PrimitivePtr &target) { target_type_ = target; }
37   virtual ~TsaChecker() = default;
38 
39  protected:
CanActivateAtomicAdd(const AnfNodePtr & anf_node)40   bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override {
41     if (!FindCandidate(anf_node)) {
42       return false;
43     }
44 
45     for (auto atomic_add_info : atomic_add_infos_) {
46       auto tsa_cnode = atomic_add_info.op_node;
47       if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
48         return false;
49       }
50     }
51 
52     return true;
53   }
54 };
55 
FindTsaFirstRealInputInGraph(const KernelGraphPtr &,const CNodePtr & tsa_node,const AnfNodePtr & node) const56 std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &,
57                                                                                       const CNodePtr &tsa_node,
58                                                                                       const AnfNodePtr &node) const {
59   auto cnode = node->cast<CNodePtr>();
60   MS_EXCEPTION_IF_NULL(cnode);
61   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
62   auto mng_sub = sub_graph->manager();
63   if (mng_sub == nullptr) {
64     mng_sub = Manage(sub_graph, false);
65     sub_graph->set_manager(mng_sub);
66   }
67 
68   auto first_input = tsa_node->input(1)->cast<ParameterPtr>();
69   MS_EXCEPTION_IF_NULL(first_input);
70   auto parameters = sub_graph->parameters();
71   bool hit = false;
72   size_t tsa_first_input_index = 0;
73   for (size_t i = 0; i < parameters.size(); ++i) {
74     if (parameters[i] == first_input) {
75       tsa_first_input_index = i;
76       hit = true;
77       break;
78     }
79   }
80   if (!hit) {
81     MS_LOG(EXCEPTION) << "Cannot find tensor scatter add first input in sub-graph parameters!";
82   }
83 
84   return {cnode->input(tsa_first_input_index + 1), tsa_first_input_index};  // CNode input have a primitive, so add 1.
85 }
86 
// Returns the node that the atomic add will write into, together with the
// index of the TensorScatterAdd's first input among the composite node's inputs.
// If the current composite node is the only user of that input (and the input
// is not a Parameter or ValueNode), the input itself is reused; otherwise a
// "tsa_identity" composite node (a Reshape acting as an identity copy) is
// created so the original tensor is not clobbered by the in-place add.
std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
  const KernelGraphPtr &main_graph, const InplaceAssignerInfo &atomic_add_info, const AnfNodePtr &node) const {
  auto mng = main_graph->manager();
  if (mng == nullptr) {
    mng = Manage(main_graph, true);
    main_graph->set_manager(mng);
  }

  // Find first input of tsa
  auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.op_node, node);
  auto users = mng->node_users()[tsa_first_input.first];
  if (users.size() == 1 &&
      !(utils::isa<ValueNodePtr>(tsa_first_input.first) || utils::isa<ParameterPtr>(tsa_first_input.first))) {
    // If current composite node is only user, and first input is not Parameter or Tensor Value, then use itself.
    return tsa_first_input;
  }

  // Create a copy of first input to atomic add to.
  // Create composite op's sub-graph.
  auto new_sub_graph = std::make_shared<FuncGraph>();
  auto parameter = new_sub_graph->add_parameter();
  auto kernel_with_index = common::AnfAlgo::VisitKernel(tsa_first_input.first, 0);
  parameter->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
  parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
  // Derive the parameter's format/device type: value nodes use the default
  // format and the tensor's own data type; other nodes inherit the producing
  // output's device info.
  std::string parameter_format;
  TypeId parameter_type;
  if (utils::isa<ValueNodePtr>(kernel_with_index.first)) {
    auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
    MS_EXCEPTION_IF_NULL(tensor);
    parameter_format = kOpFormat_DEFAULT;
    parameter_type = tensor->data_type();
  } else {
    parameter_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    parameter_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  }

  // Kernel build info must be attached to the parameter before CreateCNode
  // queries its format/shape/type below.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
  para_info_builder.SetOutputsFormat({parameter_format});
  para_info_builder.SetOutputsDeviceType({parameter_type});
  para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), parameter.get());

  // Create inner op: a Reshape to the parameter's own shape, i.e. an identity copy.
  auto identity_node = CreateCNode({NewValueNode(std::make_shared<Primitive>("Reshape")), parameter}, new_sub_graph,
                                   {GetFormat(parameter), GetShape(parameter), GetType(parameter)});
  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(parameter)), identity_node);

  // Makeup sub-graph.
  new_sub_graph->set_output(identity_node);
  auto new_copy_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input.first});
  new_copy_composite_node->set_abstract(identity_node->abstract());
  Callback::Instance()->SetGraphKernelNodeKernelInfo(new_copy_composite_node);
  auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));

  return {new_copy_composite_node, tsa_first_input.second};
}
146 
ChangeKernelBuildInfo(const AnfNodePtr & composite_node,const std::vector<std::tuple<InplaceAssignerInfo,AnfNodePtr,size_t>> & outer_infos) const147 void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo(
148   const AnfNodePtr &composite_node,
149   const std::vector<std::tuple<InplaceAssignerInfo, AnfNodePtr, size_t>> &outer_infos) const {
150   // Change kernel build info with modify input
151   auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
152   MS_EXCEPTION_IF_NULL(kernel_info);
153   const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
154   MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
155   auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
156   auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
157 
158   std::vector<std::string> &modified_inputs_format = origin_inputs_format;
159   std::vector<TypeId> &modified_inputs_type = origin_inputs_type;
160 
161   for (const auto &outer_info : outer_infos) {
162     auto &modified_input = std::get<1>(outer_info);
163     auto tsa_first_input_index = std::get<kTsaInputIndex>(outer_info);
164     auto kernel_with_index = common::AnfAlgo::VisitKernel(modified_input, 0);
165     modified_inputs_format[tsa_first_input_index] =
166       AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
167     modified_inputs_type[tsa_first_input_index] =
168       AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
169   }
170 
171   auto new_selected_info = BuildSelectKernelBuildInfo(
172     modified_inputs_format, modified_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
173     origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
174   AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
175 }
176 
ProcessOriginalCNode(const AnfNodePtr & composite_node,const std::vector<std::tuple<InplaceAssignerInfo,AnfNodePtr,size_t>> & outer_nodes) const177 void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
178   const AnfNodePtr &composite_node,
179   const std::vector<std::tuple<InplaceAssignerInfo, AnfNodePtr, size_t>> &outer_nodes) const {
180   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
181   auto mng_sub = sub_graph->manager();
182   if (mng_sub == nullptr) {
183     mng_sub = Manage(sub_graph, false);
184     sub_graph->set_manager(mng_sub);
185   }
186 
187   // Modify input
188   std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> parameters_infos;
189   std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> info_and_tsa_outers;
190   for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) {
191     composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index + 1, outer_node);
192     auto parameter = sub_graph->parameters()[tsa_first_input_index];
193     (void)parameters_infos.emplace_back(atomic_add_info, parameter);
194     (void)info_and_tsa_outers.emplace_back(atomic_add_info, outer_node);
195   }
196 
197   CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
198   ChangeKernelBuildInfo(composite_node, outer_nodes);
199 
200   auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
201   auto new_graph_name =
202     GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
203   sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
204   MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name;
205 }
206 
ProcessTsa(const KernelGraphPtr & main_graph,const AnfNodePtr & anf_node,const std::vector<InplaceAssignerInfo> & atomic_add_infos,const FuncGraphManagerPtr & mng) const207 void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
208                                            const std::vector<InplaceAssignerInfo> &atomic_add_infos,
209                                            const FuncGraphManagerPtr &mng) const {
210   auto origin_composite_node = anf_node->cast<CNodePtr>();
211   MS_EXCEPTION_IF_NULL(origin_composite_node);
212 
213   // Create identity node.
214   std::vector<std::tuple<InplaceAssignerInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
215   std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> info_and_outer_nodes;
216   for (auto atomic_add_info : atomic_add_infos) {
217     auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node);
218     (void)info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second);
219     (void)info_and_outer_nodes.emplace_back(atomic_add_info, outer.first);
220   }
221 
222   // Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplace-assign to it.
223   ProcessOriginalCNode(origin_composite_node, info_and_outer_nodes_with_index);
224 
225   // Insert Depend before origin TensorScatterAdd's user to keep execution order.
226   ProcessOriginCNodeUser(main_graph, origin_composite_node, info_and_outer_nodes, mng);
227   std::stringstream ss;
228   ss << "Target node: " << origin_composite_node->fullname_with_scope() << ", outer nodes: ";
229   for (auto iter : info_and_outer_nodes) {
230     ss << iter.second->fullname_with_scope() << ", ";
231   }
232 }
233 
Run(const FuncGraphPtr & func_graph)234 bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
235   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
236   MS_EXCEPTION_IF_NULL(kernel_graph);
237   auto mng = kernel_graph->manager();
238   if (mng == nullptr) {
239     mng = Manage(kernel_graph, true);
240     kernel_graph->set_manager(mng);
241   }
242 
243   bool changed = false;
244   std::shared_ptr<AtomicAddChecker> atomic_add_checker =
245     std::make_shared<TsaChecker>(std::make_shared<Primitive>("TensorScatterAdd"));
246   if (atomic_add_checker == nullptr) {
247     return changed;
248   }
249 
250   auto topo_nodes = TopoSort(kernel_graph->get_return());
251   for (const auto &node : topo_nodes) {
252     if (!atomic_add_checker->Check(node)) {
253       continue;
254     }
255     auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
256     ProcessTsa(kernel_graph, node, atomic_add_infos, mng);
257     changed = true;
258   }
259 
260   if (changed) {
261     GkUtils::UpdateFuncGraphManager(mng, func_graph);
262   }
263 
264   return changed;
265 }
266 }  // namespace mindspore::graphkernel
267