• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/abstract_function.h"
18 #include <vector>
19 #include <utility>
20 #include <algorithm>
21 #include "base/base.h"
22 #include "utils/hashing.h"
23 #include "utils/hash_set.h"
24 #include "utils/ms_utils.h"
25 #include "abstract/abstract_value.h"
26 
27 namespace mindspore {
28 namespace abstract {
// Forward declarations. Neither type is used by any definition in the remainder of this file.
class AnalysisEngine;
class Evaluator;
31 
MakeAbstractFunction(const AbstractFuncAtomPtrList & func_list)32 AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
33   if (func_list.size() == 1) {
34     return func_list[0];
35   }
36   return std::make_shared<AbstractFuncUnion>(func_list);
37 }
38 
/// Joins this atom with `other` and returns the smallest abstract function covering both.
///
/// - `other` is an atom equal to this one (operator==): this atom is returned.
/// - `other` is a union that already contains this atom (IsSuperSet): `other` is returned.
/// - Otherwise: a new AbstractFuncUnion holding the atoms of both sides.
/// Throws (MS_EXCEPTION_IF_NULL) if `other` is null, or if it is neither an atom nor a union.
AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto self = shared_from_base<AbstractFuncAtom>();
  if (!other->isa<AbstractFuncAtom>()) {
    // Anything that is not an atom must be a union.
    auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
    MS_EXCEPTION_IF_NULL(other_union);
    if (other_union->IsSuperSet(self)) {
      return other;
    }
  } else if (*self == *other) {
    return self;
  }
  return std::make_shared<AbstractFuncUnion>(self, other);
}
54 }
55 
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const56 void AbstractFuncAtom::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
57   visit_func(const_cast<AbstractFuncAtom *>(this)->shared_from_base<AbstractFuncAtom>());
58 }
59 
operator ==(const AbstractFunction & other) const60 bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; }
61 
AbstractFuncUnion(const AbstractFuncAtomPtrList & func_list)62 AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) : func_list_(func_list) {}
63 
AbstractFuncUnion(const AbstractFunctionPtr & first,const AbstractFunctionPtr & second)64 AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) {
65   MS_EXCEPTION_IF_NULL(first);
66   MS_EXCEPTION_IF_NULL(second);
67   AbstractFuncAtomPtrList new_func_list;
68   auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); };
69   first->Visit(build_func_list);
70   second->Visit(build_func_list);
71   func_list_ = std::move(new_func_list);
72 }
73 
ToString() const74 std::string AbstractFuncUnion::ToString() const {
75   std::ostringstream buffer;
76   buffer << "AbstractFuncUnion({";
77   int64_t i = 0;
78   for (const auto &func : func_list_) {
79     MS_EXCEPTION_IF_NULL(func);
80     buffer << "[" << i << "]: " << func->ToString() << ", ";
81     i++;
82   }
83   buffer << "})";
84   return buffer.str();
85 }
86 
ToString(bool verbose) const87 std::string AbstractFuncUnion::ToString(bool verbose) const {
88   if (verbose) {
89     return ToString();
90   }
91   std::ostringstream buffer;
92   buffer << type_name() << "({";
93   size_t i = 0;
94   for (const auto &func : func_list_) {
95     MS_EXCEPTION_IF_NULL(func);
96     buffer << func->ToString(false);
97     i++;
98     if (i < func_list_.size()) {
99       buffer << ", ";
100     }
101   }
102   buffer << "})";
103   return buffer.str();
104 }
105 
IsSuperSet(const AbstractFunctionPtr & other)106 bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) {
107   MS_EXCEPTION_IF_NULL(other);
108   bool all_in_list = true;
109   other->Visit([this, &all_in_list](const AbstractFuncAtomPtr &func) {
110     if (all_in_list) {
111       auto iter = std::find(func_list_.begin(), func_list_.end(), func);
112       if (iter == func_list_.end()) {
113         all_in_list = false;
114       }
115     }
116   });
117   return all_in_list;
118 }
119 
/// Joins this union with `other`.
///
/// - `other` is an atom already contained in this union: this union is returned.
/// - `other` is a union that already contains all of this union's atoms: `other` is returned.
/// - Otherwise: a new AbstractFuncUnion holding the atoms of both sides.
/// Throws (MS_EXCEPTION_IF_NULL) if `other` is null, or if it is neither an atom nor a union.
AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
  auto self = shared_from_base<AbstractFunction>();
  MS_EXCEPTION_IF_NULL(other);
  if (!other->isa<AbstractFuncAtom>()) {
    auto other_union = dyn_cast_ptr<AbstractFuncUnion>(other);
    MS_EXCEPTION_IF_NULL(other_union);
    if (other_union->IsSuperSet(self)) {
      return other;
    }
  } else if (IsSuperSet(other)) {
    return self;
  }
  return std::make_shared<AbstractFuncUnion>(self, other);
}
136 
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const137 void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
138   for (const auto &poss : func_list_) {
139     visit_func(poss);
140   }
141 }
142 
operator ==(const AbstractFunction & other) const143 bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
144   if (!other.isa<AbstractFuncUnion>()) {
145     return false;
146   }
147   const auto &other_union = static_cast<const AbstractFuncUnion &>(other);
148   if (func_list_.size() != other_union.func_list_.size()) {
149     return false;
150   }
151   for (size_t i = 0; i < func_list_.size(); ++i) {
152     if (!common::IsEqual(func_list_[i], other_union.func_list_[i])) {
153       return false;
154     }
155   }
156   return true;
157 }
158 
hash() const159 std::size_t AbstractFuncUnion::hash() const {
160   std::size_t hash_sum = 0;
161   for (const auto &f : func_list_) {
162     MS_EXCEPTION_IF_NULL(f);
163     hash_sum = hash_combine(hash_sum, f->hash());
164   }
165   return hash_sum;
166 }
167 
operator ==(const AbstractFunction & other) const168 bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
169   if (!other.isa<PrimitiveAbstractClosure>()) {
170     return false;
171   }
172   const auto &other_abs = static_cast<const PrimitiveAbstractClosure &>(other);
173   return (prim_ == other_abs.prim_) && (tracking_id_ == other_abs.tracking_id_);
174 }
175 
hash() const176 std::size_t PrimitiveAbstractClosure::hash() const {
177   auto hash_value = static_cast<std::size_t>(tid());
178   hash_value = hash_combine(hash_value, PointerHash<PrimitivePtr>{}(prim_));
179   if (tracking_id_ != 0) {
180     hash_value = hash_combine(hash_value, static_cast<size_t>(tracking_id_));
181   }
182   return hash_value;
183 }
184 
ToString(bool verbose) const185 std::string PrimitiveAbstractClosure::ToString(bool verbose) const {
186   if (verbose) {
187     return ToString();
188   }
189   return type_name() + " (" + prim_->name() + ")";
190 }
191 
operator ==(const AbstractFunction & other) const192 bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
193   if (!other.isa<FuncGraphAbstractClosure>()) {
194     return false;
195   }
196   const auto &other_fg = static_cast<const FuncGraphAbstractClosure &>(other);
197   MS_EXCEPTION_IF_NULL(func_graph());
198   MS_EXCEPTION_IF_NULL(other_fg.func_graph());
199   return func_graph() == other_fg.func_graph() && context_ == other_fg.context_ &&
200          tracking_id_ == other_fg.tracking_id_;
201 }
202 
IsEqualExceptTrackingId(const FuncGraphAbstractClosure & other) const203 bool FuncGraphAbstractClosure::IsEqualExceptTrackingId(const FuncGraphAbstractClosure &other) const {
204   MS_EXCEPTION_IF_NULL(func_graph());
205   MS_EXCEPTION_IF_NULL(other.func_graph());
206   return (this == &other) || (func_graph() == other.func_graph() && context_ == other.context_);
207 }
208 
HashWithoutTrackingId() const209 std::size_t FuncGraphAbstractClosure::HashWithoutTrackingId() const {
210   MS_EXCEPTION_IF_NULL(func_graph());
211   auto hash_value = hash_combine(tid(), PointerHash<FuncGraphPtr>{}(func_graph()));
212   return hash_combine(hash_value, PointerHash<AnalysisContextPtr>{}(context_));
213 }
214 
hash() const215 std::size_t FuncGraphAbstractClosure::hash() const {
216   MS_EXCEPTION_IF_NULL(func_graph());
217   auto hash_value = hash_combine(tid(), PointerHash<FuncGraphPtr>{}(func_graph()));
218   hash_value = hash_combine(hash_value, PointerHash<AnalysisContextPtr>{}(context_));
219   if (tracking_id_ != 0) {
220     hash_value = hash_combine(hash_value, static_cast<size_t>(tracking_id_));
221   }
222   return hash_value;
223 }
224 
ToString() const225 std::string FuncGraphAbstractClosure::ToString() const {
226   std::stringstream ss;
227   MS_EXCEPTION_IF_NULL(func_graph());
228   MS_EXCEPTION_IF_NULL(context_);
229   ss << "FuncGraphAbstractClosure: "
230      << "FuncGraph: " << func_graph()->ToString() << "; Context: " << context_->ToString();
231   return ss.str();
232 }
233 
ToString(bool verbose) const234 std::string FuncGraphAbstractClosure::ToString(bool verbose) const {
235   if (verbose) {
236     return ToString();
237   }
238   std::stringstream ss;
239   MS_EXCEPTION_IF_NULL(func_graph());
240   ss << type_name() << "(" << func_graph()->ToString() << ")";
241   return ss.str();
242 }
243 
operator ==(const AbstractFunction & other) const244 bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
245   if (!other.isa<MetaFuncGraphAbstractClosure>()) {
246     return false;
247   }
248   const auto &other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure &>(other);
249   return (meta_func_graph_ == other_meta_fg.meta_func_graph_) && (tracking_id_ == other_meta_fg.tracking_id_);
250 }
251 
hash() const252 std::size_t MetaFuncGraphAbstractClosure::hash() const {
253   MS_EXCEPTION_IF_NULL(meta_func_graph_);
254   auto hash_value = hash_combine(tid(), PointerHash<MetaFuncGraphPtr>{}(meta_func_graph_));
255   if (tracking_id_ != 0) {
256     hash_value = hash_combine(hash_value, static_cast<size_t>(tracking_id_));
257   }
258   return hash_value;
259 }
260 
ToString() const261 std::string MetaFuncGraphAbstractClosure::ToString() const {
262   MS_EXCEPTION_IF_NULL(meta_func_graph_);
263   return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name();
264 }
265 
266 namespace {
267 // Helper class to prevent recursive calls.
268 class VisitedHistory {
269  public:
VisitedHistory(const void * address)270   explicit VisitedHistory(const void *address) : visited_(!history_.emplace(address).second) { ++deep_; }
~VisitedHistory()271   ~VisitedHistory() {
272     --deep_;
273     // cppcheck-suppress *
274     if (deep_ == 0) {  // The result of cppcheck is "Condition (deep_==0) is always true". But it's wrong.
275       history_.clear();
276     }
277   }
IsVisited() const278   bool IsVisited() const { return visited_; }
279 
280  private:
281   static inline thread_local mindspore::HashSet<const void *> history_;
282   static inline thread_local size_t deep_ = 0;
283   bool visited_{false};
284 };
285 }  // namespace
286 
operator ==(const AbstractFunction & other) const287 bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
288   if (!other.isa<PartialAbstractClosure>()) {
289     return false;
290   }
291   // Avoid to recursively compare.
292   VisitedHistory history(this);
293   if (history.IsVisited()) {
294     return true;
295   }
296   const auto &other_partial = static_cast<const PartialAbstractClosure &>(other);
297   if (!common::IsEqual(fn_, other_partial.fn_)) {
298     return false;
299   }
300   if (args_abs_list_.size() != other_partial.args_abs_list_.size()) {
301     return false;
302   }
303   for (size_t i = 0; i < args_abs_list_.size(); ++i) {
304     const auto &a = args_abs_list_[i];
305     const auto &b = other_partial.args_abs_list_[i];
306     if (a != nullptr && a->isa<AbstractFunction>()) {
307       if (!common::IsEqual(a, b)) {
308         return false;
309       }
310     } else if (a != b) {
311       return false;
312     }
313   }
314   return true;
315 }
316 
hash() const317 std::size_t PartialAbstractClosure::hash() const {
318   // Avoid to recursively hashing.
319   VisitedHistory history(this);
320   if (history.IsVisited()) {
321     return 0;
322   }
323   MS_EXCEPTION_IF_NULL(fn_);
324   auto hash_value = hash_combine(tid(), fn_->hash());
325   for (const auto &arg : args_abs_list_) {
326     if (arg != nullptr && arg->isa<AbstractFunction>()) {
327       hash_value = hash_combine(hash_value, arg->hash());
328     } else {
329       hash_value = hash_combine(hash_value, PointerHash<AbstractBasePtr>{}(arg));
330     }
331   }
332   return hash_value;
333 }
334 
ToString() const335 std::string PartialAbstractClosure::ToString() const {
336   // Avoid to recursively ToString.
337   VisitedHistory history(this);
338   if (history.IsVisited()) {
339     return "<recurred>";
340   }
341   std::ostringstream buffer;
342   buffer << "PartialAbstractClosure{" << fn_->ToString() << "(";
343   for (const auto &arg : args_abs_list_) {
344     buffer << (arg == nullptr ? "<null>" : arg->ToString()) << ", ";
345   }
346   buffer << ")}";
347   return buffer.str();
348 }
349 
ToString(bool verbose) const350 std::string PartialAbstractClosure::ToString(bool verbose) const {
351   if (verbose) {
352     return ToString();
353   }
354   std::ostringstream buffer;
355   buffer << type_name() << "(" << fn_->ToString(false) << " (argc=" << args_abs_list_.size() << "))";
356   return buffer.str();
357 }
358 
operator ==(const AbstractFunction & other) const359 bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
360   if (!other.isa<JTransformedAbstractClosure>()) {
361     return false;
362   }
363   const auto &other_transformed = static_cast<const JTransformedAbstractClosure &>(other);
364   return fn_ == other_transformed.fn_;
365 }
366 
hash() const367 std::size_t JTransformedAbstractClosure::hash() const {
368   return hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
369 }
370 
operator ==(const AbstractFunction & other) const371 bool TaylorTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
372   if (!other.isa<TaylorTransformedAbstractClosure>()) {
373     return false;
374   }
375   const auto &other_transformed = static_cast<const TaylorTransformedAbstractClosure &>(other);
376   return fn_ == other_transformed.fn_;
377 }
378 
hash() const379 std::size_t TaylorTransformedAbstractClosure::hash() const {
380   return hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
381 }
382 
operator ==(const AbstractFunction & other) const383 bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
384   if (!other.isa<ShardTransformedAbstractClosure>()) {
385     return false;
386   }
387   const auto &other_transformed = static_cast<const ShardTransformedAbstractClosure &>(other);
388   return fn_ == other_transformed.fn_;
389 }
390 
hash() const391 std::size_t ShardTransformedAbstractClosure::hash() const {
392   return hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
393 }
394 
operator ==(const AbstractFunction & other) const395 bool VmapTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
396   if (!other.isa<VmapTransformedAbstractClosure>()) {
397     return false;
398   }
399   const auto &other_transformed = static_cast<const VmapTransformedAbstractClosure &>(other);
400   return fn_ == other_transformed.fn_ && in_axes_ == other_transformed.in_axes_ &&
401          out_axes_ == other_transformed.out_axes_;
402 }
403 
hash() const404 std::size_t VmapTransformedAbstractClosure::hash() const {
405   auto hash_value = hash_combine(tid(), PointerHash<AbstractFuncAtomPtr>{}(fn_));
406   hash_value = hash_combine(hash_value, PointerHash<ValuePtr>{}(in_axes_));
407   hash_value = hash_combine(hash_value, PointerHash<ValuePtr>{}(out_axes_));
408   return hash_value;
409 }
410 
operator ==(const AbstractFunction & other) const411 bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
412   if (!other.isa<VirtualAbstractClosure>()) {
413     return false;
414   }
415   const auto &other_virtual = static_cast<const VirtualAbstractClosure &>(other);
416   if (!common::IsEqual(output_, other_virtual.output_)) {
417     return false;
418   }
419   return AbstractBasePtrListDeepEqual(args_abs_list_, other_virtual.args_abs_list_);
420 }
421 
hash() const422 std::size_t VirtualAbstractClosure::hash() const {
423   MS_EXCEPTION_IF_NULL(output_);
424   auto hash_value = hash_combine(tid(), output_->hash());
425   return hash_combine(hash_value, AbstractBasePtrListHash(args_abs_list_));
426 }
427 
ToString() const428 std::string VirtualAbstractClosure::ToString() const {
429   std::ostringstream buffer;
430   buffer << "VirtualAbstractClosure(args: {";
431   int64_t i = 0;
432   for (const auto &arg : args_abs_list_) {
433     MS_EXCEPTION_IF_NULL(arg);
434     if (arg->isa<AbstractFuncAtom>()) {
435       // If the arg is a subclass of AbstractFuncAtom, a recursive dead loop may occur.
436       // So in this case, we use type_name() instead of ToString().
437       buffer << "[" << i << "]: " << arg->type_name() << ", ";
438     } else {
439       buffer << "[" << i << "]: " << arg->ToString() << ", ";
440     }
441     i++;
442   }
443   MS_EXCEPTION_IF_NULL(output_);
444   buffer << "}, output: " << output_->ToString() << ")";
445   return buffer.str();
446 }
447 
operator ==(const AbstractFunction & other) const448 bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
449   if (!other.isa<TypedPrimitiveAbstractClosure>()) {
450     return false;
451   }
452   // Avoid to recursively compare.
453   VisitedHistory history(this);
454   if (history.IsVisited()) {
455     return true;
456   }
457   const auto &other_typed = static_cast<const TypedPrimitiveAbstractClosure &>(other);
458   if (prim_ != other_typed.prim_) {
459     return false;
460   }
461   if (!common::IsEqual(output_, other_typed.output_)) {
462     return false;
463   }
464   return AbstractBasePtrListDeepEqual(args_abs_list_, other_typed.args_abs_list_);
465 }
466 
hash() const467 std::size_t TypedPrimitiveAbstractClosure::hash() const {
468   // Avoid to recursively hashing.
469   VisitedHistory history(this);
470   if (history.IsVisited()) {
471     return 0;
472   }
473   auto hash_value = hash_combine(tid(), PointerHash<PrimitivePtr>{}(prim_));
474   if (output_ != nullptr) {
475     hash_value = hash_combine(hash_value, output_->hash());
476   }
477   hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_abs_list_));
478   return hash_value;
479 }
480 
ToString() const481 std::string TypedPrimitiveAbstractClosure::ToString() const {
482   // Avoid to recursively ToString.
483   VisitedHistory history(this);
484   if (history.IsVisited()) {
485     return "<recurred>";
486   }
487   std::ostringstream buffer;
488   buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {";
489   for (const auto &arg : args_abs_list_) {
490     buffer << (arg == nullptr ? "<null>" : arg->ToString()) << ", ";
491   }
492   MS_EXCEPTION_IF_NULL(output_);
493   buffer << "}, output: " << output_->ToString() << ")";
494   return buffer.str();
495 }
496 }  // namespace abstract
497 }  // namespace mindspore
498