1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/static_analysis/program_specialize.h"
20
21 #include <algorithm>
22 #include <exception>
23 #include <unordered_set>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "frontend/operator/ops.h"
27 #include "frontend/operator/composite/do_signature.h"
28 #include "abstract/abstract_function.h"
29 #include "abstract/utils.h"
30 #include "ir/graph_utils.h"
31 #include "utils/log_adapter.h"
32 #include "utils/compile_config.h"
33 #include "pipeline/jit/ps/debug/trace.h"
34 #include "pipeline/jit/ps/fallback.h"
35 #include "include/common/fallback.h"
36 #include "include/common/utils/convert_utils_py.h"
37
38 namespace mindspore {
39 namespace abstract {
40 namespace {
GetEvalResult(const AnfNodeConfigPtr & conf)41 EvalResultPtr GetEvalResult(const AnfNodeConfigPtr &conf) {
42 try {
43 MS_EXCEPTION_IF_NULL(conf);
44 const auto &eval_result = conf->ObtainEvalResult();
45 MS_EXCEPTION_IF_NULL(eval_result);
46 return eval_result;
47 } catch (const std::exception &e) {
48 constexpr int recursive_level = 2;
49 static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
50 if (enable_pre_lift && IsPrimitiveCNode(conf->node(), prim::kPrimPartial)) {
51 MS_LOG(ERROR) << "node: " << conf->node()->DebugString(recursive_level);
52 auto abs_res = std::make_shared<AbstractNone>();
53 auto eval_result = std::make_shared<EvalResult>(abs_res, std::make_shared<AttrValueMap>());
54 return eval_result;
55 }
56 MS_LOG(INTERNAL_EXCEPTION) << "Fail to get eval result with conf " << conf->ToString();
57 }
58 }
59
// Create a ValueNode holding `v`, attach `abs_base` as its abstract and copy the debug info
// of `origin_node` onto it.
//
// Raises (through MS_EXCEPTION_IF_NULL) when `abs_base` or `origin_node` is null.
AnfNodePtr BuildValueNode(const ValuePtr &v, const AnfNodePtr &origin_node, const AbstractBasePtr &abs_base) {
  MS_EXCEPTION_IF_NULL(abs_base);
  // `origin_node` is dereferenced below, so reject null up front.
  MS_EXCEPTION_IF_NULL(origin_node);
  AnfNodePtr value_node = NewValueNode(v);
  value_node->set_abstract(abs_base);
  value_node->set_debug_info(origin_node->debug_info());
  MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
  return value_node;
}
68
IsVisible(FuncGraphPtr fg,const FuncGraphPtr & parent)69 bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
70 while (fg != nullptr && fg != parent) {
71 fg = fg->parent();
72 }
73 return fg == parent;
74 }
75
CanSpecializeValueNode(const AnfNodePtr & node)76 bool CanSpecializeValueNode(const AnfNodePtr &node) {
77 if (IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
78 return true;
79 }
80 if (IsValueNode<FuncGraph>(node)) {
81 if (node->abstract() != nullptr) {
82 auto abs_func = node->abstract()->cast_ptr<FuncGraphAbstractClosure>();
83 // If this funcgraph had specialized in ProcessCNode of FirstPass,
84 // then ignore it.
85 if (abs_func != nullptr && abs_func->specialized()) {
86 MS_LOG(DEBUG) << "Ignore specializing func graph: " << abs_func->ToString();
87 return false;
88 }
89 }
90 return true;
91 }
92 return false;
93 }
94
PurifyAbstractOfSequence(ProgramSpecializer * const specializer)95 void PurifyAbstractOfSequence(ProgramSpecializer *const specializer) {
96 MS_EXCEPTION_IF_NULL(specializer);
97 constexpr int recursive_level = 2;
98 for (auto &abstract_and_node : specializer->sequence_abstract_list()) {
99 auto &sequence_abs = abstract_and_node.first;
100 MS_EXCEPTION_IF_NULL(sequence_abs);
101 MS_EXCEPTION_IF_NULL(abstract_and_node.second);
102 if (!sequence_abs->PurifyElements()) {
103 MS_LOG(INFO) << "Purify elements failed, abstract: " << sequence_abs->ToString()
104 << ", node: " << abstract_and_node.second->DebugString(recursive_level);
105 } else {
106 MS_LOG(DEBUG) << "Purify elements, abstract: " << sequence_abs->ToString()
107 << ", node: " << abstract_and_node.second->DebugString(recursive_level);
108 }
109 }
110 }
111
112 // Second elimination.
113 // Eliminate the dead node in sequence node, and purify the abstract of sequence node.
EliminateCollectedSequenceNodes(ProgramSpecializer * const specializer)114 void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
115 MS_EXCEPTION_IF_NULL(specializer);
116 // Call PurifyElements() to purify tuple/list elements.
117 static const auto enable_only_mark_unused_element = (common::GetCompileConfig("DDE_ONLY_MARK") == "1");
118 if (enable_only_mark_unused_element) {
119 return;
120 }
121
122 // Purify the abstract of tuple/list.
123 PurifyAbstractOfSequence(specializer);
124 // Eliminate DeadNode in tuple/list.
125 for (auto &dead_node_info : specializer->dead_node_list()) {
126 auto pos = dead_node_info.second;
127 auto node = dead_node_info.first;
128 auto flags = GetSequenceNodeElementsUseFlags(node);
129 if (flags == nullptr) {
130 continue;
131 }
132
133 // Handle MakeTuple/MakeList CNode.
134 auto cnode = dyn_cast_ptr<CNode>(node);
135 if (cnode != nullptr) {
136 if (pos + 1 >= cnode->size()) {
137 continue;
138 }
139 if (!IsDeadNode(cnode->input(pos + 1))) {
140 continue;
141 }
142
143 constexpr int recursive_level = 2;
144 MS_LOG(DEBUG) << "Erase elements[" << pos << "] DeadNode as zero for " << cnode->DebugString(recursive_level);
145 // Change the node.
146 auto zero_value = NewValueNode(MakeValue<int64_t>(0));
147 zero_value->set_abstract(
148 std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(0), std::make_shared<Problem>()));
149 cnode->set_input(pos + 1, zero_value);
150
151 // Change the abstract.
152 (*flags)[pos] = false; // Change the use flag as 0.
153 auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
154 if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
155 MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
156 << ", node: " << node->DebugString(recursive_level);
157 }
158 continue;
159 }
160 // Handle ValueTuple/ValueList.
161 if (IsValueNode<ValueTuple>(node) || IsValueNode<ValueList>(node)) {
162 auto sequence_value = GetValuePtr<ValueSequence>(node);
163 MS_EXCEPTION_IF_NULL(sequence_value);
164 if (pos >= sequence_value->value().size()) {
165 continue;
166 }
167 ValuePtr element_value = sequence_value->value()[pos];
168 auto element_err_value = element_value->cast_ptr<ValueProblem>();
169 if (element_err_value == nullptr || !element_err_value->IsDead()) {
170 continue;
171 }
172
173 MS_LOG(DEBUG) << "Erase elements[" << pos << "] DeadNode as zero for " << node->DebugString();
174 // Change the node.
175 auto zero = MakeValue<int64_t>(0);
176 auto value_list = const_cast<ValuePtrList &>(sequence_value->value());
177 value_list[pos] = zero;
178
179 // Change the abstract.
180 (*flags)[pos] = false; // Change the use flag as 0.
181 auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
182 if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
183 constexpr int recursive_level = 2;
184 MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
185 << ", node: " << node->DebugString(recursive_level);
186 }
187 }
188 }
189 }
190
BroadenArgs(const AbstractBasePtrList & args_abs_list,AbstractBasePtrList * broaded_args)191 void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList *broaded_args) {
192 MS_EXCEPTION_IF_NULL(broaded_args);
193 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args),
194 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
195 MS_EXCEPTION_IF_NULL(arg);
196 if (arg->GetValueTrack() != kValueAny) {
197 return arg->Broaden();
198 }
199 return arg;
200 });
201 }
202
203 // These abstract sequence can't handled by DDE.
IsInvalidAbstractSequence(const AbstractSequencePtr & abs)204 bool IsInvalidAbstractSequence(const AbstractSequencePtr &abs) {
205 if (abs == nullptr || abs->isa<AbstractSparseTensor>() || abs->sequence_nodes() == nullptr ||
206 abs->sequence_nodes()->empty()) {
207 return true;
208 }
209 if (abs->dyn_len_arg() || abs->dynamic_len()) {
210 return true;
211 }
212 return false;
213 }
214 } // namespace
215
// Specialize the program rooted at `fg` under `context`.
// A specializer is created for the top graph and pushed onto the todo stack; the stack is then
// drained, running the top entry until it reports done. Finally the collected sequence nodes
// are cleaned up and the specialized top graph is returned.
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Specialize topmost function graph: "
                << (context->func_graph() ? context->func_graph()->ToString() : "FG(Null)");
  // Only the first context seen is remembered as the top context.
  if (top_context_ == nullptr) {
    top_context_ = context;
    MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
  }

  auto root_spec = NewFuncGraphSpecializer(context, fg);
  PushFuncGraphTodoItem(root_spec);
  while (!func_graph_todo_items_.empty()) {
    auto pending_spec = func_graph_todo_items_.top();
    MS_EXCEPTION_IF_NULL(pending_spec);
    if (!pending_spec->done()) {
      // Run the specializer on top of the stack; it stays there until it is done.
      pending_spec->Run();
    } else {
      func_graph_todo_items_.pop();
    }
  }

  auto specialized_top = root_spec->specialized_func_graph();
  MS_LOG(DEBUG) << "Specialized top graph: " << specialized_top->ToString();
  EliminateCollectedSequenceNodes(this);
  return specialized_top;
}
242
GetFuncGraphSpecializer(const AnalysisContextPtr & context)243 std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
244 MS_EXCEPTION_IF_NULL(context);
245 auto iter = specializations_.find(context);
246 if (iter != specializations_.end()) {
247 return iter->second;
248 }
249 return nullptr;
250 }
251
NewFuncGraphSpecializer(const AnalysisContextPtr & context,const FuncGraphPtr & fg)252 FuncGraphSpecializerPtr ProgramSpecializer::NewFuncGraphSpecializer(const AnalysisContextPtr &context,
253 const FuncGraphPtr &fg) {
254 MS_EXCEPTION_IF_NULL(context);
255 auto result = specializations_.emplace(context, nullptr);
256 if (result.second) {
257 MS_LOG(DEBUG) << "Make new specializer of context: " << context->ToString() << ", fg: " << fg->ToString();
258 auto fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
259 result.first->second = fg_spec;
260 return fg_spec;
261 }
262 MS_LOG(INTERNAL_EXCEPTION) << "Specializer exist in cache, can't not create again, context: " << context->ToString();
263 }
264
SetSpecializedAbstract(const AbstractFunctionPtr & old_abs_func,const AbstractFunctionPtr & new_abs_func,const CNodePtr & cnode,const AnfNodePtr & func)265 void ProgramSpecializer::SetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func,
266 const AbstractFunctionPtr &new_abs_func, const CNodePtr &cnode,
267 const AnfNodePtr &func) {
268 MS_EXCEPTION_IF_NULL(cnode);
269 MS_EXCEPTION_IF_NULL(func);
270 MS_EXCEPTION_IF_NULL(old_abs_func);
271 MS_EXCEPTION_IF_NULL(new_abs_func);
272 auto iter = specialized_abs_map_.find(old_abs_func);
273 if (iter == specialized_abs_map_.end()) {
274 MS_LOG(DEBUG) << "Emplace cnode: " << cnode->DebugString() << ", func: " << func->ToString()
275 << ", old_abstract: " << old_abs_func->ToString() << ", new_abs_func: " << new_abs_func->ToString();
276 (void)specialized_abs_map_.emplace(old_abs_func, std::make_pair(true, new_abs_func));
277 } else {
278 MS_LOG(DEBUG) << "Duplicate abstract from cnode: " << cnode->DebugString() << ", func: " << func->ToString()
279 << ", old_abstract: " << old_abs_func->ToString() << ", new_abs_func: " << new_abs_func->ToString();
280 if (!(*iter->second.second == *new_abs_func)) {
281 MS_LOG(DEBUG) << "Duplicate abstract from cnode: " << cnode->DebugString() << ", func: " << func->ToString()
282 << ", old_abstract: " << old_abs_func->ToString() << ", first: " << iter->second.second->ToString()
283 << ", new_abs_func: " << new_abs_func->ToString();
284 // Cannot determined which one to use.
285 iter->second.first = false;
286 }
287 }
288 }
289
// Find the specialized abstract recorded for `old_abs_func`.
// A record whose flag was cleared yields nullptr. Without a record, a FuncGraphAbstractClosure
// falls back to the unique abstract of its func graph, if any. nullptr is returned otherwise.
AbstractFunctionPtr ProgramSpecializer::GetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func) {
  MS_EXCEPTION_IF_NULL(old_abs_func);
  const auto record = specialized_abs_map_.find(old_abs_func);
  if (record != specialized_abs_map_.end()) {
    const auto &entry = record->second;
    if (!entry.first) {
      return nullptr;
    }
    MS_EXCEPTION_IF_NULL(entry.second);
    MS_LOG(DEBUG) << "Find abstract for old_abstract: " << old_abs_func->ToString()
                  << ", new_abs_func: " << entry.second->ToString();
    return entry.second;
  }

  if (old_abs_func->isa<FuncGraphAbstractClosure>()) {
    const auto &fg_closure = dyn_cast_ptr<FuncGraphAbstractClosure>(old_abs_func);
    auto unique_abs = GetUniqueFuncGraphAbstract(fg_closure->func_graph());
    if (unique_abs != nullptr) {
      MS_EXCEPTION_IF_NULL(fg_closure->func_graph());
      MS_LOG(DEBUG) << "Find unique abstract for funcgraph: " << fg_closure->func_graph()->ToString() << " in "
                    << old_abs_func->ToString() << ", unique_abs: " << unique_abs->ToString();
      return unique_abs;
    }
  }
  MS_LOG(DEBUG) << "Cannot find abstract for old_abstract: " << old_abs_func->ToString();
  return nullptr;
}
315
// Build the specialized counterpart of `old_abs_func`.
//   - AbstractFuncUnion: every atom is specialized recursively (async atoms are resolved first);
//     an atom that can't be specialized to an AbstractFuncAtom is kept as is.
//   - FuncGraph/MetaFuncGraph closure: the recorded specialized abstract, or nullptr.
//   - PartialAbstractClosure: a new partial closure over the specialized function, keeping the
//     old args, node and append-to-end setting; nullptr when the function has no specialized atom.
//   - Anything else: nullptr.
AbstractFunctionPtr ProgramSpecializer::SpecializeAbstractFuncRecursively(const AbstractFunctionPtr &old_abs_func) {
  MS_EXCEPTION_IF_NULL(old_abs_func);

  if (old_abs_func->isa<AbstractFuncUnion>()) {
    AbstractFuncAtomPtrList specialized_atoms;
    auto collect_atom = [this, &specialized_atoms](const AbstractFuncAtomPtr &poss) {
      MS_EXCEPTION_IF_NULL(poss);
      auto atom = poss;
      if (poss->isa<AsyncAbstractFuncAtom>()) {
        auto async_atom = poss->cast_ptr<AsyncAbstractFuncAtom>();
        const auto &unique_func = async_atom->GetUnique();
        MS_EXCEPTION_IF_NULL(unique_func);
        atom = unique_func->cast<AbstractFuncAtomPtr>();
        MS_EXCEPTION_IF_NULL(atom);
        MS_LOG(DEBUG) << "Resolved AsyncAbstractFuncAtom is: " << atom->ToString();
      }
      auto specialized = this->SpecializeAbstractFuncRecursively(atom);
      // Keep the (resolved) old atom unless an atom-typed specialization is available.
      AbstractFuncAtomPtr chosen_atom = atom;
      if (specialized == nullptr) {
        MS_LOG(DEBUG) << "Cannot resolve old_abs: " << atom->ToString() << " to specialized abstract, use old one";
      } else if (specialized->isa<AbstractFuncAtom>()) {
        MS_LOG(DEBUG) << "Resolve old_abs: " << atom->ToString()
                      << " to specialized abstract, specialized abstract: " << specialized->ToString();
        chosen_atom = specialized->cast<AbstractFuncAtomPtr>();
      } else {
        MS_LOG(DEBUG) << "Cannot resolve old_abs: " << atom->ToString()
                      << " to AbstractFuncAtom, use old one. Specialized abstract: " << specialized->ToString();
      }
      specialized_atoms.push_back(chosen_atom);
    };
    old_abs_func->Visit(collect_atom);
    return std::make_shared<AbstractFuncUnion>(specialized_atoms);
  }

  if (old_abs_func->isa<FuncGraphAbstractClosure>() || old_abs_func->isa<MetaFuncGraphAbstractClosure>()) {
    AbstractFunctionPtr specialized = GetSpecializedAbstract(old_abs_func);
    if (specialized == nullptr) {
      MS_LOG(DEBUG) << "cannot find specialized abstract, old_abstract: " << old_abs_func->ToString();
    } else {
      MS_LOG(DEBUG) << "Find specialized abstract, old_abstract: " << old_abs_func->ToString()
                    << ", specialized_abstract: " << specialized->ToString();
    }
    return specialized;
  }

  if (!old_abs_func->isa<PartialAbstractClosure>()) {
    return nullptr;
  }
  const auto &old_partial = old_abs_func->cast<PartialAbstractClosurePtr>();
  const auto &old_fn = old_partial->fn();
  auto specialized_fn = GetSpecializedAbstract(old_fn);
  if (specialized_fn == nullptr || !specialized_fn->isa<AbstractFuncAtom>()) {
    MS_LOG(DEBUG) << "Cannot find specialized abstract, old_abstract: " << old_abs_func->ToString();
    return nullptr;
  }
  auto specialized_fn_atom = specialized_fn->cast<AbstractFuncAtomPtr>();
  auto new_partial =
    std::make_shared<PartialAbstractClosure>(specialized_fn_atom, old_partial->args(), old_partial->node());
  new_partial->set_need_append_to_end(old_partial->need_append_to_end());
  AbstractFunctionPtr new_abs = new_partial;
  MS_LOG(DEBUG) << "Find specialized abstract, old_abstract: " << old_abs_func->ToString()
                << ", specialized_abstract: " << new_abs->ToString();
  return new_abs;
}
377
SpecializeCNodeInput0FuncGraph()378 void ProgramSpecializer::SpecializeCNodeInput0FuncGraph() {
379 MS_EXCEPTION_IF_NULL(manager_);
380 const auto &all_nodes = manager_->all_nodes();
381 for (auto node : all_nodes) {
382 MS_EXCEPTION_IF_NULL(node);
383 if (!node->isa<CNode>()) {
384 continue;
385 }
386 auto &input0 = node->cast_ptr<CNode>()->input(0);
387 MS_EXCEPTION_IF_NULL(input0);
388 if (IsValueNode<FuncGraph>(input0) || IsValueNode<Primitive>(input0)) {
389 continue;
390 }
391 MS_EXCEPTION_IF_NULL(node);
392 const auto &old_abs = input0->abstract();
393 if (old_abs == nullptr) {
394 constexpr auto recursive_level = 2;
395 MS_LOG(INTERNAL_EXCEPTION) << "Node's first input abstract should not be null, "
396 << node->DebugString(recursive_level);
397 }
398 if (!(old_abs->isa<FuncGraphAbstractClosure>() || old_abs->isa<MetaFuncGraphAbstractClosure>() ||
399 old_abs->isa<AbstractFuncUnion>() || old_abs->isa<PartialAbstractClosure>())) {
400 continue;
401 }
402 auto old_abs_func = old_abs->cast<AbstractFunctionPtr>();
403 auto new_abs_func = SpecializeAbstractFuncRecursively(old_abs_func);
404 if (new_abs_func != nullptr) {
405 input0->set_abstract(new_abs_func);
406 MS_LOG(DEBUG) << "Find specialized abstract for node: " << input0->DebugString()
407 << ", old_abstract: " << old_abs->ToString()
408 << ", specialized_abstract: " << new_abs_func->ToString();
409 } else {
410 MS_LOG(DEBUG) << "cannot find specialized abstract for node: " << input0->DebugString()
411 << ", old_abstract: " << old_abs_func->ToString();
412 }
413 }
414 }
415
// Hand out consecutive ids starting from 1; each call returns the next one.
static int64_t GetNextCounter() {
  static int64_t last_issued = 0;
  last_issued += 1;
  return last_issued;
}
420
FuncGraphSpecializer(ProgramSpecializer * const s,const FuncGraphPtr & fg,const AnalysisContextPtr & context)421 FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
422 const AnalysisContextPtr &context)
423 : specializer_(s), func_graph_(fg), context_(context) {
424 parent_ = s->GetFuncGraphSpecializer(context->parent());
425 MS_EXCEPTION_IF_NULL(context->parent());
426 if (ParentNotSpecialized(context)) {
427 MS_LOG(INTERNAL_EXCEPTION) << "Parent func graph should be handled in advance, fg: " << fg->ToString()
428 << ", context: " << context->ToString()
429 << ", parent context: " << context->parent()->ToString();
430 }
431 engine_ = s->engine();
432 cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
433 specialized_func_graph_ = cloner_->cloned_func_graphs().find(fg)->second;
434 AddTodoItem(fg->get_return());
435 AddTodoItem(fg->parameters());
436 }
437
// Return the replica of `node` in the specializer that owns it, cloning it (disconnected) on
// first request. ValueNodes are returned unchanged. For a CNode the inputs of the replica are
// replicated as well. Raises when no owning specializer exists or the clone is not recorded.
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return node;
  }
  std::shared_ptr<FuncGraphSpecializer> owner = GetTopSpecializer(node);
  if (owner == nullptr) {
    constexpr auto recursive_level = 2;
    MS_LOG(INTERNAL_EXCEPTION) << "Specializer should not be null, node: " << node->DebugString(recursive_level)
                               << ", NodeInfo: \n"
                               << trace::GetDebugInfoStr(node->debug_info()) << "\n"
                               << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " has no parent context?";
  }

  // If had replicated, just return that.
  auto cached = owner->cloned_nodes().find(node);
  if (cached != owner->cloned_nodes().end()) {
    return cached->second;
  }

  auto replica = owner->cloner_->CloneDisconnected(node);
  if (node->isa<CNode>()) {
    if (!replica->isa<CNode>()) {
      MS_LOG(INTERNAL_EXCEPTION) << "new_node must be a CNode, but is " << replica->DebugString() << ".";
    }
    UpdateNewCNodeInputs(node, replica);
  }

  // The clone must now be recorded, and it must differ from the original node.
  auto recorded = owner->cloned_nodes().find(node);
  if (recorded == owner->cloned_nodes().end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  }
  if (recorded->second == node) {
    MS_LOG(INTERNAL_EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  }
  return replica;
}
475
UpdateNewCNodeInputs(const AnfNodePtr & node,const AnfNodePtr & new_node)476 void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
477 MS_EXCEPTION_IF_NULL(node);
478 auto c_node = node->cast_ptr<CNode>();
479 MS_EXCEPTION_IF_NULL(c_node);
480 auto inputs = c_node->weak_inputs();
481 AnfNodeWeakPtrList new_inputs;
482 (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(new_inputs),
483 [this](const AnfNodeWeakPtr &weak_inp) -> AnfNodePtr {
484 auto inp = weak_inp.lock();
485 MS_EXCEPTION_IF_NULL(inp);
486 auto new_inp = ReplicateDisconnectedNode(inp);
487 // Refer the comments in BuildReplacedNode.
488 if (inp->isa<CNode>()) {
489 auto c_inp = inp->cast<CNodePtr>();
490 MS_EXCEPTION_IF_NULL(c_inp);
491 auto c_new_inp = new_inp->cast<CNodePtr>();
492 MS_EXCEPTION_IF_NULL(c_new_inp);
493 MS_EXCEPTION_IF_NULL(c_new_inp->func_graph());
494 MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> "
495 << new_inp->DebugString();
496 c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
497 }
498 return new_inp;
499 });
500 MS_EXCEPTION_IF_NULL(new_node);
501 auto c_new_node = new_node->cast_ptr<CNode>();
502 MS_EXCEPTION_IF_NULL(c_new_node);
503 c_new_node->set_weak_inputs(new_inputs);
504 }
505
// Return the already-cloned counterpart of `node` from the specializer that owns it, or `node`
// itself when no clone is recorded. Raises when no owning specializer can be found.
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  std::shared_ptr<FuncGraphSpecializer> owner = GetTopSpecializer(node);
  if (owner == nullptr) {
    constexpr auto recursive_level = 2;
    MS_LOG(INTERNAL_EXCEPTION) << "Specializer should not be null, node: " << node->DebugString(recursive_level)
                               << ", NodeInfo: \n"
                               << trace::GetDebugInfoStr(node->debug_info()) << "\n"
                               << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " has no parent context?";
  }
  auto cloned = owner->cloned_nodes().find(node);
  if (cloned == owner->cloned_nodes().end()) {
    return node;
  }
  return cloned->second;
}
521
522 // Return itself if node's ValueNode as top,
523 // return the top func graph specializer as top if node's forward Parameter,
524 // or, return the top parent specializer as top.
GetTopSpecializer(const AnfNodePtr & node)525 std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(const AnfNodePtr &node) {
526 MS_EXCEPTION_IF_NULL(node);
527 FuncGraphPtr fg = node->func_graph();
528 if (fg == nullptr) { // If ValueNode, return current specializer.
529 MS_LOG(DEBUG) << "Node's a ValueNode, node: " << node->DebugString();
530 return shared_from_this();
531 }
532 std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
533 while (fg != specializer->func_graph_) {
534 if (specializer->parent_ == nullptr && node->isa<Parameter>()) {
535 // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
536 auto &top_context = specializer_->top_context();
537 MS_EXCEPTION_IF_NULL(top_context);
538 if (top_context->func_graph() == fg) { // `fg` is top func graph.
539 MS_LOG(INFO) << "Used top func graph specializer as parent for "
540 << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", node: " << node->DebugString()
541 << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
542 specializer = specializer_->GetFuncGraphSpecializer(top_context);
543 if (specializer == nullptr) {
544 constexpr auto recursive_level = 2;
545 MS_LOG(INTERNAL_EXCEPTION) << "Specializer must not be null, node: " << node->DebugString(recursive_level)
546 << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
547 }
548 } else {
549 MS_EXCEPTION_IF_NULL(top_context->func_graph());
550 MS_LOG(INFO) << "Used current specializer, fg: " << fg->ToString()
551 << ", current fg: " << specializer->func_graph_->ToString()
552 << ", top fg: " << top_context->func_graph()->ToString();
553 }
554 break;
555 } else {
556 specializer = specializer->parent_;
557 }
558 if (specializer == nullptr) {
559 return nullptr;
560 }
561 }
562 return specializer;
563 }
564
Run()565 void FuncGraphSpecializer::Run() {
566 MS_LOG(DEBUG) << "Before run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
567 << ", cloned func graph name: "
568 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", func graph: "
569 << (func_graph_ ? func_graph_->get_return() ? func_graph_->get_return()->DebugString() : "return null"
570 : "FG(null)");
571 FirstPass();
572 SecondPass();
573 MS_LOG(DEBUG) << "After run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
574 << ", cloned func graph name: "
575 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", new func graph: "
576 << (specialized_func_graph_ ? specialized_func_graph_->get_return()
577 ? specialized_func_graph_->get_return()->DebugString()
578 : "return null"
579 : "FG(null)");
580 }
581
FirstPass()582 void FuncGraphSpecializer::FirstPass() {
583 while (!todo_.empty()) {
584 AnfNodePtr node = todo_.back();
585 todo_.pop_back();
586 if (node->func_graph() == nullptr) {
587 // Do nothing for ValueNode
588 continue;
589 }
590 if (node->func_graph() != func_graph_) {
591 std::shared_ptr<FuncGraphSpecializer> parent = nullptr;
592 if (parent_ != nullptr) {
593 parent = parent_;
594 } else if (specializer_->top_context() && specializer_->top_context()->func_graph() == node->func_graph() &&
595 node->isa<Parameter>()) {
596 // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
597 parent = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
598 MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
599 << ", node: " << node->DebugString()
600 << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
601 }
602 if (parent == nullptr) {
603 MS_LOG(INTERNAL_EXCEPTION) << "Parent must not be null, node: " << node->DebugString()
604 << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
605 }
606 parent->AddTodoItem(node);
607 parent->FirstPass();
608 AnfNodePtr new_node = parent->GetReplicatedNode(node);
609 if (new_node->isa<CNode>()) {
610 MS_LOG(DEBUG) << "ProcessCNode in FirstPass for " << func_graph_->ToString()
611 << ", node: " << node->DebugString() << ", new_node: " << new_node->DebugString();
612 (void)parent->ProcessCNode(new_node->cast<CNodePtr>());
613 }
614 continue;
615 }
616 if (marked_.count(node) > 0) {
617 continue;
618 }
619 (void)marked_.insert(node);
620 ProcessNode(node);
621 }
622 }
623
624 // Specialize CNode in func graphs
SecondPass()625 void FuncGraphSpecializer::SecondPass() {
626 if (second_pass_todo_.empty()) {
627 second_pass_todo_ = BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node());
628 }
629 MS_LOG(DEBUG) << "Start in index: " << second_pass_todo_index_ << ", fg: " << func_graph_->ToString()
630 << ", todo list size: " << second_pass_todo_.size();
631 while (second_pass_todo_index_ < second_pass_todo_.size()) {
632 auto success = ProcessCNode(second_pass_todo_[second_pass_todo_index_]);
633 if (!success) {
634 MS_LOG(DEBUG) << "Suspend in index: " << second_pass_todo_index_
635 << ", node: " << second_pass_todo_[second_pass_todo_index_]->DebugString();
636 return;
637 }
638 ++second_pass_todo_index_;
639 }
640 MS_EXCEPTION_IF_NULL(func_graph_);
641 MS_LOG(DEBUG) << "Set done of fg: " << func_graph_->ToString();
642 done_ = true;
643 }
644
645 namespace {
UpdateForEmptySequenceNode(const AnfNodePtr & new_node,const AnfNodePtr & old_node,const AbstractSequencePtr & old_sequence_abs)646 void UpdateForEmptySequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
647 const AbstractSequencePtr &old_sequence_abs) {
648 if (!IsValueNode<ValueTuple>(new_node) && !IsValueNode<ValueList>(new_node)) {
649 return;
650 }
651 MS_EXCEPTION_IF_NULL(old_sequence_abs);
652 auto sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
653 (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_node));
654 old_sequence_abs->set_sequence_nodes(sequence_nodes);
655 auto flags = GetSequenceNodeElementsUseFlags(old_node);
656 if (flags != nullptr) {
657 SetSequenceNodeElementsUseFlags(new_node, flags);
658 } else {
659 SetSequenceNodeElementsUseFlags(new_node,
660 std::make_shared<std::vector<bool>>(old_sequence_abs->elements().size(), true));
661 }
662 }
663
664 // Update elements use flags for MakeTuple/tuple node,
665 // and update the node's AbstractSequence 'sequence_nodes' info.
UpdateSequenceNode(const AnfNodePtr & new_node,const AnfNodePtr & old_node,const AbstractBasePtr & old_abs)666 void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, const AbstractBasePtr &old_abs) {
667 if (new_node == old_node) {
668 return;
669 }
670 MS_EXCEPTION_IF_NULL(old_node);
671 auto old_sequence_abs = dyn_cast<AbstractSequence>(old_abs);
672 if (old_sequence_abs == nullptr || old_sequence_abs->isa<AbstractSparseTensor>()) {
673 MS_LOG(DEBUG) << "The abstract is not AbstractTuple/AbstractList, " << old_node->DebugString() << " --> "
674 << new_node->DebugString();
675 return;
676 }
677 if (old_sequence_abs->sequence_nodes() == nullptr || old_sequence_abs->sequence_nodes()->empty()) {
678 MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString();
679 // The abstract of old_node may have not sequence_nodes when it is a parameter or tuple output cnode.
680 UpdateForEmptySequenceNode(new_node, old_node, old_sequence_abs);
681 return;
682 }
683
684 // Since the 'old_node' may not equal to 'old_abs' sequence node,
685 // if the new_node is built by the abstract of 'forward old node',
686 // we just set 'new_node' as 'old_abs' sequence node here.
687 if (IsValueNode<ValueTuple>(new_node) || IsValueNode<ValueList>(new_node)) {
688 // Just find a valid sequence node.
689 for (auto &weak_node : *old_sequence_abs->sequence_nodes()) {
690 auto sequence_node = weak_node.lock();
691 if (sequence_node == nullptr) {
692 continue;
693 }
694 auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
695 if (flags == nullptr) {
696 continue;
697 }
698 // Copy the flags to new node, and set new node to sequence abstract.
699 // Actually, here we needn't require unique sequence nodes pointer between abstract any more.
700 SetSequenceNodeElementsUseFlags(new_node, flags);
701 old_sequence_abs->InsertSequenceNode(new_node);
702 return;
703 }
704 MS_LOG(INFO) << "Not found any valid sequence node, " << old_node->DebugString() << " --> "
705 << new_node->DebugString();
706 return;
707 }
708
709 for (auto &weak_node : *old_sequence_abs->sequence_nodes()) {
710 auto sequence_node = weak_node.lock();
711 if (sequence_node == nullptr) {
712 MS_LOG(DEBUG) << "The sequence_nodes is free. " << old_node->DebugString() << " --> " << new_node->DebugString();
713 continue;
714 }
715 if (sequence_node != old_node) {
716 continue;
717 }
718
719 // Update new node's flags with old one, and update old sequence abstract's source node.
720 auto flags = GetSequenceNodeElementsUseFlags(old_node);
721 MS_LOG(DEBUG) << "Update sequence node, " << old_node->DebugString() << " --> " << new_node->DebugString()
722 << ", elements_use_flags: " << (*flags);
723 SetSequenceNodeElementsUseFlags(new_node, flags);
724 old_sequence_abs->UpdateSequenceNode(sequence_node, new_node);
725
726 // Update new sequence abstract if it's not equal to old one.
727 const AbstractBasePtr &new_abs = new_node->abstract();
728 if (old_abs == new_abs) {
729 continue;
730 }
731 MS_LOG(ERROR) << "New abstract, " << old_node->DebugString() << " --> " << new_node->DebugString()
732 << ", elements_use_flags: " << (*flags);
733 auto new_sequence_abs = dyn_cast_ptr<AbstractSequence>(new_abs);
734 if (new_sequence_abs == nullptr) {
735 MS_LOG(INTERNAL_EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString();
736 }
737 if (new_sequence_abs->sequence_nodes() == nullptr || new_sequence_abs->sequence_nodes()->empty()) {
738 std::shared_ptr<AnfNodeWeakPtrList> new_sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
739 (void)new_sequence_nodes->emplace_back(AnfNodeWeakPtr(new_node));
740 new_sequence_abs->set_sequence_nodes(new_sequence_nodes);
741 } else {
742 new_sequence_abs->InsertSequenceNode(new_node);
743 }
744 }
745 }
746
747 // Purify specific input of a CNode.
748 template <typename T, typename S>
PurifySequenceValueNode(const CNodePtr & cnode,size_t index,ProgramSpecializer * const specializer)749 void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) {
750 MS_EXCEPTION_IF_NULL(cnode);
751 const auto &old_input = cnode->input(index);
752 MS_EXCEPTION_IF_NULL(old_input);
753 auto sequence_value = GetValuePtr<T>(old_input);
754 if (sequence_value == nullptr) {
755 return;
756 }
757 auto flags = GetSequenceNodeElementsUseFlags(old_input);
758 if (flags == nullptr) {
759 return;
760 }
761 auto old_input_abs = old_input->abstract();
762 MS_EXCEPTION_IF_NULL(old_input_abs);
763 auto old_sequence_abs = dyn_cast<AbstractSequence>(old_input_abs);
764 MS_EXCEPTION_IF_NULL(old_sequence_abs);
765 // Dynamic len abstract sequence no need purify.
766 if (IsInvalidAbstractSequence(old_sequence_abs)) {
767 return;
768 }
769
770 std::vector<size_t> dead_node_positions;
771 ValuePtrList elements;
772 AbstractBasePtrList elements_abs{};
773 auto sequence_value_size = sequence_value->value().size();
774 if (flags->size() < sequence_value_size) {
775 MS_LOG(INTERNAL_EXCEPTION) << "Inner exception. CNode: " << cnode->ToString() << " input: " << old_input->ToString()
776 << " flags size: " << flags->size()
777 << " values size: " << sequence_value->value().size();
778 }
779 for (size_t i = 0; i < sequence_value_size; ++i) {
780 ValuePtr old_sequence_value = sequence_value->value()[i];
781 MS_EXCEPTION_IF_NULL(old_sequence_value);
782 auto old_sequence_err_value = old_sequence_value->cast_ptr<ValueProblem>();
783 if (old_sequence_err_value != nullptr && old_sequence_err_value->IsDead()) {
784 MS_LOG(DEBUG) << "Collect for erasing elements[" << i << "] DeadNode as zero for " << old_input->DebugString()
785 << ", which is inputs[" << index << "] of " << cnode->DebugString();
786 (void)dead_node_positions.emplace_back(i);
787 }
788 if (!(*flags)[i]) {
789 auto zero = MakeValue<int64_t>(0);
790 (void)elements.emplace_back(zero);
791 (void)elements_abs.emplace_back(zero->ToAbstract());
792 MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
793 << index << "] of " << cnode->DebugString();
794 } else {
795 (void)elements.emplace_back(old_sequence_value);
796 (void)elements_abs.emplace_back(old_sequence_abs->elements()[i]);
797 }
798 }
799 auto new_sequence_value = std::make_shared<T>(elements);
800 auto new_input = NewValueNode(new_sequence_value);
801 auto new_sequence_abs = std::make_shared<S>(elements_abs);
802 std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
803 (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input));
804 new_sequence_abs->set_sequence_nodes(sequence_nodes);
805 if constexpr (std::is_same<S, AbstractList>()) {
806 auto old_sequence_abs_list = old_sequence_abs->cast<AbstractListPtr>();
807 MS_EXCEPTION_IF_NULL(old_sequence_abs_list);
808 if (fallback::HasObjInExtraInfoHolder(old_sequence_abs_list)) {
809 MS_LOG(DEBUG) << "old AbstractList has python object, attach it to new AbstractList.";
810 auto list_obj = fallback::GetObjFromExtraInfoHolder(old_sequence_abs_list);
811 auto create_in_graph = fallback::GetCreateInGraphFromExtraInfoHolder(old_sequence_abs_list);
812 fallback::AttachPyObjToExtraInfoHolder(new_sequence_abs, list_obj, create_in_graph);
813 }
814 }
815
816 new_input->set_abstract(new_sequence_abs);
817
818 // Always reset tuple value node's use flags as non-use.
819 SetSequenceNodeElementsUseFlags(new_input, flags);
820 MS_LOG(DEBUG) << "Update ValueTuple/ValueList, " << old_input->DebugString() << " --> " << new_input->DebugString()
821 << ", which is inputs[" << index << "] of " << cnode->DebugString() << ", flags: " << (*flags);
822 // Keep the node not to release before we purify its abstract.
823 (void)specializer->sequence_abstract_list().emplace_back(std::pair(new_sequence_abs, old_input));
824 for (size_t pos : dead_node_positions) {
825 (void)specializer->dead_node_list().emplace_back(std::pair(new_input, pos));
826 }
827 cnode->set_input(index, new_input);
828 }
829
PurifyNamedTupleValueNode(const CNodePtr & cnode,size_t index,ProgramSpecializer * const specializer)830 void PurifyNamedTupleValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) {
831 MS_EXCEPTION_IF_NULL(cnode);
832 const auto &old_input = cnode->input(index);
833 MS_EXCEPTION_IF_NULL(old_input);
834 auto sequence_value = GetValuePtr<ValueNamedTuple>(old_input);
835 if (sequence_value == nullptr) {
836 return;
837 }
838 auto flags = GetSequenceNodeElementsUseFlags(old_input);
839 if (flags == nullptr) {
840 return;
841 }
842 auto old_input_abs = old_input->abstract();
843 MS_EXCEPTION_IF_NULL(old_input_abs);
844 auto old_sequence_abs = dyn_cast<AbstractSequence>(old_input_abs);
845 MS_EXCEPTION_IF_NULL(old_sequence_abs);
846 // Dynamic len abstract sequence no need purify.
847 if (IsInvalidAbstractSequence(old_sequence_abs)) {
848 return;
849 }
850
851 std::vector<size_t> dead_node_positions;
852 ValuePtrList elements;
853 AbstractBasePtrList elements_abs{};
854 auto sequence_value_size = sequence_value->value().size();
855 if (flags->size() < sequence_value_size) {
856 MS_LOG(INTERNAL_EXCEPTION) << "Inner exception. CNode: " << cnode->ToString() << " input: " << old_input->ToString()
857 << " flags size: " << flags->size()
858 << " values size: " << sequence_value->value().size();
859 }
860 for (size_t i = 0; i < sequence_value_size; ++i) {
861 ValuePtr old_sequence_value = sequence_value->value()[i];
862 MS_EXCEPTION_IF_NULL(old_sequence_value);
863 auto old_sequence_err_value = old_sequence_value->cast_ptr<ValueProblem>();
864 if (old_sequence_err_value != nullptr && old_sequence_err_value->IsDead()) {
865 MS_LOG(DEBUG) << "Collect for erasing elements[" << i << "] DeadNode as zero for " << old_input->DebugString()
866 << ", which is inputs[" << index << "] of " << cnode->DebugString();
867 (void)dead_node_positions.emplace_back(i);
868 }
869 if (!(*flags)[i]) {
870 auto zero = MakeValue<int64_t>(0);
871 (void)elements.emplace_back(zero);
872 (void)elements_abs.emplace_back(zero->ToAbstract());
873 MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
874 << index << "] of " << cnode->DebugString();
875 } else {
876 (void)elements.emplace_back(old_sequence_value);
877 (void)elements_abs.emplace_back(old_sequence_abs->elements()[i]);
878 }
879 }
880
881 const auto &sub_class_name = sequence_value->sub_class_name();
882 const auto &keys = sequence_value->key();
883 abstract::AbstractBasePtrList key_abs;
884 (void)std::transform(keys.begin(), keys.end(), std::back_inserter(key_abs), [](const ValuePtr &key) {
885 MS_EXCEPTION_IF_NULL(key);
886 return key->ToAbstract();
887 });
888 auto new_sequence_value = std::make_shared<ValueNamedTuple>(sub_class_name, keys, elements);
889 auto new_input = NewValueNode(new_sequence_value);
890 auto new_sequence_abs = std::make_shared<AbstractNamedTuple>(sub_class_name, key_abs, elements_abs);
891 std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
892 (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input));
893 new_sequence_abs->set_sequence_nodes(sequence_nodes);
894
895 new_input->set_abstract(new_sequence_abs);
896
897 // Always reset tuple value node's use flags as non-use.
898 SetSequenceNodeElementsUseFlags(new_input, flags);
899 MS_LOG(DEBUG) << "Update ValueNamedTuple, " << old_input->DebugString() << " --> " << new_input->DebugString()
900 << ", which is inputs[" << index << "] of " << cnode->DebugString() << ", flags: " << (*flags);
901 // Keep the node not to release before we purify its abstract.
902 (void)specializer->sequence_abstract_list().emplace_back(std::pair(new_sequence_abs, old_input));
903 for (size_t pos : dead_node_positions) {
904 (void)specializer->dead_node_list().emplace_back(std::pair(new_input, pos));
905 }
906 cnode->set_input(index, new_input);
907 }
908 } // namespace
909
910 // First elimination.
911 // Eliminate the unused items of Tuple/List.
912 // Just adjust the nodes, not change the abstracts and dead nodes.
EliminateUnusedSequenceItem(const CNodePtr & cnode) const913 void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) const {
914 if (cnode == nullptr || cnode->abstract() == nullptr) {
915 MS_LOG(INTERNAL_EXCEPTION) << "The parameter \'node\' and its abstract should not be null.";
916 }
917 auto &sequence_abstract_list = specializer_->sequence_abstract_list();
918
919 // Add CNode's inputs if they're sequence abstract, and sequence nodes exist.
920 (void)std::for_each(cnode->weak_inputs().cbegin(), cnode->weak_inputs().cend(),
921 [&sequence_abstract_list](const AnfNodeWeakPtr &weak_input) {
922 auto input = weak_input.lock();
923 MS_EXCEPTION_IF_NULL(input);
924 const AbstractBasePtr input_abs = input->abstract();
925 AbstractSequencePtr input_sequence_abs = dyn_cast<AbstractSequence>(input_abs);
926 if (IsInvalidAbstractSequence(input_sequence_abs)) {
927 return;
928 }
929 // Not call PurifyElements() here, just add to list.
930 (void)sequence_abstract_list.emplace_back(std::pair(input_sequence_abs, input));
931 });
932
933 // Add CNode if it's sequence abstract, and sequence nodes exist.
934 const AbstractBasePtr abs = cnode->abstract();
935 AbstractSequencePtr sequence_abs = dyn_cast<AbstractSequence>(abs);
936 if (IsInvalidAbstractSequence(sequence_abs)) {
937 return;
938 }
939 // Not call PurifyElements() here, just add to list.
940 (void)sequence_abstract_list.emplace_back(std::pair(sequence_abs, cnode));
941
942 // Purify MakeTuple/MakeList CNode.
943 if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
944 auto flags = GetSequenceNodeElementsUseFlags(cnode);
945 if (flags != nullptr) {
946 std::vector<AnfNodePtr> inputs;
947 (void)inputs.emplace_back(cnode->input(0));
948 for (size_t i = 0; i < (*flags).size(); ++i) {
949 auto old_input = cnode->input(i + 1);
950 if (!(*flags)[i]) {
951 auto zero_value = NewValueNode(MakeValue<int64_t>(0));
952 zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(0)));
953 (void)inputs.emplace_back(zero_value);
954 constexpr int recursive_level = 2;
955 MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << cnode->DebugString(recursive_level);
956 } else if (IsDeadNode(old_input)) {
957 constexpr int recursive_level = 2;
958 MS_LOG(DEBUG) << "Collect for erasing elements[" << i << "] DeadNode as zero for " << cnode << "/"
959 << cnode->DebugString(recursive_level);
960 (void)specializer_->dead_node_list().emplace_back(std::pair(cnode, i));
961 (void)inputs.emplace_back(old_input);
962 } else {
963 (void)inputs.emplace_back(old_input);
964 }
965 }
966 cnode->set_inputs(std::move(inputs));
967 cnode->set_abstract(sequence_abs);
968 }
969 }
970 // Purify each Tuple/List ValueNode in CNode.
971 for (size_t i = 1; i < cnode->size(); ++i) {
972 if (IsValueNode<ValueTuple>(cnode->input(i))) {
973 if (IsValueNode<ValueNamedTuple>(cnode->input(i))) {
974 PurifyNamedTupleValueNode(cnode, i, specializer_);
975 } else {
976 PurifySequenceValueNode<ValueTuple, AbstractTuple>(cnode, i, specializer_);
977 }
978 } else if (IsValueNode<ValueList>(cnode->input(i))) {
979 PurifySequenceValueNode<ValueList, AbstractList>(cnode, i, specializer_);
980 }
981 }
982 }
983
ProcessNode(const AnfNodePtr & node)984 void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
985 MS_EXCEPTION_IF_NULL(node);
986 ScopeGuard scope_guard(node->scope());
987 AnfNodeConfigPtr conf = MakeConfig(node);
988 MS_EXCEPTION_IF_NULL(conf);
989 TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
990 AnfNodePtr new_node = GetReplicatedNode(node);
991 MS_EXCEPTION_IF_NULL(new_node);
992 if (new_node->func_graph() != specialized_func_graph_) {
993 MS_LOG(INTERNAL_EXCEPTION) << "Found not specialized node, node: " << node->DebugString()
994 << ", new_node: " << new_node->DebugString() << ", new_node->func_graph(): "
995 << (new_node->func_graph() ? new_node->func_graph()->ToString() : "FG(Null)")
996 << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
997 }
998 const EvalResultPtr &conf_eval_result = GetEvalResult(conf);
999 MS_EXCEPTION_IF_NULL(conf_eval_result);
1000 new_node->set_abstract(conf_eval_result->abstract());
1001 MS_EXCEPTION_IF_NULL(new_node->abstract());
1002
1003 // Update PartialAbstractClosure's bound node.
1004 if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
1005 auto partial_closure = dyn_cast_ptr<PartialAbstractClosure>(new_node->abstract());
1006 MS_EXCEPTION_IF_NULL(partial_closure);
1007 auto partial_node = partial_closure->node();
1008 if (partial_node != nullptr && GetTopSpecializer(partial_node) != nullptr) {
1009 auto new_partial_node = GetReplicatedNode(partial_node);
1010 if (new_partial_node != partial_node) { // Old Partial CNode was replaced. Need update.
1011 partial_closure->set_node(new_partial_node);
1012 }
1013 }
1014 }
1015 MS_LOG(DEBUG) << "Set new_node: " << new_node->DebugString() << ", abstract as: " << new_node->abstract()->ToString()
1016 << ", func_graph_: " << func_graph_->ToString()
1017 << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
1018
1019 if (!node->isa<CNode>()) {
1020 return;
1021 }
1022 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1023 auto attrs = conf_eval_result->attribute();
1024 auto c_old = node->cast_ptr<CNode>();
1025 auto c_new = new_node->cast_ptr<CNode>();
1026 MS_EXCEPTION_IF_NULL(c_new);
1027 auto new_inputs = c_new->weak_inputs();
1028 auto old_inputs = c_old->weak_inputs();
1029 for (size_t i = 0; i < old_inputs.size(); ++i) {
1030 auto node_input = old_inputs[i].lock();
1031 MS_EXCEPTION_IF_NULL(node_input);
1032 AnfNodeConfigPtr input_conf = MakeConfig(node_input);
1033 MS_EXCEPTION_IF_NULL(input_conf);
1034 const auto &eval_result = GetEvalResult(input_conf);
1035 const AbstractBasePtr &abs = eval_result->abstract();
1036 // Check if there's an inplace abstract and use it.
1037 AbstractBasePtr real_abs;
1038 if (abs->inplace_abstract() == nullptr) {
1039 real_abs = abs;
1040 } else {
1041 real_abs = abs->inplace_abstract();
1042 MS_LOG(INFO) << "Use inplace abstract, " << abs->ToString() << " -> " << real_abs->ToString();
1043 }
1044 bool ignore_build_value = false;
1045 AnfNodePtr replace_node = nullptr;
1046 MS_EXCEPTION_IF_NULL(specializer_->engine());
1047 if (specializer_->engine()->check_side_effect()) {
1048 auto cnode_input = dyn_cast_ptr<CNode>(node_input);
1049 ignore_build_value = (cnode_input != nullptr && cnode_input->has_side_effect_node());
1050 if (ignore_build_value) {
1051 MS_LOG(INFO) << "Don't build value node for CNode which contains isolated side-effect inputs, node: "
1052 << cnode_input->DebugString() << ", flag: " << cnode_input->has_side_effect_node();
1053 }
1054 }
1055 if (!ignore_build_value) {
1056 // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
1057 // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
1058 replace_node = BuildPossibleValueNode(node_input, real_abs, attrs, node);
1059 }
1060 if (replace_node == nullptr) {
1061 replace_node = BuildReplacedNode(input_conf);
1062 MS_EXCEPTION_IF_NULL(replace_node);
1063 replace_node->set_abstract(real_abs);
1064 MS_LOG(DEBUG) << "Set replaced input[" << i << "]: " << replace_node->DebugString()
1065 << ", NodeConfig: " << input_conf->ToString() << ", result: " << real_abs.get() << "/"
1066 << real_abs->ToString();
1067 } else {
1068 MS_EXCEPTION_IF_NULL(real_abs);
1069 MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
1070 << ", real_abs: " << real_abs->ToString() << ", replace_node: " << replace_node->DebugString();
1071 }
1072 MS_EXCEPTION_IF_NULL(replace_node);
1073 if (enable_eliminate_unused_element) {
1074 UpdateSequenceNode(replace_node, node_input, real_abs);
1075 }
1076 if (new_inputs[i].lock() != replace_node) {
1077 new_node->func_graph()->AddOwnNode(replace_node);
1078 new_inputs[i] = replace_node;
1079 MS_LOG(DEBUG) << "Set new_input[" << i << "]: " << replace_node->DebugString();
1080 }
1081 }
1082 c_new->set_weak_inputs(new_inputs);
1083 MS_LOG(DEBUG) << "Update cnode: " << c_new << "/" << c_new->DebugString();
1084 }
1085
BuildReplacedNode(const AnfNodeConfigPtr & conf)1086 AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
1087 MS_EXCEPTION_IF_NULL(conf);
1088 auto conf_iter = engine_->anfnode_config_map().find(conf);
1089 AnfNodeConfigPtr new_conf = conf;
1090 while (conf_iter != engine_->anfnode_config_map().end()) {
1091 MS_LOG(DEBUG) << "Origin conf: node(" << (new_conf->node() ? new_conf->node()->DebugString() : "Node(Null)") << ")";
1092 new_conf = conf_iter->second;
1093 MS_EXCEPTION_IF_NULL(new_conf);
1094 const auto &forward_node = new_conf->node();
1095 MS_EXCEPTION_IF_NULL(forward_node);
1096 MS_LOG(DEBUG) << "Replaced conf: node(" << forward_node->DebugString() << ")";
1097 const auto &replicated_forward_node = ReplicateDisconnectedNode(forward_node);
1098 if (replicated_forward_node && replicated_forward_node->isa<CNode>()) {
1099 // The AnfNode in order_list can be:
1100 // case 1: also in FuncGraphManager, so it can be got from nodes API of func_graph. it will
1101 // be replaced in CloneOrderList in Cloner.
1102 // case 2: AnfNode is not in FuncGraphManager which generated in Analyze phase, so it will not
1103 // be cloned by normal clone API.
1104 // 2.1: A forward node , the original node is in FuncGraphManager. The original node will
1105 // be cloned in CloneOrderList in Cloner, and the replicated forward node will replace
1106 // the replicated original node here.
1107 // 2.2: an input of a forward node, such as Cast CNode generated in DoCast. It is also another
1108 // original node to fowrad.
1109 // 2.3: an input of an input of a forward node, but it's not an original node. Like the Cast CNode
1110 // in MixedPrecisionCastHelper.
1111 // For 2.2 and 2.3, we will put a placeholder in order list of replicated func_graph, refer to
1112 // CloneOrderlist, and it will be replaced inside ReplicateDisconnectedNode.
1113 // For 2.1 the following code will do the job, replace replicated origin cnode with the replicated
1114 // forward one in the replicated func_graph.
1115 MS_EXCEPTION_IF_NULL(conf_iter->first);
1116 const auto &origin_node = conf_iter->first->node();
1117 const auto &replicated_origin_node = GetReplicatedNode(origin_node);
1118 if (replicated_origin_node != origin_node) {
1119 MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString()
1120 << ", with replicated forwarded node: " << replicated_forward_node->DebugString();
1121 MS_EXCEPTION_IF_NULL(replicated_forward_node->func_graph());
1122 replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node);
1123 } else {
1124 MS_LOG(INTERNAL_EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: "
1125 << (origin_node ? origin_node->DebugString() : "Node(Null)");
1126 }
1127 }
1128 conf_iter = engine_->anfnode_config_map().find(new_conf);
1129 }
1130 AddTodoItem(new_conf->node());
1131 auto repl = GetReplicatedNode(new_conf->node());
1132 if (repl->func_graph()) {
1133 MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node: " << repl->DebugString()
1134 << ") to replace origin: " << new_conf->node()->DebugString();
1135 } else {
1136 MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
1137 << ") to replace origin: " << new_conf->node()->DebugString();
1138 }
1139 return repl;
1140 }
1141
// Builds the specialized node for the function 'func' called by 'cnode'.
//
// cnode:         the calling CNode, forwarded to the value node builders.
// func:          the function node, must not be null.
// abs:           abstract of the function, must be an AbstractFunction.
// args_abs_list: abstracts of the call arguments.
//
// Returns nullptr when the inner builder yields no node with status kSpecializeSuccess. For the statuses
// kSpecializeDead/kSpecializePoly a ValueProblem value node is returned instead. Any other failure
// raises an internal exception.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, const AnfNodePtr &func,
                                                      const AbstractBasePtr &abs,
                                                      const AbstractBasePtrList &args_abs_list) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(func);
  auto abs_function = dyn_cast_ptr<AbstractFunction>(abs);
  MS_EXCEPTION_IF_NULL(abs_function);

  AbstractFunctionPtr func_abs = abs_function->GetUnique();
  SpecializeStatusCode errcode;
  ScopeGuard scope_guard(func->scope());
  AnfNodePtr specialized_node = BuildSpecializedNodeInner(cnode, func, abs, func_abs, args_abs_list, &errcode);
  if (specialized_node == nullptr) {
    constexpr auto recursive_level = 2;
    switch (errcode) {
      case kSpecializeSuccess:
        // If errcode is success, it means child graph specialize.
        return nullptr;
      case kSpecializeDead: {
        const auto dead_value = std::make_shared<ValueProblem>(ValueProblemType::kDead);
        const auto dead_abstract = std::make_shared<AbstractProblem>(dead_value, func);
        specialized_node = BuildValueNode(dead_value, cnode, dead_abstract);
        MS_LOG(DEBUG) << "DEAD for func: " << func->DebugString(recursive_level) << ", abstract: " << abs->ToString();
        break;
      }
      case kSpecializePoly: {
        const auto poly_value = std::make_shared<ValueProblem>(ValueProblemType::kPoly);
        const auto poly_abstract = std::make_shared<AbstractProblem>(poly_value, func);
        specialized_node = BuildValueNode(poly_value, cnode, poly_abstract);
        MS_LOG(DEBUG) << "POLY for func: " << func->DebugString(recursive_level) << ", abstract: " << abs->ToString();
        break;
      }
      default:
        MS_LOG(INTERNAL_EXCEPTION) << "Failed to build specialized func, func: " << func->DebugString()
                                   << ", abstract: " << abs->ToString();
        break;
    }
  }

  // Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
  MS_EXCEPTION_IF_NULL(func_abs);
  if (!func_abs->isa<MetaFuncGraphAbstractClosure>()) {
    return specialized_node;
  }
  auto specialized_fg = GetValuePtr<FuncGraph>(specialized_node);
  const bool ends_with_umonad = (args_abs_list.size() > 1) && args_abs_list.back() != nullptr &&
                                args_abs_list.back()->isa<AbstractUMonad>();
  if (specialized_fg != nullptr && ends_with_umonad) {
    specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  }
  return specialized_node;
}
1188
// Builds the value node of the specialized function for one call site.
//
// cnode:    the calling CNode; set as bound node of the evaluator and forwarded to BuildValueNode().
// func:     the function node; returned unchanged for stub func graphs.
// abs:      abstract of the function, must not be null.
// func_abs: the unique abstract function of 'abs', must not be null.
// args:     abstracts of the call arguments.
// errcode:  output status, must not be null. Set to kSpecializeSuccess on entry and overwritten with the
//           status of AcquireUniqueEvalResult() when that fails.
//
// Returns nullptr when no unique eval result can be acquired (see 'errcode'), or when a new func graph
// specializer had to be created and pushed as todo item (then 'errcode' stays kSpecializeSuccess).
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode, const AnfNodePtr &func,
                                                           const AbstractBasePtr &abs,
                                                           const AbstractFunctionPtr &func_abs,
                                                           const AbstractBasePtrList &args,
                                                           SpecializeStatusCode *errcode) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(func_abs);
  MS_EXCEPTION_IF_NULL(errcode);
  *errcode = kSpecializeSuccess;
  auto real_func = dyn_cast_ptr<TypedPrimitiveAbstractClosure>(func_abs);
  if (real_func != nullptr) {
    return BuildValueNode(real_func->prim(), cnode, abs);
  }

  EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
  MS_EXCEPTION_IF_NULL(eval);
  eval->set_bound_node(cnode);
  AbstractBasePtrList args_abs_list = eval->NormalizeArgs(args);
  std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  SpecializeStatusCode status = AcquireUniqueEvalResult(func_abs, eval, args_abs_list, &result);
  if (status != kSpecializeSuccess) {
    *errcode = status;
    return nullptr;
  }
  args_abs_list = result.first;
  AbstractBasePtr unique_output = result.second;

  auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func_abs);
  if (prim_func != nullptr) {
    auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), args_abs_list, unique_output);
    return BuildValueNode(prim_func->prim(), cnode, type_func);
  }

  if (!eval->isa<BaseFuncGraphEvaluator>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Expect the eval is a BaseGraphEvaluator, but got " << eval->ToString()
                               << ", func: " << func->DebugString() << ", abs: " << func_abs->ToString()
                               << ", args: " << args;
  }
  auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);

  if (func_abs->context() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Func context is nullptr NodeInfo: "
                               << trace::GetDebugInfoStr(func_graph_->debug_info());
  }
  auto context = GetAnalysisContext(engine_, real_eval, args_abs_list);
  if (context == nullptr) {
    // 'cnode' is not validated by this function, so keep the error message itself null-safe.
    MS_LOG(INTERNAL_EXCEPTION) << "Failed to get context from static analysis cache, call node: "
                               << (cnode != nullptr ? cnode->DebugString() : "Node(Null)")
                               << ", args: " << mindspore::ToString(args);
  }

  // Validate the func graph before the log below dereferences it.
  MS_EXCEPTION_IF_NULL(context->func_graph());
  constexpr auto recursive_level = 2;
  MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << args_abs_list
                << ", func: " << func->DebugString(recursive_level) << ", context: " << context.get() << ", "
                << context->ToString();
  if (context->func_graph()->stub()) {
    MS_EXCEPTION_IF_NULL(context->func_graph()->get_return());
    MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
                  << ", args: " << args_abs_list.size()
                  << ", graph: " << context->func_graph()->get_return()->DebugString() << ", " << func->ToString();
    return func;
  }
  // Get the upper most func graph of which parent has been specialized.
  while (ParentNotSpecialized(context)) {
    context = context->parent();
  }
  auto fg_spec = specializer_->GetFuncGraphSpecializer(context);
  // If func graph specializer does not exist before, make a new specializer and push to stack, and return nullptr.
  if (fg_spec == nullptr) {
    fg_spec = specializer_->NewFuncGraphSpecializer(context, context->func_graph());
    specializer_->PushFuncGraphTodoItem(fg_spec);
    return nullptr;
  }

  FuncGraphPtr func_graph = fg_spec->specialized_func_graph();
  // Validate the specialized func graph before the log below dereferences it.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "Get spec fg of func graph: " << context->func_graph()->ToString()
                << ", specialized fg: " << func_graph->ToString();
  func_graph->set_flag(kFuncGraphFlagUndetermined, false);
  static auto dummy_context = AnalysisContext::DummyContext();
  MS_EXCEPTION_IF_NULL(dummy_context);
  // Build a map that map unspecialized abstract function to specialized function, later it can be used
  // for specialize input0 of CNode in specialized func graph if input0 is not FuncGraph.
  auto new_abs_func = std::make_shared<FuncGraphAbstractClosure>(func_graph, dummy_context, nullptr, true);
  specializer_->SetSpecializedAbstract(func_abs, new_abs_func, cnode, func);
  if (func_abs->isa<FuncGraphAbstractClosure>()) {
    const auto &func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(func_abs);
    specializer_->SetSpecializedFuncGraphToAbstract(func_graph_abs->func_graph(), new_abs_func);
  }
  return BuildValueNode(func_graph, cnode, new_abs_func);
}
1280
1281 // The CNode function is Parameter.
1282 // If the Parameter is PartialApp, unpack it and rebuild a new one.
BuildSpecializedParameterCNode(const CNodePtr & cnode)1283 AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterCNode(const CNodePtr &cnode) {
1284 MS_EXCEPTION_IF_NULL(cnode);
1285 auto new_inputs = cnode->weak_inputs();
1286 if (new_inputs.empty()) {
1287 MS_LOG(INTERNAL_EXCEPTION) << "inputs can't be empty.";
1288 }
1289 AnfNodePtr func = new_inputs[0].lock();
1290 MS_EXCEPTION_IF_NULL(func);
1291 AbstractBasePtr func_abs = func->abstract();
1292
1293 AbstractBasePtrList args;
1294 auto real_func_abs = func_abs;
1295 MS_EXCEPTION_IF_NULL(func_abs);
1296 if (func_abs->isa<PartialAbstractClosure>()) {
1297 auto partial_closure = dyn_cast_ptr<PartialAbstractClosure>(func_abs);
1298 real_func_abs = partial_closure->fn();
1299 args = partial_closure->args();
1300 }
1301 (void)std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
1302 [](const AnfNodeWeakPtr &weak_inp) -> AbstractBasePtr {
1303 auto inp = weak_inp.lock();
1304 MS_EXCEPTION_IF_NULL(inp);
1305 return inp->abstract();
1306 });
1307
1308 ScopeGuard scope_guard(cnode->scope());
1309 auto specialized_node = BuildSpecializedNode(cnode, func, real_func_abs, args);
1310 if (specialized_node == nullptr) {
1311 return nullptr;
1312 }
1313
1314 // Built for Non-Partial parameter function.
1315 if (!func_abs->isa<PartialAbstractClosure>()) {
1316 MS_LOG(DEBUG) << "cnode: " << cnode->DebugString() << ", func_abs: " << func_abs->ToString()
1317 << ", specialized_node: " << specialized_node->DebugString();
1318 return specialized_node;
1319 }
1320
1321 // To build for Partial parameter function.
1322 auto partial_closure = dyn_cast<PartialAbstractClosure>(func_abs);
1323 AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, cnode, FromValueInside(prim::kPrimPartial)),
1324 specialized_node};
1325 auto partial_node = partial_closure->node();
1326 if (partial_node == nullptr) {
1327 MS_LOG(INTERNAL_EXCEPTION) << "Partial node is null, cnode: " << cnode->DebugString()
1328 << ", func_abs: " << func_abs->ToString();
1329 }
1330 if (!partial_node->isa<CNode>()) {
1331 MS_LOG(INTERNAL_EXCEPTION) << "Must be cnode, but " << partial_node->DebugString();
1332 }
1333 auto partial_cnode = partial_node->cast<CNodePtr>();
1334 constexpr auto extra_args_size = 2;
1335 if (partial_cnode->size() != partial_closure->args().size() + extra_args_size) {
1336 MS_LOG(INTERNAL_EXCEPTION) << "Size of cnode: " << partial_cnode->DebugString()
1337 << " is not equal to 2 added to size of args: "
1338 << mindspore::ToString(partial_closure->args());
1339 }
1340 auto attrs = std::make_shared<AttrValueMap>();
1341 for (size_t i = 0; i < partial_closure->args().size(); i++) {
1342 auto old_node = partial_cnode->input(i + extra_args_size);
1343 MS_EXCEPTION_IF_NULL(old_node);
1344 auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
1345 if (possibile_value_node != nullptr) {
1346 partial_node_list.push_back(possibile_value_node);
1347 } else {
1348 if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
1349 MS_LOG(INTERNAL_EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
1350 }
1351 partial_node_list.push_back(old_node);
1352 }
1353 }
1354 MS_EXCEPTION_IF_NULL(cnode->func_graph());
1355 auto wrapped_node = cnode->func_graph()->NewCNode(std::move(partial_node_list));
1356 wrapped_node->set_abstract(partial_closure);
1357 MS_LOG(DEBUG) << "cnode: " << cnode->DebugString() << ", func_abs: " << func_abs->ToString()
1358 << ", wrapped_node: " << wrapped_node->DebugString();
1359 return wrapped_node;
1360 }
1361
GetEvalCache(const EvaluatorPtr & eval)1362 const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
1363 MS_EXCEPTION_IF_NULL(eval);
1364 auto cache_iter = eval_cache_.find(eval);
1365 if (cache_iter == eval_cache_.end()) {
1366 eval_cache_[eval] = eval->evaluator_cache_mgr();
1367 return eval->evaluator_cache_mgr();
1368 }
1369 return cache_iter->second;
1370 }
1371
BuildFromBroadedArgs(const EvaluatorPtr & eval)1372 std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgs(const EvaluatorPtr &eval) {
1373 MS_EXCEPTION_IF_NULL(eval);
1374 std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
1375 EvalResultPtr res = nullptr;
1376 AbstractBasePtrList broaded_args_list;
1377 std::vector<AbstractBasePtrList> args_vector;
1378 auto eval_cache_iter = eval_cache_.find(eval);
1379 if (eval_cache_iter == eval_cache_.end()) {
1380 MS_LOG(INTERNAL_EXCEPTION) << "Evaluator: " << eval->ToString() << " not exist in cache.";
1381 }
1382 MS_EXCEPTION_IF_NULL(eval_cache_iter->second);
1383 auto &origin_eval_cache = eval_cache_iter->second->GetCache();
1384 for (auto &args_map : origin_eval_cache) {
1385 auto args = args_map.first;
1386 args_vector.push_back(args);
1387 }
1388 // If joinable, maybe choices size is 1 or dynamic shape.
1389 constexpr auto args_size = 2;
1390 if (args_vector.size() < args_size) {
1391 MS_LOG(INTERNAL_EXCEPTION) << "Should have " << args_size << " or more choices, but: " << args_vector.size();
1392 }
1393 AbstractBasePtrList joined_args = args_vector[0];
1394 for (size_t i = 1; i < args_vector.size(); ++i) {
1395 // The args may be not joinable (AbstractScalar join with AbstractTensor), just ignore that case.
1396 try {
1397 MS_LOG_TRY_CATCH_SCOPE;
1398 joined_args = abstract::AbstractJoin(joined_args, args_vector[i]);
1399 } catch (const std::exception &e) {
1400 MS_LOG(DEBUG) << "Cannot join, args1: " << ::mindspore::ToString(joined_args)
1401 << ", args2: " << ::mindspore::ToString(args_vector[i]);
1402 return std::make_pair(AbstractBasePtrList(), nullptr);
1403 }
1404 }
1405 MS_LOG(DEBUG) << "Joined args list: " << joined_args.size() << ", " << ::mindspore::ToString(joined_args);
1406
1407 EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
1408 const auto joined_eval_result = origin_eval_cache.get(joined_args);
1409 if (joined_eval_result != nullptr) {
1410 MS_LOG(DEBUG) << "Find unique choice in original eval cache for joined args list: "
1411 << joined_eval_result->abstract()->ToString();
1412 real->SetValue(joined_args, joined_eval_result);
1413 eval_cache_[eval] = real;
1414 return std::make_pair(joined_args, joined_eval_result->abstract());
1415 }
1416 for (const auto &args : args_vector) {
1417 broaded_args_list.clear();
1418 BroadenArgs(args, &broaded_args_list);
1419 (void)choices.insert(broaded_args_list);
1420 MS_LOG(DEBUG) << "Broaded args list: " << broaded_args_list.size() << ", "
1421 << ::mindspore::ToString(broaded_args_list);
1422 }
1423 if (choices.size() == 1) {
1424 ConfigPtrList args_conf_list;
1425 (void)std::transform(broaded_args_list.cbegin(), broaded_args_list.cend(), std ::back_inserter(args_conf_list),
1426 [](const AbstractBasePtr &v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
1427 MS_LOG(DEBUG) << "Cannot find joined args in cache, run with broaded args list: " << broaded_args_list.size()
1428 << ", " << ::mindspore::ToString(broaded_args_list);
1429 res = eval->SingleRun(engine_, args_conf_list, nullptr);
1430 MS_EXCEPTION_IF_NULL(res);
1431 real->SetValue(broaded_args_list, res);
1432 eval_cache_[eval] = real;
1433 return std::make_pair(broaded_args_list, res->abstract());
1434 }
1435 MS_LOG(DEBUG) << "Choices.size: " << choices.size();
1436 return std::make_pair(AbstractBasePtrList(), nullptr);
1437 }
1438
1439 namespace {
IsHighOrderCall(const AnfNodePtr & func)1440 bool IsHighOrderCall(const AnfNodePtr &func) {
1441 return !func->isa<ValueNode>() && func->abstract()->isa<AbstractFunction>() &&
1442 !func->abstract()->isa<AbstractFuncUnion>();
1443 }
1444
1445 // Update inputs' user data from their abstracts to nodes.
UpdateInputsUserData(const CNodePtr & old_cnode,const AnfNodeWeakPtrList & new_weak_inputs)1446 void UpdateInputsUserData(const CNodePtr &old_cnode, const AnfNodeWeakPtrList &new_weak_inputs) {
1447 const auto &old_weak_inputs = old_cnode->weak_inputs();
1448 if (old_weak_inputs.size() != new_weak_inputs.size()) {
1449 MS_LOG(DEBUG) << "Old inputs size is not equal to new inputs size, node: " << old_cnode->DebugString();
1450 return;
1451 }
1452 // Update real type and shape info.
1453 for (size_t i = 0; i < old_cnode->size(); ++i) {
1454 const auto &old_input = old_weak_inputs[i].lock();
1455 MS_EXCEPTION_IF_NULL(old_input);
1456 const auto &old_input_abs = old_input->abstract();
1457 if (old_input_abs == nullptr) {
1458 MS_LOG(INTERNAL_EXCEPTION) << "The pointer 'old_input_abs' is null, old input node: " << old_input->DebugString();
1459 }
1460 auto new_weak_input = new_weak_inputs[i].lock();
1461 if (new_weak_input == nullptr) {
1462 MS_LOG(INTERNAL_EXCEPTION) << "The " << i << "th input is null, " << old_cnode->DebugString();
1463 }
1464 if (fallback::HasRealType(old_input_abs)) {
1465 const auto &real_type = fallback::GetRealType<AbstractBase, Type>(old_input_abs);
1466 fallback::SetRealType<AnfNode, Type>(new_weak_input, real_type);
1467 }
1468 if (fallback::HasRealShape(old_input_abs)) {
1469 const auto &real_type = fallback::GetRealShape<AbstractBase, BaseShape>(old_input_abs);
1470 fallback::SetRealShape<AnfNode, BaseShape>(new_weak_input, real_type);
1471 }
1472 if (fallback::HasObjInExtraInfoHolder(old_input_abs)) {
1473 MS_LOG(DEBUG) << "Inherit python list object from old input abstract.";
1474 auto list_py_obj = fallback::GetObjFromExtraInfoHolder(old_input_abs);
1475 fallback::AttachPyObjToExtraInfoHolder(new_weak_input->abstract(), list_py_obj, false);
1476 }
1477 }
1478 }
1479
// Unwraps a (possibly nested) Partial CNode.
// While the current callee is a Partial CNode `Partial(f, a1, ..., an)`, its bound arguments
// a1..an are inserted at the front of `*new_inputs_ptr` and `f` becomes the current callee, so
// arguments bound by inner Partial layers end up before those of outer layers.
// Returns the first callee that is not a Partial CNode (`func` itself if it is not one).
// Raises when `new_inputs_ptr` is null, when a Partial CNode has no callee input, or when the
// callee input of a Partial CNode has expired.
AnfNodePtr BuildRealInputsFromPartialCNode(const AnfNodePtr &func, AnfNodeWeakPtrList *new_inputs_ptr) {
  MS_EXCEPTION_IF_NULL(new_inputs_ptr);
  auto &new_inputs = *new_inputs_ptr;
  AnfNodePtr real_func = func;
  // First element is partial, second is func so arg is start from 2
  constexpr size_t arg_start_index = 2;
  while (IsPrimitiveCNode(real_func, prim::kPrimPartial)) {
    auto func_cnode = real_func->cast_ptr<CNode>();
    MS_EXCEPTION_IF_NULL(func_cnode);
    auto &inputs = func_cnode->weak_inputs();
    // Both `inputs.cbegin() + 2` and `inputs[1]` below require at least two inputs.
    if (inputs.size() < arg_start_index) {
      MS_LOG(INTERNAL_EXCEPTION) << "Partial CNode should have at least " << arg_start_index << " inputs, but got "
                                 << inputs.size() << ", node: " << func_cnode->DebugString();
    }
    (void)new_inputs.insert(new_inputs.cbegin(), inputs.cbegin() + arg_start_index, inputs.cend());
    // Lock the callee into its own variable: `real_func` keeps the current Partial node alive
    // while `func_cnode` and `inputs`, which point into it, are still used by the log below.
    auto next_func = inputs[1].lock();
    MS_EXCEPTION_IF_NULL(next_func);
    MS_LOG(DEBUG) << "Real func: " << next_func->ToString() << ", func_cnode: " << func_cnode->DebugString()
                  << ", new_inputs size: " << new_inputs.size();
    real_func = next_func;
  }
  return real_func;
}
1496
1497 // If it's Partial CNode, repack the inputs.
1498 // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
GetCNodeRealInputs(const CNodePtr & cnode)1499 AnfNodeWeakPtrList GetCNodeRealInputs(const CNodePtr &cnode) {
1500 auto &inputs = cnode->weak_inputs();
1501 if (inputs.empty()) {
1502 MS_LOG(INTERNAL_EXCEPTION) << "Inputs of CNode is empty";
1503 }
1504 AnfNodePtr func = inputs[0].lock();
1505 MS_EXCEPTION_IF_NULL(func);
1506 if (!IsPrimitiveCNode(func, prim::kPrimPartial)) {
1507 return inputs;
1508 }
1509
1510 // First element is func, so start from 1.
1511 AnfNodeWeakPtrList new_inputs(inputs.begin() + 1, inputs.end());
1512 func = BuildRealInputsFromPartialCNode(func, &new_inputs);
1513 (void)new_inputs.insert(new_inputs.cbegin(), func);
1514 cnode->func_graph()->AddOwnNode(func);
1515 return new_inputs;
1516 }
1517 } // namespace
1518
ProcessCNodeEnd(const CNodePtr & cnode,const AnfNodeWeakPtrList & new_weak_inputs)1519 void FuncGraphSpecializer::ProcessCNodeEnd(const CNodePtr &cnode, const AnfNodeWeakPtrList &new_weak_inputs) {
1520 // Update inputs' user data from their abstracts to nodes.
1521 UpdateInputsUserData(cnode, new_weak_inputs);
1522 // Set the updated inputs.
1523 cnode->set_weak_inputs(new_weak_inputs);
1524
1525 // Eliminate the unused elements in the tuple/list.
1526 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1527 static const auto enable_only_mark_unused_element = (common::GetCompileConfig("DDE_ONLY_MARK") == "1");
1528 if (enable_eliminate_unused_element && !enable_only_mark_unused_element) {
1529 EliminateUnusedSequenceItem(cnode);
1530 }
1531 constexpr auto recursive_level = 2;
1532 // Only success processed node can be added to seen.
1533 MS_LOG(DEBUG) << "New CNode: " << cnode->DebugString(recursive_level);
1534 specializer_->AddSeen(cnode);
1535 }
1536
1537 // Process Switch App CNode in advance.
1538 // Including: Switch App CNode, Switch CNode, and Switch inputs CNodes(Partial CNode).
ProcessSwitchAppCNode(const CNodePtr & cnode)1539 bool FuncGraphSpecializer::ProcessSwitchAppCNode(const CNodePtr &cnode) {
1540 auto new_switch_app_inputs = cnode->weak_inputs();
1541 if (new_switch_app_inputs.empty()) {
1542 MS_LOG(INTERNAL_EXCEPTION) << "Inputs of CNode is empty";
1543 }
1544 const AnfNodePtr &func = new_switch_app_inputs[0].lock();
1545 MS_EXCEPTION_IF_NULL(func);
1546 if (!IsPrimitiveCNode(func, prim::kPrimSwitch)) {
1547 return false;
1548 }
1549 const auto &switch_cnode = dyn_cast<CNode>(func);
1550 auto new_switch_inputs = switch_cnode->weak_inputs();
1551 if (new_switch_inputs.empty()) {
1552 MS_LOG(INTERNAL_EXCEPTION) << "Switch CNode input is empty";
1553 }
1554
1555 // Specialize the switch app fg arguments, from index 1(cond).
1556 bool finished = true;
1557 constexpr size_t switch_fg_arg_start_index = 1;
1558 constexpr size_t switch_fg_arg_end_index = 4;
1559 for (size_t i = switch_fg_arg_start_index; i < switch_fg_arg_end_index; ++i) {
1560 auto switch_input_node = new_switch_inputs[i].lock();
1561 MS_EXCEPTION_IF_NULL(switch_input_node);
1562 CNodePtr switch_input_cnode = nullptr;
1563 AnfNodePtr real_switch_input_cnode_func = nullptr;
1564 AnfNodeWeakPtrList real_switch_input_cnode_inputs;
1565 if (IsPrimitiveCNode(switch_input_node, prim::kPrimPartial)) {
1566 switch_input_cnode = dyn_cast<CNode>(switch_input_node);
1567 MS_EXCEPTION_IF_NULL(switch_input_cnode);
1568 real_switch_input_cnode_func =
1569 BuildRealInputsFromPartialCNode(switch_input_cnode, &real_switch_input_cnode_inputs);
1570 } else {
1571 if (!IsValueNode<FuncGraph>(switch_input_node)) {
1572 // The Switch input[i] is not Partial CNode, or FuncGraph node
1573 continue;
1574 }
1575 real_switch_input_cnode_func = switch_input_node;
1576 // Since BuildSpecializedNode() 1st argument CNode is used for debug info, we use switch node for FuncGraph input.
1577 switch_input_cnode = switch_cnode;
1578 }
1579
1580 if (!CanSpecializeValueNode(real_switch_input_cnode_func)) {
1581 continue;
1582 }
1583 constexpr size_t switch_app_arg_start_index = 1;
1584 for (size_t j = switch_app_arg_start_index; j < new_switch_app_inputs.size(); ++j) {
1585 (void)real_switch_input_cnode_inputs.emplace_back(new_switch_app_inputs[j]);
1586 }
1587 AbstractBasePtrList args;
1588 AbstractBasePtr func_abs = real_switch_input_cnode_func->abstract();
1589 // First element is function, so the arguments start from 1.
1590 for (size_t j = 0; j < real_switch_input_cnode_inputs.size(); ++j) {
1591 args.push_back(real_switch_input_cnode_inputs[j].lock()->abstract());
1592 }
1593 auto specialized_func_node = BuildSpecializedNode(switch_input_cnode, real_switch_input_cnode_func, func_abs, args);
1594 if (specialized_func_node == nullptr) {
1595 finished = false;
1596 continue;
1597 }
1598 if (!finished) {
1599 continue;
1600 }
1601 // Rebuild a Partial CNode.
1602 if (!IsDeadNode(specialized_func_node) && IsPrimitiveCNode(switch_input_node, prim::kPrimPartial)) {
1603 // Fill new Partial CNode's inputs list.
1604 AnfNodePtr partial_value_node = NewValueNode(prim::kPrimPartial);
1605 partial_value_node->set_abstract(FromValueInside(prim::kPrimPartial));
1606 partial_value_node->set_debug_info(switch_input_node->debug_info());
1607 MS_EXCEPTION_IF_NULL(switch_input_cnode->func_graph());
1608 switch_input_cnode->func_graph()->AddOwnNode(partial_value_node);
1609 switch_input_cnode->func_graph()->AddOwnNode(specialized_func_node);
1610 AnfNodeWeakPtrList partial_node_list = {partial_value_node, specialized_func_node};
1611 // Specialize Partial CNode func graph inputs.
1612 constexpr auto partial_arg_start_index = 2;
1613 (void)std::copy(switch_input_cnode->weak_inputs().cbegin() + partial_arg_start_index,
1614 switch_input_cnode->weak_inputs().cend(), std::back_inserter(partial_node_list));
1615 for (size_t j = partial_arg_start_index; j < partial_node_list.size(); ++j) {
1616 auto old_node = partial_node_list[j].lock();
1617 MS_EXCEPTION_IF_NULL(old_node);
1618 if (CanSpecializeValueNode(old_node)) {
1619 auto new_partial_input_node =
1620 BuildSpecializedNode(switch_input_cnode, old_node, old_node->abstract(), std::vector<AbstractBasePtr>{});
1621 if (new_partial_input_node == nullptr) {
1622 return false;
1623 }
1624 partial_node_list[j] = new_partial_input_node;
1625 switch_input_cnode->func_graph()->AddOwnNode(new_partial_input_node);
1626 }
1627 }
1628
1629 // Finish the Partial CNode specialize.
1630 MS_EXCEPTION_IF_NULL(switch_input_cnode);
1631 ProcessCNodeEnd(switch_input_cnode, partial_node_list);
1632 new_switch_inputs[i] = switch_input_cnode;
1633 } else {
1634 new_switch_inputs[i] = specialized_func_node;
1635 }
1636 }
1637
1638 // Wait for sub func graph specialize finish.
1639 if (!finished) {
1640 return false;
1641 }
1642
1643 ProcessCNodeEnd(switch_cnode, new_switch_inputs);
1644
1645 new_switch_app_inputs[0] = switch_cnode;
1646 ProcessCNodeEnd(cnode, new_switch_app_inputs);
1647
1648 return true;
1649 }
1650
ProcessCNode(const CNodePtr & cnode)1651 bool FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
1652 MS_EXCEPTION_IF_NULL(cnode);
1653 if (specializer_->seen().count(cnode) > 0) {
1654 return true;
1655 }
1656 constexpr auto recursive_level = 2;
1657 MS_LOG(DEBUG) << "Handle CNode: " << cnode->DebugString(recursive_level);
1658 auto new_inputs = GetCNodeRealInputs(cnode);
1659 const AnfNodePtr &func = new_inputs[0].lock();
1660
1661 // Deal with Switch App CNode.
1662 static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
1663 if (enable_pre_lift && IsPrimitiveCNode(func, prim::kPrimSwitch)) {
1664 return ProcessSwitchAppCNode(cnode);
1665 }
1666
1667 // Deal with the CNode|Parameter function call including Partial closure ahead.
1668 if (IsHighOrderCall(func)) {
1669 MS_EXCEPTION_IF_NULL(func->abstract());
1670 auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
1671 EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
1672 std::pair<AbstractBasePtrList, AbstractBasePtr> result;
1673 AbstractBasePtrList empty_args;
1674 auto status = AcquireUniqueEvalResult(func_abs, eval, empty_args, &result);
1675 MS_EXCEPTION_IF_NULL(func->func_graph());
1676 MS_LOG(DEBUG) << "POLY: " << (status == kSpecializePoly) << ", func: " << func->ToString()
1677 << ", abstract: " << func_abs->ToString() << ", "
1678 << func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER);
1679 // If a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early.
1680 if (status == kSpecializePoly ||
1681 (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
1682 auto wrapped_node = BuildSpecializedParameterCNode(cnode);
1683 if (wrapped_node == nullptr) {
1684 return false;
1685 }
1686 MS_LOG(DEBUG) << "Partial closure or parameter call is handled, wrapped_node: "
1687 << wrapped_node->DebugString(recursive_level);
1688 new_inputs[0] = wrapped_node;
1689 cnode->func_graph()->AddOwnNode(wrapped_node);
1690 }
1691 }
1692
1693 // Specialize the function, aka inputs[0], if input0 is a ValueNode<FuncGraph> or ValueNode<Primitive>,
1694 // CanSpecializeValueNode return true, otherwise false.
1695 if (CanSpecializeValueNode(func)) {
1696 // For primitive node, we build the primitive node with inferred attributes in the first pass,
1697 // so we do not build replaced node again here in second pass.
1698 if (IsValueNode<Primitive>(func)) {
1699 new_inputs[0] = func;
1700 cnode->func_graph()->AddOwnNode(func);
1701 } else {
1702 AbstractBasePtrList args;
1703 AbstractBasePtr func_abs = new_inputs[0].lock()->abstract();
1704 // First element is function, so the arguments start from 1.
1705 for (size_t i = 1; i < new_inputs.size(); ++i) {
1706 args.push_back(new_inputs[i].lock()->abstract());
1707 }
1708 auto specialized_func_node = BuildSpecializedNode(cnode, func, func_abs, args);
1709 if (specialized_func_node == nullptr) {
1710 return false;
1711 }
1712
1713 new_inputs[0] = specialized_func_node;
1714 cnode->func_graph()->AddOwnNode(specialized_func_node);
1715 MS_LOG(DEBUG) << "Specalize func: " << func->type_name() << "/" << func->DebugString(recursive_level)
1716 << ", new_func: " << new_inputs[0].lock()->DebugString(recursive_level) << ", args: " << args;
1717 }
1718 }
1719
1720 // Specialize the arguments, except inputs[0].
1721 for (size_t i = 1; i < new_inputs.size(); ++i) {
1722 auto old_node = new_inputs[i].lock();
1723 if (CanSpecializeValueNode(old_node)) {
1724 auto new_node = BuildSpecializedNode(cnode, old_node, old_node->abstract(), std::vector<AbstractBasePtr>{});
1725 if (new_node == nullptr) {
1726 return false;
1727 }
1728
1729 MS_LOG(DEBUG) << "Specalize arg[" << i << "]: " << old_node->DebugString(recursive_level)
1730 << ", new_node: " << new_node->DebugString(recursive_level);
1731 new_inputs[i] = new_node;
1732 cnode->func_graph()->AddOwnNode(new_node);
1733 }
1734 }
1735 ProcessCNodeEnd(cnode, new_inputs);
1736 return true;
1737 }
1738
ParentNotSpecialized(const AnalysisContextPtr & context) const1739 bool FuncGraphSpecializer::ParentNotSpecialized(const AnalysisContextPtr &context) const {
1740 auto parent_context = context->parent();
1741 auto parent_specializer = specializer_->GetFuncGraphSpecializer(parent_context);
1742 // If can't get specializer of parent and parent is not DummyContext, it means parent not specialized.
1743 auto parent_not_specialized = parent_specializer == nullptr && parent_context->func_graph() != nullptr;
1744 return parent_not_specialized;
1745 }
1746
1747 namespace {
DumpEvaluatorCache(const EvaluatorPtr & eval,const AbstractBasePtrList & args_abs_list)1748 void DumpEvaluatorCache(const EvaluatorPtr &eval, const AbstractBasePtrList &args_abs_list) {
1749 MS_EXCEPTION_IF_NULL(eval);
1750 const EvaluatorCacheMgrPtr &evaluator_cache_mgr = eval->evaluator_cache_mgr();
1751 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
1752 MS_LOG(DEBUG) << "Find unique args_abs_list failed, total " << args_abs_list.size() << ". Check cache all items.";
1753 MS_LOG(DEBUG) << "[" << eval << "/" << eval->ToString()
1754 << "] Dump current key, args_abs_list hash: " << AbstractBasePtrListHash(args_abs_list)
1755 << ", args_abs_list: " << args_abs_list;
1756
1757 int64_t i = 0;
1758 const EvalResultCache &map = evaluator_cache_mgr->GetCache();
1759 for (const auto &item : map) {
1760 MS_LOG(DEBUG) << "\tevaluator_cache[" << i++ << "]: {args_abs_list hash: " << AbstractBasePtrListHash(item.first)
1761 << ", args_abs_list: " << item.first << "}";
1762 }
1763 }
1764
IsPolyFunc(const AbstractFunctionPtr & func,const AbstractBasePtrList & args_abs_list)1765 bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_abs_list) {
1766 MS_EXCEPTION_IF_NULL(func);
1767 if (func->isa<PrimitiveAbstractClosure>() && args_abs_list.empty()) {
1768 MS_LOG(DEBUG) << "High order primitive return POLY.";
1769 return true;
1770 }
1771 if (func->isa<MetaFuncGraphAbstractClosure>() && args_abs_list.empty()) {
1772 auto meta_func_graph_wrapper = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
1773 auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
1774 if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
1775 auto do_signature = dyn_cast_ptr<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
1776 if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
1777 MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
1778 return true;
1779 }
1780 }
1781 }
1782 return false;
1783 }
1784 } // namespace
1785
AcquireUniqueEvalResult(const AbstractFunctionPtr & func,const EvaluatorPtr & eval,const AbstractBasePtrList & args_abs_list,std::pair<AbstractBasePtrList,AbstractBasePtr> * res)1786 SpecializeStatusCode FuncGraphSpecializer::AcquireUniqueEvalResult(
1787 const AbstractFunctionPtr &func, const EvaluatorPtr &eval, const AbstractBasePtrList &args_abs_list,
1788 std::pair<AbstractBasePtrList, AbstractBasePtr> *res) {
1789 MS_EXCEPTION_IF_NULL(func);
1790 MS_EXCEPTION_IF_NULL(eval);
1791 MS_EXCEPTION_IF_NULL(res);
1792
1793 EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
1794 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
1795 auto data = evaluator_cache_mgr->GetValue(args_abs_list);
1796 if (data != nullptr) {
1797 *res = std::make_pair(args_abs_list, data->abstract());
1798 return kSpecializeSuccess;
1799 }
1800 DumpEvaluatorCache(eval, args_abs_list);
1801
1802 auto cache = GetEvalCache(eval);
1803 MS_EXCEPTION_IF_NULL(cache);
1804 const EvalResultCache &choices = cache->GetCache();
1805 auto eval_result = choices.get(args_abs_list);
1806 if (eval_result != nullptr) {
1807 *res = std::make_pair(args_abs_list, eval_result->abstract());
1808 return kSpecializeSuccess;
1809 } else if (choices.size() == 1) {
1810 MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
1811 MS_EXCEPTION_IF_NULL(choices.begin()->second);
1812 *res = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
1813 return kSpecializeSuccess;
1814 } else if (choices.empty()) {
1815 MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
1816 << func->type_name() << ", evaluator: " << eval->ToString() << ", ptr: " << eval.get();
1817 return kSpecializeDead;
1818 } else {
1819 if (IsPolyFunc(func, args_abs_list)) {
1820 return kSpecializePoly;
1821 }
1822 *res = BuildFromBroadedArgs(eval);
1823 if (!res->first.empty()) {
1824 MS_LOG(DEBUG) << "Build for generalized args_abs_list successfully.";
1825 // Synchronize the new evaluated abstract with the abstract from common evaluating routine.
1826 MS_EXCEPTION_IF_NULL(res->second);
1827 auto new_sequence_abs = dyn_cast<abstract::AbstractSequence>(res->second);
1828 for (auto &choice : choices) {
1829 MS_EXCEPTION_IF_NULL(choice.second);
1830 MS_EXCEPTION_IF_NULL(choice.second->abstract());
1831 auto abs = choice.second->abstract()->cast<AbstractSequencePtr>();
1832 if (abs != nullptr) {
1833 SynchronizeSequenceElementsUseFlagsRecursively(abs, new_sequence_abs);
1834 }
1835 }
1836 return kSpecializeSuccess;
1837 }
1838 MS_LOG(DEBUG) << "Found POLY node, it may be unused code or unresolved polymorphism, "
1839 << "func: " << func->ToString() << ", choices.size: " << choices.size()
1840 << ", args_abs_list.size: " << args_abs_list.size();
1841 return kSpecializePoly;
1842 }
1843 }
1844
// Returns a primitive that carries every attribute in `attrs`.
// When `prim` already has all of them with equal values, `prim` itself is returned; otherwise a
// clone of `prim` with all `attrs` entries added is returned and `prim` is left untouched.
// Raises if `prim` is null, or if a compared attribute value is null.
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(prim);
  auto &prim_attrs = prim->attrs();
  // Stops at the first attribute that is missing from `prim` or has a different value.
  const bool already_has_attrs = std::all_of(attrs->begin(), attrs->end(), [&prim_attrs](const auto &wanted) {
    auto existing = prim_attrs.find(wanted.first);
    if (existing == prim_attrs.end()) {
      return false;
    }
    MS_EXCEPTION_IF_NULL(existing->second);
    MS_EXCEPTION_IF_NULL(wanted.second);
    return *(existing->second) == *(wanted.second);
  });
  if (already_has_attrs) {
    return prim;
  }

  auto cloned_prim = prim->Clone();
  for (auto &wanted : *attrs) {
    cloned_prim->AddAttr(wanted.first, wanted.second);
  }
  return cloned_prim;
}
1872
GetValueForAbstractFunction(const AbstractFunctionPtr & abs,const AttrValueMapPtr & attrs)1873 ValuePtr GetValueForAbstractFunction(const AbstractFunctionPtr &abs, const AttrValueMapPtr &attrs) {
1874 ValuePtr value = nullptr;
1875 if (abs->isa<PrimitiveAbstractClosure>()) {
1876 auto real_fn = dyn_cast_ptr<PrimitiveAbstractClosure>(abs);
1877 MS_EXCEPTION_IF_NULL(real_fn);
1878 // For primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
1879 if (attrs != nullptr) {
1880 value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
1881 } else {
1882 value = real_fn->prim();
1883 }
1884 } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
1885 auto real_fn = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(abs);
1886 value = real_fn->meta_func_graph();
1887 } else if (abs->isa<FuncGraphAbstractClosure>()) {
1888 auto real_fn = dyn_cast_ptr<FuncGraphAbstractClosure>(abs);
1889 value = real_fn->func_graph();
1890 } else {
1891 return nullptr;
1892 }
1893 return value;
1894 }
1895
// Builds a value node for the function described by `abs`, or returns nullptr when no
// deterministic value node may be built.
// nullptr is returned when:
//   - `abs` has no extractable value (see GetValueForAbstractFunction);
//   - the value is a func graph flagged FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH;
//   - the value is a func graph with a parent that is neither referenced through a
//     ValueNode<FuncGraph> whose parent is visible from func_graph_, nor a non-K-graph passed as
//     a Parameter to a J CNode.
AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNodePtr &origin_node,
                                                                   const AbstractBasePtr &ival,
                                                                   const AttrValueMapPtr &attrs,
                                                                   const AnfNodePtr &cnode,
                                                                   const AbstractFunctionPtr &abs) {
  ValuePtr value = GetValueForAbstractFunction(abs, attrs);
  if (value == nullptr) {
    return nullptr;
  }
  // `fg` is set only when the value is a func graph.
  auto fg = value->isa<FuncGraph>() ? value->cast_ptr<FuncGraph>() : nullptr;
  if (fg != nullptr && fg->has_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH)) {
    return nullptr;
  }

  const bool can_build_directly =
    fg == nullptr || fg->parent() == nullptr ||
    (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, fg->parent()));
  if (can_build_directly) {
    if (IS_OUTPUT_ON(MsLogLevel::kDebug)) {
      if (cnode != nullptr) {
        MS_LOG(DEBUG) << "Specialize non-value to func graph, value: " << value->ToString()
                      << ", cnode: " << cnode->DebugString() << ", origin_node: " << origin_node->DebugString()
                      << ", func_graph_: " << func_graph_->ToString();
      }
      if (fg != nullptr && fg->parent() != nullptr) {
        MS_LOG(DEBUG) << "Specialize func graph, " << value->ToString()
                      << " has_parent, is_visible: " << IsVisible(func_graph_, fg->parent());
      }
    }
    return BuildValueNode(value, origin_node, ival);
  }

  // From here on the value is a func graph that has a parent.
  // Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph.
  const bool is_j_parameter = cnode != nullptr && IsPrimitiveCNode(cnode, prim::kPrimJ) &&
                              origin_node->isa<Parameter>() && !fg->has_flag(FUNC_GRAPH_FLAG_K_GRAPH);
  if (!is_j_parameter) {
    return nullptr;
  }
  MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString();
  return BuildValueNode(value, origin_node, ival);
}
1930
// Tries to replace `origin_node` by a value node built from its abstract `ival`.
// Function abstracts are delegated to BuildValueNodeForAbstractFunction(), except for
// AbstractFuncUnion which never yields a value node. For other abstracts a value node is built
// from `ival->BuildValue()` unless the value contains ValueAny, is a Monad, or `origin_node` is a
// Depend, ListInplaceClear or PyExecute CNode.
// Returns nullptr when no value node is built. Raises if `origin_node` or `ival` is null.
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                                        const AttrValueMapPtr &attrs, const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

  AbstractFunctionPtr func_abs = dyn_cast<AbstractFunction>(ival);
  if (func_abs != nullptr) {
    // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
    if (func_abs->isa<AbstractFuncUnion>()) {
      return nullptr;
    }
    return BuildValueNodeForAbstractFunction(origin_node, ival, attrs, cnode, func_abs);
  }

  ValuePtr built_value = ival->BuildValue();
  if (built_value->ContainsValueAny()) {
    return nullptr;
  }
  // If node is an AutoMonad node, don't convert the node to value node `U` or `IO` to avoid side-effect op miss.
  if (built_value->isa<Monad>()) {
    return nullptr;
  }
  // Keep primitives 'depend', 'ListInplaceClear' and 'PyExecute' not to be optimized.
  const bool must_keep_node = IsPrimitiveCNode(origin_node, prim::kPrimDepend) ||
                              IsPrimitiveCNode(origin_node, prim::kPrimListInplaceClear) ||
                              IsPrimitiveCNode(origin_node, prim::kPrimPyExecute);
  if (must_keep_node) {
    return nullptr;
  }
  return BuildValueNode(built_value, origin_node, ival);
}
1967
GetAnalysisContext(const AnalysisEnginePtr & engine,const BaseFuncGraphEvaluatorPtr & evaluator,const AbstractBasePtrList & args_abs_list) const1968 inline AnalysisContextPtr FuncGraphSpecializer::GetAnalysisContext(const AnalysisEnginePtr &engine,
1969 const BaseFuncGraphEvaluatorPtr &evaluator,
1970 const AbstractBasePtrList &args_abs_list) const {
1971 MS_EXCEPTION_IF_NULL(evaluator);
1972 // If it is common calling header, try to use the context generated by the infer process of body calling header, so
1973 // need broaden the args to keep context of common calling header same with context of body calling header.
1974 AbstractBasePtrList normalized_args_abs_list = evaluator->NormalizeArgs(args_abs_list);
1975 FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_abs_list);
1976 auto parent_context = evaluator->parent_context();
1977 MS_EXCEPTION_IF_NULL(parent_context);
1978 auto cached_context = parent_context->GetCachedContext(fg, normalized_args_abs_list);
1979 if (cached_context != nullptr) {
1980 return cached_context;
1981 }
1982 // If can't get context by broadened args, try to get context by not broadened args.
1983 cached_context = parent_context->GetCachedContext(fg, args_abs_list);
1984 if (cached_context != nullptr) {
1985 return cached_context;
1986 }
1987 // if it is a bprop meta func graph, need to make a new context and do static analysis in ProcessNode.
1988 return NewContext(parent_context, fg, normalized_args_abs_list);
1989 }
1990 } // namespace abstract
1991 } // namespace mindspore
1992