• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/analysis_context.h"
18 
19 #include <algorithm>
20 
21 #include "utils/symbolic.h"
22 #include "utils/trace_base.h"
23 
24 namespace mindspore {
25 namespace abstract {
26 std::list<AnalysisContextPtr> AnalysisContext::all_context_;
NewContext(const FuncGraphPtr & func_graph,const AbstractBasePtrList & args_spec_list)27 AnalysisContextPtr AnalysisContext::NewContext(const FuncGraphPtr &func_graph,
28                                                const AbstractBasePtrList &args_spec_list) {
29   // Find func graph's parent and its parent context firstly.
30   MS_EXCEPTION_IF_NULL(func_graph);
31   FuncGraphPtr parent_graph = func_graph->parent();
32   AnalysisContextPtr parent_context = nullptr;
33   auto iter = extant_context_cache_.find(parent_graph);
34   if (iter != extant_context_cache_.end()) {
35     parent_context = iter->second.lock();
36   }
37   if (parent_context == nullptr) {  // If parent context is not found, we'll raise exception.
38     std::ostringstream oss;
39     oss << "BUG: Failed to find parent context in current context: " << this->ToString()
40         << ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
41     if (parent_graph != nullptr) {
42       oss << parent_graph->ToString();
43     } else {
44       oss << "nullptr";
45     }
46     MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
47   }
48 
49   // Check if we created a context for func graph with the same arguments before.
50   auto children_context_map_iter = parent_context->children_cache_.find(func_graph);
51   if (children_context_map_iter != parent_context->children_cache_.end()) {
52     auto children_context_map = children_context_map_iter->second;
53     auto children_context_iter = children_context_map.find(args_spec_list);
54     if (children_context_iter != children_context_map.end()) {
55       return children_context_iter->second.lock();
56     }
57   }
58 
59   // Create a new context for the func graph and its specific arguments.
60   AnalysisContextPtr new_context = CreateContext(parent_context, func_graph, args_spec_list);
61   // To avoid cycle-reference, use weak_ptr here.
62   auto weak_new_context = std::weak_ptr<AnalysisContext>(new_context);
63   new_context->extant_context_cache_[func_graph] = weak_new_context;
64   parent_context->children_cache_[func_graph][args_spec_list] = weak_new_context;
65   return new_context;
66 }
67 
FindOwnOrParentContext(const FuncGraphPtr & func_graph)68 AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &func_graph) {
69   auto p_iter = extant_context_cache_.find(func_graph);
70   AnalysisContextPtr extant_context = nullptr;
71   if (p_iter != extant_context_cache_.end()) {
72     extant_context = p_iter->second.lock();
73   } else {
74     auto iter_parent = extant_context_cache_.find(func_graph->parent());
75     if (iter_parent != extant_context_cache_.end()) {
76       extant_context = iter_parent->second.lock();
77     }
78   }
79   // If this happen, it would be a bug in code. But we raise exception to keep the scene.
80   if (extant_context == nullptr) {
81     std::ostringstream oss;
82     oss << "BUG: Failed to find context for: " << func_graph->ToString() << ", parent_graph: ";
83     if (func_graph->parent() != nullptr) {
84       oss << func_graph->parent()->ToString();
85     } else {
86       oss << "nullptr";
87     }
88     oss << " extant context list: {";
89     for (const auto &iter : extant_context_cache_) {
90       if (iter.first == nullptr) {
91         oss << " [graph: nullptr";
92       } else {
93         oss << " [graph: " << iter.first->ToString();
94       }
95       // iter.second cannot be nullptr even iter.first is nullptr as it will
96       // always be a Context() object.
97       oss << ", context: " << iter.second.lock()->ToString() << "]";
98     }
99     oss << "}";
100     MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
101   }
102   return extant_context;
103 }
104 
DummyContext()105 AnalysisContextPtr AnalysisContext::DummyContext() {
106   AnalysisContextPtr dummy_context = CreateContext(nullptr, nullptr, AbstractBasePtrList());
107   dummy_context->extant_context_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
108   return dummy_context;
109 }
110 
IsDummyContext()111 bool AnalysisContext::IsDummyContext() {
112   return parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty();
113 }
114 
115 const AnalysisContextPtr kDummyAnalysisContext =
116   AnalysisContext::CreateContext(nullptr, nullptr, AbstractBasePtrList());
117 
operator ==(const AnalysisContext & other) const118 bool AnalysisContext::operator==(const AnalysisContext &other) const {
119   if (func_graph_ != other.func_graph_) {
120     return false;
121   }
122 
123   if (args_spec_list_.size() != other.args_spec_list_.size()) {
124     return false;
125   }
126 
127   if (((parent_ == nullptr) && (other.parent_ != nullptr)) || ((parent_ != nullptr) && (other.parent_ == nullptr))) {
128     return false;
129   }
130   // Compare parent with content.
131   bool is_parent_equal = false;
132   if (parent_ == other.parent_) {
133     is_parent_equal = true;
134   } else if (*parent_ == *other.parent_) {
135     is_parent_equal = true;
136   } else {
137     return false;
138   }
139   for (std::size_t i = 0; i < args_spec_list_.size(); i++) {
140     if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) {
141       return false;
142     }
143   }
144   return is_parent_equal;
145 }
146 
147 // brief The key which controls the graph cloning in Specialize.
148 // Originally, specialize use context directly as the key for cloning graph. The graph will be cloned multiple times
149 // for different context, which means the graph is called from different node with different arguments and different
150 // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what
151 // graph can be reused.
152 // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined
153 // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused
154 // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize.
155 // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies
156 // on correct shape to specialize a tensor constant.
SpecializeKey() const157 AnalysisContextPtr AnalysisContext::SpecializeKey() const {
158   AbstractBasePtrList args_broad_shp;
159   (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp),
160                        [](const AbstractBasePtr &arg) -> AbstractBasePtr {
161                          MS_EXCEPTION_IF_NULL(arg);
162                          if (arg->isa<AbstractScalar>()) {
163                            auto val = arg->GetValueTrack();
164                            MS_EXCEPTION_IF_NULL(val);
165                            if (val->isa<SymbolicKeyInstance>()) {
166                              auto scalar_spec = dyn_cast<AbstractScalar>(arg);
167                              MS_EXCEPTION_IF_NULL(scalar_spec);
168                              auto ret_spec = scalar_spec->Broaden();
169                              return ret_spec;
170                            }
171                          }
172                          if (arg->isa<AbstractRef>()) {
173                            MS_LOG(DEBUG) << "refkey broaden";
174                            return arg->Broaden();
175                          }
176                          return arg;
177                        });
178   AnalysisContextPtr context_new = CreateContext(nullptr, func_graph_, args_broad_shp);
179   context_new->parent_ = parent_;
180   return context_new;
181 }
182 
hash()183 std::size_t AnalysisContext::hash() {
184   std::size_t hash_value = 0;
185   // hash() recursion exit condition.
186   if (parent_ != nullptr) {
187     hash_value = hash_combine(hash_value, parent_->hash());
188   }
189   if (func_graph_ != nullptr) {
190     hash_value = hash_combine(hash_value, func_graph_->hash());
191   }
192   return hash_value;
193 }
194 
ToString() const195 std::string AnalysisContext::ToString() const {
196   std::ostringstream buffer;
197   buffer << "{";
198   if (func_graph_ != nullptr) {
199     buffer << "Func Graph: " << func_graph_->ToString();
200   }
201   buffer << " Args: ";
202   int64_t i = 0;
203   for (const auto &arg : args_spec_list_) {
204     buffer << "[" << i << "]: " << arg->ToString() << ", ";
205     i++;
206   }
207   if (parent_ != nullptr) {
208     buffer << "Parent: " << parent_->ToString();
209   }
210   buffer << "}";
211   return buffer.str();
212 }
213 
ClearContext()214 void AnalysisContext::ClearContext() {
215   for (auto &item : all_context_) {
216     item->parent_ = nullptr;
217     item->func_graph_ = nullptr;
218     item->args_spec_list_.clear();
219     item->extant_context_cache_.clear();
220     item->children_cache_.clear();
221   }
222   all_context_.clear();
223 }
224 
/// @brief Allocate a context and record it in all_context_, which keeps it alive until ClearContext().
/// @param parent Parent context (may be null).
/// @param fg Func graph of the context (may be null).
/// @param args_spec_list Abstract argument values of the context.
/// @return The newly created, registered context.
AnalysisContextPtr AnalysisContext::CreateContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg,
                                                  const AbstractBasePtrList &args_spec_list) {
  all_context_.push_back(std::make_shared<AnalysisContext>(parent, fg, args_spec_list));
  return all_context_.back();
}
231 }  // namespace abstract
232 }  // namespace mindspore
233