1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "pipeline/jit/pi/graph_capture/graph_analyzer.h"
#include <algorithm>
#include <set>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "pipeline/jit/pi/pi_jit_config.h"
#include "pipeline/jit/pi/graph_guard/infer.h"
#include "pipeline/jit/pi/graph_capture/graph.h"
#include "pipeline/jit/pi/graph_capture/special_func_infer.h"
#include "pipeline/jit/pi/graph_capture/graph_build.h"
#include "pipeline/jit/pi/graph_capture/side_effect.h"
28
29 namespace mindspore {
30 namespace pijit {
31
32 extern TracePtr GetTrace(ValueNode *node, bool strict, bool print, int depth, int max_depth);
33
34 const int kMsFlagSet = AObject::kMsFlagGradFunc | AObject::kMsFlagStandardFunc | AObject::kMsFlagShardFunc |
35 AObject::kMsFlagVmapFunc | AObject::kMsFlagJitFunc;
36 static bool IsRepeatWithoutSideEffect(ValueNode *v, bool repeat_attr_item_access);
37
CheckBuildTupleRepeatable(ValueNode * value,bool repeat_attr_item_access)38 static bool CheckBuildTupleRepeatable(ValueNode *value, bool repeat_attr_item_access) {
39 for (auto i : value->getInputs()) {
40 if (i->GetOpcode() == BUILD_TUPLE || !IsRepeatWithoutSideEffect(i, repeat_attr_item_access)) {
41 return false;
42 }
43 }
44 return true;
45 }
46
CheckBuildSliceRepeatable(const std::vector<ValueNode * > & inputs,bool repeat_attr_item_access)47 static bool CheckBuildSliceRepeatable(const std::vector<ValueNode *> &inputs, bool repeat_attr_item_access) {
48 for (auto i : inputs) {
49 if (i->GetOpcode() != LOAD_CONST) {
50 return false;
51 }
52 }
53 return true;
54 }
55
56 // These are operations that are repeated and have no side effects.
IsRepeatWithoutSideEffect(ValueNode * v,bool repeat_attr_item_access)57 static bool IsRepeatWithoutSideEffect(ValueNode *v, bool repeat_attr_item_access) {
58 if (IsNonLocalValue(v)) {
59 return true;
60 }
61
62 AObject::Type type = v->GetVobj() ? v->GetVobj()->GetType() : AObject::kTypeAnyValue;
63 auto opcode = v->GetOpcode();
64 if (opcode == BUILD_TUPLE) {
65 return CheckBuildTupleRepeatable(v, repeat_attr_item_access);
66 } else if (opcode == BUILD_SLICE) {
67 // NOTE: mindspore can't resolve call 'slice' class
68 return CheckBuildSliceRepeatable(v->getInputs(), repeat_attr_item_access);
69 } else if (opcode == BINARY_SUBSCR || opcode == LOAD_ATTR) {
70 return type == AObject::kTypeAnyValue ? false : repeat_attr_item_access;
71 } else if (opcode == BUILD_MAP) {
72 if (type == AObject::kTypeDict) {
73 AbstractDict *d = static_cast<AbstractDict *>(v->GetVobj());
74 return d->size() == 0 || d->KeyType() != AObject::kTypeAnyValue;
75 }
76 return false;
77 }
78 return false;
79 }
80
81 namespace {
82 /**
83 * mindspore func_graph assume these unsupported value is constant, so it same as global.
84 * avoid parameter unsupported error by global
85 */
ValidateGraphParameters(ValueNode * node)86 bool ValidateGraphParameters(ValueNode *node) {
87 static const std::set<AObject::Type> unsupported_parameter = {
88 AObject::kTypeAnyValue, AObject::kTypeFunction, AObject::kTypeBoundMethod,
89 AObject::kTypePrimitive, AObject::kTypeMetaFuncGraph, AObject::kTypeCell,
90 };
91 AObject *info = node->GetVobj();
92 if (info == nullptr) {
93 return false;
94 }
95 return unsupported_parameter.find(info->GetType()) == unsupported_parameter.end();
96 }
97 } // namespace
98
ProduceInterpretValue(ValueNode * v)99 bool GraphAnalyzer::ProduceInterpretValue(ValueNode *v) {
100 bool repeat_op = graph_->Config().GetBoolConfig(GraphJitConfig::kEnableOptimizeForAttrItem);
101 auto &locals = GetCaptureInfo().interpret_.values;
102 auto &values = GetCaptureInfo().captured_.values;
103 for (auto i : v->getInputs()) {
104 if (IsNonLocalValue(i) || locals.find(i) != locals.end()) {
105 continue;
106 }
107 if (values.find(i) == values.end()) {
108 MS_LOG(INTERNAL_EXCEPTION) << "capture info can't find the value [" << i->ToString() << "]";
109 }
110 if (!IsRepeatWithoutSideEffect(i, repeat_op)) {
111 return false;
112 }
113 // duplicate some operations if possible
114 if (ProduceInterpretValue(i)) {
115 continue;
116 }
117 return false;
118 }
119 AddToEscaped(v);
120 return true;
121 }
122
123 // if operation can't be repeated, or block has attr access side effect
124 // can't reorder attr access op, must be interpret all attr, item access operation
CheckAttrItemSupport(ValueNode * v,bool repeat_op)125 static bool CheckAttrItemSupport(ValueNode *v, bool repeat_op) {
126 int op = v->GetOpcode();
127 AObject::Type type = v->input(0)->GetVobj() ? v->input(0)->GetVobj()->GetType() : AObject::kTypeAnyValue;
128 // item access
129 if (op == BINARY_SUBSCR) {
130 return type != AObject::kTypeAnyValue;
131 }
132 // attr access
133 if (type == AObject::kTypeAnyValue || type == AObject::kTypeBoundMethod) {
134 return false;
135 }
136 if (type == AObject::kTypeTensor && !FindTensorName(v->GetName())) {
137 return false;
138 }
139 return true;
140 }
141
IsSideEffect(ValueNode * v)142 static bool IsSideEffect(ValueNode *v) {
143 static const std::set<std::string> funcs = {"assign", "Assign"};
144 static const std::set<int> unsupported_op = {
145 STORE_DEREF, DELETE_DEREF, STORE_GLOBAL, DELETE_GLOBAL, STORE_ATTR, DELETE_ATTR,
146 STORE_SUBSCR, DELETE_SUBSCR, IMPORT_STAR, RAISE_VARARGS, RERAISE, FORMAT_VALUE,
147 };
148 Opcode opcode(v->GetOpcode());
149 if (opcode.MayDelete()) {
150 return false;
151 }
152 if (opcode.IsCall()) {
153 AObject *f = v->input(0)->GetVobj();
154 if (f == nullptr) {
155 return true;
156 }
157 if (f->TestMsFlag(AObject::kMsFlagGradFunc)) {
158 return false;
159 }
160 return funcs.find(GetFuncName(f->GetPyObject())) != funcs.end();
161 }
162 return unsupported_op.find(v->GetOpcode()) != unsupported_op.end();
163 }
164
HandleCallableToGraph(AObject * f)165 bool GraphAnalyzer::HandleCallableToGraph(AObject *f) {
166 static bool known_type[AObject::kTypeCount] = {false};
167 if (known_type[AObject::kTypePrimitive] == false) {
168 known_type[AObject::kTypePrimitive] = true;
169 known_type[AObject::kTypeCell] = true;
170 known_type[AObject::kTypeMetaFuncGraph] = true;
171 known_type[AObject::kTypePrimitiveFunction] = true;
172 }
173 if (f == nullptr) {
174 return false;
175 }
176 // don't pass unknown callable to graph
177 bool is_known_func = known_type[f->GetType()] || CheckJitConstexpr(f->GetPyObject());
178 bool is_ms_support_func = f->TestMsFlag(kMsFlagSet);
179 if (!is_known_func && !is_ms_support_func) {
180 return false;
181 }
182 if (f->GetType() == AObject::kTypePrimitive) {
183 PyTypeObject *tp = f->GetTypeObject();
184 std::string name = (tp && tp->tp_name ? tp->tp_name : "");
185 if (name == "Assign") {
186 return false;
187 }
188 }
189 return true;
190 }
191
AddToCaptured(ValueNode * v)192 bool GraphAnalyzer::AddToCaptured(ValueNode *v) {
193 if (IsNonLocalValue(v)) {
194 return true;
195 }
196 if (v->GetVobj() && v->GetVobj()->TestMsFlag(AObject::kMsFlagGradFunc)) {
197 GetCaptureInfo().has_grad_ = true;
198 GetCaptureInfo().captured_.values.insert(v);
199 GetCaptureInfo().captured_.operations.push_back(v);
200 return true;
201 }
202
203 int op = v->GetOpcode();
204 bool repeat_op = graph_->Config().GetBoolConfig(GraphJitConfig::kEnableOptimizeForAttrItem);
205 if ((op == LOAD_ATTR || op == BINARY_SUBSCR) && !CheckAttrItemSupport(v, repeat_op)) {
206 return false;
207 }
208
209 bool is_call_op = Opcode(v->GetOpcode()).IsCall();
210 if (is_call_op) {
211 AObject *f = v->input(0)->GetVobj();
212 bool can_pass = HandleCallableToGraph(f);
213 if (!can_pass) {
214 return false;
215 }
216 GetCaptureInfo().has_grad_ |= f->TestMsFlag(AObject::kMsFlagGradFunc);
217 }
218
219 auto &locals = GetCaptureInfo().interpret_.values; // interpret values
220 auto &values = GetCaptureInfo().captured_.values; // graph produced values
221 for (auto i : v->getInputs()) {
222 bool produced_in_graph = values.find(i) != values.end() || IsNonLocalValue(i);
223 MS_EXCEPTION_IF_CHECK_FAIL(produced_in_graph || locals.find(i) != locals.end(),
224 "check values order, all input must be generate before this value " + i->ToString());
225 if (i->GetVobj() == nullptr) {
226 return false;
227 }
228 AObject::Type type = i->GetVobj()->GetType();
229 PyTypeObject *tp = i->GetVobj()->GetTypeObject();
230 if (type == AObject::kTypeAnyValue && !IsMsClass(reinterpret_cast<PyObject *>(tp))) {
231 // don't pass unknown object to graph
232 return false;
233 }
234 if (type == AObject::kTypeCell && !is_call_op) {
235 // don't pass a cell object that not call to graph.
236 return false;
237 }
238 }
239
240 GetCaptureInfo().captured_.values.insert(v);
241 GetCaptureInfo().captured_.operations.push_back(v);
242 return true;
243 }
244
AddToEscaped(ValueNode * v)245 void GraphAnalyzer::AddToEscaped(ValueNode *v) {
246 MS_EXCEPTION_IF_CHECK_FAIL(GetCaptureInfo().interpret_.values.find(v) == GetCaptureInfo().interpret_.values.end(),
247 "duplicate escaped values");
248 GetCaptureInfo().interpret_.values.insert(v);
249 GetCaptureInfo().interpret_.operations.push_back(v);
250 }
251
252 extern TracePtr GetTrace(ValueNode *node, bool strict, bool print, int depth, int max_depth);
253
TryToCapture(AbstractNode * n)254 bool GraphAnalyzer::TryToCapture(AbstractNode *n) {
255 ValueNode *v = static_cast<ValueNode *>(n);
256 if (IsNonLocalValue(v)) {
257 return true;
258 }
259 if (graph_->GetSideEffect()->IsRecord(v)) {
260 return true;
261 }
262 bool is_side_effect = IsSideEffect(v);
263 if (!is_side_effect && AddToCaptured(v)) {
264 return true;
265 }
266 if (!GetCaptureInfo().captured_.values.empty() && is_side_effect) {
267 return false;
268 }
269 if (ProduceInterpretValue(v)) {
270 return true;
271 }
272 if (!HasTensorOperation()) {
273 CleanCapturedValue();
274 AddToEscaped(v);
275 return true;
276 }
277
278 if (v->GetGraph() != nullptr && this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
279 GRAPH_JIT_LOG_F("capture failed, operations is unsupported [%s] at [%U: %d]", v->ToString().c_str(),
280 v->GetGraph()->GetCodeObj()->co_filename, v->GetLineNo());
281 GRAPH_JIT_LOG_F("parameters");
282 for (auto &i : v->getInputs()) {
283 PyObject *op = i->GetVobj() ? i->GetVobj()->GetPyObject().ptr() : nullptr;
284 GRAPH_JIT_LOG_F("%s", op ? AObject::ToString(op).c_str() : "NULL");
285 }
286 }
287
288 MS_LOG(DEBUG) << "---operation that depend on the graph outputs, break graph---";
289 return false;
290 }
291
AnalyzeCall(CallNode * call_node)292 bool GraphAnalyzer::AnalyzeCall(CallNode *call_node) {
293 if (call_node->GetSubGraph() == nullptr) {
294 return false;
295 }
296 if (call_node->GetInlineReason() != InlineReason::kInline) {
297 return false;
298 }
299
300 Graph *g = call_node->GetGraph();
301
302 CapturedInfo back_up = info_;
303 const auto &p = call_node->GetParams();
304 // capture parameter handle operations
305 auto iter = std::find_if(p.begin(), p.end(), [this](ValueNode *i) { return !this->TryToCapture(i); });
306 // capture sub-graph
307 if (iter == p.end() && AnalyzeRecursive(call_node->GetSubGraph())) {
308 return true;
309 }
310 info_ = back_up;
311 g->StopTraceAt(call_node->bci(), StopTraceReason::kStopTraceDataDependsOnGraphOut);
312 return false;
313 }
314
AnalyzeRecursive(Graph * g)315 bool GraphAnalyzer::AnalyzeRecursive(Graph *g) {
316 for (auto n : g->GetTracedNodes()) {
317 int bci = static_cast<ValueNode *>(n)->bci();
318 if (n->GetType() == AbstractNode::Call && AnalyzeCall(static_cast<CallNode *>(n))) {
319 continue;
320 }
321 if (bci != -1 && g->GetStopTraceBci() == bci) {
322 return false;
323 }
324 if (!TryToCapture(n)) {
325 g->StopTraceAt(bci, StopTraceReason::kStopTraceDataDependsOnGraphOut);
326 return false;
327 }
328 }
329 return true;
330 }
331
CollectCapturedInputs()332 void GraphAnalyzer::CollectCapturedInputs() {
333 auto &locals = GetCaptureInfo().interpret_.values;
334 auto &values = GetCaptureInfo().captured_.values;
335 mindspore::CompactSet<ValueNode *> inputs;
336 for (ValueNode *i : GetCaptureInfo().captured_.operations) {
337 for (auto input : i->getInputs()) {
338 if (values.find(input) != values.end() || IsNonLocalValue(input)) {
339 continue;
340 }
341 MS_EXCEPTION_IF_CHECK_FAIL(locals.find(input) != locals.end(), "check graph input");
342 inputs.insert(input);
343 }
344 }
345 GetCaptureInfo().captured_.inputs = {inputs.begin(), inputs.end()};
346 }
347
UseDefAnalyze()348 void GraphAnalyzer::UseDefAnalyze() {
349 // UD analyze: alive nodes analysis
350 std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
351 if (!aliveLocals.empty()) {
352 bool isStopAnalyze = false;
353 while (!isStopAnalyze) {
354 isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
355 if (isStopAnalyze) {
356 break;
357 }
358 aliveLocals = GetAliveLocals(graph_);
359 }
360 }
361 }
362
OptimizeSideEffectRecord() const363 void GraphAnalyzer::OptimizeSideEffectRecord() const {
364 if (graph_->GetSideEffect()->IsEmpty()) {
365 return;
366 }
367 auto alive = graph_->CollectAliveNode(graph_->GetStopTraceBci());
368 auto side_effect_required_size = graph_->GetSideEffect()->GetRequiredNodes().size();
369 auto size = alive.size() - side_effect_required_size;
370 graph_->GetSideEffect()->Optimize({alive.begin(), alive.begin() + size});
371 }
372
ResetSideEffectRecord() const373 void GraphAnalyzer::ResetSideEffectRecord() const {
374 // if break point is changed, rollback graph nodes(only reset break bci) and side-effect record
375 int break_bci = graph_->GetStopTraceBci();
376 if (break_bci == -1 || graph_->GetSideEffect()->IsEmpty()) {
377 return;
378 }
379 const auto &nodes = graph_->GetTracedNodes();
380 auto iter = std::find_if(nodes.begin(), nodes.end(), [&break_bci](ValueNode *i) { return i->bci() > break_bci; });
381 graph_->GetSideEffect()->ResetRecord({nodes.begin(), iter});
382 OptimizeSideEffectRecord(); // after reset record, rollback side-effect record status
383 }
384
Analyze()385 void GraphAnalyzer::Analyze() {
386 OptimizeSideEffectRecord(); // first optimize, remove dead local side-effects and it's required nodes
387
388 const FrameStates &enter_frame = graph_->GetFrame(0);
389 GetCaptureInfo().interpret_.values.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
390 AnalyzeRecursive(graph_);
391 if (!HasTensorOperation()) {
392 CleanCapturedValue();
393 }
394 UseDefAnalyze();
395 ResetSideEffectRecord(); // if rollback nodes, rollback side-effects
396
397 CollectCapturedAndInterpret();
398 CollectGraphInputs();
399
400 need_interpret_ = true;
401
402 if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().interpret_.operations.empty()) {
403 return;
404 }
405 bool support_ret = graph_->GetRetVal()->GetVobj() && graph_->GetRetVal()->GetVobj()->IsMindSporeSupportedType();
406 if (!support_ret) {
407 return;
408 }
409 PyCodeObject *co = graph_->GetCodeObj();
410 const auto &args = enter_frame.GetLocals();
411 int argc = co->co_argcount + co->co_kwonlyargcount;
412 // check all parameters is graph supported, but here not check variable arguments
413 auto end = args.begin() + argc;
414 auto iter = std::find_if(args.begin(), end, [](ValueNode *i) { return !ValidateGraphParameters(i); });
415 if (iter == end) {
416 need_interpret_ = false;
417 }
418 need_interpret_ |= !graph_->GetSideEffect()->IsEmpty();
419 }
420
buildLastFrame(Graph * g)421 FrameStates buildLastFrame(Graph *g) { return g->GetFrame(g->GetStopTraceBci()); }
422
GetAliveLocals(Graph * g)423 std::vector<ValueNode *> GraphAnalyzer::GetAliveLocals(Graph *g) {
424 int bci = g->GetStopTraceBci();
425 if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
426 GRAPH_JIT_LOG_F("UD analyze: enter GetAliveLocals bci %d", bci);
427 }
428 std::vector<ValueNode *> outputs = g->CollectAliveNode(bci);
429 mindspore::CompactSet<ValueNode *> uniques;
430 for (auto output : outputs) {
431 uniques.insert(output);
432 }
433 outputs.assign(uniques.begin(), uniques.end());
434 return outputs;
435 }
436
PrintAliveNodes(std::vector<ValueNode * > aliveNodes)437 void PrintAliveNodes(std::vector<ValueNode *> aliveNodes) {
438 GRAPH_JIT_LOG_F("UD analyze: alive node size : %ld", aliveNodes.size());
439 for (auto node : aliveNodes) {
440 if (node) {
441 GRAPH_JIT_LOG_F("UD analyze: alive node: %s", node->ToString().c_str());
442 }
443 }
444 }
445
AnalyzeAliveLocals(std::vector<ValueNode * > aliveNodes)446 bool GraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) {
447 if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
448 PrintAliveNodes(aliveNodes);
449 }
450 bool isAllNodesSupportOutput = true;
451 for (auto node : aliveNodes) {
452 AObject *o = node->GetVobj();
453 bool supported_type = o && o->IsMindSporeSupportedType();
454 if (supported_type) {
455 continue;
456 }
457 auto capturedLocals = info_.captured_.operations;
458 if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
459 continue;
460 }
461
462 if (!HasTensorOperation()) {
463 CleanCapturedValue();
464 break;
465 }
466
467 // reset break graph point
468 isAllNodesSupportOutput = false;
469 int new_break_point = node->bci();
470 auto curNode = node;
471 MS_EXCEPTION_IF_CHECK_FAIL(new_break_point != -1, "break point cannot be -1");
472 MS_EXCEPTION_IF_NULL(curNode->GetGraph());
473 if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
474 GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
475 }
476 this->graph_->StopTraceAt(new_break_point, StopTraceReason::kStopTraceDataDependsOnGraphOut);
477
478 // re-collect captured info
479 info_.clear();
480 const FrameStates &enter_frame = graph_->GetFrame(0);
481 GetCaptureInfo().interpret_.values.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
482 (void)AnalyzeRecursive(graph_);
483 break;
484 }
485 return isAllNodesSupportOutput;
486 }
487
SkipSpecialFuncOrPrimitive(const py::object & callable)488 static bool SkipSpecialFuncOrPrimitive(const py::object &callable) {
489 if (callable.ptr() == nullptr) {
490 return false;
491 }
492 if (CheckJitConstexpr(callable) || CheckMSConstexpr(callable)) {
493 return true;
494 }
495 if (IsPrimitiveType<true>(Py_TYPE(callable.ptr()))) {
496 std::string name = callable.attr("name").cast<std::string>();
497 return GetSpecialPrimitiveInferFunc().find(name) != GetSpecialPrimitiveInferFunc().end();
498 }
499 return false;
500 }
501
HasTensorOperation() const502 bool GraphAnalyzer::HasTensorOperation() const {
503 bool has_tensor_cal = false;
504 for (auto i : info_.captured_.values) {
505 AObject *value = i->GetVobj();
506 Opcode op(i->GetOpcode());
507 if (op.IsCall()) {
508 if (SkipSpecialFuncOrPrimitive(i->input(0)->GetVobj()->GetPyObject())) {
509 continue;
510 }
511 if (value->GetType() == AObject::kTypeCFunction) {
512 continue;
513 }
514 return true;
515 }
516 if (op.IsBinaryMath() && value->GetType() == AObject::kTypeTensor) {
517 return true;
518 }
519 }
520 return has_tensor_cal;
521 }
522
clear()523 void GraphAnalyzer::CapturedInfo::Info::clear() {
524 values.clear();
525 inputs.clear();
526 operations.clear();
527 outputs.clear();
528 }
529
clear()530 void GraphAnalyzer::CapturedInfo::GraphInputs::clear() {
531 args.clear();
532 globals.clear();
533 vargs = nullptr;
534 kwargs = nullptr;
535 }
536
clear()537 void GraphAnalyzer::CapturedInfo::clear() {
538 captured_.clear();
539 interpret_.clear();
540 graph_inputs_.clear();
541 }
542
ToString()543 std::string GraphAnalyzer::CapturedInfo::Info::ToString() {
544 std::stringstream s;
545 s << "values: ";
546 for (auto i : values) {
547 s << i->ToString() << "\n";
548 }
549 s << "inputs: \n";
550 for (auto i : inputs) {
551 s << i->ToString() << "\n";
552 }
553 s << "operations: \n";
554 for (auto i : operations) {
555 s << i->ToString() << "\n";
556 }
557 s << "outputs: \n";
558 for (auto i : outputs) {
559 s << i->ToString() << "\n";
560 }
561 return s.str();
562 }
563
ToString()564 std::string GraphAnalyzer::CapturedInfo::GraphInputs::ToString() {
565 std::stringstream s;
566 s << "globals: ";
567 for (auto i : globals) {
568 s << i->ToString() << "\n";
569 }
570 s << "args: \n";
571 for (auto i : args) {
572 s << i->ToString() << "\n";
573 }
574 s << "vargs: ";
575 if (vargs != nullptr) {
576 s << vargs->ToString();
577 }
578 s << "\n";
579 s << "kwargs: ";
580 if (kwargs != nullptr) {
581 s << kwargs->ToString();
582 }
583 s << "\n";
584 return s.str();
585 }
586
ToString()587 std::string GraphAnalyzer::CapturedInfo::ToString() {
588 std::stringstream s;
589 s << "1. captured_ info: \n";
590 s << captured_.ToString();
591 s << "2. interpret_ info: \n";
592 s << interpret_.ToString();
593 s << "3. graph_inputs_: \n";
594 s << graph_inputs_.ToString();
595 s << "4. has_grad_: " << has_grad_ << "\n";
596 return s.str();
597 }
598
CleanCapturedValue()599 void GraphAnalyzer::CleanCapturedValue() {
600 auto &locals = info_.interpret_.values;
601 for (auto i : info_.captured_.operations) {
602 if (locals.find(i) == locals.end()) {
603 locals.insert(i);
604 info_.interpret_.operations.emplace_back(i);
605 }
606 }
607 info_.captured_.values.clear();
608 info_.captured_.operations.clear();
609 }
610
CollectGraphOutputs(const mindspore::CompactSet<ValueNode * > & interpret,const std::vector<ValueNode * > & alive)611 static std::vector<ValueNode *> CollectGraphOutputs(const mindspore::CompactSet<ValueNode *> &interpret,
612 const std::vector<ValueNode *> &alive) {
613 std::vector<ValueNode *> outputs;
614 for (auto i : alive) {
615 if (interpret.find(i) == interpret.end() && !IsNonLocalValue(i)) {
616 outputs.push_back(i);
617 }
618 }
619 return outputs;
620 }
621
CollectCapturedAndInterpret()622 void GraphAnalyzer::CollectCapturedAndInterpret() {
623 CollectCapturedInputs();
624 int break_bci = graph_->GetStopTraceBci();
625 std::vector<ValueNode *> alive_nodes = graph_->CollectAliveNode(break_bci, &alive_locals_);
626
627 GetCaptureInfo().captured_.outputs = CollectGraphOutputs(GetCaptureInfo().interpret_.values, alive_nodes);
628 GetCaptureInfo().interpret_.inputs = graph_->GetFrame(0).GetLocals();
629 GetCaptureInfo().interpret_.outputs = std::move(alive_nodes);
630 }
631
CollectGraphInputs()632 void GraphAnalyzer::CollectGraphInputs() {
633 PyCodeObject *co_ = graph_->GetCodeObj();
634 auto &interpret_ = GetCaptureInfo().interpret_;
635 auto &captured_ = GetCaptureInfo().captured_;
636 auto &graph_inputs = GetCaptureInfo().graph_inputs_;
637
638 // NOTE: if *vargs is cell variable, it is not parameter node
639 MS_EXCEPTION_IF_CHECK_FAIL(co_->co_nlocals == static_cast<int>(interpret_.inputs.size()),
640 "interpret inputs must be same as locals");
641
642 ValueNode *vargs = nullptr;
643 ValueNode *kwargs = nullptr;
644 int arg_index = co_->co_argcount + co_->co_kwonlyargcount;
645 if ((co_->co_flags & CO_VARARGS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
646 vargs = interpret_.inputs[arg_index];
647 }
648 arg_index += (IntToSize(co_->co_flags) & CO_VARARGS) != 0;
649 if ((IntToSize(co_->co_flags) & CO_VARKEYWORDS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
650 kwargs = interpret_.inputs[arg_index];
651 }
652
653 // Identify parameters and global variables
654 for (auto input : captured_.inputs) {
655 if (input == graph_inputs.vargs) {
656 graph_inputs.vargs = vargs;
657 } else if (input == graph_inputs.kwargs) {
658 graph_inputs.kwargs = kwargs;
659 } else if (ValidateGraphParameters(input)) {
660 graph_inputs.args.push_back(input);
661 } else {
662 graph_inputs.globals.push_back(input);
663 }
664 }
665
666 size_t inputs_count = captured_.inputs.size();
667 captured_.inputs = graph_inputs.args;
668 if (graph_inputs.vargs != nullptr) {
669 captured_.inputs.push_back(graph_inputs.vargs);
670 }
671 if (graph_inputs.kwargs != nullptr) {
672 captured_.inputs.push_back(graph_inputs.kwargs);
673 }
674 captured_.inputs.insert(captured_.inputs.end(), graph_inputs.globals.begin(), graph_inputs.globals.end());
675 MS_EXCEPTION_IF_CHECK_FAIL(inputs_count == captured_.inputs.size(), "error parameters");
676 }
677
CollectCapturedInputs()678 void MindGraphAnalyzer::CollectCapturedInputs() {
679 auto &inputs = GetCaptureInfo().captured_.inputs;
680 const FrameStates &enter_frame = graph_->GetFrame(0);
681 PyCodeObject *co = graph_->GetCodeObj();
682 int argc = co->co_argcount + co->co_kwonlyargcount;
683 argc += SizeToInt(IntToSize(co->co_flags) & CO_VARARGS) ? 1 : 0;
684 argc += SizeToInt(IntToSize(co->co_flags) & CO_VARKEYWORDS) ? 1 : 0;
685 for (Py_ssize_t m = 0; m < argc; ++m) {
686 const auto &local = enter_frame.Local(m);
687 if (local == &ValueNode::kUnboundLocal) {
688 continue;
689 }
690 inputs.push_back(local);
691 }
692 }
693
Analyze()694 void MindGraphAnalyzer::Analyze() {
695 OptimizeSideEffectRecord();
696
697 auto origin_stop_bci = graph_->GetStopTraceBci();
698 UseDefAnalyze();
699
700 const FrameStates &enter_frame = graph_->GetFrame(0);
701 GetCaptureInfo().interpret_.values.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
702
703 auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
704 MS_EXCEPTION_IF_NULL(mind_graph_builder);
705 auto func_graph_builder = mind_graph_builder->FGBuilder();
706 if (func_graph_builder->graph() == nullptr) {
707 // Graph build failed, add all nodes to ordered_escaped_locals.
708 MS_LOG(DEBUG) << "Failed to build graph";
709 GetCaptureInfo().interpret_.operations.clear();
710 for (const auto &traced_node : graph_->GetTracedNodes()) {
711 if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
712 break;
713 }
714 AddToEscaped(traced_node);
715 }
716 graph_->StopTraceAt(origin_stop_bci, StopTraceReason::kStopTraceDataDependsOnGraphOut);
717 ResetSideEffectRecord();
718
719 need_interpret_ = true;
720 GetCaptureInfo().captured_.clear();
721 CollectCapturedAndInterpret();
722 return;
723 }
724 ResetSideEffectRecord();
725
726 CollectCapturedAndInterpret();
727 CollectGraphInputs();
728
729 need_interpret_ = true;
730 if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().interpret_.operations.empty()) {
731 return;
732 }
733 bool support_ret = graph_->GetRetVal()->GetVobj() && graph_->GetRetVal()->GetVobj()->IsMindSporeSupportedType();
734 if (!support_ret) {
735 return;
736 }
737 need_interpret_ = !graph_->GetSideEffect()->IsEmpty();
738 }
739
AnalyzeAliveLocals(std::vector<ValueNode * > aliveNodes)740 bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) {
741 bool isAllNodesSupportOutput = true;
742 auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
743 MS_EXCEPTION_IF_NULL(mind_graph_builder);
744 auto func_graph_builder = mind_graph_builder->FGBuilder();
745 MS_EXCEPTION_IF_NULL(func_graph_builder);
746 func_graph_builder->ClearOutputNodes();
747 GetCaptureInfo().captured_.outputs.clear();
748 for (auto node : aliveNodes) {
749 // If the value can get from local, no need to add to graph output.
750 if (IsNonLocalValue(node)) {
751 MS_LOG(DEBUG) << "Skip non local value used as graph return.";
752 continue;
753 }
754 auto capturedLocals = info_.captured_.operations;
755 if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
756 continue;
757 }
758 AObject *o = node->GetVobj();
759 auto out_py_obj = o->GetPyObject();
760 if (func_graph_builder->AddOutput(out_py_obj, true)) {
761 MS_LOG(INFO) << "Add output success for node: " << node->ToString();
762 GetCaptureInfo().captured_.outputs.push_back(node);
763 continue;
764 }
765 MS_LOG(INFO) << "Add output failed for node: " << node->ToString();
766 GetCaptureInfo().captured_.outputs.clear();
767 // reset break graph point
768 isAllNodesSupportOutput = false;
769 int new_break_point = node->bci();
770 auto curNode = node;
771 if (new_break_point == -1) {
772 // No node is unsupported output since no node in captured output.
773 isAllNodesSupportOutput = true;
774 break;
775 }
776 MS_EXCEPTION_IF_NULL(curNode->GetGraph());
777 if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
778 GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
779 }
780 this->graph_->StopTraceAt(new_break_point, StopTraceReason::kStopTraceDataDependsOnGraphOut);
781 break;
782 }
783 return isAllNodesSupportOutput;
784 }
785
UpdateCapturedOrder()786 void MindGraphAnalyzer::UpdateCapturedOrder() {
787 const auto &traced_nodes = graph_->GetTracedNodes();
788 auto stop_bci = graph_->GetStopTraceBci();
789 if (stop_bci == -1) {
790 GetCaptureInfo().captured_.operations = traced_nodes;
791 } else {
792 GetCaptureInfo().captured_.operations.clear();
793 for (const auto &traced_node : traced_nodes) {
794 if (traced_node->bci() >= stop_bci) {
795 break;
796 }
797 GetCaptureInfo().captured_.operations.push_back(traced_node);
798 }
799 }
800 const auto &captured_local_order = GetCaptureInfo().captured_.operations;
801 mindspore::CompactSet<ValueNode *> new_capture_local_values;
802 for (auto val : captured_local_order) {
803 new_capture_local_values.insert(val);
804 }
805 GetCaptureInfo().captured_.values = new_capture_local_values;
806 }
807
CollectCapturedAndInterpret()808 void MindGraphAnalyzer::CollectCapturedAndInterpret() {
809 CollectCapturedInputs();
810 int break_bci = graph_->GetStopTraceBci();
811 std::vector<ValueNode *> alive_nodes = graph_->CollectAliveNode(break_bci, &alive_locals_);
812
813 GetCaptureInfo().interpret_.inputs = graph_->GetFrame(0).GetLocals();
814 GetCaptureInfo().interpret_.outputs = std::move(alive_nodes);
815
816 // remove side-effect node
817 auto is_remove = [this](ValueNode *node) { return this->graph_->GetSideEffect()->IsRecord(node); };
818 auto *ops = &GetCaptureInfo().captured_.operations;
819 ops->erase(std::remove_if(ops->begin(), ops->end(), is_remove), ops->end());
820 ops = &GetCaptureInfo().interpret_.operations;
821 ops->erase(std::remove_if(ops->begin(), ops->end(), is_remove), ops->end());
822 }
823
UseDefAnalyze()824 void MindGraphAnalyzer::UseDefAnalyze() {
825 // UD analyze: alive nodes analysis
826 std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
827 if (!aliveLocals.empty()) {
828 bool stop_analyze = false;
829 while (!stop_analyze) {
830 UpdateCapturedOrder();
831 // Add graph output according to leaf nodes.
832 stop_analyze = AnalyzeAliveLocals(aliveLocals);
833 if (!stop_analyze) {
834 aliveLocals = GetAliveLocals(graph_);
835 }
836 }
837 }
838 }
839
CollectGraphInputs()840 void MindGraphAnalyzer::CollectGraphInputs() {
841 PyCodeObject *co_ = graph_->GetCodeObj();
842 auto &interpret_ = GetCaptureInfo().interpret_;
843 auto &captured_ = GetCaptureInfo().captured_;
844 auto &graph_inputs = GetCaptureInfo().graph_inputs_;
845
846 // NOTE: if *vargs is cell variable, it is not parameter node
847 MS_EXCEPTION_IF_CHECK_FAIL(co_->co_nlocals == static_cast<int>(interpret_.inputs.size()),
848 "interpret inputs must be same as locals");
849
850 ValueNode *vargs = nullptr;
851 ValueNode *kwargs = nullptr;
852 int arg_index = co_->co_argcount + co_->co_kwonlyargcount;
853 if ((IntToSize(co_->co_flags) & CO_VARARGS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
854 vargs = interpret_.inputs[arg_index];
855 }
856 arg_index += (IntToSize(co_->co_flags) & CO_VARARGS) != 0;
857 if ((IntToSize(co_->co_flags) & CO_VARKEYWORDS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
858 kwargs = interpret_.inputs[arg_index];
859 }
860
861 // Identify parameters and global variables
862 for (auto input : captured_.inputs) {
863 if (input == graph_inputs.vargs) {
864 graph_inputs.vargs = vargs;
865 } else if (input == graph_inputs.kwargs) {
866 graph_inputs.kwargs = kwargs;
867 } else {
868 graph_inputs.args.push_back(input);
869 }
870 }
871
872 size_t inputs_count = captured_.inputs.size();
873 captured_.inputs = graph_inputs.args;
874 if (graph_inputs.vargs != nullptr) {
875 captured_.inputs.push_back(graph_inputs.vargs);
876 }
877 if (graph_inputs.kwargs != nullptr) {
878 captured_.inputs.push_back(graph_inputs.kwargs);
879 }
880 captured_.inputs.insert(captured_.inputs.end(), graph_inputs.globals.begin(), graph_inputs.globals.end());
881 MS_EXCEPTION_IF_CHECK_FAIL(inputs_count == captured_.inputs.size(), "error parameters");
882 }
883
884 } // namespace pijit
885 } // namespace mindspore
886