/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/core/parallel_op_combine.h"

#include <vector>
#include <string>
#include <set>
#include <deque>
#include <utility>
#include <algorithm>
#include <unordered_set>
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/framework_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "include/backend/kernel_graph.h"
#include "utils/anf_utils.h"
#include "include/common/utils/utils.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "utils/ms_context.h"
#include "ops/array_ops.h"
#include "backend/common/graph_kernel/adapter/callback_impl.h"

namespace mindspore::graphkernel {
namespace {
constexpr auto kPerm = "perm";
constexpr auto kShape = "shape";
constexpr size_t kMinUpdateSize = 2;
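// Read the "perm" attribute of a Transpose primitive and normalize its
// elements (Int64Imm or Int32Imm) into a vector of int64_t.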
std::vector<int64_t> GetTransposePerm(const PrimitivePtr &primitive) {
  ValuePtr perm = primitive->GetAttr(kPerm);
  MS_EXCEPTION_IF_NULL(perm);
  auto perm_val = perm->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(perm_val);
  auto perm_val_data = perm_val->value();
  std::vector<int64_t> perm_int;
  (void)std::transform(perm_val_data.begin(), perm_val_data.end(), std::back_inserter(perm_int),
                       [](const ValuePtr &e) -> int64_t {
                         if (e->isa<Int64Imm>()) {
                           return GetValue<int64_t>(e);
                         } else if (e->isa<Int32Imm>()) {
                           return GetValue<int>(e);
                         } else {
                           MS_LOG(EXCEPTION) << "Perm elements must be int32 or int64.";
                           return -1;
                         }
                       });
  return perm_int;
}
}  // namespace
BranchGroupFinder::BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op,
                                     FAreCompatibleOps fare_compatible_ops)
    : op_name_(op_name), fis_supported_op_(fis_supported_op), fare_compatible_ops_(fare_compatible_ops) {}

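// Collect the users of `producer` that are real kernels accepted by
// `fis_supported_op_`, and record them in `children_map_` so that branches can
// be built from the producer later.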
AnfNodeIndexSet BranchGroupFinder::GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer) {
  AnfNodeIndexSet consumers;
  auto users = mng->node_users()[producer];
  for (const auto &it : users) {
    auto user = it.first;
    if (user && user->cast<CNodePtr>() && AnfUtils::IsRealKernel(user) && fis_supported_op_(user)) {
      consumers.add(CNodeIndexPair(it.first, it.second));
      (void)children_map_[producer].insert(user);
    }
  }
  return consumers;
}

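// Search for combinable branch groups. Starting from the graph parameters and
// the inputs of `start_node`, traverse all supported consumers to build
// `children_map_`; every node with more than one child becomes a root. A
// root's children that begin with `op_name_` each start a branch, and branches
// whose target ops sit at the same position and are compatible (per
// `fare_compatible_ops_`) are collected into one group.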
std::vector<Group> BranchGroupFinder::Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph) {
  auto graph_kernel_fg = func_graph == nullptr ? common::AnfAlgo::GetCNodeFuncGraphPtr(start_node) : func_graph;
  MS_EXCEPTION_IF_NULL(graph_kernel_fg);
  auto mng = graph_kernel_fg->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto cnode = start_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::deque<AnfNodePtr> init_consumer;
  (void)std::transform(graph_kernel_fg->parameters().begin(), graph_kernel_fg->parameters().end(),
                       std::back_inserter(init_consumer), [](const AnfNodePtr &global_in) { return global_in; });
  for (size_t i = 1; i < cnode->size(); ++i) {
    init_consumer.push_back(cnode->input(i));
  }
  while (!init_consumer.empty()) {
    auto new_node = init_consumer.front();
    init_consumer.pop_front();
    auto new_consumer = GetConsumers(mng, new_node);
    (void)std::transform(new_consumer.begin(), new_consumer.end(), std::back_inserter(init_consumer),
                         [](const CNodeIndexPair &index_pair) { return index_pair.first; });
  }
  for (const auto &it : children_map_) {
    if (it.second.size() > 1) {
      (void)op_roots_.insert(it.first);
    }
  }
  std::vector<Group> groups;
  for (const auto &root : op_roots_) {
    size_t ngroups = groups.size();
    const auto &children = children_map_.at(root);
    for (const auto &child : children) {
      auto prim = GetCNodePrimitive(child);
      if (!prim) {
        continue;
      }
      auto prim_name = prim->name();
      // A branch must start with the target node specified by `op_name_`.
      if (prim_name != op_name_) {
        continue;
      }
      auto branch = CreateBranch(child);
      branch.SetDataRoot(root);
      auto it = std::find_if(groups.begin() + ngroups, groups.end(), [this, &branch](const Group &group) {
        MS_EXCEPTION_IF_CHECK_FAIL(!group.empty() && !group[0].ops.empty(), "group is empty or group[0] has no ops");
        auto top_branch = group[0];
        return (branch.target_op_pos == top_branch.target_op_pos) &&
               fare_compatible_ops_(branch.GetTargetOp(), top_branch.GetTargetOp());
      });
      if (it != groups.end()) {
        it->push_back(branch);
      } else {
        (void)groups.emplace_back();
        groups.back().push_back(branch);
      }
    }
  }
  return groups;
}

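// Follow the single-consumer chain starting at `lead_op`, collecting the ops
// on the chain and the position of the target op (`op_name_`) within it.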
Branch BranchGroupFinder::CreateBranch(AnfNodePtr lead_op) {
  AnfNodePtrList ops{lead_op};
  int root_idx = GetCNodePrimitive(lead_op)->name() == op_name_ ? 0 : -1;
  auto it = children_map_.find(lead_op);
  while (it != children_map_.end() && it->second.size() == 1) {
    auto node = *(it->second).begin();
    ops.push_back(node);
    auto prim_name = GetCNodePrimitive(node)->name();
    if (prim_name == op_name_) {
      // Record the index of the node just pushed.
      root_idx = static_cast<int>(ops.size()) - 1;
    }
    it = children_map_.find(node);
  }
  return Branch(ops, root_idx);
}

ParallelOpCombiner::ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout)
    : op_name_(op_name), min_num_branches_(min_num_branches), layout_(layout) {}

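// Entry point of the combiner: find the parallel branch groups reachable from
// `root` and combine every group with at least `min_num_branches_` branches.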
AnfNodePtr ParallelOpCombiner::Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(root);
  if (func_graph) {
    main_graph_ = func_graph;
  } else {
    main_graph_ = common::AnfAlgo::GetCNodeFuncGraphPtr(root);
  }
  MS_EXCEPTION_IF_NULL(main_graph_);
  auto finder = BranchGroupFinder(
    op_name_, [&](const AnfNodePtr &n) { return IsSupportedOp(n); },
    [&](const AnfNodePtr &a, const AnfNodePtr &b) { return CanOpsBeCombined(a, b); });
  auto groups = finder.Find(root, main_graph_);
  children_map_ = std::move(finder.children_map_);
  for (const Group &group : groups) {
    if (group.size() < min_num_branches_) {
      MS_LOG(INFO) << "group size = " << group.size() << " < " << min_num_branches_ << ", skip.";
      continue;
    }
    CombineBranches(group);
  }
  return combined_;
}

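// Combine a group level by level: first fuse the target ops into one combined
// op, then merge each following depth (bounded by the shortest branch) as long
// as CheckLevel allows, and finally reconnect the group's outputs.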
void ParallelOpCombiner::CombineBranches(const Group &branches) {
  auto combined = MakeCombinedOp(branches);
  auto it = std::min_element(branches.begin(), branches.end(), [](const Branch &branch_a, const Branch &branch_b) {
    return branch_a.ops.size() < branch_b.ops.size();
  });
  size_t depth = it->ops.size();
  size_t pos;
  for (pos = 0; pos < depth; ++pos) {
    if (static_cast<int>(pos) == it->target_op_pos) {
      continue;
    }
    if (!CheckLevel(branches, pos)) {
      break;
    }
    combined = MakeCombinedAnfNodePtrFromFollowingOps(combined, branches, pos);
  }
  if (pos > 0) {
    UpdateGroupOutput(combined, branches, pos - 1);
  }
  combined_ = combined;
}

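// A level is combinable only if every branch holds the same primitive at
// `depth`, the primitive is not in `unsupported_ops_`, and its arguments are
// compatible with those of the representative (first) branch.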
bool ParallelOpCombiner::CheckLevel(const Group &branches, size_t depth) {
  auto repr = branches[0].ops[depth];
  auto repr_prim_name = GetCNodePrimitive(repr)->name();
  // Check whether all branches can be combined at the current depth.
  for (auto it = branches.begin() + 1; it != branches.end(); ++it) {
    const Branch &branch = *it;
    auto node = branch.ops[depth];
    auto prim_name = GetCNodePrimitive(node)->name();
    if (prim_name != repr_prim_name) {
      MS_LOG(INFO) << "Primitives are not compatible: " << prim_name << " vs " << repr_prim_name;
      return false;
    }
    if (unsupported_ops_.find(prim_name) != unsupported_ops_.end()) {
      MS_LOG(INFO) << "Op " << prim_name << " is not supported for combination yet, stop.";
      return false;
    }
    if (!IsArgCompatible(repr, node)) {
      return false;
    }
  }
  MS_LOG(DEBUG) << "Op " << repr_prim_name << " can be combined at depth " << depth;
  return true;
}

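// Rebuild the kernel info of a freshly combined node. The non-lite build
// delegates to Callback::ResetKernelInfo; the lite build derives one common
// output format from the predecessor's output format (falling back to the
// original users' formats, then to `layout_`).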
bool ParallelOpCombiner::AutoUpdateInfo(const CNodePtr &to_update) {
  if (to_update->size() < kMinUpdateSize) {
    MS_LOG(ERROR) << "Cannot auto update " << to_update->fullname_with_scope() << " with input size "
                  << to_update->size();
    return false;
  }
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  Callback::Instance()->ResetKernelInfo(to_update);
#else
  auto rep_input = to_update->input(1);
  // NOTE: We assume the inputs' formats and types are consistent with the outputs'.
  std::string input_format = Callback::Instance()->GetTargetFromContext() == kAscendDevice ? "" : kOpFormat_NCHW;
  auto GetPrevOutFormat = [&input_format](const CNodePtr &cnode) -> bool {
    if (cnode == nullptr || !cnode->HasAttr(kOutputsFormat)) {
      return false;
    }
    auto prev_of = GetValue<std::vector<std::string>>(cnode->GetAttr(kOutputsFormat));
    if (!prev_of.empty()) {
      input_format = prev_of[0];
      return true;
    }
    return false;
  };
  if (AnfUtils::IsRealKernel(rep_input)) {
    (void)GetPrevOutFormat(rep_input->cast<CNodePtr>());
  }
  if (input_format.empty()) {
    auto it = children_map_.find(rep_input);
    if (it != children_map_.end()) {
      for (const auto &orig_user : it->second) {
        if (GetPrevOutFormat(orig_user->cast<CNodePtr>())) {
          break;
        }
      }
    }
  }
  if (input_format.empty()) {
    MS_LOG(WARNING) << "Cannot find the previous node's output format, use " << layout_
                    << " by default, which may cause errors.";
    input_format = layout_;
  }
  std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(to_update), input_format);
  to_update->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
#endif
  return true;
}

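// For every input position, gather the inputs at `depth` that are private to a
// branch, i.e. not the shared output of the previous level (or of the data
// root when depth is 0).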
std::map<size_t, AnfNodePtrList> ParallelOpCombiner::GetUniqueInputs(const Group &branches, size_t depth) const {
  std::map<size_t, AnfNodePtrList> unique_inputs;
  AnfNodePtrList parent_in_branch;
  if (depth >= 1) {
    (void)std::transform(branches.begin(), branches.end(), std::back_inserter(parent_in_branch),
                         [&depth](const Branch &br) { return br.ops[depth - 1]; });
  } else {
    Branch b1 = branches[0];
    parent_in_branch.push_back(b1.GetRootData());
  }

  for (const auto &br : branches) {
    auto op = br.ops[depth];
    auto cnode = op->cast<CNodePtr>();
    // At this point the ops are known to have the same argument length (checked beforehand).
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto in = cnode->input(i);
      if (std::any_of(parent_in_branch.begin(), parent_in_branch.end(),
                      [&in](const AnfNodePtr &p) { return in == p; })) {
        continue;
      }
      unique_inputs[i].push_back(in);
    }
  }
  return unique_inputs;
}

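// Build a Concat CNode over `input_node`. On Ascend the inputs are wrapped in
// a MakeTuple first. All inputs are assumed to share the first input's shape,
// so the output shape only scales `concat_dim` by `input_num`; e.g. four
// [8, 32] inputs concatenated on dim 0 yield [32, 32].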
CNodePtr GraphBuilder::NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node,
                                     size_t concat_dim, size_t input_num) {
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  if (Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
    auto maketuple = NewTupleNode(func_graph, input_node);
    concat_inputs.push_back(maketuple);
  } else {
    for (const auto &n : input_node) {
      concat_inputs.push_back(n);
    }
  }
  auto concat = func_graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat);
  func_graph->AddNode(concat);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node[0], 0)};
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node[0], 0);
  shape[concat_dim] *= SizeToLong(input_num);
  std::vector<ShapeVector> shapes(1, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(concat_dim)), concat);
  common::AnfAlgo::SetNodeAttr(kAttrN, MakeValue(static_cast<int64_t>(input_num)), concat);
  return concat;
}

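// Wrap `shared_inputs` into a MakeTuple node whose abstract is the tuple of
// the inputs' abstracts.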
CNodePtr GraphBuilder::NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs) {
  auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  AbstractBasePtrList abs_list;
  for (const auto &in : shared_inputs) {
    mk_inputs.push_back(in);
    abs_list.push_back(in->abstract());
  }
  auto make_tuple_node = func_graph->NewCNode(mk_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple_node);
  func_graph->AddNode(make_tuple_node);
  make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
  return make_tuple_node;
}

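// Build a Split CNode that splits `input_node` evenly into `split_num` outputs
// along `split_dim`; the extent on that axis is assumed to be divisible by
// `split_num`.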
CNodePtr GraphBuilder::NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
                                    size_t split_num) {
  if (split_num == 0) {
    MS_LOG(EXCEPTION) << "split_num should not be zero.";
  }
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                          input_node};
  auto split = func_graph->NewCNode(split_inputs);
  MS_EXCEPTION_IF_NULL(split);
  func_graph->AddNode(split);
  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
  std::vector<TypeId> dtypes(split_num, dtype);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
  shape[split_dim] /= SizeToLong(split_num);
  std::vector<ShapeVector> shapes(split_num, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(split_dim), split);
  common::AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue<int64_t>(split_num), split);
  return split;
}

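// Build an attribute-free elementwise CNode whose output type and shape are
// inherited from its first real input.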
CNodePtr GraphBuilder::NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
  auto node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  func_graph->AddNode(node);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
  std::vector<ShapeVector> shapes = {common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0)};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
  return node;
}

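// Build a Reshape CNode. The output shape is inferred from how the original
// reshape mapped its input shape to its output shape, applied to the new
// (combined) input shape, and is also stored in the "shape" attribute.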
CNodePtr GraphBuilder::NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
                                      const AnfNodePtr &orig_node) {
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
  auto node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  func_graph->AddNode(node);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
  auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
  auto orig_shape_in = common::AnfAlgo::GetPrevNodeOutputInferShape(orig_node, 0);
  auto orig_shape_out = common::AnfAlgo::GetOutputInferShape(orig_node, 0);
  auto new_out_shape = InferReshapeOut(orig_shape_in, orig_shape_out, new_shape_in);
  GetCNodePrimitive(node)->set_attr(kShape, MakeValue(new_out_shape));
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
  return node;
}

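// Build a Transpose CNode, inferring the output shape by permuting the new
// input shape with the node's own "perm" attribute.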
CNodePtr GraphBuilder::NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
  auto node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  func_graph->AddNode(node);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
  auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
  auto perm_int = GetTransposePerm(GetCNodePrimitive(node));
  auto new_out_shape = InferTransposeOut(new_shape_in, perm_int);
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
  return node;
}

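// Infer the combined reshape's output shape. Only the concat case (equal input
// ranks) is handled; the stack case is not implemented yet.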
ShapeVector GraphBuilder::InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                          const ShapeVector &new_reshape_in) {
  if (orig_reshape_in.size() == new_reshape_in.size()) {
    return InferConcatReshapeOut(orig_reshape_in, orig_reshape_out, new_reshape_in);
  }
  MS_LOG(EXCEPTION) << "Reshape inference for the stack combiner is not implemented yet.";
  return ShapeVector();
}

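// Permute the input shape with `perm`; e.g. in_shape = {2, 3, 4} and
// perm = {0, 2, 1} give {2, 4, 3}.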
ShapeVector GraphBuilder::InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm) {
  ShapeVector out_shape;
  out_shape.reserve(perm.size());
  for (int64_t i : perm) {
    auto idx = LongToSize(i);
    out_shape.push_back(in_shape[idx]);
  }
  return out_shape;
}
}  // namespace mindspore::graphkernel