1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "abstract/abstract_function.h"

#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
20
21 namespace mindspore {
22 namespace abstract {
// Forward declarations. Neither name is referenced anywhere else in this file.
// NOTE(review): likely removable if nothing else in this translation unit relies on them.
class AnalysisEngine;
class Evaluator;
MakeAbstractFunction(const AbstractFuncAtomPtrList & func_list)25 AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
26 if (func_list.size() == 1) {
27 return func_list[0];
28 }
29 return std::make_shared<AbstractFuncUnion>(func_list);
30 }
31
// Joins this atom with `other`.
// - `other` is an atom equal to this one (per operator==): the result is this atom.
// - `other` is a union that already contains this atom: the result is `other` unchanged.
// - Otherwise a new AbstractFuncUnion over both operands is returned.
// `other` must be non-null, and an AbstractFuncAtom or an AbstractFuncUnion (anything else fails
// the null check on the dyn_cast result).
AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto self = shared_from_base<AbstractFuncAtom>();
  if (!other->isa<AbstractFuncAtom>()) {
    auto union_other = dyn_cast<AbstractFuncUnion>(other);
    MS_EXCEPTION_IF_NULL(union_other);
    if (union_other->IsSuperSet(self)) {
      return other;
    }
  } else if (*self == *other) {
    return self;
  }
  return std::make_shared<AbstractFuncUnion>(self, other);
}
48
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const49 void AbstractFuncAtom::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
50 visit_func(const_cast<AbstractFuncAtom *>(this)->shared_from_base<AbstractFuncAtom>());
51 }
52
operator ==(const AbstractFunction & other) const53 bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; }
54
AbstractFuncUnion(const AbstractFuncAtomPtrList & func_list)55 AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; }
56
AbstractFuncUnion(const AbstractFunctionPtr & first,const AbstractFunctionPtr & second)57 AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) {
58 AbstractFuncAtomPtrList new_func_list;
59 auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); };
60 MS_EXCEPTION_IF_NULL(first);
61 MS_EXCEPTION_IF_NULL(second);
62 first->Visit(build_func_list);
63 second->Visit(build_func_list);
64 func_list_ = new_func_list;
65 }
66
ToString() const67 std::string AbstractFuncUnion::ToString() const {
68 std::ostringstream buffer;
69 buffer << "AbstractFuncUnion({";
70 int64_t i = 0;
71 for (const auto &func : func_list_) {
72 MS_EXCEPTION_IF_NULL(func);
73 buffer << "[" << i << "]: " << func->ToString() << ", ";
74 i++;
75 }
76 buffer << "})";
77 return buffer.str();
78 }
79
IsSuperSet(const AbstractFunctionPtr & other)80 bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) {
81 MS_EXCEPTION_IF_NULL(other);
82 std::vector<bool> is_in_list;
83 auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) {
84 auto iter = find(func_list_.begin(), func_list_.end(), func);
85 if (iter == func_list_.end()) {
86 is_in_list.push_back(false);
87 }
88 return true;
89 };
90 other->Visit(build_in_list);
91 return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; });
92 }
93
// Joins this union with `other`, which must be a non-null atom or union (anything else fails the
// null check on the dyn_cast result).
// - `other` is a union that already contains every atom of this union: the result is `other`.
// - This union already contains every atom of `other`: the result is this union.
//   This applies to an atom or a union; previously only the atom case was checked, so a contained
//   union produced a new union with duplicate atoms.
// - Otherwise a new AbstractFuncUnion over both operands is returned.
AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto this_func = shared_from_base<AbstractFunction>();
  if (!other->isa<AbstractFuncAtom>()) {
    auto other_union = dyn_cast<AbstractFuncUnion>(other);
    MS_EXCEPTION_IF_NULL(other_union);
    if (other_union->IsSuperSet(this_func)) {
      return other;
    }
  }
  // IsSuperSet visits `other`, so it works for both a single atom and a union.
  if (IsSuperSet(other)) {
    return this_func;
  }
  return std::make_shared<AbstractFuncUnion>(this_func, other);
}
110
Visit(std::function<void (const AbstractFuncAtomPtr &)> visit_func) const111 void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
112 for (const AbstractFuncAtomPtr &poss : func_list_) {
113 visit_func(poss);
114 }
115 }
116
operator ==(const AbstractFunction & other) const117 bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
118 if (!other.isa<AbstractFuncUnion>()) {
119 return false;
120 }
121 auto other_union = static_cast<const AbstractFuncUnion *>(&other);
122 if (func_list_.size() != other_union->func_list_.size()) {
123 return false;
124 }
125 return func_list_ == other_union->func_list_;
126 }
127
hash() const128 std::size_t AbstractFuncUnion::hash() const {
129 std::size_t hash_sum = 0;
130 for (const auto &f : func_list_) {
131 MS_EXCEPTION_IF_NULL(f);
132 hash_sum = hash_combine(hash_sum, f->hash());
133 }
134 return hash_sum;
135 }
136
operator ==(const AbstractFunction & other) const137 bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
138 if (!other.isa<PrimitiveAbstractClosure>()) {
139 return false;
140 }
141 auto other_prim = static_cast<const PrimitiveAbstractClosure *>(&other);
142 MS_EXCEPTION_IF_NULL(prim_);
143 return (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id());
144 }
145
hash() const146 std::size_t PrimitiveAbstractClosure::hash() const {
147 auto hash_value = hash_combine(tid(), prim_->hash());
148 // Keep in sync with operator==() which compares the prim_ pointer;
149 hash_value = hash_combine(hash_value, std::hash<Primitive *>{}(prim_.get()));
150 if (tracking_id() != nullptr) {
151 hash_value = hash_combine(hash_value, tracking_id()->hash());
152 }
153 return hash_value;
154 }
155
operator ==(const AbstractFunction & other) const156 bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
157 if (!other.isa<FuncGraphAbstractClosure>()) {
158 return false;
159 }
160 auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
161 return func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ &&
162 tracking_id() == other_fg->tracking_id();
163 }
164
hash() const165 std::size_t FuncGraphAbstractClosure::hash() const {
166 auto hash_value = hash_combine(tid(), func_graph_->hash());
167 hash_value = hash_combine(hash_value, context_->hash());
168 if (tracking_id() != nullptr) {
169 hash_value = hash_combine(hash_value, tracking_id()->hash());
170 }
171 return hash_value;
172 }
173
ToString() const174 std::string FuncGraphAbstractClosure::ToString() const {
175 std::stringstream ss;
176 MS_EXCEPTION_IF_NULL(func_graph_);
177 MS_EXCEPTION_IF_NULL(context_);
178 ss << "FuncGraphAbstractClosure: "
179 << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString();
180 return ss.str();
181 }
182
operator ==(const AbstractFunction & other) const183 bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
184 if (!other.isa<MetaFuncGraphAbstractClosure>()) {
185 return false;
186 }
187 auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
188 return meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id();
189 }
190
hash() const191 std::size_t MetaFuncGraphAbstractClosure::hash() const {
192 MS_EXCEPTION_IF_NULL(meta_func_graph_);
193 auto hash_value = hash_combine(tid(), meta_func_graph_->hash());
194 if (tracking_id() != nullptr) {
195 hash_value = hash_combine(hash_value, tracking_id()->hash());
196 }
197 return hash_value;
198 }
199
ToString() const200 std::string MetaFuncGraphAbstractClosure::ToString() const {
201 MS_EXCEPTION_IF_NULL(meta_func_graph_);
202 return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name();
203 }
204
operator ==(const AbstractFunction & other) const205 bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
206 if (!other.isa<PartialAbstractClosure>()) {
207 return false;
208 }
209 auto other_partial = static_cast<const PartialAbstractClosure *>(&other);
210 if (fn_ != other_partial->fn_) {
211 return false;
212 }
213 if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
214 return false;
215 }
216 return args_spec_list_ == other_partial->args_spec_list_;
217 }
218
hash() const219 std::size_t PartialAbstractClosure::hash() const {
220 MS_EXCEPTION_IF_NULL(fn_);
221 auto hash_value = hash_combine(tid(), fn_->hash());
222 hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
223 return hash_value;
224 }
225
ToString() const226 std::string PartialAbstractClosure::ToString() const {
227 std::ostringstream buffer;
228 buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
229 for (const auto &arg : args_spec_list_) {
230 MS_EXCEPTION_IF_NULL(arg);
231 buffer << arg->ToString() << ", ";
232 }
233 buffer << "))";
234 return buffer.str();
235 }
236
operator ==(const AbstractFunction & other) const237 bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
238 if (!other.isa<JTransformedAbstractClosure>()) {
239 return false;
240 }
241 auto other_transformed = static_cast<const JTransformedAbstractClosure *>(&other);
242 return fn_ == other_transformed->fn_;
243 }
244
hash() const245 std::size_t JTransformedAbstractClosure::hash() const {
246 MS_EXCEPTION_IF_NULL(fn_);
247 auto hash_value = hash_combine(tid(), fn_->hash());
248 return hash_value;
249 }
250
operator ==(const AbstractFunction & other) const251 bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
252 if (!other.isa<VirtualAbstractClosure>()) {
253 return false;
254 }
255 auto other_virtual = static_cast<const VirtualAbstractClosure *>(&other);
256 if (output_ != other_virtual->output_) {
257 return false;
258 }
259 if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) {
260 return false;
261 }
262 return args_spec_list_ == other_virtual->args_spec_list_;
263 }
264
hash() const265 std::size_t VirtualAbstractClosure::hash() const {
266 MS_EXCEPTION_IF_NULL(output_);
267 auto hash_value = hash_combine(tid(), output_->hash());
268 hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
269 return hash_value;
270 }
271
ToString() const272 std::string VirtualAbstractClosure::ToString() const {
273 std::ostringstream buffer;
274 buffer << "VirtualAbstractClosure(args: {";
275 int64_t i = 0;
276 for (const auto &arg : args_spec_list_) {
277 MS_EXCEPTION_IF_NULL(arg);
278 buffer << "[" << i << "]: " << arg->ToString() << ", ";
279 i++;
280 }
281 MS_EXCEPTION_IF_NULL(output_);
282 buffer << "}, output: " << output_->ToString() << ")";
283 return buffer.str();
284 }
285
operator ==(const AbstractFunction & other) const286 bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
287 if (!other.isa<TypedPrimitiveAbstractClosure>()) {
288 return false;
289 }
290 auto other_typed = static_cast<const TypedPrimitiveAbstractClosure *>(&other);
291 if (output_ != other_typed->output_) {
292 return false;
293 }
294 if (prim_ != other_typed->prim_) {
295 return false;
296 }
297 if (args_spec_list_.size() != other_typed->args_spec_list_.size()) {
298 return false;
299 }
300 return args_spec_list_ == other_typed->args_spec_list_;
301 }
302
hash() const303 std::size_t TypedPrimitiveAbstractClosure::hash() const {
304 auto hash_value = hash_combine(tid(), prim_->hash());
305 hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
306 return hash_value;
307 }
308
ToString() const309 std::string TypedPrimitiveAbstractClosure::ToString() const {
310 std::ostringstream buffer;
311 buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {";
312 int64_t i = 0;
313 for (const auto &arg : args_spec_list_) {
314 MS_EXCEPTION_IF_NULL(arg);
315 buffer << "[" << i << "]: " << arg->ToString() << ", ";
316 i++;
317 }
318 MS_EXCEPTION_IF_NULL(output_);
319 buffer << "}, output: " << output_->ToString() << ")";
320 return buffer.str();
321 }
322
operator ==(const AbstractFunction & other) const323 bool PyInterpretAbstractClosure::operator==(const AbstractFunction &other) const {
324 if (!other.isa<PyInterpretAbstractClosure>()) {
325 return false;
326 }
327 auto other_partial = static_cast<const PyInterpretAbstractClosure *>(&other);
328 if (fn_ != other_partial->fn_) {
329 return false;
330 }
331 if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
332 return false;
333 }
334 return args_spec_list_ == other_partial->args_spec_list_;
335 }
336
hash() const337 std::size_t PyInterpretAbstractClosure::hash() const {
338 MS_EXCEPTION_IF_NULL(fn_);
339 auto hash_value = hash_combine(tid(), fn_->hash());
340 hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
341 return hash_value;
342 }
343
ToString() const344 std::string PyInterpretAbstractClosure::ToString() const {
345 std::ostringstream buffer;
346 buffer << "PyInterpretAbstractClosure(" << fn_->ToString() << "(";
347 for (const auto &arg : args_spec_list_) {
348 MS_EXCEPTION_IF_NULL(arg);
349 buffer << arg->ToString() << ", ";
350 }
351 buffer << "))";
352 return buffer.str();
353 }
354
operator ==(const AbstractFunction & other) const355 bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
356 return !other.isa<DummyAbstractClosure>();
357 }
358 } // namespace abstract
359 } // namespace mindspore
360