1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/abstract_function.h"
18 #include <vector>
19 #include <utility>
20 #include <algorithm>
21 #include "base/base.h"
22 #include "utils/hashing.h"
23 #include "utils/hash_set.h"
24 #include "utils/ms_utils.h"
25 #include "abstract/abstract_value.h"
26
27 namespace mindspore {
28 namespace abstract {
29 class Evaluator;
30 class AnalysisEngine;
31
MakeAbstractFunction(const AbstractFuncAtomPtrList & func_list)32 AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
33 if (func_list.size() == 1) {
34 return func_list[0];
35 }
36 return std::make_shared<AbstractFuncUnion>(func_list);
37 }
38
// Join this atom with `other` and return the least abstract function covering both.
//  - `other` is an atom: this atom is returned when `*this == *other`, otherwise a two-member union.
//  - `other` is a union: `other` itself is returned when it already covers this atom (IsSuperSet),
//    otherwise a new union holding this atom followed by the members of `other`.
// Throws if `other` is null or is neither an atom nor an AbstractFuncUnion.
AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto self = shared_from_base<AbstractFuncAtom>();
  if (!other->isa<AbstractFuncAtom>()) {
    auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
    MS_EXCEPTION_IF_NULL(other_union);
    if (other_union->IsSuperSet(self)) {
      return other;
    }
  } else if (*self == *other) {
    return self;
  }
  return std::make_shared<AbstractFuncUnion>(self, other);
}
55
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const56 void AbstractFuncAtom::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
57 visit_func(const_cast<AbstractFuncAtom *>(this)->shared_from_base<AbstractFuncAtom>());
58 }
59
operator ==(const AbstractFunction & other) const60 bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; }
61
AbstractFuncUnion(const AbstractFuncAtomPtrList & func_list)62 AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) : func_list_(func_list) {}
63
AbstractFuncUnion(const AbstractFunctionPtr & first,const AbstractFunctionPtr & second)64 AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) {
65 MS_EXCEPTION_IF_NULL(first);
66 MS_EXCEPTION_IF_NULL(second);
67 AbstractFuncAtomPtrList new_func_list;
68 auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); };
69 first->Visit(build_func_list);
70 second->Visit(build_func_list);
71 func_list_ = std::move(new_func_list);
72 }
73
ToString() const74 std::string AbstractFuncUnion::ToString() const {
75 std::ostringstream buffer;
76 buffer << "AbstractFuncUnion({";
77 int64_t i = 0;
78 for (const auto &func : func_list_) {
79 MS_EXCEPTION_IF_NULL(func);
80 buffer << "[" << i << "]: " << func->ToString() << ", ";
81 i++;
82 }
83 buffer << "})";
84 return buffer.str();
85 }
86
ToString(bool verbose) const87 std::string AbstractFuncUnion::ToString(bool verbose) const {
88 if (verbose) {
89 return ToString();
90 }
91 std::ostringstream buffer;
92 buffer << type_name() << "({";
93 size_t i = 0;
94 for (const auto &func : func_list_) {
95 MS_EXCEPTION_IF_NULL(func);
96 buffer << func->ToString(false);
97 i++;
98 if (i < func_list_.size()) {
99 buffer << ", ";
100 }
101 }
102 buffer << "})";
103 return buffer.str();
104 }
105
IsSuperSet(const AbstractFunctionPtr & other)106 bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) {
107 MS_EXCEPTION_IF_NULL(other);
108 bool all_in_list = true;
109 other->Visit([this, &all_in_list](const AbstractFuncAtomPtr &func) {
110 if (all_in_list) {
111 auto iter = std::find(func_list_.begin(), func_list_.end(), func);
112 if (iter == func_list_.end()) {
113 all_in_list = false;
114 }
115 }
116 });
117 return all_in_list;
118 }
119
// Join this union with `other`.
//  - `other` is an atom: this union is returned when IsSuperSet(other) holds, otherwise a new union with the
//    atom appended.
//  - `other` is a union: `other` is returned when it is a superset of this union, otherwise a new union
//    concatenating this union's members and `other`'s members.
// Throws if `other` is null or is neither an atom nor an AbstractFuncUnion.
AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
  auto this_func = shared_from_base<AbstractFunction>();
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractFuncAtom>()) {
    auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
    MS_EXCEPTION_IF_NULL(other_union);
    if (other_union->IsSuperSet(this_func)) {
      return other;
    }
  } else if (IsSuperSet(other)) {
    return this_func;
  }
  return std::make_shared<AbstractFuncUnion>(this_func, other);
}
136
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const137 void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
138 for (const auto &poss : func_list_) {
139 visit_func(poss);
140 }
141 }
142
operator ==(const AbstractFunction & other) const143 bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
144 if (!other.isa<AbstractFuncUnion>()) {
145 return false;
146 }
147 const auto &other_union = static_cast<const AbstractFuncUnion &>(other);
148 if (func_list_.size() != other_union.func_list_.size()) {
149 return false;
150 }
151 for (size_t i = 0; i < func_list_.size(); ++i) {
152 if (!common::IsEqual(func_list_[i], other_union.func_list_[i])) {
153 return false;
154 }
155 }
156 return true;
157 }
158
hash() const159 std::size_t AbstractFuncUnion::hash() const {
160 std::size_t hash_sum = 0;
161 for (const auto &f : func_list_) {
162 MS_EXCEPTION_IF_NULL(f);
163 hash_sum = hash_combine(hash_sum, f->hash());
164 }
165 return hash_sum;
166 }
167
operator ==(const AbstractFunction & other) const168 bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
169 if (!other.isa<PrimitiveAbstractClosure>()) {
170 return false;
171 }
172 const auto &other_abs = static_cast<const PrimitiveAbstractClosure &>(other);
173 return (prim_ == other_abs.prim_) && (tracking_id_ == other_abs.tracking_id_);
174 }
175
hash() const176 std::size_t PrimitiveAbstractClosure::hash() const {
177 auto hash_value = static_cast<std::size_t>(tid());
178 hash_value = hash_combine(hash_value, PointerHash<PrimitivePtr>{}(prim_));
179 if (tracking_id_ != 0) {
180 hash_value = hash_combine(hash_value, static_cast<size_t>(tracking_id_));
181 }
182 return hash_value;
183 }
184
ToString(bool verbose) const185 std::string PrimitiveAbstractClosure::ToString(bool verbose) const {
186 if (verbose) {
187 return ToString();
188 }
189 return type_name() + " (" + prim_->name() + ")";
190 }
191
operator ==(const AbstractFunction & other) const192 bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
193 if (!other.isa<FuncGraphAbstractClosure>()) {
194 return false;
195 }
196 const auto &other_fg = static_cast<const FuncGraphAbstractClosure &>(other);
197 MS_EXCEPTION_IF_NULL(func_graph());
198 MS_EXCEPTION_IF_NULL(other_fg.func_graph());
199 return func_graph() == other_fg.func_graph() && context_ == other_fg.context_ &&
200 tracking_id_ == other_fg.tracking_id_;
201 }
202
IsEqualExceptTrackingId(const FuncGraphAbstractClosure & other) const203 bool FuncGraphAbstractClosure::IsEqualExceptTrackingId(const FuncGraphAbstractClosure &other) const {
204 MS_EXCEPTION_IF_NULL(func_graph());
205 MS_EXCEPTION_IF_NULL(other.func_graph());
206 return (this == &other) || (func_graph() == other.func_graph() && context_ == other.context_);
207 }
208
HashWithoutTrackingId() const209 std::size_t FuncGraphAbstractClosure::HashWithoutTrackingId() const {
210 MS_EXCEPTION_IF_NULL(func_graph());
211 auto hash_value = hash_combine(tid(), PointerHash<FuncGraphPtr>{}(func_graph()));
212 return hash_combine(hash_value, PointerHash<AnalysisContextPtr>{}(context_));
213 }
214
hash() const215 std::size_t FuncGraphAbstractClosure::hash() const {
216 MS_EXCEPTION_IF_NULL(func_graph());
217 auto hash_value = hash_combine(tid(), PointerHash<FuncGraphPtr>{}(func_graph()));
218 hash_value = hash_combine(hash_value, PointerHash<AnalysisContextPtr>{}(context_));
219 if (tracking_id_ != 0) {
220 hash_value = hash_combine(hash_value, static_cast<size_t>(tracking_id_));
221 }
222 return hash_value;
223 }
224
ToString() const225 std::string FuncGraphAbstractClosure::ToString() const {
226 std::stringstream ss;
227 MS_EXCEPTION_IF_NULL(func_graph());
228 MS_EXCEPTION_IF_NULL(context_);
229 ss << "FuncGraphAbstractClosure: "
230 << "FuncGraph: " << func_graph()->ToString() << "; Context: " << context_->ToString();
231 return ss.str();
232 }
233
ToString(bool verbose) const234 std::string FuncGraphAbstractClosure::ToString(bool verbose) const {
235 if (verbose) {
236 return ToString();
237 }
238 std::stringstream ss;
239 MS_EXCEPTION_IF_NULL(func_graph());
240 ss << type_name() << "(" << func_graph()->ToString() << ")";
241 return ss.str();
242 }
243
operator ==(const AbstractFunction & other) const244 bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
245 if (!other.isa<MetaFuncGraphAbstractClosure>()) {
246 return false;
247 }
248 const auto &other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure &>(other);
249 return (meta_func_graph_ == other_meta_fg.meta_func_graph_) && (tracking_id_ == other_meta_fg.tracking_id_);
250 }
251
hash() const252 std::size_t MetaFuncGraphAbstractClosure::hash() const {
253 MS_EXCEPTION_IF_NULL(meta_func_graph_);
254 auto hash_value = hash_combine(tid(), PointerHash<MetaFuncGraphPtr>{}(meta_func_graph_));
255 if (tracking_id_ != 0) {
256 hash_value = hash_combine(hash_value, static_cast<size_t>(tracking_id_));
257 }
258 return hash_value;
259 }
260
ToString() const261 std::string MetaFuncGraphAbstractClosure::ToString() const {
262 MS_EXCEPTION_IF_NULL(meta_func_graph_);
263 return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name();
264 }
265
266 namespace {
267 // Helper class to prevent recursive calls.
268 class VisitedHistory {
269 public:
VisitedHistory(const void * address)270 explicit VisitedHistory(const void *address) : visited_(!history_.emplace(address).second) { ++deep_; }
~VisitedHistory()271 ~VisitedHistory() {
272 --deep_;
273 // cppcheck-suppress *
274 if (deep_ == 0) { // The result of cppcheck is "Condition (deep_==0) is always true". But it's wrong.
275 history_.clear();
276 }
277 }
IsVisited() const278 bool IsVisited() const { return visited_; }
279
280 private:
281 static inline thread_local mindspore::HashSet<const void *> history_;
282 static inline thread_local size_t deep_ = 0;
283 bool visited_{false};
284 };
285 } // namespace
286
operator ==(const AbstractFunction & other) const287 bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
288 if (!other.isa<PartialAbstractClosure>()) {
289 return false;
290 }
291 // Avoid to recursively compare.
292 VisitedHistory history(this);
293 if (history.IsVisited()) {
294 return true;
295 }
296 const auto &other_partial = static_cast<const PartialAbstractClosure &>(other);
297 if (!common::IsEqual(fn_, other_partial.fn_)) {
298 return false;
299 }
300 if (args_abs_list_.size() != other_partial.args_abs_list_.size()) {
301 return false;
302 }
303 for (size_t i = 0; i < args_abs_list_.size(); ++i) {
304 const auto &a = args_abs_list_[i];
305 const auto &b = other_partial.args_abs_list_[i];
306 if (a != nullptr && a->isa<AbstractFunction>()) {
307 if (!common::IsEqual(a, b)) {
308 return false;
309 }
310 } else if (a != b) {
311 return false;
312 }
313 }
314 return true;
315 }
316
hash() const317 std::size_t PartialAbstractClosure::hash() const {
318 // Avoid to recursively hashing.
319 VisitedHistory history(this);
320 if (history.IsVisited()) {
321 return 0;
322 }
323 MS_EXCEPTION_IF_NULL(fn_);
324 auto hash_value = hash_combine(tid(), fn_->hash());
325 for (const auto &arg : args_abs_list_) {
326 if (arg != nullptr && arg->isa<AbstractFunction>()) {
327 hash_value = hash_combine(hash_value, arg->hash());
328 } else {
329 hash_value = hash_combine(hash_value, PointerHash<AbstractBasePtr>{}(arg));
330 }
331 }
332 return hash_value;
333 }
334
ToString() const335 std::string PartialAbstractClosure::ToString() const {
336 // Avoid to recursively ToString.
337 VisitedHistory history(this);
338 if (history.IsVisited()) {
339 return "<recurred>";
340 }
341 std::ostringstream buffer;
342 buffer << "PartialAbstractClosure{" << fn_->ToString() << "(";
343 for (const auto &arg : args_abs_list_) {
344 buffer << (arg == nullptr ? "<null>" : arg->ToString()) << ", ";
345 }
346 buffer << ")}";
347 return buffer.str();
348 }
349
ToString(bool verbose) const350 std::string PartialAbstractClosure::ToString(bool verbose) const {
351 if (verbose) {
352 return ToString();
353 }
354 std::ostringstream buffer;
355 buffer << type_name() << "(" << fn_->ToString(false) << " (argc=" << args_abs_list_.size() << "))";
356 return buffer.str();
357 }
358
operator ==(const AbstractFunction & other) const359 bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
360 if (!other.isa<JTransformedAbstractClosure>()) {
361 return false;
362 }
363 const auto &other_transformed = static_cast<const JTransformedAbstractClosure &>(other);
364 return fn_ == other_transformed.fn_;
365 }
366
hash() const367 std::size_t JTransformedAbstractClosure::hash() const {
368 return hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
369 }
370
operator ==(const AbstractFunction & other) const371 bool TaylorTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
372 if (!other.isa<TaylorTransformedAbstractClosure>()) {
373 return false;
374 }
375 const auto &other_transformed = static_cast<const TaylorTransformedAbstractClosure &>(other);
376 return fn_ == other_transformed.fn_;
377 }
378
hash() const379 std::size_t TaylorTransformedAbstractClosure::hash() const {
380 return hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
381 }
382
operator ==(const AbstractFunction & other) const383 bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
384 if (!other.isa<ShardTransformedAbstractClosure>()) {
385 return false;
386 }
387 const auto &other_transformed = static_cast<const ShardTransformedAbstractClosure &>(other);
388 return fn_ == other_transformed.fn_;
389 }
390
hash() const391 std::size_t ShardTransformedAbstractClosure::hash() const {
392 return hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
393 }
394
operator ==(const AbstractFunction & other) const395 bool VmapTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
396 if (!other.isa<VmapTransformedAbstractClosure>()) {
397 return false;
398 }
399 const auto &other_transformed = static_cast<const VmapTransformedAbstractClosure &>(other);
400 return fn_ == other_transformed.fn_ && in_axes_ == other_transformed.in_axes_ &&
401 out_axes_ == other_transformed.out_axes_;
402 }
403
hash() const404 std::size_t VmapTransformedAbstractClosure::hash() const {
405 auto hash_value = hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
406 hash_value = hash_combine(hash_value, PointerHash<ValuePtr>{}(in_axes_));
407 hash_value = hash_combine(hash_value, PointerHash<ValuePtr>{}(out_axes_));
408 return hash_value;
409 }
410
operator ==(const AbstractFunction & other) const411 bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
412 if (!other.isa<VirtualAbstractClosure>()) {
413 return false;
414 }
415 const auto &other_virtual = static_cast<const VirtualAbstractClosure &>(other);
416 if (!common::IsEqual(output_, other_virtual.output_)) {
417 return false;
418 }
419 return AbstractBasePtrListDeepEqual(args_abs_list_, other_virtual.args_abs_list_);
420 }
421
hash() const422 std::size_t VirtualAbstractClosure::hash() const {
423 MS_EXCEPTION_IF_NULL(output_);
424 auto hash_value = hash_combine(tid(), output_->hash());
425 return hash_combine(hash_value, AbstractBasePtrListHash(args_abs_list_));
426 }
427
ToString() const428 std::string VirtualAbstractClosure::ToString() const {
429 std::ostringstream buffer;
430 buffer << "VirtualAbstractClosure(args: {";
431 int64_t i = 0;
432 for (const auto &arg : args_abs_list_) {
433 MS_EXCEPTION_IF_NULL(arg);
434 if (arg->isa<AbstractFuncAtom>()) {
435 // If the arg is a subclass of AbstractFuncAtom, a recursive dead loop may occur.
436 // So in this case, we use type_name() instead of ToString().
437 buffer << "[" << i << "]: " << arg->type_name() << ", ";
438 } else {
439 buffer << "[" << i << "]: " << arg->ToString() << ", ";
440 }
441 i++;
442 }
443 MS_EXCEPTION_IF_NULL(output_);
444 buffer << "}, output: " << output_->ToString() << ")";
445 return buffer.str();
446 }
447
operator ==(const AbstractFunction & other) const448 bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
449 if (!other.isa<TypedPrimitiveAbstractClosure>()) {
450 return false;
451 }
452 // Avoid to recursively compare.
453 VisitedHistory history(this);
454 if (history.IsVisited()) {
455 return true;
456 }
457 const auto &other_typed = static_cast<const TypedPrimitiveAbstractClosure &>(other);
458 if (prim_ != other_typed.prim_) {
459 return false;
460 }
461 if (!common::IsEqual(output_, other_typed.output_)) {
462 return false;
463 }
464 return AbstractBasePtrListDeepEqual(args_abs_list_, other_typed.args_abs_list_);
465 }
466
hash() const467 std::size_t TypedPrimitiveAbstractClosure::hash() const {
468 // Avoid to recursively hashing.
469 VisitedHistory history(this);
470 if (history.IsVisited()) {
471 return 0;
472 }
473 auto hash_value = hash_combine(tid(), PointerHash<PrimitivePtr>{}(prim_));
474 if (output_ != nullptr) {
475 hash_value = hash_combine(hash_value, output_->hash());
476 }
477 hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_abs_list_));
478 return hash_value;
479 }
480
ToString() const481 std::string TypedPrimitiveAbstractClosure::ToString() const {
482 // Avoid to recursively ToString.
483 VisitedHistory history(this);
484 if (history.IsVisited()) {
485 return "<recurred>";
486 }
487 std::ostringstream buffer;
488 buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {";
489 for (const auto &arg : args_abs_list_) {
490 buffer << (arg == nullptr ? "<null>" : arg->ToString()) << ", ";
491 }
492 MS_EXCEPTION_IF_NULL(output_);
493 buffer << "}, output: " << output_->ToString() << ")";
494 return buffer.str();
495 }
496 } // namespace abstract
497 } // namespace mindspore
498