• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "abstract/abstract_function.h"

#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <vector>
20 
21 namespace mindspore {
22 namespace abstract {
23 class Evaluator;
24 class AnalysisEngine;
MakeAbstractFunction(const AbstractFuncAtomPtrList & func_list)25 AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
26   if (func_list.size() == 1) {
27     return func_list[0];
28   }
29   return std::make_shared<AbstractFuncUnion>(func_list);
30 }
31 
Join(const AbstractFunctionPtr & other)32 AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
33   MS_EXCEPTION_IF_NULL(other);
34   auto this_func = shared_from_base<AbstractFuncAtom>();
35   if (other->isa<AbstractFuncAtom>()) {
36     if (*this_func == *other) {
37       return this_func;
38     }
39     return std::make_shared<AbstractFuncUnion>(this_func, other);
40   }
41   auto other_union = dyn_cast<AbstractFuncUnion>(other);
42   MS_EXCEPTION_IF_NULL(other_union);
43   if (other_union->IsSuperSet(this_func)) {
44     return other;
45   }
46   return std::make_shared<AbstractFuncUnion>(this_func, other);
47 }
48 
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const49 void AbstractFuncAtom::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
50   visit_func(const_cast<AbstractFuncAtom *>(this)->shared_from_base<AbstractFuncAtom>());
51 }
52 
operator ==(const AbstractFunction & other) const53 bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; }
54 
AbstractFuncUnion(const AbstractFuncAtomPtrList & func_list)55 AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; }
56 
AbstractFuncUnion(const AbstractFunctionPtr & first,const AbstractFunctionPtr & second)57 AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) {
58   AbstractFuncAtomPtrList new_func_list;
59   auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); };
60   MS_EXCEPTION_IF_NULL(first);
61   MS_EXCEPTION_IF_NULL(second);
62   first->Visit(build_func_list);
63   second->Visit(build_func_list);
64   func_list_ = new_func_list;
65 }
66 
ToString() const67 std::string AbstractFuncUnion::ToString() const {
68   std::ostringstream buffer;
69   buffer << "AbstractFuncUnion({";
70   int64_t i = 0;
71   for (const auto &func : func_list_) {
72     MS_EXCEPTION_IF_NULL(func);
73     buffer << "[" << i << "]: " << func->ToString() << ", ";
74     i++;
75   }
76   buffer << "})";
77   return buffer.str();
78 }
79 
IsSuperSet(const AbstractFunctionPtr & other)80 bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) {
81   MS_EXCEPTION_IF_NULL(other);
82   std::vector<bool> is_in_list;
83   auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) {
84     auto iter = find(func_list_.begin(), func_list_.end(), func);
85     if (iter == func_list_.end()) {
86       is_in_list.push_back(false);
87     }
88     return true;
89   };
90   other->Visit(build_in_list);
91   return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; });
92 }
93 
Join(const AbstractFunctionPtr & other)94 AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
95   auto this_func = shared_from_base<AbstractFunction>();
96   MS_EXCEPTION_IF_NULL(other);
97   if (other->isa<AbstractFuncAtom>()) {
98     if (IsSuperSet(other)) {
99       return this_func;
100     }
101     return std::make_shared<AbstractFuncUnion>(this_func, other);
102   }
103   auto other_union = dyn_cast<AbstractFuncUnion>(other);
104   MS_EXCEPTION_IF_NULL(other_union);
105   if (other_union->IsSuperSet(this_func)) {
106     return other;
107   }
108   return std::make_shared<AbstractFuncUnion>(this_func, other);
109 }
110 
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const111 void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
112   for (const AbstractFuncAtomPtr &poss : func_list_) {
113     visit_func(poss);
114   }
115 }
116 
operator ==(const AbstractFunction & other) const117 bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
118   if (!other.isa<AbstractFuncUnion>()) {
119     return false;
120   }
121   auto other_union = static_cast<const AbstractFuncUnion *>(&other);
122   if (func_list_.size() != other_union->func_list_.size()) {
123     return false;
124   }
125   return func_list_ == other_union->func_list_;
126 }
127 
hash() const128 std::size_t AbstractFuncUnion::hash() const {
129   std::size_t hash_sum = 0;
130   for (const auto &f : func_list_) {
131     MS_EXCEPTION_IF_NULL(f);
132     hash_sum = hash_combine(hash_sum, f->hash());
133   }
134   return hash_sum;
135 }
136 
operator ==(const AbstractFunction & other) const137 bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
138   if (!other.isa<PrimitiveAbstractClosure>()) {
139     return false;
140   }
141   auto other_prim = static_cast<const PrimitiveAbstractClosure *>(&other);
142   MS_EXCEPTION_IF_NULL(prim_);
143   return (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id());
144 }
145 
hash() const146 std::size_t PrimitiveAbstractClosure::hash() const {
147   auto hash_value = hash_combine(tid(), prim_->hash());
148   // Keep in sync with operator==() which compares the prim_ pointer;
149   hash_value = hash_combine(hash_value, std::hash<Primitive *>{}(prim_.get()));
150   if (tracking_id() != nullptr) {
151     hash_value = hash_combine(hash_value, tracking_id()->hash());
152   }
153   return hash_value;
154 }
155 
operator ==(const AbstractFunction & other) const156 bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
157   if (!other.isa<FuncGraphAbstractClosure>()) {
158     return false;
159   }
160   auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
161   return func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ &&
162          tracking_id() == other_fg->tracking_id();
163 }
164 
hash() const165 std::size_t FuncGraphAbstractClosure::hash() const {
166   auto hash_value = hash_combine(tid(), func_graph_->hash());
167   hash_value = hash_combine(hash_value, context_->hash());
168   if (tracking_id() != nullptr) {
169     hash_value = hash_combine(hash_value, tracking_id()->hash());
170   }
171   return hash_value;
172 }
173 
ToString() const174 std::string FuncGraphAbstractClosure::ToString() const {
175   std::stringstream ss;
176   MS_EXCEPTION_IF_NULL(func_graph_);
177   MS_EXCEPTION_IF_NULL(context_);
178   ss << "FuncGraphAbstractClosure: "
179      << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString();
180   return ss.str();
181 }
182 
operator ==(const AbstractFunction & other) const183 bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
184   if (!other.isa<MetaFuncGraphAbstractClosure>()) {
185     return false;
186   }
187   auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
188   return meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id();
189 }
190 
hash() const191 std::size_t MetaFuncGraphAbstractClosure::hash() const {
192   MS_EXCEPTION_IF_NULL(meta_func_graph_);
193   auto hash_value = hash_combine(tid(), meta_func_graph_->hash());
194   if (tracking_id() != nullptr) {
195     hash_value = hash_combine(hash_value, tracking_id()->hash());
196   }
197   return hash_value;
198 }
199 
ToString() const200 std::string MetaFuncGraphAbstractClosure::ToString() const {
201   MS_EXCEPTION_IF_NULL(meta_func_graph_);
202   return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name();
203 }
204 
operator ==(const AbstractFunction & other) const205 bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
206   if (!other.isa<PartialAbstractClosure>()) {
207     return false;
208   }
209   auto other_partial = static_cast<const PartialAbstractClosure *>(&other);
210   if (fn_ != other_partial->fn_) {
211     return false;
212   }
213   if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
214     return false;
215   }
216   return args_spec_list_ == other_partial->args_spec_list_;
217 }
218 
hash() const219 std::size_t PartialAbstractClosure::hash() const {
220   MS_EXCEPTION_IF_NULL(fn_);
221   auto hash_value = hash_combine(tid(), fn_->hash());
222   hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
223   return hash_value;
224 }
225 
ToString() const226 std::string PartialAbstractClosure::ToString() const {
227   std::ostringstream buffer;
228   buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
229   for (const auto &arg : args_spec_list_) {
230     MS_EXCEPTION_IF_NULL(arg);
231     buffer << arg->ToString() << ", ";
232   }
233   buffer << "))";
234   return buffer.str();
235 }
236 
operator ==(const AbstractFunction & other) const237 bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
238   if (!other.isa<JTransformedAbstractClosure>()) {
239     return false;
240   }
241   auto other_transformed = static_cast<const JTransformedAbstractClosure *>(&other);
242   return fn_ == other_transformed->fn_;
243 }
244 
hash() const245 std::size_t JTransformedAbstractClosure::hash() const {
246   MS_EXCEPTION_IF_NULL(fn_);
247   auto hash_value = hash_combine(tid(), fn_->hash());
248   return hash_value;
249 }
250 
operator ==(const AbstractFunction & other) const251 bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
252   if (!other.isa<VirtualAbstractClosure>()) {
253     return false;
254   }
255   auto other_virtual = static_cast<const VirtualAbstractClosure *>(&other);
256   if (output_ != other_virtual->output_) {
257     return false;
258   }
259   if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) {
260     return false;
261   }
262   return args_spec_list_ == other_virtual->args_spec_list_;
263 }
264 
hash() const265 std::size_t VirtualAbstractClosure::hash() const {
266   MS_EXCEPTION_IF_NULL(output_);
267   auto hash_value = hash_combine(tid(), output_->hash());
268   hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
269   return hash_value;
270 }
271 
ToString() const272 std::string VirtualAbstractClosure::ToString() const {
273   std::ostringstream buffer;
274   buffer << "VirtualAbstractClosure(args: {";
275   int64_t i = 0;
276   for (const auto &arg : args_spec_list_) {
277     MS_EXCEPTION_IF_NULL(arg);
278     buffer << "[" << i << "]: " << arg->ToString() << ", ";
279     i++;
280   }
281   MS_EXCEPTION_IF_NULL(output_);
282   buffer << "}, output: " << output_->ToString() << ")";
283   return buffer.str();
284 }
285 
operator ==(const AbstractFunction & other) const286 bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
287   if (!other.isa<TypedPrimitiveAbstractClosure>()) {
288     return false;
289   }
290   auto other_typed = static_cast<const TypedPrimitiveAbstractClosure *>(&other);
291   if (output_ != other_typed->output_) {
292     return false;
293   }
294   if (prim_ != other_typed->prim_) {
295     return false;
296   }
297   if (args_spec_list_.size() != other_typed->args_spec_list_.size()) {
298     return false;
299   }
300   return args_spec_list_ == other_typed->args_spec_list_;
301 }
302 
hash() const303 std::size_t TypedPrimitiveAbstractClosure::hash() const {
304   auto hash_value = hash_combine(tid(), prim_->hash());
305   hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
306   return hash_value;
307 }
308 
ToString() const309 std::string TypedPrimitiveAbstractClosure::ToString() const {
310   std::ostringstream buffer;
311   buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {";
312   int64_t i = 0;
313   for (const auto &arg : args_spec_list_) {
314     MS_EXCEPTION_IF_NULL(arg);
315     buffer << "[" << i << "]: " << arg->ToString() << ", ";
316     i++;
317   }
318   MS_EXCEPTION_IF_NULL(output_);
319   buffer << "}, output: " << output_->ToString() << ")";
320   return buffer.str();
321 }
322 
operator ==(const AbstractFunction & other) const323 bool PyInterpretAbstractClosure::operator==(const AbstractFunction &other) const {
324   if (!other.isa<PyInterpretAbstractClosure>()) {
325     return false;
326   }
327   auto other_partial = static_cast<const PyInterpretAbstractClosure *>(&other);
328   if (fn_ != other_partial->fn_) {
329     return false;
330   }
331   if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
332     return false;
333   }
334   return args_spec_list_ == other_partial->args_spec_list_;
335 }
336 
hash() const337 std::size_t PyInterpretAbstractClosure::hash() const {
338   MS_EXCEPTION_IF_NULL(fn_);
339   auto hash_value = hash_combine(tid(), fn_->hash());
340   hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
341   return hash_value;
342 }
343 
ToString() const344 std::string PyInterpretAbstractClosure::ToString() const {
345   std::ostringstream buffer;
346   buffer << "PyInterpretAbstractClosure(" << fn_->ToString() << "(";
347   for (const auto &arg : args_spec_list_) {
348     MS_EXCEPTION_IF_NULL(arg);
349     buffer << arg->ToString() << ", ";
350   }
351   buffer << "))";
352   return buffer.str();
353 }
354 
operator ==(const AbstractFunction & other) const355 bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
356   return !other.isa<DummyAbstractClosure>();
357 }
358 }  // namespace abstract
359 }  // namespace mindspore
360