/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/core/parallel_op_concatenate.h"

#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/adapter/callback_impl.h"
namespace mindspore::graphkernel {
ParallelOpConcatenater::ParallelOpConcatenater(const std::string &op_name, uint64_t min_num_branches,
                                               const std::string &layout)
    : ParallelOpCombiner(op_name, min_num_branches, layout) {}

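// Two branch ops are concat-compatible when they have the same input count and
// every corresponding input has the same inferred shape.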
bool ParallelOpConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) {
  auto cnode_a = a->cast<CNodePtr>();
  auto cnode_b = b->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode_a);
  MS_EXCEPTION_IF_NULL(cnode_b);
  auto arg_size = cnode_a->size();
  if (arg_size != cnode_b->size()) {
    MS_LOG(DEBUG) << "Args size not compatible: " << arg_size << " vs " << cnode_b->size();
    return false;
  }
  auto cb = Callback::Instance();
  for (size_t i = 1; i < arg_size; ++i) {
    auto shape_a = cb->GetInputInferShape(a, i);
    auto shape_b = cb->GetInputInferShape(b, i);
    if (shape_a != shape_b) {
      MS_LOG(ERROR) << "Args shape not compatible: " << shape_a << " vs " << shape_b;
      return false;
    }
  }
  return true;
}

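// Builds the combined node that replaces the ops at `depth` in all branches: derives the
// elemwise concat/split plan, concatenates the branch-specific inputs, and then rebuilds the
// sampled op (Reshape, Transpose, or a generic elemwise node) on top of the combined inputs.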
AnfNodePtr ParallelOpConcatenater::MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches,
                                                                          size_t depth) {
  if (branches.empty()) {
    MS_LOG(EXCEPTION) << "Failed to sample ops in an empty group.";
  }
  auto ew_plan = GetElemWiseFollowingPlan(branches, depth);
  plans_.push_back(ew_plan);
  auto overall_inputs = ReloadInputs(branches, depth, data);
  // Since all ops at the same depth in a group should be identical, we just sample the op in the first branch.
  Branch b0 = branches[0];
  auto orig_node = b0.GetOp(static_cast<int>(depth));
  MS_EXCEPTION_IF_NULL(orig_node);
  auto prim = GetCNodePrimitive(orig_node);
  MS_EXCEPTION_IF_NULL(prim);
  CNodePtr new_node;
  if (prim->name() == kReshapeOpName) {
    new_node = GraphBuilder::NewReshapeNode(main_graph_, overall_inputs, orig_node);
  } else if (prim->name() == kTransposeOpName) {
    new_node = GraphBuilder::NewTransposeNode(main_graph_, overall_inputs);
  } else {
    new_node = GraphBuilder::NewElemwiseNoAttrNode(main_graph_, overall_inputs);
  }
  MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(new_node), "AutoUpdateInfo fail");
  return new_node;
}

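// For every input position that differs across branches, concatenates the per-branch
// inputs along `concat_idx`; returns a map from input position to the new Concat node.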
std::map<size_t, AnfNodePtr> ParallelOpConcatenater::ConcatUniqueInputs(std::map<size_t, AnfNodePtrList> unique_inputs,
                                                                        size_t concat_idx) {
  std::map<size_t, AnfNodePtr> concated_inputs;
  for (const auto &it : unique_inputs) {
    size_t input_idx = it.first;
    const auto &local_inputs = it.second;
    if (local_inputs.size() < kDim2) {
      MS_LOG(WARNING) << "Concat op needs at least 2 inputs, but got " << local_inputs.size();
      continue;
    }
    auto concat_node = GraphBuilder::NewConcatNode(main_graph_, local_inputs, concat_idx, local_inputs.size());
    MS_EXCEPTION_IF_NULL(concat_node);
    MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(concat_node), "AutoUpdateInfo fail");
    concated_inputs[input_idx] = concat_node;
  }
  return concated_inputs;
}

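// Splits the combined output back into one slice per branch and redirects all users of
// each original branch op to the corresponding TupleGetItem of the new Split node.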
void ParallelOpConcatenater::UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) {
  if (depth >= plans_.size()) {
    MS_LOG(EXCEPTION) << "Cannot get plan at depth " << depth << ", current plan size = " << plans_.size();
  }
  auto ew_plan = plans_[depth];
  auto split_node = GraphBuilder::NewSplitNode(main_graph_, data, ew_plan.split_out_idx, branches.size());
  MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(split_node), "AutoUpdateInfo fail");
  main_graph_->AddNode(split_node);
  auto mng = main_graph_->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (size_t i = 0; i < branches.size(); ++i) {
    const auto &br = branches[i];
    auto target = br.ops[depth];
    auto idx_val = MakeValue(SizeToLong(i));
    auto gt_idx = NewValueNode(idx_val);
    gt_idx->set_abstract(idx_val->ToAbstract());
    AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), split_node, gt_idx};
    auto new_out = main_graph_->NewCNode(gt_inputs);
    new_out->set_abstract(target->abstract()->Clone());
    (void)mng->Replace(target, new_out);
  }
}

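// Derives the concat/split plan for the elemwise ops at `depth` from the previous plan,
// shifting the axes when ranks differ (e.g. when an input is broadcast).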
ConcatenatePlan ParallelOpConcatenater::GetElemWiseFollowingPlan(const Group &branches, size_t depth) {
  if (depth < 1 || depth > plans_.size()) {
    MS_LOG(EXCEPTION) << "Should get the plan at depth " << (depth - 1) << " first, current plan size = "
                      << plans_.size();
  }
  auto last_plan = plans_[depth - 1];
  ConcatenatePlan ew_plan;
  auto unique_inputs = GetUniqueInputs(branches, depth);
  auto cb = Callback::Instance();
  // Take the inferred shape of the first unique input as the input shape of this plan.
  for (const auto &it : unique_inputs) {
    for (const auto &in : it.second) {
      if (!ew_plan.in_shape.empty()) {
        break;
      }
      ew_plan.in_shape = cb->GetOutputInferShape(in, 0);
    }
  }
  // Right-align an axis index when the new shape has fewer leading dims than the base shape.
  auto UpdateIdx = [](const ShapeVector &base_shape, const ShapeVector &new_shape, int base_idx) -> int {
    if (new_shape.empty()) {
      return base_idx;
    }
    auto rank_diff = static_cast<int>(base_shape.size()) - static_cast<int>(new_shape.size());
    if (rank_diff > base_idx) {
      return base_idx;
    }
    return base_idx - rank_diff;
  };
  ew_plan.concat_in_idx = UpdateIdx(last_plan.in_shape, ew_plan.in_shape, last_plan.concat_in_idx);
  Branch b0 = branches[0];
  auto op = b0.ops[depth];
  ew_plan.out_shape = cb->GetOutputInferShape(op, 0);
  ew_plan.split_out_idx = UpdateIdx(last_plan.out_shape, ew_plan.out_shape, last_plan.split_out_idx);
  MS_LOG(DEBUG) << "EW plan: " << ew_plan.concat_in_idx << ", " << ew_plan.split_out_idx << ", " << ew_plan.out_shape;
  return ew_plan;
}

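// Rebuilds the input list for the combined op: concatenated nodes where inputs differ
// across branches, and the shared input everywhere else.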
AnfNodePtrList ParallelOpConcatenater::ReloadInputs(const Group &branches, size_t depth, AnfNodePtr shared_input) {
  Branch b0 = branches[0];
  auto cnode = b0.ops[depth]->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_size = cnode->size();
  auto plan = plans_[depth];
  auto unique_inputs = GetUniqueInputs(branches, depth);
  AnfNodePtrList overall_inputs{cnode->input(0)};  // input 0 is the primitive
  auto concated_inputs = ConcatUniqueInputs(unique_inputs, plan.concat_in_idx);
  for (size_t i = 1; i < input_size; ++i) {
    if (concated_inputs.find(i) != concated_inputs.end()) {
      overall_inputs.push_back(concated_inputs[i]);
    } else {
      overall_inputs.push_back(shared_input);
    }
  }
  return overall_inputs;
}

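// Infers the output shape of a Reshape after its input has been enlarged by concatenation:
// each original output dim is mapped back to the input dim it folds into, then scaled by
// that input dim's growth factor (the factor is applied only once per input dim).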
ShapeVector GraphBuilder::InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                                const ShapeVector &new_reshape_in) {
  std::map<int, int> idx_map_rev;  // output dim -> the input dim it folds into
  std::map<int, int> mul_map;      // input dim -> growth factor after concatenation
  int oidx = static_cast<int>(orig_reshape_out.size()) - 1;
  // Walk both shapes from the last dim and record which input dim covers each output dim.
  for (int ridx = static_cast<int>(orig_reshape_in.size()) - 1; ridx >= 0; --ridx) {
    auto cur_size = orig_reshape_in[ridx];
    mul_map[ridx] = new_reshape_in[ridx] / orig_reshape_in[ridx];
    while (oidx >= 0 && cur_size >= orig_reshape_out[oidx] && cur_size % orig_reshape_out[oidx] == 0) {
      idx_map_rev[oidx] = ridx;
      cur_size = cur_size / orig_reshape_out[oidx];
      oidx--;
    }
  }
  ShapeVector new_shape_out;
  // Scale each output dim by its input dim's growth factor, applying the factor only once.
  for (int i = 0; i < static_cast<int>(orig_reshape_out.size()); ++i) {
    auto in_idx = idx_map_rev[i];
    auto mul = mul_map[in_idx];
    new_shape_out.push_back(orig_reshape_out[i] * mul);
    mul_map[in_idx] = 1;
  }
  return new_shape_out;
}
}  // namespace mindspore::graphkernel