• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/inplace_assign_builder.h"
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
26 #include "backend/common/graph_kernel/graph_kernel_helper.h"
27 #include "include/common/debug/anf_ir_dump.h"
28 #include "include/common/utils/utils.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "mindspore/core/ops/nn_optimizer_ops.h"
32 #include "mindspore/core/ops/sequence_ops.h"
33 #include "utils/log_adapter.h"
34 
35 namespace mindspore::graphkernel {
36 namespace {
CreateAssign(const FuncGraphPtr & sub_graph,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & parameters_infos,size_t idx)37 CNodePtr CreateAssign(const FuncGraphPtr &sub_graph,
38                       const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &parameters_infos, size_t idx) {
39   if (idx >= parameters_infos.size()) {
40     MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")";
41   }
42   MS_EXCEPTION_IF_NULL(sub_graph);
43 
44   const auto &target_node = parameters_infos[idx].first.op_node;
45   const auto &new_parameter = parameters_infos[idx].second;
46 
47   auto node = CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, target_node}, sub_graph,
48                           {GetFormat(target_node), GetShape(target_node), GetType(target_node)});
49   return node;
50 }
51 
GetItemIdx(const AnfNodePtr & node)52 size_t GetItemIdx(const AnfNodePtr &node) {
53   if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
54     MS_LOG(EXCEPTION) << "Expect TupleGetItem node, but got " << common::AnfAlgo::GetCNodeName(node);
55   }
56   auto get_item_cnode = node->cast<CNodePtr>();
57   MS_EXCEPTION_IF_NULL(get_item_cnode);
58   auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
59   MS_EXCEPTION_IF_NULL(value_input);
60   auto value_node = value_input->cast<ValueNodePtr>();
61   MS_EXCEPTION_IF_NULL(value_node);
62   auto item_idx = LongToSize(GetValue<int64_t>(value_node->value()));
63   return item_idx;
64 }
65 }  // namespace
66 
CorrectKernelBuildInfo(const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & inplace_infos)67 void InplaceAssignBuilder::CorrectKernelBuildInfo(
68   const AnfNodePtr &composite_node, const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &inplace_infos) {
69   // Change kernel build info.
70   auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
71   MS_EXCEPTION_IF_NULL(kernel_info);
72   const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
73   MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
74   auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
75   auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
76 
77   std::vector<std::string> &new_inputs_format = origin_inputs_format;
78   std::vector<TypeId> &new_inputs_type = origin_inputs_type;
79   for (const auto &inplace_info : inplace_infos) {
80     if (inplace_info.first.inplace_to_origin_input < 0) {
81       auto &new_input = inplace_info.second;
82       auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
83       new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
84       new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
85     }
86   }
87 
88   auto new_selected_info = BuildSelectKernelBuildInfo(
89     new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
90     origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
91   AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
92 }
93 
CreateAssignNodeAndCorrectReturn(const FuncGraphPtr & sub_graph,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & parameters_infos) const94 void InplaceAssignBuilder::CreateAssignNodeAndCorrectReturn(
95   const FuncGraphPtr &sub_graph,
96   const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &parameters_infos) const {
97   std::map<size_t, size_t> target_indices;
98   for (size_t i = 0; i < parameters_infos.size(); ++i) {
99     target_indices[parameters_infos[i].first.real_output_index + 1] = i;
100   }
101 
102   // Change output to Assign node.
103   auto output = sub_graph->output();
104   if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
105     auto output_cnode = output->cast<CNodePtr>();
106     MS_EXCEPTION_IF_NULL(output_cnode);
107     for (size_t i = 1; i < output_cnode->size(); ++i) {
108       std::map<size_t, size_t>::const_iterator cur_input = target_indices.find(i);
109       if (cur_input == target_indices.end()) {
110         continue;
111       }
112       auto inplace = CreateAssign(sub_graph, parameters_infos, cur_input->second);
113       output_cnode->set_input(i, inplace);
114     }
115   } else if (parameters_infos.size() == 1) {
116     auto inplace = CreateAssign(sub_graph, parameters_infos, 0);
117     sub_graph->set_output(inplace);
118   }
119 }
120 
CreateCleanCompositeNode(const InplaceAssignerInfo & op_info,const FuncGraphPtr & main_graph,TypeId dst_type)121 CNodePtr InplaceAssignBuilder::CreateCleanCompositeNode(const InplaceAssignerInfo &op_info,
122                                                         const FuncGraphPtr &main_graph, TypeId dst_type) {
123   std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
124 
125   if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
126     MS_LOG(EXCEPTION) << "For CreateCleanCompositeNode, the data type: " << TypeIdToString(dst_type, true)
127                       << " is not in supported list: [float16, float32, float64].";
128   }
129 
130   // Create zero value which will be broadcast to target shape.
131   auto format = GetFormat(op_info.op_node);
132   auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
133   ValueNodePtr value_node;
134   if (dtype == kNumberTypeFloat32) {
135     float val = 0;
136     value_node = CreateTensorValueNode({format, {1}, TypeIdToType(dtype)}, &val, sizeof(float));
137   } else {
138     double val = 0;
139     value_node = CreateTensorValueNode({format, {1}, TypeIdToType(dtype)}, &val, sizeof(double));
140   }
141 
142   // Create composite op's sub-graph.
143   auto new_sub_graph = std::make_shared<FuncGraph>();
144 
145   AnfNodePtr broadcast_input_node;
146   if (dst_type == kNumberTypeFloat16) {
147     AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
148     auto cast_node_inner = CreateCNode(cast_inputs, new_sub_graph, {format, {1}, TypeIdToType(dst_type)});
149     SetNodeAttrSafely("dst_type", kFloat32, cast_node_inner);
150     broadcast_input_node = cast_node_inner;
151   } else {
152     broadcast_input_node = value_node;
153   }
154 
155   // Create broadcast basic op.
156   auto dst_shape_vec = GetShape(op_info.op_node);
157   auto device_shape = GetDeviceShape(op_info.op_node);
158   auto shape_node = CreateTensorValueNode({kOpFormat_DEFAULT, {SizeToLong(device_shape.size())}, kInt64},
159                                           device_shape.data(), device_shape.size() * sizeof(int64_t));
160 
161   AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node, shape_node};
162   auto broadcast_to_node_inner =
163     CreateCNode(clean_inputs, new_sub_graph, {format, dst_shape_vec, GetType(op_info.op_node)});
164 
165   // Makeup sub-graph.
166   new_sub_graph->set_output(broadcast_to_node_inner);
167   auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
168   broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
169   Callback::Instance()->SetGraphKernelNodeKernelInfo(broadcast_to_composite_node);
170   auto graph_attr =
171     GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "inplace_assign_builder");
172   new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
173   new_sub_graph->set_attr("composite_type", MakeValue("inplace_assign_builder"));
174 
175   return broadcast_to_composite_node;
176 }
177 
ProcessOriginCNode(const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & info_and_inplace_assignee_addr)178 void InplaceAssignBuilder::ProcessOriginCNode(
179   const AnfNodePtr &composite_node,
180   const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr) {
181   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
182   auto mng_sub = sub_graph->manager();
183   if (mng_sub == nullptr) {
184     mng_sub = Manage(sub_graph, false);
185     sub_graph->set_manager(mng_sub);
186   }
187 
188   // Add input
189   std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> parameters_infos;
190   std::vector<AnfNodePtr> additonal_inputs;
191   for (const auto &[target_node_info, input] : info_and_inplace_assignee_addr) {
192     // Add attribute to target node.
193     SetTargetAttrs(target_node_info.op_node);
194 
195     // add parameter
196     if (target_node_info.inplace_to_origin_input < 0) {
197       auto parameter = sub_graph->add_parameter();
198       parameter->set_abstract(input->abstract());
199       parameter->set_kernel_info(input->kernel_info_ptr());
200       (void)parameters_infos.emplace_back(target_node_info, parameter);
201       (void)additonal_inputs.emplace_back(input);
202     } else {
203       auto params = sub_graph->parameters();
204       (void)parameters_infos.emplace_back(target_node_info,
205                                           params[IntToSize(target_node_info.inplace_to_origin_input)]);
206     }
207   }
208 
209   auto inputs = composite_node->cast<CNodePtr>()->inputs();
210   (void)inputs.insert(inputs.end(), additonal_inputs.begin(), additonal_inputs.end());
211   composite_node->cast<CNodePtr>()->set_inputs(inputs);
212 
213   CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
214   CorrectKernelBuildInfo(composite_node, info_and_inplace_assignee_addr);
215 }
216 
FindOriginCNodeUsers(const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & info_and_inplace_assignee_addr,const FuncGraphManagerPtr & mng) const217 std::vector<InplaceAssignUserInfo> InplaceAssignBuilder::FindOriginCNodeUsers(
218   const AnfNodePtr &composite_node,
219   const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr,
220   const FuncGraphManagerPtr &mng) const {
221   std::vector<InplaceAssignUserInfo> user_node_infos;
222 
223   std::map<size_t, AnfNodePtr> real_indices_and_input_node;
224   for (auto &[info, clean] : info_and_inplace_assignee_addr) {
225     (void)real_indices_and_input_node.emplace(info.real_output_index, clean);
226   }
227 
228   if (info_and_inplace_assignee_addr[0].first.real_output_num <= 1) {
229     // Find users directly.
230     auto users = mng->node_users()[composite_node];
231     for (const auto &[user, index] : users) {
232       user_node_infos.push_back({info_and_inplace_assignee_addr[0].second, composite_node, user, IntToSize(index)});
233     }
234   } else {
235     std::vector<std::pair<AnfNodePtr, AnfNodePtr>> getitem_user_nodes;
236     auto users = mng->node_users()[composite_node];
237     for (const auto &node_index : users) {
238       // 1. First, find TupleGetItem nodes.
239       const auto &user_node = node_index.first;
240       if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
241         continue;
242       }
243       auto item_idx = GetItemIdx(user_node);
244       const auto iter = real_indices_and_input_node.find(item_idx);
245       if (iter != real_indices_and_input_node.end()) {
246         (void)getitem_user_nodes.emplace_back(user_node, iter->second);
247       }
248     }
249     // 2. Find users of TupleGetItem nodes.
250     for (size_t i = 0; i < getitem_user_nodes.size(); ++i) {
251       const auto &getitem_node = getitem_user_nodes[i].first;
252       const auto &broadcast_to_node = getitem_user_nodes[i].second;
253       auto real_users = mng->node_users()[getitem_node];
254       for (const auto &[user, index] : real_users) {
255         user_node_infos.push_back({broadcast_to_node, getitem_node, user, IntToSize(index)});
256       }
257     }
258   }
259 
260   return user_node_infos;
261 }
262 
ProcessOriginCNodeUser(const FuncGraphPtr & main_graph,const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & info_and_inplace_assignee_addr,const FuncGraphManagerPtr & mng) const263 void InplaceAssignBuilder::ProcessOriginCNodeUser(
264   const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
265   const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr,
266   const FuncGraphManagerPtr &mng) const {
267   // 1. Find users.
268   auto user_nodes = FindOriginCNodeUsers(composite_node, info_and_inplace_assignee_addr, mng);
269   for (const auto &iter : user_nodes) {
270     // 2. Make sure modified composite node running first, So firstly, create depend_node, then add edge to connect
271     // work_node, broadcast_node and depend_node to keep order.
272     AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), iter.inplace_assignee_addr, iter.work_node};
273     auto depend_node = main_graph->NewCNode(depend_inputs);
274     depend_node->set_abstract(iter.inplace_assignee_addr->abstract());
275     main_graph->AddNode(depend_node);
276     auto user_cnode = iter.user_node->cast<CNodePtr>();
277     MS_EXCEPTION_IF_NULL(user_cnode);
278     user_cnode->set_input(iter.user_input_idx, depend_node);
279   }
280 }
281 }  // namespace mindspore::graphkernel
282