1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/static_analysis/program_specialize.h"
20
21 #include <algorithm>
22 #include <exception>
23 #include "frontend/operator/ops.h"
24 #include "frontend/operator/composite/do_signature.h"
25 #include "abstract/abstract_function.h"
26 #include "abstract/utils.h"
27 #include "utils/utils.h"
28 #include "ir/graph_utils.h"
29 #include "utils/log_adapter.h"
30 #include "debug/trace.h"
31
32 namespace mindspore {
33 namespace abstract {
34 namespace {
GetEvaluatedValue(const AnfNodeConfigPtr & conf)35 inline AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
36 MS_EXCEPTION_IF_NULL(conf);
37 if (conf->node()->intermediate_abstract()) {
38 return conf->node()->intermediate_abstract();
39 }
40 MS_EXCEPTION_IF_NULL(conf->ObtainEvalResult());
41 return conf->ObtainEvalResult()->abstract();
42 }
43
BuildValueNode(const ValuePtr & v,const AbstractBasePtr & abs_base)44 AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
45 MS_EXCEPTION_IF_NULL(abs_base);
46 AnfNodePtr value_node = NewValueNode(v);
47 value_node->set_abstract(abs_base);
48 MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
49 return value_node;
50 }
51
IsVisible(FuncGraphPtr fg,const FuncGraphPtr & parent)52 bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
53 while (fg != nullptr && fg != parent) {
54 fg = fg->parent();
55 }
56 return fg == parent;
57 }
58
CheckAbstractTensor(const AbstractBasePtr & abs_base)59 bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
60 MS_EXCEPTION_IF_NULL(abs_base);
61 if (abs_base->isa<AbstractTensor>()) {
62 return true;
63 } else if (abs_base->isa<AbstractSequeue>()) {
64 const auto &abs_seq = abs_base->cast<AbstractSequeuePtr>();
65 MS_EXCEPTION_IF_NULL(abs_seq);
66 const auto &elements = abs_seq->elements();
67 return std::all_of(elements.cbegin(), elements.cend(), [](const auto &v) { return CheckAbstractTensor(v); });
68 } else {
69 return false;
70 }
71 }
72 } // namespace
73
// Entry point of program specialization. The first context ever passed in is
// remembered as the top context; then `fg` is specialized under `context`.
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Specialize topmost function graph: "
                << (context->func_graph() ? context->func_graph()->ToString() : "FG(Null)");
  if (top_context_ != nullptr) {
    return SpecializeFuncGraph(fg, context);
  }
  top_context_ = context;
  MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
  return SpecializeFuncGraph(fg, context);
}
85
// Returns the specialized clone of `fg` for `context`. Results are memoized by
// `context->SpecializeKey()`; on a miss a FuncGraphSpecializer is created and run.
FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  // Compute the key once; it is used both for the lookup and for the registration.
  const auto key = context->SpecializeKey();
  auto iter = specializations_.find(key);
  if (iter != specializations_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    return iter->second->specialized_func_graph();
  }

  std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
  FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
  // Register before running: Run() can re-enter SpecializeFuncGraph (through
  // BuildSpecializedNodeInner), and a re-entrant request for the same key must hit the cache.
  specializations_[key] = fg_spec;
  fg_spec->Run();
  return fg2;
}
101
GetFuncGraphSpecializer(const AnalysisContextPtr & context)102 std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
103 MS_EXCEPTION_IF_NULL(context);
104 auto iter = specializations_.find(context->SpecializeKey());
105 if (iter != specializations_.end()) {
106 return iter->second;
107 }
108 if (context->func_graph() != nullptr) {
109 MS_LOG(EXCEPTION) << "Specialize inner error";
110 }
111 return nullptr;
112 }
113
// Returns the next value of a process-wide counter as a string ("1", "2", ...).
// The counter is a plain static with no synchronization.
std::string GetNextCounter() {
  static int64_t g_CloneCounter = 1;
  return std::to_string(g_CloneCounter++);
}
120
FuncGraphSpecializer(ProgramSpecializer * const s,const FuncGraphPtr & fg,const AnalysisContextPtr & context)121 FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
122 const AnalysisContextPtr &context)
123 : specializer_(s), func_graph_(fg), context_(context) {
124 parent_ = s->GetFuncGraphSpecializer(context->parent());
125 engine_ = s->engine();
126 cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
127 repl_node_ = cloner_->cloned_node();
128 specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
129 todo_.push_back(fg->get_return());
130 auto ps = fg->parameters();
131 (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
132 }
133
// Returns the copy of `node` held by the specializer owning the node's graph.
// - ValueNodes are returned as-is.
// - A node without a recorded copy is cloned on demand. This covers nodes generated
//   during analysis, which the cloner does not know about.
// - For a CNode, its inputs are replicated recursively as well.
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return node;
  }
  std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
  MS_EXCEPTION_IF_NULL(specializer);

  // If already replicated, just return that.
  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  auto iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    return iter->second;
  }
  MS_EXCEPTION_IF_NULL(specializer->cloner_);
  auto new_node = specializer->cloner_->CloneDisconnected(node);
  MS_EXCEPTION_IF_NULL(new_node);
  if (node->isa<CNode>()) {
    if (!new_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
    }
    UpdateNewCNodeInputs(node, new_node);
  }

  // The clone must now be recorded in the replication map and differ from the original.
  iter = specializer->repl_node_->find(node);
  if (iter == specializer->repl_node_->end()) {
    MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  }
  if (iter->second == node) {
    MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  }
  return new_node;
}
165
UpdateNewCNodeInputs(const AnfNodePtr & node,const AnfNodePtr & new_node)166 void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
167 MS_EXCEPTION_IF_NULL(node);
168 auto c_node = node->cast<CNodePtr>();
169 MS_EXCEPTION_IF_NULL(c_node);
170 auto inputs = c_node->inputs();
171 std::vector<AnfNodePtr> new_inputs;
172 (void)std::transform(
173 inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr {
174 auto new_inp = ReplicateDisconnectedNode(inp);
175 // Refer the comments in BuildReplacedNode.
176 if (inp->isa<CNode>()) {
177 auto c_inp = inp->cast<CNodePtr>();
178 MS_EXCEPTION_IF_NULL(c_inp);
179 auto c_new_inp = new_inp->cast<CNodePtr>();
180 MS_EXCEPTION_IF_NULL(c_new_inp);
181 MS_EXCEPTION_IF_NULL(c_new_inp->func_graph());
182 MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString();
183 c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
184 }
185 return new_inp;
186 });
187
188 auto c_new_node = new_node->cast<CNodePtr>();
189 MS_EXCEPTION_IF_NULL(c_new_node);
190 c_new_node->set_inputs(new_inputs);
191 }
192
// Maps `node` to its clone in the specializer owning the node's graph. Returns
// `node` itself when no clone has been recorded.
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  auto specializer = GetTopSpecializer(node);
  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  const auto found = specializer->repl_node_->find(node);
  if (found == specializer->repl_node_->end()) {
    return node;
  }
  return found->second;
}
202
203 // Return itself if node's ValueNode as top,
204 // return the top func graph specializer as top if node's forward Parameter,
205 // or, return the top parent specializer as top.
GetTopSpecializer(const AnfNodePtr & node)206 std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(const AnfNodePtr &node) {
207 MS_EXCEPTION_IF_NULL(node);
208 FuncGraphPtr fg = node->func_graph();
209 if (fg == nullptr) { // If ValueNode, return current specializer.
210 MS_LOG(DEBUG) << "Node's a ValueNode, node: " << node->DebugString();
211 return shared_from_this();
212 }
213 std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
214 while (fg != specializer->func_graph_) {
215 if (specializer->parent_ == nullptr && node->isa<Parameter>()) {
216 // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
217 MS_EXCEPTION_IF_NULL(specializer_->top_context());
218 if (specializer_->top_context()->func_graph() == fg) { // `fg` is top func graph.
219 specializer = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
220 MS_LOG(INFO) << "Used top func graph specializer as parent for "
221 << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", node: " << node->DebugString()
222 << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
223 MS_EXCEPTION_IF_NULL(specializer);
224 break;
225 }
226 } else {
227 specializer = specializer->parent_;
228 }
229 if (specializer == nullptr) {
230 MS_LOG(EXCEPTION) << "`specializer` should not be null, node: " << node->DebugString()
231 << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info()) << ".\n"
232 << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
233 << " has no parent context? At least not " << fg->ToString();
234 }
235 }
236 return specializer;
237 }
238
Run()239 void FuncGraphSpecializer::Run() {
240 MS_LOG(DEBUG) << "Before run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
241 << ", cloned func graph name: "
242 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", func graph: "
243 << (func_graph_ ? func_graph_->get_return() ? func_graph_->get_return()->DebugString() : "return null"
244 : "FG(null)");
245 FirstPass();
246 SecondPass();
247 MS_LOG(DEBUG) << "After run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
248 << ", cloned func graph name: "
249 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", new func graph: "
250 << (specialized_func_graph_ ? specialized_func_graph_->get_return()
251 ? specialized_func_graph_->get_return()->DebugString()
252 : "return null"
253 : "FG(null)");
254 }
255
FirstPass()256 void FuncGraphSpecializer::FirstPass() {
257 while (todo_.size()) {
258 AnfNodePtr node = todo_.back();
259 todo_.pop_back();
260 if (node->func_graph() == nullptr) {
261 // do nothing for ValueNode
262 continue;
263 }
264 if (node->func_graph() != func_graph_) {
265 std::shared_ptr<FuncGraphSpecializer> parent = nullptr;
266 if (parent_ != nullptr) {
267 parent = parent_;
268 } else if (specializer_->top_context()->func_graph() == node->func_graph() && node->isa<Parameter>()) {
269 // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
270 parent = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
271 MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
272 << ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
273 }
274 if (parent == nullptr) {
275 MS_LOG(EXCEPTION) << "Parent must not null, node: " << node->DebugString()
276 << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
277 }
278 parent->AddTodoItem(node);
279 parent->FirstPass();
280 AnfNodePtr new_node = parent->GetReplicatedNode(node);
281 if (node->isa<CNode>()) {
282 parent->ProcessCNode(new_node->cast<CNodePtr>());
283 }
284 continue;
285 }
286 if (marked_.count(node) > 0) {
287 continue;
288 }
289 (void)marked_.insert(node);
290 ProcessNode(node);
291 }
292 }
293
294 // Specialize CNode in func graphs
SecondPass()295 void FuncGraphSpecializer::SecondPass() {
296 for (auto &node : BroadFirstSearchGraphCNodes({specialized_func_graph_->get_return()})) {
297 if (node->isa<CNode>()) {
298 ProcessCNode(node->cast<CNodePtr>());
299 }
300 }
301 }
302
ProcessNode(const AnfNodePtr & node)303 void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
304 MS_EXCEPTION_IF_NULL(node);
305 ScopeGuard scope_guard(node->scope());
306 AnfNodeConfigPtr conf = MakeConfig(node);
307 AnfNodePtr new_node = GetReplicatedNode(node);
308 MS_EXCEPTION_IF_NULL(new_node);
309 if (new_node->func_graph() != specialized_func_graph_) {
310 MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
311 << ", new_node: " << new_node->DebugString() << ", new_node->func_graph(): "
312 << (new_node->func_graph() ? new_node->func_graph()->ToString() : "FG(Null)")
313 << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
314 return;
315 }
316 new_node->set_abstract(GetEvaluatedValue(conf));
317 if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
318 auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
319 if (partial_abstract->node() == node) {
320 partial_abstract->set_node(new_node);
321 }
322 }
323
324 MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
325
326 if (node->isa<CNode>()) {
327 auto attrs = conf->ObtainEvalResult()->attribute();
328 auto c_old = node->cast<CNodePtr>();
329 auto c_new = new_node->cast<CNodePtr>();
330 MS_EXCEPTION_IF_NULL(c_new);
331 auto new_inputs = c_new->inputs();
332 auto old_inputs = c_old->inputs();
333 for (size_t i = 0; i < old_inputs.size(); ++i) {
334 auto node_input = old_inputs[i];
335 AnfNodeConfigPtr iconf = MakeConfig(node_input);
336 AbstractBasePtr ival = GetEvaluatedValue(iconf);
337 // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
338 // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
339 AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
340 if (replace_node == nullptr) {
341 replace_node = BuildReplacedNode(iconf);
342 replace_node->set_abstract(ival);
343 MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
344 } else {
345 MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
346 << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
347 }
348 if (new_inputs[i] != replace_node) {
349 new_inputs[i] = replace_node;
350 MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
351 }
352 }
353 c_new->set_inputs(new_inputs);
354 }
355 }
356
// Resolves the node that replaces `conf`'s node in the specialized graph.
// The engine's anfnode_config_map() maps configs that were forwarded during analysis
// to their target configs. This walks that chain to its end. For each step:
// - the forwarded node is replicated into its specialized graph;
// - if the replica is a CNode, it takes the replicated origin node's place in the
//   order list.
// The final node is queued on `todo_` (drained by FirstPass), and its replica is returned.
// NOTE(review): there is no cycle guard on the chain; presumably the engine never
// produces a cyclic forward map — confirm.
AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);

  auto conf_iter = engine_->anfnode_config_map().find(conf);
  AnfNodeConfigPtr new_conf = conf;
  // Follow the forwarding chain until reaching a config that is not forwarded any further.
  while (conf_iter != engine_->anfnode_config_map().end()) {
    MS_LOG(DEBUG) << "Origin conf: node(" << (new_conf->node() ? new_conf->node()->DebugString() : "Node(Null)") << ")";
    new_conf = conf_iter->second;
    MS_EXCEPTION_IF_NULL(new_conf);
    const auto &forward_node = new_conf->node();
    MS_LOG(DEBUG) << "Replaced conf: node(" << forward_node->DebugString() << ")";
    const auto &replicated_forward_node = ReplicateDisconnectedNode(forward_node);
    if (replicated_forward_node && replicated_forward_node->isa<CNode>()) {
      // The AnfNode in order_list can be:
      // case 1: also in FuncGraphManager, so it can be got from nodes API of func_graph. It will
      //         be replaced in CloneOrderList in Cloner.
      // case 2: AnfNode is not in FuncGraphManager, because it was generated in the Analyze phase,
      //         so it will not be cloned by the normal clone API.
      //    2.1: A forward node, whose original node is in FuncGraphManager. The original node will
      //         be cloned in CloneOrderList in Cloner, and the replicated forward node will replace
      //         the replicated original node here.
      //    2.2: An input of a forward node, such as the Cast CNode generated in DoCast. It is also
      //         another original node to forward.
      //    2.3: An input of an input of a forward node, but not an original node, like the Cast CNode
      //         in MixedPrecisionCastHelper.
      // For 2.2 and 2.3, we put a placeholder in the order list of the replicated func_graph (refer to
      // CloneOrderlist), and it will be replaced inside ReplicateDisconnectedNode.
      // For 2.1, the following code does the job: it replaces the replicated origin cnode with the
      // replicated forward one in the replicated func_graph.
      MS_EXCEPTION_IF_NULL(conf_iter->first);
      const auto &origin_node = conf_iter->first->node();
      const auto &replicated_origin_node = GetReplicatedNode(origin_node);
      // GetReplicatedNode returns its argument unchanged when no clone exists, which is an error here.
      if (replicated_origin_node != origin_node) {
        MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString()
                      << ", with replicated forwarded node: " << replicated_forward_node->DebugString();
        MS_EXCEPTION_IF_NULL(replicated_forward_node->func_graph());
        replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node);
      } else {
        MS_LOG(EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: "
                          << (origin_node ? origin_node->DebugString() : "Node(Null)");
      }
    }
    conf_iter = engine_->anfnode_config_map().find(new_conf);
  }
  // Queue the final node so FirstPass processes it as well.
  todo_.push_back(new_conf->node());
  auto repl = GetReplicatedNode(new_conf->node());
  if (repl->func_graph()) {
    MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
                  << ") to replace origin:" << new_conf->node()->DebugString();
  } else {
    MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
                  << ") to replace origin: " << new_conf->node()->DebugString();
  }
  return repl;
}
412
413 namespace {
414 const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
415 const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
416
CanSpecializeNode(const AnfNodePtr & node)417 inline bool CanSpecializeNode(const AnfNodePtr &node) {
418 if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
419 return true;
420 }
421 return false;
422 }
423 } // namespace
424
// Builds the node replacing `node`, whose abstract `abs` must be an AbstractFunction,
// when it is called with `argvals`. If no unique specialization can be found, a
// DeadNode/PolyNode ValueNode carrying an AbstractError is returned instead.
// A specialized MetaFuncGraph whose last argument is a UMonad is flagged for re-auto-monad.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                                      const AbstractBasePtrList &argvals) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(node);
  AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
  MS_EXCEPTION_IF_NULL(real_a);

  AbstractFunctionPtr func = real_a->GetUnique();
  MS_EXCEPTION_IF_NULL(func);
  // Initialized so the value is well-defined even on paths that do not set it.
  SpecializeStatusCode errcode = kSpecializeSuccess;
  ScopeGuard scope_guard(node->scope());
  AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
  if (repl == nullptr) {
    if (errcode == kSpecializeFindUniqueArgvalDead) {
      const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
      repl = BuildValueNode(kDeadNode, error_dead_node);
      MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
    } else if (errcode == kSpecializeFindUniqueArgvalPoly) {
      const auto error_poly_node = std::make_shared<AbstractError>(kPolyNode, node);
      repl = BuildValueNode(kPolyNode, error_poly_node);
      MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
    } else {
      MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
                        << ", abstract: " << abs->ToString();
    }
  }

  // Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
  if (func->isa<MetaFuncGraphAbstractClosure>()) {
    auto specialized_fg = GetValueNode<FuncGraphPtr>(repl);
    if (specialized_fg != nullptr && (argvals.size() > 1) && argvals.back() != nullptr &&
        argvals.back()->isa<AbstractUMonad>()) {
      specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
    }
  }
  return repl;
}
462
// Builds the specialized callee for `func`, the unique function abstract of `abs`,
// applied to `args`.
// - A TypedPrimitiveAbstractClosure becomes a ValueNode of its primitive.
// - Otherwise the unique argument list is resolved with FindUniqueArgvals. On failure,
//   `*errcode` receives the status and nullptr is returned.
// - A PrimitiveAbstractClosure becomes a primitive ValueNode typed with those arguments.
// - A func graph evaluator yields a ValueNode of the specialized func graph. Stub graphs
//   are the exception: the original `node` is returned for them.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                                           const AbstractFunctionPtr &func,
                                                           const AbstractBasePtrList &args,
                                                           SpecializeStatusCode *errcode) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(func);
  MS_EXCEPTION_IF_NULL(errcode);
  *errcode = kSpecializeSuccess;

  auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
  if (real_func != nullptr) {
    return BuildValueNode(real_func->prim(), abs);
  }

  EvaluatorPtr eval = engine_->GetEvaluatorFor(func);
  MS_EXCEPTION_IF_NULL(eval);
  AbstractBasePtrList argvals = eval->NormalizeArgs(args);

  std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
  if (status != kSpecializeSuccess) {
    *errcode = status;
    return nullptr;
  }
  argvals = result.first;
  AbstractBasePtr unique_output = result.second;

  auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  if (prim_func != nullptr) {
    auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
    return BuildValueNode(prim_func->prim(), type_func);
  }

  if (!eval->isa<BaseFuncGraphEvaluator>()) {
    MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
  }
  auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
  MS_EXCEPTION_IF_NULL(real_eval);

  if (func->context() == nullptr) {
    MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  }
  AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
  MS_EXCEPTION_IF_NULL(context);
  // Check the func graph before anything, including the debug log, dereferences it.
  MS_EXCEPTION_IF_NULL(context->func_graph());
  MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
                << ", graph: " << context->func_graph()->get_return()->DebugString();
  if (context->func_graph()->stub()) {
    MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
                  << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
                  << ", " << node->ToString();
    return node;
  }
  FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
  MS_EXCEPTION_IF_NULL(v);
  v->set_flag(kFuncGraphFlagUndetermined, false);
  return BuildValueNode(v, abs);
}
519
MakeContext(const AnalysisEnginePtr & engine,const BaseFuncGraphEvaluatorPtr & evaluator,const AbstractBasePtrList & args_spec_list)520 inline AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine,
521 const BaseFuncGraphEvaluatorPtr &evaluator,
522 const AbstractBasePtrList &args_spec_list) {
523 AbstractBasePtrList normalized_args_spec_list = evaluator->NormalizeArgs(args_spec_list);
524 FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_spec_list);
525 MS_EXCEPTION_IF_NULL(evaluator->parent_context());
526 AnalysisContextPtr new_context = evaluator->parent_context()->NewContext(fg, normalized_args_spec_list);
527 return new_context;
528 }
529
BuildSpecializedParameterNode(const CNodePtr & new_node)530 AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
531 MS_EXCEPTION_IF_NULL(new_node);
532 auto new_inputs = new_node->inputs();
533 if (new_inputs.empty()) {
534 MS_LOG(EXCEPTION) << "inputs can't be empty.";
535 }
536 AnfNodePtr func = new_inputs[0];
537 MS_EXCEPTION_IF_NULL(new_inputs[0]);
538 AbstractBasePtr fnval = new_inputs[0]->abstract();
539
540 AbstractBasePtrList args;
541 auto backed_fnval = fnval;
542 if (fnval->isa<PartialAbstractClosure>()) {
543 auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
544 backed_fnval = partial_closure->fn();
545 args = partial_closure->args();
546 }
547 std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
548 [](const AnfNodePtr &inp) { return inp->abstract(); });
549
550 ScopeGuard scope_guard(new_node->scope());
551
552 auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
553 auto wrapped_node = specialized_node;
554 if (fnval->isa<PartialAbstractClosure>()) {
555 auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
556 AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
557 specialized_node};
558 auto anf_node = partial_closure->node();
559 if (!anf_node->isa<CNode>()) {
560 MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
561 }
562 auto cnode = anf_node->cast<CNodePtr>();
563 if (cnode->size() != partial_closure->args().size() + 2) {
564 MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
565 << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
566 }
567 auto attrs = std::make_shared<AttrValueMap>();
568 for (size_t i = 0; i < partial_closure->args().size(); i++) {
569 auto old_node = cnode->input(i + 2);
570 auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
571 if (possibile_value_node != nullptr) {
572 partial_node_list.push_back(possibile_value_node);
573 } else {
574 if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
575 MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
576 }
577 partial_node_list.push_back(old_node);
578 }
579 }
580 MS_EXCEPTION_IF_NULL(new_node->func_graph());
581 wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
582 wrapped_node->set_abstract(partial_closure);
583 }
584 return wrapped_node;
585 }
586
GetEvalCache(const EvaluatorPtr & eval)587 const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
588 MS_EXCEPTION_IF_NULL(eval);
589 auto cache_iter = evalcaches_.find(eval);
590 if (cache_iter == evalcaches_.end()) {
591 evalcaches_[eval] = eval->evaluator_cache_mgr();
592 return eval->evaluator_cache_mgr();
593 }
594 return cache_iter->second;
595 }
596
BuildFromBroadedArgsVal(const EvaluatorPtr & eval)597 std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
598 const EvaluatorPtr &eval) {
599 MS_EXCEPTION_IF_NULL(eval);
600 std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
601 EvalResultPtr ret = nullptr;
602 AbstractBasePtrList broaded_argvals;
603 std::vector<AbstractBasePtrList> args_vector;
604 auto eval_cache_iter = evalcaches_.find(eval);
605 if (eval_cache_iter == evalcaches_.end()) {
606 MS_LOG(EXCEPTION) << "Evaluator:" << eval->ToString() << " not exist in cache.";
607 }
608 auto &origin_eval_cache = eval_cache_iter->second->GetCache();
609 for (auto &argvals_map : origin_eval_cache) {
610 auto argvals = argvals_map.first;
611 args_vector.push_back(argvals);
612 broaded_argvals.clear();
613 BroadenArgs(argvals, &broaded_argvals);
614 (void)choices.insert(broaded_argvals);
615 MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
616 }
617 if (choices.size() == 1) {
618 constexpr auto args_size = 2;
619 if (args_vector.size() < args_size) {
620 MS_LOG(EXCEPTION) << "Should have " << args_size << " or more choices, but: " << args_vector.size();
621 }
622 AbstractBasePtrList joined_argvals = args_vector[0];
623 for (size_t i = 1; i < args_vector.size(); ++i) {
624 joined_argvals = abstract::AbstractJoin(joined_argvals, args_vector[i]);
625 }
626 MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
627 EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
628 const auto joined_eval_result = origin_eval_cache.get(joined_argvals);
629 if (joined_eval_result != nullptr) {
630 MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();
631
632 real->SetValue(joined_argvals, joined_eval_result);
633 evalcaches_[eval] = real;
634 return std::make_pair(joined_argvals, joined_eval_result->abstract());
635 } else {
636 bool all_args_tensor = std::all_of(broaded_argvals.cbegin(), broaded_argvals.cend(),
637 [](const AbstractBasePtr &v) { return CheckAbstractTensor(v); });
638 if (all_args_tensor) {
639 ConfigPtrList args_conf_list;
640 (void)std::transform(broaded_argvals.cbegin(), broaded_argvals.cend(), std ::back_inserter(args_conf_list),
641 [](const AbstractBasePtr &v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
642 MS_LOG(WARNING) << "Cannot find joined argvals in cache, run with broaded argsvals: " << broaded_argvals.size()
643 << ", " << ::mindspore::ToString(broaded_argvals);
644 ret = eval->SingleRun(engine_, args_conf_list, nullptr);
645 MS_EXCEPTION_IF_NULL(ret);
646 real->SetValue(broaded_argvals, ret);
647 evalcaches_[eval] = real;
648 return std::make_pair(broaded_argvals, ret->abstract());
649 }
650 }
651 }
652 MS_LOG(DEBUG) << "Choices.size: " << choices.size();
653 return std::make_pair(AbstractBasePtrList(), nullptr);
654 }
655
ProcessCNode(const CNodePtr & new_node)656 void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
657 MS_EXCEPTION_IF_NULL(new_node);
658 if (specializer_->seen().count(new_node) > 0) {
659 return;
660 }
661 specializer_->AddSeen(new_node);
662 auto new_inputs = new_node->inputs();
663 if (new_inputs.empty()) {
664 MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
665 }
666 AnfNodePtr func = new_inputs[0];
667 MS_EXCEPTION_IF_NULL(func);
668
669 // First element is func so arg start from 1
670 std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
671 // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
672 const size_t arg_start_index = 2;
673 while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
674 std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
675 // First element is partial, second is func so arg is start from 2
676 (void)args.insert(args.begin(), inputs.begin() + SizeToInt(arg_start_index), inputs.end());
677 func = inputs[1];
678 }
679 new_inputs = args;
680 (void)new_inputs.insert(new_inputs.begin(), func);
681
682 AbstractBasePtrList argvals;
683 MS_EXCEPTION_IF_NULL(new_inputs[0]);
684 AbstractBasePtr fnval = new_inputs[0]->abstract();
685 MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
686 << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
687
688 // First element is func so function arguments start from 1
689 for (size_t i = 1; i < new_inputs.size(); ++i) {
690 argvals.push_back(new_inputs[i]->abstract());
691 MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
692 << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
693 }
694
695 if (!func->isa<ValueNode>()) {
696 MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
697 if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
698 auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
699 EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
700 std::pair<AbstractBasePtrList, AbstractBasePtr> result;
701 AbstractBasePtrList empty_args;
702 auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
703 MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
704 // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
705 MS_EXCEPTION_IF_NULL(func->func_graph());
706 if (status == kSpecializeFindUniqueArgvalPoly ||
707 (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
708 auto wrapped_node = BuildSpecializedParameterNode(new_node);
709 new_inputs[0] = wrapped_node;
710 }
711 }
712 }
713
714 if (CanSpecializeNode(func)) {
715 // for primitive node, we build the primitive node with inferred attributes in the first pass
716 // so we do not build replaced node again here in second pass
717 if (IsValueNode<Primitive>(func)) {
718 new_inputs[0] = func;
719 } else {
720 new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
721 }
722 }
723
724 for (size_t i = 0; i < argvals.size();) {
725 size_t next = i + 1;
726 if (CanSpecializeNode(args[i])) {
727 new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
728 }
729 i = next;
730 }
731 new_node->set_inputs(new_inputs);
732 }
733
734 namespace {
DumpEvaluatorCache(const EvaluatorCacheMgrPtr & evaluator_cache_mgr,const AbstractBasePtrList & argvals)735 void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) {
736 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
737 MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
738 int64_t i = 0;
739 const EvalResultCache &map = evaluator_cache_mgr->GetCache();
740 for (const auto &item : map) {
741 MS_LOG(DEBUG) << "evaluator_cache[" << i++ << "]: " << item.first;
742 }
743 }
744
IsPolyFunc(const AbstractFunctionPtr & func,const AbstractBasePtrList & argvals)745 bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
746 MS_EXCEPTION_IF_NULL(func);
747 if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
748 MS_LOG(DEBUG) << "High order primitive return POLY.";
749 return true;
750 }
751 if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
752 auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
753 auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
754 if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
755 auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
756 if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
757 MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
758 return true;
759 }
760 }
761 }
762 return false;
763 }
764 } // end anonymous namespace
765
FindUniqueArgvals(const AbstractFunctionPtr & func,const EvaluatorPtr & eval,const AbstractBasePtrList & argvals,std::pair<AbstractBasePtrList,AbstractBasePtr> * result)766 SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
767 const AbstractBasePtrList &argvals,
768 std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
769 MS_EXCEPTION_IF_NULL(func);
770 MS_EXCEPTION_IF_NULL(eval);
771 MS_EXCEPTION_IF_NULL(result);
772
773 EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
774 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
775 auto data = evaluator_cache_mgr->GetValue(argvals);
776 if (data != nullptr) {
777 *result = std::make_pair(argvals, data->abstract());
778 return kSpecializeSuccess;
779 }
780 DumpEvaluatorCache(evaluator_cache_mgr, argvals);
781
782 auto cache = GetEvalCache(eval);
783 MS_EXCEPTION_IF_NULL(cache);
784 const EvalResultCache &choices = cache->GetCache();
785 if (choices.get(argvals) != nullptr) {
786 MS_EXCEPTION_IF_NULL(cache->GetValue(argvals));
787 *result = std::make_pair(argvals, cache->GetValue(argvals)->abstract());
788 return kSpecializeSuccess;
789 } else if (choices.size() == 1) {
790 MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
791 MS_EXCEPTION_IF_NULL(choices.begin()->second);
792 *result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
793 return kSpecializeSuccess;
794 } else if (choices.empty()) {
795 MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
796 << func->type_name();
797 return kSpecializeFindUniqueArgvalDead;
798 } else {
799 if (IsPolyFunc(func, argvals)) {
800 return kSpecializeFindUniqueArgvalPoly;
801 }
802
803 MS_LOG(DEBUG) << "Try to find generalized argvals.";
804 *result = BuildFromBroadedArgsVal(eval);
805 if (!result->first.empty()) {
806 return kSpecializeSuccess;
807 }
808 MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
809 return kSpecializeFindUniqueArgvalPoly;
810 }
811 }
/**
 * Return a primitive that carries every attribute in `attrs`.
 *
 * If `prim` already holds each key of `attrs` with an equal value, `prim` itself is returned. Otherwise `prim`
 * is cloned, all entries of `attrs` are added to the clone, and the clone is returned. The original primitive
 * is never modified.
 *
 * @param prim The primitive to start from; must not be null.
 * @param attrs The attributes the result must carry; must not be null.
 * @return `prim` unchanged, or an updated clone of it.
 */
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(prim);
  // `attrs` is dereferenced below; do not rely on callers having checked it.
  MS_EXCEPTION_IF_NULL(attrs);
  const auto &prim_attrs = prim->attrs();
  bool is_attr_same = true;
  for (const auto &item : *attrs) {
    auto itr = prim_attrs.find(item.first);
    if (itr == prim_attrs.end()) {
      is_attr_same = false;
      break;
    }
    MS_EXCEPTION_IF_NULL(itr->second);
    MS_EXCEPTION_IF_NULL(item.second);
    if (!(*(itr->second) == *(item.second))) {
      is_attr_same = false;
      break;
    }
  }
  if (is_attr_same) {
    return prim;
  }
  auto cloned_prim = prim->Clone();
  MS_EXCEPTION_IF_NULL(cloned_prim);
  for (const auto &item : *attrs) {
    (void)cloned_prim->AddAttr(item.first, item.second);
  }
  return cloned_prim;
}
839
/**
 * Try to build a constant ValueNode that can replace `origin_node`, based on its abstract `ival`.
 *
 * For a function abstract:
 *   - a FuncUnion yields nullptr;
 *   - a primitive, meta func graph or func graph closure yields a ValueNode holding that value.
 * Primitives are first updated with `attrs` when given. A FuncGraph that has a parent is only emitted when
 * `origin_node` is itself a FuncGraph ValueNode and the parent is visible from `func_graph_`.
 *
 * For any other abstract, the built value is used. The result is nullptr when that value is null or
 * AnyValue, or when `origin_node` is a Depend CNode.
 *
 * @param origin_node The node being specialized; must not be null.
 * @param ival Its inferred abstract; must not be null.
 * @param attrs Optional inferred primitive attributes (may be null).
 * @return The new ValueNode, or nullptr when no deterministic constant can be built.
 */
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                                        const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

  AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
  if (abs == nullptr) {
    ValuePtr val = ival->BuildValue();
    // No usable value (null or AnyValue): nothing constant to fold into a ValueNode.
    if (val == nullptr || val->isa<AnyValue>()) {
      return nullptr;
    }
    // keep primitive 'depend' not to be optimized
    if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) {
      return nullptr;
    }
    return BuildValueNode(val, ival);
  }

  // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
  if (abs->isa<AbstractFuncUnion>()) {
    return nullptr;
  }
  ValuePtr value = nullptr;
  if (abs->isa<PrimitiveAbstractClosure>()) {
    auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
    // for primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
    if (attrs != nullptr) {
      value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
    } else {
      value = real_fn->prim();
    }
  } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
    auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
    value = real_fn->meta_func_graph();
  } else if (abs->isa<FuncGraphAbstractClosure>()) {
    auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
    value = real_fn->func_graph();
  } else {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(value);
  // A func graph with a parent is only referenced directly when the origin is a FuncGraph ValueNode
  // and that parent is visible from the graph being specialized.
  auto value_fg = value->cast<FuncGraphPtr>();
  if (value_fg == nullptr || value_fg->parent() == nullptr ||
      (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value_fg->parent()))) {
    return BuildValueNode(value, ival);
  }
  return nullptr;
}
888
MakeConfig(const AnfNodePtr & node)889 inline AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
890 return engine_->MakeConfig(node, context_, func_graph_); // `func_graph_` is dummy here.
891 }
892 } // namespace abstract
893 } // namespace mindspore
894