• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <algorithm>
19 #include "backend/common/graph_kernel/split_model/split_model.h"
20 #include "backend/common/graph_kernel/graph_kernel_flags.h"
21 #include "utils/hash_set.h"
22 #include "mindspore/core/symbolic_shape/int_symbol.h"
23 
24 namespace mindspore::graphkernel::inner {
25 namespace {
FindIoNum(const std::vector<AreaPtr> * areas)26 uint64_t FindIoNum(const std::vector<AreaPtr> *areas) {
27   std::unordered_map<PrimOpPtr, uint64_t> degree;
28   std::unordered_set<PrimOpPtr> visited;
29   uint64_t io_num = 0;
30   for (auto a = areas->begin(); a != areas->end(); ++a) {
31     for (auto op : (*a)->ops()) {
32       visited.insert(op);
33       for (auto o : op->inputs()) {
34         if (auto prim = o->As<PrimOp>(); prim) {
35           degree[prim]++;
36         }
37       }
38     }
39   }
40   for (auto op : visited) {
41     auto iter = degree.find(op);
42     if (iter == degree.end()) {  // is output node
43       io_num++;
44     }
45     io_num++;
46     for (auto o : op->inputs()) {
47       if (auto prim = o->As<PrimOp>(); prim && visited.find(prim) != visited.end()) {  // is not input node
48         io_num--;
49         break;
50       }
51     }
52   }
53   return io_num;
54 }
55 }  // namespace
56 
ReachTable(size_t size)57 ReachTable::ReachTable(size_t size) : size_(size), reach_(size, std::vector<bool>(size, false)) {
58   for (size_t i = 0; i < size_; ++i) {
59     reach_[i][i] = true;
60     (void)alive_.insert(i);
61   }
62 }
63 
Link(size_t from,size_t to)64 void ReachTable::Link(size_t from, size_t to) {
65   // if there's an edge <from, to>, the `from` can reach to `to`'s succeeding areas.
66   // so we connect `from` to all succeeding areas of `to`.
67   for (const size_t suc : alive_) {
68     if (Reachable(to, suc)) {
69       reach_[from][suc] = true;
70     }
71   }
72 }
73 
FuseArea(size_t target,size_t other)74 void ReachTable::FuseArea(size_t target, size_t other) {
75   // if `suc` is the succeeding nodes of other_node,
76   // link the target_node's previous nodes to `suc`.
77   for (const size_t suc : alive_) {
78     if (Reachable(other, suc) && !Reachable(target, suc)) {
79       for (const size_t pre : alive_) {
80         if (Reachable(pre, target)) {
81           reach_[pre][suc] = true;
82         }
83       }
84     }
85   }
86   // if `pre` is the previous nodes of other_node,
87   // link `pre` to target_node's succeeding nodes.
88   for (const size_t pre : alive_) {
89     if (Reachable(pre, other) && !Reachable(pre, target)) {
90       for (const size_t suc : alive_) {
91         if (Reachable(target, suc)) {
92           reach_[pre][suc] = true;
93         }
94       }
95     }
96   }
97   // discard other_node.
98   (void)alive_.erase(other);
99 }
100 
HasCircle(const AreaPtr & a,const AreaPtr & b) const101 bool ReachTable::HasCircle(const AreaPtr &a, const AreaPtr &b) const {
102   // a is the input of b
103   if (Reachable(a->id(), b->id())) {
104     // use `inputs_with_relation` instead of `inputs` to avoid generating a new vector.
105     for (auto &inp : b->inputs_with_relation()) {
106       if (inp.first != a && Reachable(a->id(), inp.first->id())) {
107         return true;
108       }
109     }
110   } else {
111     // b is the input of a
112     for (auto &inp : a->inputs_with_relation()) {
113       if (inp.first != b && Reachable(b->id(), inp.first->id())) {
114         return true;
115       }
116     }
117   }
118   return false;
119 }
120 
NewArea(const PrimOpPtr & op,bool is_output)121 AreaPtr SplitModel::NewArea(const PrimOpPtr &op, bool is_output) {
122   auto new_area = std::make_shared<Area>(cur_area_id_++, op->As<PrimOp>(), is_output, node_area_map_);
123   (void)areas_.emplace_back(new_area);
124   node_area_map_[op] = new_area;
125   SetDefaultAreaMode(new_area);
126   return new_area;
127 }
128 
AlignShape(const LiteGraphPtr & litegraph) const129 void SplitModel::AlignShape(const LiteGraphPtr &litegraph) const {
130   for (auto &inp : litegraph->inputs()) {
131     if (inp->shape.empty()) {
132       inp->shape.push_back(1LL);
133       if (inp->symbolic_shape != nullptr && inp->symbolic_shape->size() == 0) {
134         inp->symbolic_shape = ListSymbol::Make({kSym1});
135       }
136     }
137   }
138   auto check_pattern = [](const NodePtr &op) {
139     auto pn = op->As<PrimOp>()->compute_type();
140     return pn == NodePattern::ELEMWISE || pn == NodePattern::BROADCAST || pn == NodePattern::REDUCE;
141   };
142   for (auto &op : litegraph->ops()) {
143     if (!check_pattern(op)) {
144       if (op->shape.empty()) {
145         op->shape.push_back(1LL);
146         if (op->symbolic_shape != nullptr && op->symbolic_shape->size() == 0) {
147           op->symbolic_shape = ListSymbol::Make({kSym1});
148         }
149       }
150       continue;
151     }
152     auto cur_shape_size = op->shape.size();
153     for (auto &inp : op->inputs()) {
154       if (inp->shape.size() > cur_shape_size) {
155         cur_shape_size = inp->shape.size();
156       }
157     }
158     if (cur_shape_size > op->shape.size()) {
159       auto num = cur_shape_size - op->shape.size();
160       (void)op->shape.insert(op->shape.cbegin(), num, 1LL);
161       if (op->symbolic_shape != nullptr) {
162         auto symbols = op->symbolic_shape->symbols();
163         (void)symbols.insert(symbols.begin(), num, kSym1);
164         op->symbolic_shape = ListSymbol::Make(std::move(symbols));
165       }
166     }
167   }
168 }
169 
InitGraph(const LiteGraphPtr & litegraph)170 void SplitModel::InitGraph(const LiteGraphPtr &litegraph) {
171   AlignShape(litegraph);
172   auto &outputs = litegraph->GetOutputs();
173   HashSet<NodePtr> outputs_set(outputs.begin(), outputs.end());
174   for (const auto &op : litegraph->ops()) {
175     if (op->NodeType() != NType::Primitive) {
176       MS_LOG(EXCEPTION) << "Op " << op->debug_name() << " should be a Primitive node, but got " << op->NodeType();
177     }
178     bool is_output = (outputs_set.count(op) > 0);
179     (void)NewArea(op->As<PrimOp>(), is_output);
180   }
181 
182   // Initialize reach table in reversed topological order
183   reach_table_ = std::make_shared<ReachTable>(litegraph->ops().size());
184   MS_EXCEPTION_IF_NULL(reach_table_);
185   for (auto iter = areas_.rbegin(); iter != areas_.rend(); ++iter) {
186     auto users = (*iter)->users();
187     for (auto &user : users) {
188       reach_table_->Link((*iter)->id(), user->id());
189     }
190   }
191 }
192 
AddPattern(const std::shared_ptr<FusePattern> & pn,bool enable)193 void SplitModel::AddPattern(const std::shared_ptr<FusePattern> &pn, bool enable) {
194   (void)patterns_.emplace_back(std::make_pair(pn, enable));
195   patterns_.back().first->SetCircleChecker(reach_table_);
196 }
197 
LimitAreaSize(const AreaPtr & dom,std::vector<AreaPtr> * areas) const198 void SplitModel::LimitAreaSize(const AreaPtr &dom, std::vector<AreaPtr> *areas) const {
199   uint64_t max_size = GraphKernelFlags::GetInstance().composite_op_limit_size;
200   auto dom_size = dom->size();
201   for (auto a = areas->begin(); a != areas->end(); ++a) {
202     dom_size += (*a)->size();
203   }
204   if (GraphKernelFlags::GetInstance().kernel_generator == "DVM") {
205     const uint64_t MAX_DVM_SIZE = 96;
206     max_size = std::min(MAX_DVM_SIZE, max_size);
207     if (dom_size <= max_size) {
208       uint64_t io_num = FindIoNum(areas);
209       max_size = max_size >= io_num ? max_size - io_num : 0;
210       if (dom_size <= max_size) {
211         return;
212       }
213     }
214   } else {
215     if (dom_size <= max_size) {
216       return;
217     }
218   }
219   // fuse the smaller area in priority
220   std::sort(areas->begin(), areas->end(),
221             [max_size](const AreaPtr &a, const AreaPtr &b) { return a->size() < b->size(); });
222   auto iter = std::find_if(areas->begin(), areas->end(), [cur_size = dom->size(), max_size](const AreaPtr &a) mutable {
223     cur_size += a->size();
224     return cur_size > max_size;
225   });
226   (void)areas->erase(iter, areas->cend());
227 }
228 
FuseAreas(const AreaPtr & dom,const std::vector<AreaPtr> & areas,FuseDirection direction)229 void SplitModel::FuseAreas(const AreaPtr &dom, const std::vector<AreaPtr> &areas, FuseDirection direction) {
230   if (areas.empty()) {
231     return;
232   }
233   auto target = dom;
234   for (auto a : areas) {
235     if (direction == FuseDirection::BACKWARD) {
236       // always use back node to fuse the front node.
237       std::swap(target, a);
238     }
239     for (auto &op : a->ops()) {
240       node_area_map_[op] = target;
241     }
242     target->FuseInput(a);
243     reach_table_->FuseArea(target->id(), a->id());
244   }
245   if (target->pattern() > NodePattern::RESHAPE) {
246     target->SetMode(AreaMode::COMPOSITE);
247   }
248 }
249 
RunOnePattern(const FusePatternPtr & pattern)250 bool SplitModel::RunOnePattern(const FusePatternPtr &pattern) {
251   // in one step, we only match the adjacent areas of the "area",
252   // so if matched, we should handle the same area again in the next step
253   bool changed = false;
254   for (auto iter = areas_.begin(); iter != areas_.end();) {
255     auto area = *iter;
256     if (!area->IsAlive()) {
257       iter = areas_.erase(iter);
258       continue;
259     }
260     if (pattern->Run(area)) {
261       MS_LOG(DEBUG) << "Area " << area->ToString() << " matches " << pattern->ToString();
262       LimitAreaSize(area, &pattern->fused_areas_);
263       if (!pattern->fused_areas_.empty()) {
264         FuseAreas(area, pattern->fused_areas_, pattern->direction());
265         changed = true;
266         continue;
267       }
268     }
269     ++iter;
270   }
271   return changed;
272 }
273 
RunFusePatterns()274 void SplitModel::RunFusePatterns() {
275   // process one pattern for all areas before process next pattern.
276   for (auto &[pattern, enable] : patterns_) {
277     if (!enable) {
278       continue;
279     }
280     MS_LOG(DEBUG) << "Run pattern " << pattern->name();
281     (void)RunOnePattern(pattern);
282   }
283   // remove the areas that is fused
284   for (auto iter = areas_.begin(); iter != areas_.end();) {
285     if (!(*iter)->IsAlive()) {
286       iter = areas_.erase(iter);
287     } else {
288       ++iter;
289     }
290   }
291 }
292 
Run(const LiteGraphPtr & litegraph)293 void SplitModel::Run(const LiteGraphPtr &litegraph) {
294   InitGraph(litegraph);
295   InitFusePatterns();
296   RunFusePatterns();
297 }
298 }  // namespace mindspore::graphkernel::inner
299