1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/inplace_assign_builder.h"
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
26 #include "backend/common/graph_kernel/graph_kernel_helper.h"
27 #include "include/common/debug/anf_ir_dump.h"
28 #include "include/common/utils/utils.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "mindspore/core/ops/nn_optimizer_ops.h"
32 #include "mindspore/core/ops/sequence_ops.h"
33 #include "utils/log_adapter.h"
34
35 namespace mindspore::graphkernel {
36 namespace {
CreateAssign(const FuncGraphPtr & sub_graph,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & parameters_infos,size_t idx)37 CNodePtr CreateAssign(const FuncGraphPtr &sub_graph,
38 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> ¶meters_infos, size_t idx) {
39 if (idx >= parameters_infos.size()) {
40 MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")";
41 }
42 MS_EXCEPTION_IF_NULL(sub_graph);
43
44 const auto &target_node = parameters_infos[idx].first.op_node;
45 const auto &new_parameter = parameters_infos[idx].second;
46
47 auto node = CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, target_node}, sub_graph,
48 {GetFormat(target_node), GetShape(target_node), GetType(target_node)});
49 return node;
50 }
51
GetItemIdx(const AnfNodePtr & node)52 size_t GetItemIdx(const AnfNodePtr &node) {
53 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
54 MS_LOG(EXCEPTION) << "Expect TupleGetItem node, but got " << common::AnfAlgo::GetCNodeName(node);
55 }
56 auto get_item_cnode = node->cast<CNodePtr>();
57 MS_EXCEPTION_IF_NULL(get_item_cnode);
58 auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
59 MS_EXCEPTION_IF_NULL(value_input);
60 auto value_node = value_input->cast<ValueNodePtr>();
61 MS_EXCEPTION_IF_NULL(value_node);
62 auto item_idx = LongToSize(GetValue<int64_t>(value_node->value()));
63 return item_idx;
64 }
65 } // namespace
66
CorrectKernelBuildInfo(const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & inplace_infos)67 void InplaceAssignBuilder::CorrectKernelBuildInfo(
68 const AnfNodePtr &composite_node, const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &inplace_infos) {
69 // Change kernel build info.
70 auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
71 MS_EXCEPTION_IF_NULL(kernel_info);
72 const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
73 MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
74 auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
75 auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
76
77 std::vector<std::string> &new_inputs_format = origin_inputs_format;
78 std::vector<TypeId> &new_inputs_type = origin_inputs_type;
79 for (const auto &inplace_info : inplace_infos) {
80 if (inplace_info.first.inplace_to_origin_input < 0) {
81 auto &new_input = inplace_info.second;
82 auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
83 new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
84 new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
85 }
86 }
87
88 auto new_selected_info = BuildSelectKernelBuildInfo(
89 new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
90 origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
91 AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
92 }
93
CreateAssignNodeAndCorrectReturn(const FuncGraphPtr & sub_graph,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & parameters_infos) const94 void InplaceAssignBuilder::CreateAssignNodeAndCorrectReturn(
95 const FuncGraphPtr &sub_graph,
96 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> ¶meters_infos) const {
97 std::map<size_t, size_t> target_indices;
98 for (size_t i = 0; i < parameters_infos.size(); ++i) {
99 target_indices[parameters_infos[i].first.real_output_index + 1] = i;
100 }
101
102 // Change output to Assign node.
103 auto output = sub_graph->output();
104 if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
105 auto output_cnode = output->cast<CNodePtr>();
106 MS_EXCEPTION_IF_NULL(output_cnode);
107 for (size_t i = 1; i < output_cnode->size(); ++i) {
108 std::map<size_t, size_t>::const_iterator cur_input = target_indices.find(i);
109 if (cur_input == target_indices.end()) {
110 continue;
111 }
112 auto inplace = CreateAssign(sub_graph, parameters_infos, cur_input->second);
113 output_cnode->set_input(i, inplace);
114 }
115 } else if (parameters_infos.size() == 1) {
116 auto inplace = CreateAssign(sub_graph, parameters_infos, 0);
117 sub_graph->set_output(inplace);
118 }
119 }
120
CreateCleanCompositeNode(const InplaceAssignerInfo & op_info,const FuncGraphPtr & main_graph,TypeId dst_type)121 CNodePtr InplaceAssignBuilder::CreateCleanCompositeNode(const InplaceAssignerInfo &op_info,
122 const FuncGraphPtr &main_graph, TypeId dst_type) {
123 std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
124
125 if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
126 MS_LOG(EXCEPTION) << "For CreateCleanCompositeNode, the data type: " << TypeIdToString(dst_type, true)
127 << " is not in supported list: [float16, float32, float64].";
128 }
129
130 // Create zero value which will be broadcast to target shape.
131 auto format = GetFormat(op_info.op_node);
132 auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
133 ValueNodePtr value_node;
134 if (dtype == kNumberTypeFloat32) {
135 float val = 0;
136 value_node = CreateTensorValueNode({format, {1}, TypeIdToType(dtype)}, &val, sizeof(float));
137 } else {
138 double val = 0;
139 value_node = CreateTensorValueNode({format, {1}, TypeIdToType(dtype)}, &val, sizeof(double));
140 }
141
142 // Create composite op's sub-graph.
143 auto new_sub_graph = std::make_shared<FuncGraph>();
144
145 AnfNodePtr broadcast_input_node;
146 if (dst_type == kNumberTypeFloat16) {
147 AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
148 auto cast_node_inner = CreateCNode(cast_inputs, new_sub_graph, {format, {1}, TypeIdToType(dst_type)});
149 SetNodeAttrSafely("dst_type", kFloat32, cast_node_inner);
150 broadcast_input_node = cast_node_inner;
151 } else {
152 broadcast_input_node = value_node;
153 }
154
155 // Create broadcast basic op.
156 auto dst_shape_vec = GetShape(op_info.op_node);
157 auto device_shape = GetDeviceShape(op_info.op_node);
158 auto shape_node = CreateTensorValueNode({kOpFormat_DEFAULT, {SizeToLong(device_shape.size())}, kInt64},
159 device_shape.data(), device_shape.size() * sizeof(int64_t));
160
161 AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node, shape_node};
162 auto broadcast_to_node_inner =
163 CreateCNode(clean_inputs, new_sub_graph, {format, dst_shape_vec, GetType(op_info.op_node)});
164
165 // Makeup sub-graph.
166 new_sub_graph->set_output(broadcast_to_node_inner);
167 auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
168 broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
169 Callback::Instance()->SetGraphKernelNodeKernelInfo(broadcast_to_composite_node);
170 auto graph_attr =
171 GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "inplace_assign_builder");
172 new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
173 new_sub_graph->set_attr("composite_type", MakeValue("inplace_assign_builder"));
174
175 return broadcast_to_composite_node;
176 }
177
ProcessOriginCNode(const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & info_and_inplace_assignee_addr)178 void InplaceAssignBuilder::ProcessOriginCNode(
179 const AnfNodePtr &composite_node,
180 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr) {
181 auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
182 auto mng_sub = sub_graph->manager();
183 if (mng_sub == nullptr) {
184 mng_sub = Manage(sub_graph, false);
185 sub_graph->set_manager(mng_sub);
186 }
187
188 // Add input
189 std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> parameters_infos;
190 std::vector<AnfNodePtr> additonal_inputs;
191 for (const auto &[target_node_info, input] : info_and_inplace_assignee_addr) {
192 // Add attribute to target node.
193 SetTargetAttrs(target_node_info.op_node);
194
195 // add parameter
196 if (target_node_info.inplace_to_origin_input < 0) {
197 auto parameter = sub_graph->add_parameter();
198 parameter->set_abstract(input->abstract());
199 parameter->set_kernel_info(input->kernel_info_ptr());
200 (void)parameters_infos.emplace_back(target_node_info, parameter);
201 (void)additonal_inputs.emplace_back(input);
202 } else {
203 auto params = sub_graph->parameters();
204 (void)parameters_infos.emplace_back(target_node_info,
205 params[IntToSize(target_node_info.inplace_to_origin_input)]);
206 }
207 }
208
209 auto inputs = composite_node->cast<CNodePtr>()->inputs();
210 (void)inputs.insert(inputs.end(), additonal_inputs.begin(), additonal_inputs.end());
211 composite_node->cast<CNodePtr>()->set_inputs(inputs);
212
213 CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
214 CorrectKernelBuildInfo(composite_node, info_and_inplace_assignee_addr);
215 }
216
FindOriginCNodeUsers(const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & info_and_inplace_assignee_addr,const FuncGraphManagerPtr & mng) const217 std::vector<InplaceAssignUserInfo> InplaceAssignBuilder::FindOriginCNodeUsers(
218 const AnfNodePtr &composite_node,
219 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr,
220 const FuncGraphManagerPtr &mng) const {
221 std::vector<InplaceAssignUserInfo> user_node_infos;
222
223 std::map<size_t, AnfNodePtr> real_indices_and_input_node;
224 for (auto &[info, clean] : info_and_inplace_assignee_addr) {
225 (void)real_indices_and_input_node.emplace(info.real_output_index, clean);
226 }
227
228 if (info_and_inplace_assignee_addr[0].first.real_output_num <= 1) {
229 // Find users directly.
230 auto users = mng->node_users()[composite_node];
231 for (const auto &[user, index] : users) {
232 user_node_infos.push_back({info_and_inplace_assignee_addr[0].second, composite_node, user, IntToSize(index)});
233 }
234 } else {
235 std::vector<std::pair<AnfNodePtr, AnfNodePtr>> getitem_user_nodes;
236 auto users = mng->node_users()[composite_node];
237 for (const auto &node_index : users) {
238 // 1. First, find TupleGetItem nodes.
239 const auto &user_node = node_index.first;
240 if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
241 continue;
242 }
243 auto item_idx = GetItemIdx(user_node);
244 const auto iter = real_indices_and_input_node.find(item_idx);
245 if (iter != real_indices_and_input_node.end()) {
246 (void)getitem_user_nodes.emplace_back(user_node, iter->second);
247 }
248 }
249 // 2. Find users of TupleGetItem nodes.
250 for (size_t i = 0; i < getitem_user_nodes.size(); ++i) {
251 const auto &getitem_node = getitem_user_nodes[i].first;
252 const auto &broadcast_to_node = getitem_user_nodes[i].second;
253 auto real_users = mng->node_users()[getitem_node];
254 for (const auto &[user, index] : real_users) {
255 user_node_infos.push_back({broadcast_to_node, getitem_node, user, IntToSize(index)});
256 }
257 }
258 }
259
260 return user_node_infos;
261 }
262
ProcessOriginCNodeUser(const FuncGraphPtr & main_graph,const AnfNodePtr & composite_node,const std::vector<std::pair<InplaceAssignerInfo,AnfNodePtr>> & info_and_inplace_assignee_addr,const FuncGraphManagerPtr & mng) const263 void InplaceAssignBuilder::ProcessOriginCNodeUser(
264 const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
265 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr,
266 const FuncGraphManagerPtr &mng) const {
267 // 1. Find users.
268 auto user_nodes = FindOriginCNodeUsers(composite_node, info_and_inplace_assignee_addr, mng);
269 for (const auto &iter : user_nodes) {
270 // 2. Make sure modified composite node running first, So firstly, create depend_node, then add edge to connect
271 // work_node, broadcast_node and depend_node to keep order.
272 AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), iter.inplace_assignee_addr, iter.work_node};
273 auto depend_node = main_graph->NewCNode(depend_inputs);
274 depend_node->set_abstract(iter.inplace_assignee_addr->abstract());
275 main_graph->AddNode(depend_node);
276 auto user_cnode = iter.user_node->cast<CNodePtr>();
277 MS_EXCEPTION_IF_NULL(user_cnode);
278 user_cnode->set_input(iter.user_input_idx, depend_node);
279 }
280 }
281 } // namespace mindspore::graphkernel
282