1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <algorithm>
19 #include "backend/common/graph_kernel/split_model/split_model.h"
20 #include "backend/common/graph_kernel/graph_kernel_flags.h"
21 #include "utils/hash_set.h"
22 #include "mindspore/core/symbolic_shape/int_symbol.h"
23
24 namespace mindspore::graphkernel::inner {
25 namespace {
FindIoNum(const std::vector<AreaPtr> * areas)26 uint64_t FindIoNum(const std::vector<AreaPtr> *areas) {
27 std::unordered_map<PrimOpPtr, uint64_t> degree;
28 std::unordered_set<PrimOpPtr> visited;
29 uint64_t io_num = 0;
30 for (auto a = areas->begin(); a != areas->end(); ++a) {
31 for (auto op : (*a)->ops()) {
32 visited.insert(op);
33 for (auto o : op->inputs()) {
34 if (auto prim = o->As<PrimOp>(); prim) {
35 degree[prim]++;
36 }
37 }
38 }
39 }
40 for (auto op : visited) {
41 auto iter = degree.find(op);
42 if (iter == degree.end()) { // is output node
43 io_num++;
44 }
45 io_num++;
46 for (auto o : op->inputs()) {
47 if (auto prim = o->As<PrimOp>(); prim && visited.find(prim) != visited.end()) { // is not input node
48 io_num--;
49 break;
50 }
51 }
52 }
53 return io_num;
54 }
55 } // namespace
56
ReachTable(size_t size)57 ReachTable::ReachTable(size_t size) : size_(size), reach_(size, std::vector<bool>(size, false)) {
58 for (size_t i = 0; i < size_; ++i) {
59 reach_[i][i] = true;
60 (void)alive_.insert(i);
61 }
62 }
63
Link(size_t from,size_t to)64 void ReachTable::Link(size_t from, size_t to) {
65 // if there's an edge <from, to>, the `from` can reach to `to`'s succeeding areas.
66 // so we connect `from` to all succeeding areas of `to`.
67 for (const size_t suc : alive_) {
68 if (Reachable(to, suc)) {
69 reach_[from][suc] = true;
70 }
71 }
72 }
73
FuseArea(size_t target,size_t other)74 void ReachTable::FuseArea(size_t target, size_t other) {
75 // if `suc` is the succeeding nodes of other_node,
76 // link the target_node's previous nodes to `suc`.
77 for (const size_t suc : alive_) {
78 if (Reachable(other, suc) && !Reachable(target, suc)) {
79 for (const size_t pre : alive_) {
80 if (Reachable(pre, target)) {
81 reach_[pre][suc] = true;
82 }
83 }
84 }
85 }
86 // if `pre` is the previous nodes of other_node,
87 // link `pre` to target_node's succeeding nodes.
88 for (const size_t pre : alive_) {
89 if (Reachable(pre, other) && !Reachable(pre, target)) {
90 for (const size_t suc : alive_) {
91 if (Reachable(target, suc)) {
92 reach_[pre][suc] = true;
93 }
94 }
95 }
96 }
97 // discard other_node.
98 (void)alive_.erase(other);
99 }
100
HasCircle(const AreaPtr & a,const AreaPtr & b) const101 bool ReachTable::HasCircle(const AreaPtr &a, const AreaPtr &b) const {
102 // a is the input of b
103 if (Reachable(a->id(), b->id())) {
104 // use `inputs_with_relation` instead of `inputs` to avoid generating a new vector.
105 for (auto &inp : b->inputs_with_relation()) {
106 if (inp.first != a && Reachable(a->id(), inp.first->id())) {
107 return true;
108 }
109 }
110 } else {
111 // b is the input of a
112 for (auto &inp : a->inputs_with_relation()) {
113 if (inp.first != b && Reachable(b->id(), inp.first->id())) {
114 return true;
115 }
116 }
117 }
118 return false;
119 }
120
NewArea(const PrimOpPtr & op,bool is_output)121 AreaPtr SplitModel::NewArea(const PrimOpPtr &op, bool is_output) {
122 auto new_area = std::make_shared<Area>(cur_area_id_++, op->As<PrimOp>(), is_output, node_area_map_);
123 (void)areas_.emplace_back(new_area);
124 node_area_map_[op] = new_area;
125 SetDefaultAreaMode(new_area);
126 return new_area;
127 }
128
AlignShape(const LiteGraphPtr & litegraph) const129 void SplitModel::AlignShape(const LiteGraphPtr &litegraph) const {
130 for (auto &inp : litegraph->inputs()) {
131 if (inp->shape.empty()) {
132 inp->shape.push_back(1LL);
133 if (inp->symbolic_shape != nullptr && inp->symbolic_shape->size() == 0) {
134 inp->symbolic_shape = ListSymbol::Make({kSym1});
135 }
136 }
137 }
138 auto check_pattern = [](const NodePtr &op) {
139 auto pn = op->As<PrimOp>()->compute_type();
140 return pn == NodePattern::ELEMWISE || pn == NodePattern::BROADCAST || pn == NodePattern::REDUCE;
141 };
142 for (auto &op : litegraph->ops()) {
143 if (!check_pattern(op)) {
144 if (op->shape.empty()) {
145 op->shape.push_back(1LL);
146 if (op->symbolic_shape != nullptr && op->symbolic_shape->size() == 0) {
147 op->symbolic_shape = ListSymbol::Make({kSym1});
148 }
149 }
150 continue;
151 }
152 auto cur_shape_size = op->shape.size();
153 for (auto &inp : op->inputs()) {
154 if (inp->shape.size() > cur_shape_size) {
155 cur_shape_size = inp->shape.size();
156 }
157 }
158 if (cur_shape_size > op->shape.size()) {
159 auto num = cur_shape_size - op->shape.size();
160 (void)op->shape.insert(op->shape.cbegin(), num, 1LL);
161 if (op->symbolic_shape != nullptr) {
162 auto symbols = op->symbolic_shape->symbols();
163 (void)symbols.insert(symbols.begin(), num, kSym1);
164 op->symbolic_shape = ListSymbol::Make(std::move(symbols));
165 }
166 }
167 }
168 }
169
InitGraph(const LiteGraphPtr & litegraph)170 void SplitModel::InitGraph(const LiteGraphPtr &litegraph) {
171 AlignShape(litegraph);
172 auto &outputs = litegraph->GetOutputs();
173 HashSet<NodePtr> outputs_set(outputs.begin(), outputs.end());
174 for (const auto &op : litegraph->ops()) {
175 if (op->NodeType() != NType::Primitive) {
176 MS_LOG(EXCEPTION) << "Op " << op->debug_name() << " should be a Primitive node, but got " << op->NodeType();
177 }
178 bool is_output = (outputs_set.count(op) > 0);
179 (void)NewArea(op->As<PrimOp>(), is_output);
180 }
181
182 // Initialize reach table in reversed topological order
183 reach_table_ = std::make_shared<ReachTable>(litegraph->ops().size());
184 MS_EXCEPTION_IF_NULL(reach_table_);
185 for (auto iter = areas_.rbegin(); iter != areas_.rend(); ++iter) {
186 auto users = (*iter)->users();
187 for (auto &user : users) {
188 reach_table_->Link((*iter)->id(), user->id());
189 }
190 }
191 }
192
AddPattern(const std::shared_ptr<FusePattern> & pn,bool enable)193 void SplitModel::AddPattern(const std::shared_ptr<FusePattern> &pn, bool enable) {
194 (void)patterns_.emplace_back(std::make_pair(pn, enable));
195 patterns_.back().first->SetCircleChecker(reach_table_);
196 }
197
LimitAreaSize(const AreaPtr & dom,std::vector<AreaPtr> * areas) const198 void SplitModel::LimitAreaSize(const AreaPtr &dom, std::vector<AreaPtr> *areas) const {
199 uint64_t max_size = GraphKernelFlags::GetInstance().composite_op_limit_size;
200 auto dom_size = dom->size();
201 for (auto a = areas->begin(); a != areas->end(); ++a) {
202 dom_size += (*a)->size();
203 }
204 if (GraphKernelFlags::GetInstance().kernel_generator == "DVM") {
205 const uint64_t MAX_DVM_SIZE = 96;
206 max_size = std::min(MAX_DVM_SIZE, max_size);
207 if (dom_size <= max_size) {
208 uint64_t io_num = FindIoNum(areas);
209 max_size = max_size >= io_num ? max_size - io_num : 0;
210 if (dom_size <= max_size) {
211 return;
212 }
213 }
214 } else {
215 if (dom_size <= max_size) {
216 return;
217 }
218 }
219 // fuse the smaller area in priority
220 std::sort(areas->begin(), areas->end(),
221 [max_size](const AreaPtr &a, const AreaPtr &b) { return a->size() < b->size(); });
222 auto iter = std::find_if(areas->begin(), areas->end(), [cur_size = dom->size(), max_size](const AreaPtr &a) mutable {
223 cur_size += a->size();
224 return cur_size > max_size;
225 });
226 (void)areas->erase(iter, areas->cend());
227 }
228
FuseAreas(const AreaPtr & dom,const std::vector<AreaPtr> & areas,FuseDirection direction)229 void SplitModel::FuseAreas(const AreaPtr &dom, const std::vector<AreaPtr> &areas, FuseDirection direction) {
230 if (areas.empty()) {
231 return;
232 }
233 auto target = dom;
234 for (auto a : areas) {
235 if (direction == FuseDirection::BACKWARD) {
236 // always use back node to fuse the front node.
237 std::swap(target, a);
238 }
239 for (auto &op : a->ops()) {
240 node_area_map_[op] = target;
241 }
242 target->FuseInput(a);
243 reach_table_->FuseArea(target->id(), a->id());
244 }
245 if (target->pattern() > NodePattern::RESHAPE) {
246 target->SetMode(AreaMode::COMPOSITE);
247 }
248 }
249
RunOnePattern(const FusePatternPtr & pattern)250 bool SplitModel::RunOnePattern(const FusePatternPtr &pattern) {
251 // in one step, we only match the adjacent areas of the "area",
252 // so if matched, we should handle the same area again in the next step
253 bool changed = false;
254 for (auto iter = areas_.begin(); iter != areas_.end();) {
255 auto area = *iter;
256 if (!area->IsAlive()) {
257 iter = areas_.erase(iter);
258 continue;
259 }
260 if (pattern->Run(area)) {
261 MS_LOG(DEBUG) << "Area " << area->ToString() << " matches " << pattern->ToString();
262 LimitAreaSize(area, &pattern->fused_areas_);
263 if (!pattern->fused_areas_.empty()) {
264 FuseAreas(area, pattern->fused_areas_, pattern->direction());
265 changed = true;
266 continue;
267 }
268 }
269 ++iter;
270 }
271 return changed;
272 }
273
RunFusePatterns()274 void SplitModel::RunFusePatterns() {
275 // process one pattern for all areas before process next pattern.
276 for (auto &[pattern, enable] : patterns_) {
277 if (!enable) {
278 continue;
279 }
280 MS_LOG(DEBUG) << "Run pattern " << pattern->name();
281 (void)RunOnePattern(pattern);
282 }
283 // remove the areas that is fused
284 for (auto iter = areas_.begin(); iter != areas_.end();) {
285 if (!(*iter)->IsAlive()) {
286 iter = areas_.erase(iter);
287 } else {
288 ++iter;
289 }
290 }
291 }
292
Run(const LiteGraphPtr & litegraph)293 void SplitModel::Run(const LiteGraphPtr &litegraph) {
294 InitGraph(litegraph);
295 InitFusePatterns();
296 RunFusePatterns();
297 }
298 } // namespace mindspore::graphkernel::inner
299