1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/analysis_context.h"
18
19 #include <algorithm>
20
21 #include "utils/symbolic.h"
22 #include "utils/trace_base.h"
23
24 namespace mindspore {
25 namespace abstract {
// Registry of every context produced by CreateContext(); each entry is an owning pointer, so a context
// stays alive at least until ClearContext() resets the registered contexts and empties this list.
std::list<AnalysisContextPtr> AnalysisContext::all_context_;
NewContext(const FuncGraphPtr & func_graph,const AbstractBasePtrList & args_spec_list)27 AnalysisContextPtr AnalysisContext::NewContext(const FuncGraphPtr &func_graph,
28 const AbstractBasePtrList &args_spec_list) {
29 // Find func graph's parent and its parent context firstly.
30 MS_EXCEPTION_IF_NULL(func_graph);
31 FuncGraphPtr parent_graph = func_graph->parent();
32 AnalysisContextPtr parent_context = nullptr;
33 auto iter = extant_context_cache_.find(parent_graph);
34 if (iter != extant_context_cache_.end()) {
35 parent_context = iter->second.lock();
36 }
37 if (parent_context == nullptr) { // If parent context is not found, we'll raise exception.
38 std::ostringstream oss;
39 oss << "BUG: Failed to find parent context in current context: " << this->ToString()
40 << ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
41 if (parent_graph != nullptr) {
42 oss << parent_graph->ToString();
43 } else {
44 oss << "nullptr";
45 }
46 MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
47 }
48
49 // Check if we created a context for func graph with the same arguments before.
50 auto children_context_map_iter = parent_context->children_cache_.find(func_graph);
51 if (children_context_map_iter != parent_context->children_cache_.end()) {
52 auto children_context_map = children_context_map_iter->second;
53 auto children_context_iter = children_context_map.find(args_spec_list);
54 if (children_context_iter != children_context_map.end()) {
55 return children_context_iter->second.lock();
56 }
57 }
58
59 // Create a new context for the func graph and its specific arguments.
60 AnalysisContextPtr new_context = CreateContext(parent_context, func_graph, args_spec_list);
61 // To avoid cycle-reference, use weak_ptr here.
62 auto weak_new_context = std::weak_ptr<AnalysisContext>(new_context);
63 new_context->extant_context_cache_[func_graph] = weak_new_context;
64 parent_context->children_cache_[func_graph][args_spec_list] = weak_new_context;
65 return new_context;
66 }
67
FindOwnOrParentContext(const FuncGraphPtr & func_graph)68 AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &func_graph) {
69 auto p_iter = extant_context_cache_.find(func_graph);
70 AnalysisContextPtr extant_context = nullptr;
71 if (p_iter != extant_context_cache_.end()) {
72 extant_context = p_iter->second.lock();
73 } else {
74 auto iter_parent = extant_context_cache_.find(func_graph->parent());
75 if (iter_parent != extant_context_cache_.end()) {
76 extant_context = iter_parent->second.lock();
77 }
78 }
79 // If this happen, it would be a bug in code. But we raise exception to keep the scene.
80 if (extant_context == nullptr) {
81 std::ostringstream oss;
82 oss << "BUG: Failed to find context for: " << func_graph->ToString() << ", parent_graph: ";
83 if (func_graph->parent() != nullptr) {
84 oss << func_graph->parent()->ToString();
85 } else {
86 oss << "nullptr";
87 }
88 oss << " extant context list: {";
89 for (const auto &iter : extant_context_cache_) {
90 if (iter.first == nullptr) {
91 oss << " [graph: nullptr";
92 } else {
93 oss << " [graph: " << iter.first->ToString();
94 }
95 // iter.second cannot be nullptr even iter.first is nullptr as it will
96 // always be a Context() object.
97 oss << ", context: " << iter.second.lock()->ToString() << "]";
98 }
99 oss << "}";
100 MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
101 }
102 return extant_context;
103 }
104
DummyContext()105 AnalysisContextPtr AnalysisContext::DummyContext() {
106 AnalysisContextPtr dummy_context = CreateContext(nullptr, nullptr, AbstractBasePtrList());
107 dummy_context->extant_context_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
108 return dummy_context;
109 }
110
IsDummyContext()111 bool AnalysisContext::IsDummyContext() {
112 return parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty();
113 }
114
// Shared placeholder context with no parent, no func graph and no arguments.
// It is built through CreateContext(), so it is also recorded in AnalysisContext::all_context_.
// NOTE(review): unlike DummyContext(), no extant_context_cache_[nullptr] self-entry is added here;
// confirm that this difference is intended.
const AnalysisContextPtr kDummyAnalysisContext =
  AnalysisContext::CreateContext(nullptr, nullptr, AbstractBasePtrList());
117
operator ==(const AnalysisContext & other) const118 bool AnalysisContext::operator==(const AnalysisContext &other) const {
119 if (func_graph_ != other.func_graph_) {
120 return false;
121 }
122
123 if (args_spec_list_.size() != other.args_spec_list_.size()) {
124 return false;
125 }
126
127 if (((parent_ == nullptr) && (other.parent_ != nullptr)) || ((parent_ != nullptr) && (other.parent_ == nullptr))) {
128 return false;
129 }
130 // Compare parent with content.
131 bool is_parent_equal = false;
132 if (parent_ == other.parent_) {
133 is_parent_equal = true;
134 } else if (*parent_ == *other.parent_) {
135 is_parent_equal = true;
136 } else {
137 return false;
138 }
139 for (std::size_t i = 0; i < args_spec_list_.size(); i++) {
140 if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) {
141 return false;
142 }
143 }
144 return is_parent_equal;
145 }
146
147 // brief The key which controls the graph cloning in Specialize.
148 // Originally, specialize use context directly as the key for cloning graph. The graph will be cloned multiple times
149 // for different context, which means the graph is called from different node with different arguments and different
150 // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what
151 // graph can be reused.
152 // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined
153 // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused
154 // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize.
155 // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies
156 // on correct shape to specialize a tensor constant.
SpecializeKey() const157 AnalysisContextPtr AnalysisContext::SpecializeKey() const {
158 AbstractBasePtrList args_broad_shp;
159 (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp),
160 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
161 MS_EXCEPTION_IF_NULL(arg);
162 if (arg->isa<AbstractScalar>()) {
163 auto val = arg->GetValueTrack();
164 MS_EXCEPTION_IF_NULL(val);
165 if (val->isa<SymbolicKeyInstance>()) {
166 auto scalar_spec = dyn_cast<AbstractScalar>(arg);
167 MS_EXCEPTION_IF_NULL(scalar_spec);
168 auto ret_spec = scalar_spec->Broaden();
169 return ret_spec;
170 }
171 }
172 if (arg->isa<AbstractRef>()) {
173 MS_LOG(DEBUG) << "refkey broaden";
174 return arg->Broaden();
175 }
176 return arg;
177 });
178 AnalysisContextPtr context_new = CreateContext(nullptr, func_graph_, args_broad_shp);
179 context_new->parent_ = parent_;
180 return context_new;
181 }
182
hash()183 std::size_t AnalysisContext::hash() {
184 std::size_t hash_value = 0;
185 // hash() recursion exit condition.
186 if (parent_ != nullptr) {
187 hash_value = hash_combine(hash_value, parent_->hash());
188 }
189 if (func_graph_ != nullptr) {
190 hash_value = hash_combine(hash_value, func_graph_->hash());
191 }
192 return hash_value;
193 }
194
ToString() const195 std::string AnalysisContext::ToString() const {
196 std::ostringstream buffer;
197 buffer << "{";
198 if (func_graph_ != nullptr) {
199 buffer << "Func Graph: " << func_graph_->ToString();
200 }
201 buffer << " Args: ";
202 int64_t i = 0;
203 for (const auto &arg : args_spec_list_) {
204 buffer << "[" << i << "]: " << arg->ToString() << ", ";
205 i++;
206 }
207 if (parent_ != nullptr) {
208 buffer << "Parent: " << parent_->ToString();
209 }
210 buffer << "}";
211 return buffer.str();
212 }
213
ClearContext()214 void AnalysisContext::ClearContext() {
215 for (auto &item : all_context_) {
216 item->parent_ = nullptr;
217 item->func_graph_ = nullptr;
218 item->args_spec_list_.clear();
219 item->extant_context_cache_.clear();
220 item->children_cache_.clear();
221 }
222 all_context_.clear();
223 }
224
// Allocates a context for the given parent, func graph and arguments, records it in all_context_
// and returns it.
AnalysisContextPtr AnalysisContext::CreateContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg,
                                                  const AbstractBasePtrList &args_spec_list) {
  AnalysisContextPtr created = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
  all_context_.push_back(created);
  return created;
}
231 } // namespace abstract
232 } // namespace mindspore
233