1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/split_model/area.h"
17 #include <algorithm>
18 #include <sstream>
19 #include "mindspore/core/symbolic_shape/int_symbol.h"
20
21 namespace mindspore::graphkernel::inner {
22 namespace {
ShapeEqual(const NodePtr & a,const NodePtr & b,bool skip_leading_one=true)23 bool ShapeEqual(const NodePtr &a, const NodePtr &b, bool skip_leading_one = true) {
24 MS_EXCEPTION_IF_NULL(a);
25 MS_EXCEPTION_IF_NULL(b);
26 auto l = a->shape.size() < b->shape.size() ? b : a;
27 auto s = a->shape.size() < b->shape.size() ? a : b;
28 auto l_shape = l->shape;
29 auto s_shape = s->shape;
30 auto l_symbol_shape = l->symbolic_shape;
31 auto s_symbol_shape = s->symbolic_shape;
32 bool use_symbol = (l_symbol_shape != nullptr && s_symbol_shape != nullptr);
33 if (IsDynamicRank(l_shape)) {
34 return use_symbol ? (l_symbol_shape == s_symbol_shape) : false;
35 }
36 auto diff = l_shape.size() - s_shape.size();
37 if (diff != 0 && !skip_leading_one) {
38 // shapes with different rank
39 return false;
40 }
41 // check leading one
42 for (size_t i = 0; i < diff; ++i) {
43 if (l_shape[i] == 1 || (l_shape[i] < 0 && l_symbol_shape != nullptr && l_symbol_shape->item(i)->EqualsTo(kSym1))) {
44 continue;
45 }
46 return false;
47 }
48 // check other dimensions
49 for (size_t i = 0; i < s_shape.size(); ++i) {
50 auto il = i + diff;
51 if (l_shape[il] < 0 || s_shape[i] < 0) {
52 if (use_symbol && l_symbol_shape->item(il)->EqualsTo(s_symbol_shape->item(i))) {
53 continue;
54 }
55 return false;
56 } else if (l_shape[il] != s_shape[i]) {
57 return false;
58 }
59 }
60 return true;
61 }
62
GetRelation(const PrimOpPtr & node,const NodePtr & input)63 EdgeRelation GetRelation(const PrimOpPtr &node, const NodePtr &input) {
64 if (node->compute_type() != NodePattern::ELEMWISE) {
65 return EdgeRelation::INJECTIVE;
66 }
67 if (node->inputs().size() == 1) {
68 // single input elemwise op has no broadcast
69 return EdgeRelation::INJECTIVE;
70 }
71 if (IsDynamic(input->shape)) {
72 if (std::all_of(node->inputs().begin(), node->inputs().end(),
73 [input](const NodePtr &inp) { return inp == input; })) {
74 return EdgeRelation::INJECTIVE;
75 }
76 }
77 // naively set the edge relation to "broadcast" if the result shape is not equal to the input shape.
78 return ShapeEqual(node, input) ? EdgeRelation::INJECTIVE : EdgeRelation::BROADCAST;
79 }
80
SameArea(const AreaWithRelation & a,const AreaWithRelation & b)81 bool SameArea(const AreaWithRelation &a, const AreaWithRelation &b) { return a.first == b.first; }
82
AreaWithRelationCmp(const AreaWithRelation & a,const AreaWithRelation & b)83 bool AreaWithRelationCmp(const AreaWithRelation &a, const AreaWithRelation &b) {
84 // for same areas, put the area with greater EdgeRelation in front when sorting.
85 // compare the areas with unique id, instead of Area pointer, to avoid random result.
86 return SameArea(a, b) ? (a.second > b.second) : (a.first->id() < b.first->id());
87 }
88 } // namespace
89
Area(size_t id,const PrimOpPtr & prim_op,bool is_output,const HashMap<NodePtr,AreaPtr> & node_area_map)90 Area::Area(size_t id, const PrimOpPtr &prim_op, bool is_output, const HashMap<NodePtr, AreaPtr> &node_area_map)
91 : hd_(new NodeHandle(this, prim_op)), unique_id_(id), is_output_(is_output), ops_(1, prim_op) {
92 // link inputs of the handle node
93 auto init_pattern = pattern();
94 for (auto &inp : prim_op->inputs()) {
95 auto input_relation = GetRelation(prim_op, inp);
96 if (init_pattern == NodePattern::ELEMWISE && input_relation == EdgeRelation::BROADCAST) {
97 hd_->compute_type_ = NodePattern::BROADCAST;
98 }
99 if (auto inp_area_iter = node_area_map.find(inp); inp_area_iter != node_area_map.end()) {
100 (void)inputs_with_relation_.emplace_back(std::make_pair(inp_area_iter->second, input_relation));
101 }
102 }
103 // ELEMWISE if op has one variable input, other inputs are const input with shape [1]
104 // e.g. Cast(out_0, 43)
105 // Add(param0, const)
106 if (hd_->compute_type_ == NodePattern::BROADCAST && init_pattern == NodePattern::ELEMWISE) {
107 size_t scalar_input_num = 0;
108 auto input_num = prim_op->inputs().size();
109 for (size_t i = 0; i < input_num; ++i) {
110 auto inp = prim_op->inputs()[i];
111 if (inp != nullptr && inp->tensor_size() == 1 &&
112 (inp->NodeType() == NType::Tensor || inp->NodeType() == NType::Scalar)) {
113 scalar_input_num++;
114 }
115 }
116 if (scalar_input_num + 1 == input_num) {
117 hd_->compute_type_ = NodePattern::ELEMWISE;
118 if (!inputs_with_relation_.empty()) {
119 inputs_with_relation_[0].second = EdgeRelation::INJECTIVE;
120 }
121 }
122 }
123 MakeUniqueAndSyncInputs();
124 }
125
inputs() const126 std::vector<AreaPtr> Area::inputs() const {
127 std::vector<AreaPtr> result;
128 (void)std::transform(inputs_with_relation_.begin(), inputs_with_relation_.end(), std::back_inserter(result),
129 [](const AreaWithRelation &inp) { return inp.first; });
130 return result;
131 }
132
users() const133 std::vector<AreaPtr> Area::users() const {
134 std::vector<AreaPtr> result;
135 (void)std::transform(hd_->users().begin(), hd_->users().end(), std::back_inserter(result), [](const auto &u) {
136 Node *node = u.first;
137 return node->As<NodeHandle>()->area();
138 });
139 return result;
140 }
141
users_with_relation() const142 std::vector<AreaWithRelation> Area::users_with_relation() const {
143 std::vector<AreaWithRelation> result;
144 (void)std::transform(hd_->users().begin(), hd_->users().end(), std::back_inserter(result), [](const auto &u) {
145 Node *node = u.first;
146 auto area = node->As<NodeHandle>()->area();
147 // the input edge of area is unique
148 const auto relation = area->input_relation(*(u.second.begin()));
149 return std::make_pair(area, relation);
150 });
151 return result;
152 }
153
compute_size() const154 int64_t Area::compute_size() const {
155 auto op = dom();
156 MS_EXCEPTION_IF_NULL(op);
157 return SizeToLong(op->tensor_size());
158 }
159
ComputeSizeEqual(const AreaPtr & other) const160 bool Area::ComputeSizeEqual(const AreaPtr &other) const {
161 if (other == nullptr) {
162 return false;
163 }
164 auto op = dom();
165 auto other_op = other->dom();
166 if (op == nullptr || other_op == nullptr) {
167 return false;
168 }
169 auto op_shape = op->shape;
170 auto other_op_shape = other_op->shape;
171 if (!IsDynamic(op_shape) && !IsDynamic(other_op_shape)) {
172 return compute_size() == other->compute_size();
173 }
174 return ShapeEqual(op, other_op);
175 }
176
ToString() const177 std::string Area::ToString() const {
178 std::ostringstream oss;
179 bool is_first = true;
180 oss << "<";
181 for (auto op : ops_) {
182 if (is_first) {
183 is_first = false;
184 oss << id() << ":";
185 } else {
186 oss << "-";
187 }
188 oss << op->debug_name();
189 }
190 oss << ">";
191 return oss.str();
192 }
193
MakeUniqueAndSyncInputs()194 void Area::MakeUniqueAndSyncInputs() {
195 // remove the repeated inputs, keep the area with greater EdgeRelation.
196 std::sort(inputs_with_relation_.begin(), inputs_with_relation_.end(), AreaWithRelationCmp);
197 (void)inputs_with_relation_.erase(std::unique(inputs_with_relation_.begin(), inputs_with_relation_.end(), SameArea),
198 inputs_with_relation_.cend());
199 // sync the inputs to NodeHandle to maintain users
200 this->hd_->ClearInputs();
201 (void)std::for_each(inputs_with_relation_.begin(), inputs_with_relation_.end(),
202 [this](const AreaWithRelation &inp) { this->hd_->AddInput(inp.first->hd_); });
203 }
204
UpdateUsersRelation(const AreaPtr & input_area)205 void Area::UpdateUsersRelation(const AreaPtr &input_area) {
206 auto &user_node_with_index = input_area->hd_->users();
207 std::vector<AreaPtr> user_areas;
208 for (auto &[user_hd, index] : user_node_with_index) {
209 (void)user_areas.emplace_back(user_hd->As<NodeHandle>()->area());
210 const auto idx = *(index.begin());
211 user_areas.back()->inputs_with_relation_[idx].first = this->shared_from_this();
212 }
213 // the inputs should be updated outside the above for-loop,
214 // since the users cannot be updated while traversing.
215 for (auto user : user_areas) {
216 user->MakeUniqueAndSyncInputs();
217 }
218 }
219
FuseInput(const AreaPtr & input_area)220 void Area::FuseInput(const AreaPtr &input_area) {
221 auto iter = std::find_if(inputs_with_relation_.begin(), inputs_with_relation_.end(),
222 [&input_area](const AreaWithRelation &a) { return a.first == input_area; });
223 if (iter == inputs_with_relation_.end()) {
224 MS_LOG(EXCEPTION) << "The area " << input_area->ToString() << " should be the input of area " << this->ToString();
225 }
226 auto input_idx = LongToSize(iter - inputs_with_relation_.begin());
227
228 if (input_area->is_output_) {
229 is_output_ = true;
230 }
231
232 // Update ops, and discard the input_area's ops.
233 // The dominant node is ops[0], keep the dominant with greater pattern.
234 if (pattern() < input_area->pattern()) {
235 ops_.swap(input_area->ops_);
236 }
237 (void)ops_.insert(ops_.cend(), input_area->ops_.cbegin(), input_area->ops_.cend());
238
239 // update area pattern
240 hd_->compute_type_ = std::max(pattern(), input_area->pattern());
241 if ((pattern() == NodePattern::ELEMWISE) && (input_relation(input_idx) == EdgeRelation::BROADCAST)) {
242 hd_->compute_type_ = NodePattern::BROADCAST;
243 }
244
245 // update inputs and relations
246 (void)inputs_with_relation_.erase(iter);
247 (void)inputs_with_relation_.insert(inputs_with_relation_.cend(), input_area->inputs_with_relation_.cbegin(),
248 input_area->inputs_with_relation_.cend());
249 MakeUniqueAndSyncInputs();
250 UpdateUsersRelation(input_area);
251
252 // clear the input_area.
253 input_area->ops_.clear();
254 input_area->inputs_with_relation_.clear();
255 input_area->hd_->ClearInputs();
256 }
257 } // namespace mindspore::graphkernel::inner
258