• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/split_model/area.h"
17 #include <algorithm>
18 #include <sstream>
19 #include "mindspore/core/symbolic_shape/int_symbol.h"
20 
21 namespace mindspore::graphkernel::inner {
22 namespace {
ShapeEqual(const NodePtr & a,const NodePtr & b,bool skip_leading_one=true)23 bool ShapeEqual(const NodePtr &a, const NodePtr &b, bool skip_leading_one = true) {
24   MS_EXCEPTION_IF_NULL(a);
25   MS_EXCEPTION_IF_NULL(b);
26   auto l = a->shape.size() < b->shape.size() ? b : a;
27   auto s = a->shape.size() < b->shape.size() ? a : b;
28   auto l_shape = l->shape;
29   auto s_shape = s->shape;
30   auto l_symbol_shape = l->symbolic_shape;
31   auto s_symbol_shape = s->symbolic_shape;
32   bool use_symbol = (l_symbol_shape != nullptr && s_symbol_shape != nullptr);
33   if (IsDynamicRank(l_shape)) {
34     return use_symbol ? (l_symbol_shape == s_symbol_shape) : false;
35   }
36   auto diff = l_shape.size() - s_shape.size();
37   if (diff != 0 && !skip_leading_one) {
38     // shapes with different rank
39     return false;
40   }
41   // check leading one
42   for (size_t i = 0; i < diff; ++i) {
43     if (l_shape[i] == 1 || (l_shape[i] < 0 && l_symbol_shape != nullptr && l_symbol_shape->item(i)->EqualsTo(kSym1))) {
44       continue;
45     }
46     return false;
47   }
48   // check other dimensions
49   for (size_t i = 0; i < s_shape.size(); ++i) {
50     auto il = i + diff;
51     if (l_shape[il] < 0 || s_shape[i] < 0) {
52       if (use_symbol && l_symbol_shape->item(il)->EqualsTo(s_symbol_shape->item(i))) {
53         continue;
54       }
55       return false;
56     } else if (l_shape[il] != s_shape[i]) {
57       return false;
58     }
59   }
60   return true;
61 }
62 
GetRelation(const PrimOpPtr & node,const NodePtr & input)63 EdgeRelation GetRelation(const PrimOpPtr &node, const NodePtr &input) {
64   if (node->compute_type() != NodePattern::ELEMWISE) {
65     return EdgeRelation::INJECTIVE;
66   }
67   if (node->inputs().size() == 1) {
68     // single input elemwise op has no broadcast
69     return EdgeRelation::INJECTIVE;
70   }
71   if (IsDynamic(input->shape)) {
72     if (std::all_of(node->inputs().begin(), node->inputs().end(),
73                     [input](const NodePtr &inp) { return inp == input; })) {
74       return EdgeRelation::INJECTIVE;
75     }
76   }
77   // naively set the edge relation to "broadcast" if the result shape is not equal to the input shape.
78   return ShapeEqual(node, input) ? EdgeRelation::INJECTIVE : EdgeRelation::BROADCAST;
79 }
80 
SameArea(const AreaWithRelation & a,const AreaWithRelation & b)81 bool SameArea(const AreaWithRelation &a, const AreaWithRelation &b) { return a.first == b.first; }
82 
AreaWithRelationCmp(const AreaWithRelation & a,const AreaWithRelation & b)83 bool AreaWithRelationCmp(const AreaWithRelation &a, const AreaWithRelation &b) {
84   // for same areas, put the area with greater EdgeRelation in front when sorting.
85   // compare the areas with unique id, instead of Area pointer, to avoid random result.
86   return SameArea(a, b) ? (a.second > b.second) : (a.first->id() < b.first->id());
87 }
88 }  // namespace
89 
Area(size_t id,const PrimOpPtr & prim_op,bool is_output,const HashMap<NodePtr,AreaPtr> & node_area_map)90 Area::Area(size_t id, const PrimOpPtr &prim_op, bool is_output, const HashMap<NodePtr, AreaPtr> &node_area_map)
91     : hd_(new NodeHandle(this, prim_op)), unique_id_(id), is_output_(is_output), ops_(1, prim_op) {
92   // link inputs of the handle node
93   auto init_pattern = pattern();
94   for (auto &inp : prim_op->inputs()) {
95     auto input_relation = GetRelation(prim_op, inp);
96     if (init_pattern == NodePattern::ELEMWISE && input_relation == EdgeRelation::BROADCAST) {
97       hd_->compute_type_ = NodePattern::BROADCAST;
98     }
99     if (auto inp_area_iter = node_area_map.find(inp); inp_area_iter != node_area_map.end()) {
100       (void)inputs_with_relation_.emplace_back(std::make_pair(inp_area_iter->second, input_relation));
101     }
102   }
103   // ELEMWISE if op has one variable input, other inputs are const input with shape [1]
104   // e.g. Cast(out_0, 43)
105   //      Add(param0, const)
106   if (hd_->compute_type_ == NodePattern::BROADCAST && init_pattern == NodePattern::ELEMWISE) {
107     size_t scalar_input_num = 0;
108     auto input_num = prim_op->inputs().size();
109     for (size_t i = 0; i < input_num; ++i) {
110       auto inp = prim_op->inputs()[i];
111       if (inp != nullptr && inp->tensor_size() == 1 &&
112           (inp->NodeType() == NType::Tensor || inp->NodeType() == NType::Scalar)) {
113         scalar_input_num++;
114       }
115     }
116     if (scalar_input_num + 1 == input_num) {
117       hd_->compute_type_ = NodePattern::ELEMWISE;
118       if (!inputs_with_relation_.empty()) {
119         inputs_with_relation_[0].second = EdgeRelation::INJECTIVE;
120       }
121     }
122   }
123   MakeUniqueAndSyncInputs();
124 }
125 
inputs() const126 std::vector<AreaPtr> Area::inputs() const {
127   std::vector<AreaPtr> result;
128   (void)std::transform(inputs_with_relation_.begin(), inputs_with_relation_.end(), std::back_inserter(result),
129                        [](const AreaWithRelation &inp) { return inp.first; });
130   return result;
131 }
132 
users() const133 std::vector<AreaPtr> Area::users() const {
134   std::vector<AreaPtr> result;
135   (void)std::transform(hd_->users().begin(), hd_->users().end(), std::back_inserter(result), [](const auto &u) {
136     Node *node = u.first;
137     return node->As<NodeHandle>()->area();
138   });
139   return result;
140 }
141 
users_with_relation() const142 std::vector<AreaWithRelation> Area::users_with_relation() const {
143   std::vector<AreaWithRelation> result;
144   (void)std::transform(hd_->users().begin(), hd_->users().end(), std::back_inserter(result), [](const auto &u) {
145     Node *node = u.first;
146     auto area = node->As<NodeHandle>()->area();
147     // the input edge of area is unique
148     const auto relation = area->input_relation(*(u.second.begin()));
149     return std::make_pair(area, relation);
150   });
151   return result;
152 }
153 
compute_size() const154 int64_t Area::compute_size() const {
155   auto op = dom();
156   MS_EXCEPTION_IF_NULL(op);
157   return SizeToLong(op->tensor_size());
158 }
159 
ComputeSizeEqual(const AreaPtr & other) const160 bool Area::ComputeSizeEqual(const AreaPtr &other) const {
161   if (other == nullptr) {
162     return false;
163   }
164   auto op = dom();
165   auto other_op = other->dom();
166   if (op == nullptr || other_op == nullptr) {
167     return false;
168   }
169   auto op_shape = op->shape;
170   auto other_op_shape = other_op->shape;
171   if (!IsDynamic(op_shape) && !IsDynamic(other_op_shape)) {
172     return compute_size() == other->compute_size();
173   }
174   return ShapeEqual(op, other_op);
175 }
176 
ToString() const177 std::string Area::ToString() const {
178   std::ostringstream oss;
179   bool is_first = true;
180   oss << "<";
181   for (auto op : ops_) {
182     if (is_first) {
183       is_first = false;
184       oss << id() << ":";
185     } else {
186       oss << "-";
187     }
188     oss << op->debug_name();
189   }
190   oss << ">";
191   return oss.str();
192 }
193 
MakeUniqueAndSyncInputs()194 void Area::MakeUniqueAndSyncInputs() {
195   // remove the repeated inputs, keep the area with greater EdgeRelation.
196   std::sort(inputs_with_relation_.begin(), inputs_with_relation_.end(), AreaWithRelationCmp);
197   (void)inputs_with_relation_.erase(std::unique(inputs_with_relation_.begin(), inputs_with_relation_.end(), SameArea),
198                                     inputs_with_relation_.cend());
199   // sync the inputs to NodeHandle to maintain users
200   this->hd_->ClearInputs();
201   (void)std::for_each(inputs_with_relation_.begin(), inputs_with_relation_.end(),
202                       [this](const AreaWithRelation &inp) { this->hd_->AddInput(inp.first->hd_); });
203 }
204 
UpdateUsersRelation(const AreaPtr & input_area)205 void Area::UpdateUsersRelation(const AreaPtr &input_area) {
206   auto &user_node_with_index = input_area->hd_->users();
207   std::vector<AreaPtr> user_areas;
208   for (auto &[user_hd, index] : user_node_with_index) {
209     (void)user_areas.emplace_back(user_hd->As<NodeHandle>()->area());
210     const auto idx = *(index.begin());
211     user_areas.back()->inputs_with_relation_[idx].first = this->shared_from_this();
212   }
213   // the inputs should be updated outside the above for-loop,
214   // since the users cannot be updated while traversing.
215   for (auto user : user_areas) {
216     user->MakeUniqueAndSyncInputs();
217   }
218 }
219 
FuseInput(const AreaPtr & input_area)220 void Area::FuseInput(const AreaPtr &input_area) {
221   auto iter = std::find_if(inputs_with_relation_.begin(), inputs_with_relation_.end(),
222                            [&input_area](const AreaWithRelation &a) { return a.first == input_area; });
223   if (iter == inputs_with_relation_.end()) {
224     MS_LOG(EXCEPTION) << "The area " << input_area->ToString() << " should be the input of area " << this->ToString();
225   }
226   auto input_idx = LongToSize(iter - inputs_with_relation_.begin());
227 
228   if (input_area->is_output_) {
229     is_output_ = true;
230   }
231 
232   // Update ops, and discard the input_area's ops.
233   // The dominant node is ops[0], keep the dominant with greater pattern.
234   if (pattern() < input_area->pattern()) {
235     ops_.swap(input_area->ops_);
236   }
237   (void)ops_.insert(ops_.cend(), input_area->ops_.cbegin(), input_area->ops_.cend());
238 
239   // update area pattern
240   hd_->compute_type_ = std::max(pattern(), input_area->pattern());
241   if ((pattern() == NodePattern::ELEMWISE) && (input_relation(input_idx) == EdgeRelation::BROADCAST)) {
242     hd_->compute_type_ = NodePattern::BROADCAST;
243   }
244 
245   // update inputs and relations
246   (void)inputs_with_relation_.erase(iter);
247   (void)inputs_with_relation_.insert(inputs_with_relation_.cend(), input_area->inputs_with_relation_.cbegin(),
248                                      input_area->inputs_with_relation_.cend());
249   MakeUniqueAndSyncInputs();
250   UpdateUsersRelation(input_area);
251 
252   // clear the input_area.
253   input_area->ops_.clear();
254   input_area->inputs_with_relation_.clear();
255   input_area->hd_->ClearInputs();
256 }
257 }  // namespace mindspore::graphkernel::inner
258