• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_capture/graph_analyzer.h"
17 #include <algorithm>
18 #include <unordered_set>
19 #include <utility>
20 #include <string>
21 #include <vector>
22 #include "pipeline/jit/pi/pi_jit_config.h"
23 #include "pipeline/jit/pi/graph_guard/infer.h"
24 #include "pipeline/jit/pi/graph_capture/graph.h"
25 #include "pipeline/jit/pi/graph_capture/special_func_infer.h"
26 #include "pipeline/jit/pi/graph_capture/graph_build.h"
27 #include "pipeline/jit/pi/graph_capture/side_effect.h"
28 
29 namespace mindspore {
30 namespace pijit {
31 
32 extern TracePtr GetTrace(ValueNode *node, bool strict, bool print, int depth, int max_depth);
33 
34 const int kMsFlagSet = AObject::kMsFlagGradFunc | AObject::kMsFlagStandardFunc | AObject::kMsFlagShardFunc |
35                        AObject::kMsFlagVmapFunc | AObject::kMsFlagJitFunc;
36 static bool IsRepeatWithoutSideEffect(ValueNode *v, bool repeat_attr_item_access);
37 
CheckBuildTupleRepeatable(ValueNode * value,bool repeat_attr_item_access)38 static bool CheckBuildTupleRepeatable(ValueNode *value, bool repeat_attr_item_access) {
39   for (auto i : value->getInputs()) {
40     if (i->GetOpcode() == BUILD_TUPLE || !IsRepeatWithoutSideEffect(i, repeat_attr_item_access)) {
41       return false;
42     }
43   }
44   return true;
45 }
46 
CheckBuildSliceRepeatable(const std::vector<ValueNode * > & inputs,bool repeat_attr_item_access)47 static bool CheckBuildSliceRepeatable(const std::vector<ValueNode *> &inputs, bool repeat_attr_item_access) {
48   for (auto i : inputs) {
49     if (i->GetOpcode() != LOAD_CONST) {
50       return false;
51     }
52   }
53   return true;
54 }
55 
56 // These are operations that are repeated and have no side effects.
IsRepeatWithoutSideEffect(ValueNode * v,bool repeat_attr_item_access)57 static bool IsRepeatWithoutSideEffect(ValueNode *v, bool repeat_attr_item_access) {
58   if (IsNonLocalValue(v)) {
59     return true;
60   }
61 
62   AObject::Type type = v->GetVobj() ? v->GetVobj()->GetType() : AObject::kTypeAnyValue;
63   auto opcode = v->GetOpcode();
64   if (opcode == BUILD_TUPLE) {
65     return CheckBuildTupleRepeatable(v, repeat_attr_item_access);
66   } else if (opcode == BUILD_SLICE) {
67     // NOTE: mindspore can't resolve call 'slice' class
68     return CheckBuildSliceRepeatable(v->getInputs(), repeat_attr_item_access);
69   } else if (opcode == BINARY_SUBSCR || opcode == LOAD_ATTR) {
70     return type == AObject::kTypeAnyValue ? false : repeat_attr_item_access;
71   } else if (opcode == BUILD_MAP) {
72     if (type == AObject::kTypeDict) {
73       AbstractDict *d = static_cast<AbstractDict *>(v->GetVobj());
74       return d->size() == 0 || d->KeyType() != AObject::kTypeAnyValue;
75     }
76     return false;
77   }
78   return false;
79 }
80 
81 namespace {
82 /**
83  * mindspore func_graph assume these unsupported value is constant, so it same as global.
84  * avoid parameter unsupported error by global
85  */
ValidateGraphParameters(ValueNode * node)86 bool ValidateGraphParameters(ValueNode *node) {
87   static const std::set<AObject::Type> unsupported_parameter = {
88     AObject::kTypeAnyValue,  AObject::kTypeFunction,      AObject::kTypeBoundMethod,
89     AObject::kTypePrimitive, AObject::kTypeMetaFuncGraph, AObject::kTypeCell,
90   };
91   AObject *info = node->GetVobj();
92   if (info == nullptr) {
93     return false;
94   }
95   return unsupported_parameter.find(info->GetType()) == unsupported_parameter.end();
96 }
97 }  // namespace
98 
ProduceInterpretValue(ValueNode * v)99 bool GraphAnalyzer::ProduceInterpretValue(ValueNode *v) {
100   bool repeat_op = graph_->Config().GetBoolConfig(GraphJitConfig::kEnableOptimizeForAttrItem);
101   auto &locals = GetCaptureInfo().interpret_.values;
102   auto &values = GetCaptureInfo().captured_.values;
103   for (auto i : v->getInputs()) {
104     if (IsNonLocalValue(i) || locals.find(i) != locals.end()) {
105       continue;
106     }
107     if (values.find(i) == values.end()) {
108       MS_LOG(INTERNAL_EXCEPTION) << "capture info can't find the value [" << i->ToString() << "]";
109     }
110     if (!IsRepeatWithoutSideEffect(i, repeat_op)) {
111       return false;
112     }
113     // duplicate some operations if possible
114     if (ProduceInterpretValue(i)) {
115       continue;
116     }
117     return false;
118   }
119   AddToEscaped(v);
120   return true;
121 }
122 
123 // if operation can't be repeated, or block has attr access side effect
124 // can't reorder attr access op, must be interpret all attr, item access operation
CheckAttrItemSupport(ValueNode * v,bool repeat_op)125 static bool CheckAttrItemSupport(ValueNode *v, bool repeat_op) {
126   int op = v->GetOpcode();
127   AObject::Type type = v->input(0)->GetVobj() ? v->input(0)->GetVobj()->GetType() : AObject::kTypeAnyValue;
128   // item access
129   if (op == BINARY_SUBSCR) {
130     return type != AObject::kTypeAnyValue;
131   }
132   // attr access
133   if (type == AObject::kTypeAnyValue || type == AObject::kTypeBoundMethod) {
134     return false;
135   }
136   if (type == AObject::kTypeTensor && !FindTensorName(v->GetName())) {
137     return false;
138   }
139   return true;
140 }
141 
IsSideEffect(ValueNode * v)142 static bool IsSideEffect(ValueNode *v) {
143   static const std::set<std::string> funcs = {"assign", "Assign"};
144   static const std::set<int> unsupported_op = {
145     STORE_DEREF,  DELETE_DEREF,  STORE_GLOBAL, DELETE_GLOBAL, STORE_ATTR, DELETE_ATTR,
146     STORE_SUBSCR, DELETE_SUBSCR, IMPORT_STAR,  RAISE_VARARGS, RERAISE,    FORMAT_VALUE,
147   };
148   Opcode opcode(v->GetOpcode());
149   if (opcode.MayDelete()) {
150     return false;
151   }
152   if (opcode.IsCall()) {
153     AObject *f = v->input(0)->GetVobj();
154     if (f == nullptr) {
155       return true;
156     }
157     if (f->TestMsFlag(AObject::kMsFlagGradFunc)) {
158       return false;
159     }
160     return funcs.find(GetFuncName(f->GetPyObject())) != funcs.end();
161   }
162   return unsupported_op.find(v->GetOpcode()) != unsupported_op.end();
163 }
164 
HandleCallableToGraph(AObject * f)165 bool GraphAnalyzer::HandleCallableToGraph(AObject *f) {
166   static bool known_type[AObject::kTypeCount] = {false};
167   if (known_type[AObject::kTypePrimitive] == false) {
168     known_type[AObject::kTypePrimitive] = true;
169     known_type[AObject::kTypeCell] = true;
170     known_type[AObject::kTypeMetaFuncGraph] = true;
171     known_type[AObject::kTypePrimitiveFunction] = true;
172   }
173   if (f == nullptr) {
174     return false;
175   }
176   // don't pass unknown callable to graph
177   bool is_known_func = known_type[f->GetType()] || CheckJitConstexpr(f->GetPyObject());
178   bool is_ms_support_func = f->TestMsFlag(kMsFlagSet);
179   if (!is_known_func && !is_ms_support_func) {
180     return false;
181   }
182   if (f->GetType() == AObject::kTypePrimitive) {
183     PyTypeObject *tp = f->GetTypeObject();
184     std::string name = (tp && tp->tp_name ? tp->tp_name : "");
185     if (name == "Assign") {
186       return false;
187     }
188   }
189   return true;
190 }
191 
AddToCaptured(ValueNode * v)192 bool GraphAnalyzer::AddToCaptured(ValueNode *v) {
193   if (IsNonLocalValue(v)) {
194     return true;
195   }
196   if (v->GetVobj() && v->GetVobj()->TestMsFlag(AObject::kMsFlagGradFunc)) {
197     GetCaptureInfo().has_grad_ = true;
198     GetCaptureInfo().captured_.values.insert(v);
199     GetCaptureInfo().captured_.operations.push_back(v);
200     return true;
201   }
202 
203   int op = v->GetOpcode();
204   bool repeat_op = graph_->Config().GetBoolConfig(GraphJitConfig::kEnableOptimizeForAttrItem);
205   if ((op == LOAD_ATTR || op == BINARY_SUBSCR) && !CheckAttrItemSupport(v, repeat_op)) {
206     return false;
207   }
208 
209   bool is_call_op = Opcode(v->GetOpcode()).IsCall();
210   if (is_call_op) {
211     AObject *f = v->input(0)->GetVobj();
212     bool can_pass = HandleCallableToGraph(f);
213     if (!can_pass) {
214       return false;
215     }
216     GetCaptureInfo().has_grad_ |= f->TestMsFlag(AObject::kMsFlagGradFunc);
217   }
218 
219   auto &locals = GetCaptureInfo().interpret_.values;  // interpret values
220   auto &values = GetCaptureInfo().captured_.values;   // graph produced values
221   for (auto i : v->getInputs()) {
222     bool produced_in_graph = values.find(i) != values.end() || IsNonLocalValue(i);
223     MS_EXCEPTION_IF_CHECK_FAIL(produced_in_graph || locals.find(i) != locals.end(),
224                                "check values order, all input must be generate before this value " + i->ToString());
225     if (i->GetVobj() == nullptr) {
226       return false;
227     }
228     AObject::Type type = i->GetVobj()->GetType();
229     PyTypeObject *tp = i->GetVobj()->GetTypeObject();
230     if (type == AObject::kTypeAnyValue && !IsMsClass(reinterpret_cast<PyObject *>(tp))) {
231       // don't pass unknown object to graph
232       return false;
233     }
234     if (type == AObject::kTypeCell && !is_call_op) {
235       // don't pass a cell object that not call to graph.
236       return false;
237     }
238   }
239 
240   GetCaptureInfo().captured_.values.insert(v);
241   GetCaptureInfo().captured_.operations.push_back(v);
242   return true;
243 }
244 
AddToEscaped(ValueNode * v)245 void GraphAnalyzer::AddToEscaped(ValueNode *v) {
246   MS_EXCEPTION_IF_CHECK_FAIL(GetCaptureInfo().interpret_.values.find(v) == GetCaptureInfo().interpret_.values.end(),
247                              "duplicate escaped values");
248   GetCaptureInfo().interpret_.values.insert(v);
249   GetCaptureInfo().interpret_.operations.push_back(v);
250 }
251 
252 extern TracePtr GetTrace(ValueNode *node, bool strict, bool print, int depth, int max_depth);
253 
TryToCapture(AbstractNode * n)254 bool GraphAnalyzer::TryToCapture(AbstractNode *n) {
255   ValueNode *v = static_cast<ValueNode *>(n);
256   if (IsNonLocalValue(v)) {
257     return true;
258   }
259   if (graph_->GetSideEffect()->IsRecord(v)) {
260     return true;
261   }
262   bool is_side_effect = IsSideEffect(v);
263   if (!is_side_effect && AddToCaptured(v)) {
264     return true;
265   }
266   if (!GetCaptureInfo().captured_.values.empty() && is_side_effect) {
267     return false;
268   }
269   if (ProduceInterpretValue(v)) {
270     return true;
271   }
272   if (!HasTensorOperation()) {
273     CleanCapturedValue();
274     AddToEscaped(v);
275     return true;
276   }
277 
278   if (v->GetGraph() != nullptr && this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
279     GRAPH_JIT_LOG_F("capture failed, operations is unsupported [%s] at [%U: %d]", v->ToString().c_str(),
280                     v->GetGraph()->GetCodeObj()->co_filename, v->GetLineNo());
281     GRAPH_JIT_LOG_F("parameters");
282     for (auto &i : v->getInputs()) {
283       PyObject *op = i->GetVobj() ? i->GetVobj()->GetPyObject().ptr() : nullptr;
284       GRAPH_JIT_LOG_F("%s", op ? AObject::ToString(op).c_str() : "NULL");
285     }
286   }
287 
288   MS_LOG(DEBUG) << "---operation that depend on the graph outputs, break graph---";
289   return false;
290 }
291 
AnalyzeCall(CallNode * call_node)292 bool GraphAnalyzer::AnalyzeCall(CallNode *call_node) {
293   if (call_node->GetSubGraph() == nullptr) {
294     return false;
295   }
296   if (call_node->GetInlineReason() != InlineReason::kInline) {
297     return false;
298   }
299 
300   Graph *g = call_node->GetGraph();
301 
302   CapturedInfo back_up = info_;
303   const auto &p = call_node->GetParams();
304   // capture parameter handle operations
305   auto iter = std::find_if(p.begin(), p.end(), [this](ValueNode *i) { return !this->TryToCapture(i); });
306   // capture sub-graph
307   if (iter == p.end() && AnalyzeRecursive(call_node->GetSubGraph())) {
308     return true;
309   }
310   info_ = back_up;
311   g->StopTraceAt(call_node->bci(), StopTraceReason::kStopTraceDataDependsOnGraphOut);
312   return false;
313 }
314 
AnalyzeRecursive(Graph * g)315 bool GraphAnalyzer::AnalyzeRecursive(Graph *g) {
316   for (auto n : g->GetTracedNodes()) {
317     int bci = static_cast<ValueNode *>(n)->bci();
318     if (n->GetType() == AbstractNode::Call && AnalyzeCall(static_cast<CallNode *>(n))) {
319       continue;
320     }
321     if (bci != -1 && g->GetStopTraceBci() == bci) {
322       return false;
323     }
324     if (!TryToCapture(n)) {
325       g->StopTraceAt(bci, StopTraceReason::kStopTraceDataDependsOnGraphOut);
326       return false;
327     }
328   }
329   return true;
330 }
331 
CollectCapturedInputs()332 void GraphAnalyzer::CollectCapturedInputs() {
333   auto &locals = GetCaptureInfo().interpret_.values;
334   auto &values = GetCaptureInfo().captured_.values;
335   mindspore::CompactSet<ValueNode *> inputs;
336   for (ValueNode *i : GetCaptureInfo().captured_.operations) {
337     for (auto input : i->getInputs()) {
338       if (values.find(input) != values.end() || IsNonLocalValue(input)) {
339         continue;
340       }
341       MS_EXCEPTION_IF_CHECK_FAIL(locals.find(input) != locals.end(), "check graph input");
342       inputs.insert(input);
343     }
344   }
345   GetCaptureInfo().captured_.inputs = {inputs.begin(), inputs.end()};
346 }
347 
UseDefAnalyze()348 void GraphAnalyzer::UseDefAnalyze() {
349   // UD analyze: alive nodes analysis
350   std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
351   if (!aliveLocals.empty()) {
352     bool isStopAnalyze = false;
353     while (!isStopAnalyze) {
354       isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
355       if (isStopAnalyze) {
356         break;
357       }
358       aliveLocals = GetAliveLocals(graph_);
359     }
360   }
361 }
362 
OptimizeSideEffectRecord() const363 void GraphAnalyzer::OptimizeSideEffectRecord() const {
364   if (graph_->GetSideEffect()->IsEmpty()) {
365     return;
366   }
367   auto alive = graph_->CollectAliveNode(graph_->GetStopTraceBci());
368   auto side_effect_required_size = graph_->GetSideEffect()->GetRequiredNodes().size();
369   auto size = alive.size() - side_effect_required_size;
370   graph_->GetSideEffect()->Optimize({alive.begin(), alive.begin() + size});
371 }
372 
ResetSideEffectRecord() const373 void GraphAnalyzer::ResetSideEffectRecord() const {
374   // if break point is changed, rollback graph nodes(only reset break bci) and side-effect record
375   int break_bci = graph_->GetStopTraceBci();
376   if (break_bci == -1 || graph_->GetSideEffect()->IsEmpty()) {
377     return;
378   }
379   const auto &nodes = graph_->GetTracedNodes();
380   auto iter = std::find_if(nodes.begin(), nodes.end(), [&break_bci](ValueNode *i) { return i->bci() > break_bci; });
381   graph_->GetSideEffect()->ResetRecord({nodes.begin(), iter});
382   OptimizeSideEffectRecord();  // after reset record, rollback side-effect record status
383 }
384 
Analyze()385 void GraphAnalyzer::Analyze() {
386   OptimizeSideEffectRecord();  // first optimize, remove dead local side-effects and it's required nodes
387 
388   const FrameStates &enter_frame = graph_->GetFrame(0);
389   GetCaptureInfo().interpret_.values.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
390   AnalyzeRecursive(graph_);
391   if (!HasTensorOperation()) {
392     CleanCapturedValue();
393   }
394   UseDefAnalyze();
395   ResetSideEffectRecord();  // if rollback nodes, rollback side-effects
396 
397   CollectCapturedAndInterpret();
398   CollectGraphInputs();
399 
400   need_interpret_ = true;
401 
402   if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().interpret_.operations.empty()) {
403     return;
404   }
405   bool support_ret = graph_->GetRetVal()->GetVobj() && graph_->GetRetVal()->GetVobj()->IsMindSporeSupportedType();
406   if (!support_ret) {
407     return;
408   }
409   PyCodeObject *co = graph_->GetCodeObj();
410   const auto &args = enter_frame.GetLocals();
411   int argc = co->co_argcount + co->co_kwonlyargcount;
412   // check all parameters is graph supported, but here not check variable arguments
413   auto end = args.begin() + argc;
414   auto iter = std::find_if(args.begin(), end, [](ValueNode *i) { return !ValidateGraphParameters(i); });
415   if (iter == end) {
416     need_interpret_ = false;
417   }
418   need_interpret_ |= !graph_->GetSideEffect()->IsEmpty();
419 }
420 
buildLastFrame(Graph * g)421 FrameStates buildLastFrame(Graph *g) { return g->GetFrame(g->GetStopTraceBci()); }
422 
GetAliveLocals(Graph * g)423 std::vector<ValueNode *> GraphAnalyzer::GetAliveLocals(Graph *g) {
424   int bci = g->GetStopTraceBci();
425   if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
426     GRAPH_JIT_LOG_F("UD analyze: enter GetAliveLocals bci %d", bci);
427   }
428   std::vector<ValueNode *> outputs = g->CollectAliveNode(bci);
429   mindspore::CompactSet<ValueNode *> uniques;
430   for (auto output : outputs) {
431     uniques.insert(output);
432   }
433   outputs.assign(uniques.begin(), uniques.end());
434   return outputs;
435 }
436 
PrintAliveNodes(std::vector<ValueNode * > aliveNodes)437 void PrintAliveNodes(std::vector<ValueNode *> aliveNodes) {
438   GRAPH_JIT_LOG_F("UD analyze: alive node size : %ld", aliveNodes.size());
439   for (auto node : aliveNodes) {
440     if (node) {
441       GRAPH_JIT_LOG_F("UD analyze: alive node: %s", node->ToString().c_str());
442     }
443   }
444 }
445 
AnalyzeAliveLocals(std::vector<ValueNode * > aliveNodes)446 bool GraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) {
447   if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
448     PrintAliveNodes(aliveNodes);
449   }
450   bool isAllNodesSupportOutput = true;
451   for (auto node : aliveNodes) {
452     AObject *o = node->GetVobj();
453     bool supported_type = o && o->IsMindSporeSupportedType();
454     if (supported_type) {
455       continue;
456     }
457     auto capturedLocals = info_.captured_.operations;
458     if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
459       continue;
460     }
461 
462     if (!HasTensorOperation()) {
463       CleanCapturedValue();
464       break;
465     }
466 
467     //  reset break graph point
468     isAllNodesSupportOutput = false;
469     int new_break_point = node->bci();
470     auto curNode = node;
471     MS_EXCEPTION_IF_CHECK_FAIL(new_break_point != -1, "break point cannot be -1");
472     MS_EXCEPTION_IF_NULL(curNode->GetGraph());
473     if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
474       GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
475     }
476     this->graph_->StopTraceAt(new_break_point, StopTraceReason::kStopTraceDataDependsOnGraphOut);
477 
478     // re-collect captured info
479     info_.clear();
480     const FrameStates &enter_frame = graph_->GetFrame(0);
481     GetCaptureInfo().interpret_.values.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
482     (void)AnalyzeRecursive(graph_);
483     break;
484   }
485   return isAllNodesSupportOutput;
486 }
487 
SkipSpecialFuncOrPrimitive(const py::object & callable)488 static bool SkipSpecialFuncOrPrimitive(const py::object &callable) {
489   if (callable.ptr() == nullptr) {
490     return false;
491   }
492   if (CheckJitConstexpr(callable) || CheckMSConstexpr(callable)) {
493     return true;
494   }
495   if (IsPrimitiveType<true>(Py_TYPE(callable.ptr()))) {
496     std::string name = callable.attr("name").cast<std::string>();
497     return GetSpecialPrimitiveInferFunc().find(name) != GetSpecialPrimitiveInferFunc().end();
498   }
499   return false;
500 }
501 
HasTensorOperation() const502 bool GraphAnalyzer::HasTensorOperation() const {
503   bool has_tensor_cal = false;
504   for (auto i : info_.captured_.values) {
505     AObject *value = i->GetVobj();
506     Opcode op(i->GetOpcode());
507     if (op.IsCall()) {
508       if (SkipSpecialFuncOrPrimitive(i->input(0)->GetVobj()->GetPyObject())) {
509         continue;
510       }
511       if (value->GetType() == AObject::kTypeCFunction) {
512         continue;
513       }
514       return true;
515     }
516     if (op.IsBinaryMath() && value->GetType() == AObject::kTypeTensor) {
517       return true;
518     }
519   }
520   return has_tensor_cal;
521 }
522 
clear()523 void GraphAnalyzer::CapturedInfo::Info::clear() {
524   values.clear();
525   inputs.clear();
526   operations.clear();
527   outputs.clear();
528 }
529 
clear()530 void GraphAnalyzer::CapturedInfo::GraphInputs::clear() {
531   args.clear();
532   globals.clear();
533   vargs = nullptr;
534   kwargs = nullptr;
535 }
536 
clear()537 void GraphAnalyzer::CapturedInfo::clear() {
538   captured_.clear();
539   interpret_.clear();
540   graph_inputs_.clear();
541 }
542 
ToString()543 std::string GraphAnalyzer::CapturedInfo::Info::ToString() {
544   std::stringstream s;
545   s << "values: ";
546   for (auto i : values) {
547     s << i->ToString() << "\n";
548   }
549   s << "inputs: \n";
550   for (auto i : inputs) {
551     s << i->ToString() << "\n";
552   }
553   s << "operations: \n";
554   for (auto i : operations) {
555     s << i->ToString() << "\n";
556   }
557   s << "outputs: \n";
558   for (auto i : outputs) {
559     s << i->ToString() << "\n";
560   }
561   return s.str();
562 }
563 
ToString()564 std::string GraphAnalyzer::CapturedInfo::GraphInputs::ToString() {
565   std::stringstream s;
566   s << "globals: ";
567   for (auto i : globals) {
568     s << i->ToString() << "\n";
569   }
570   s << "args: \n";
571   for (auto i : args) {
572     s << i->ToString() << "\n";
573   }
574   s << "vargs: ";
575   if (vargs != nullptr) {
576     s << vargs->ToString();
577   }
578   s << "\n";
579   s << "kwargs: ";
580   if (kwargs != nullptr) {
581     s << kwargs->ToString();
582   }
583   s << "\n";
584   return s.str();
585 }
586 
ToString()587 std::string GraphAnalyzer::CapturedInfo::ToString() {
588   std::stringstream s;
589   s << "1. captured_ info: \n";
590   s << captured_.ToString();
591   s << "2. interpret_ info: \n";
592   s << interpret_.ToString();
593   s << "3. graph_inputs_: \n";
594   s << graph_inputs_.ToString();
595   s << "4. has_grad_: " << has_grad_ << "\n";
596   return s.str();
597 }
598 
CleanCapturedValue()599 void GraphAnalyzer::CleanCapturedValue() {
600   auto &locals = info_.interpret_.values;
601   for (auto i : info_.captured_.operations) {
602     if (locals.find(i) == locals.end()) {
603       locals.insert(i);
604       info_.interpret_.operations.emplace_back(i);
605     }
606   }
607   info_.captured_.values.clear();
608   info_.captured_.operations.clear();
609 }
610 
CollectGraphOutputs(const mindspore::CompactSet<ValueNode * > & interpret,const std::vector<ValueNode * > & alive)611 static std::vector<ValueNode *> CollectGraphOutputs(const mindspore::CompactSet<ValueNode *> &interpret,
612                                                     const std::vector<ValueNode *> &alive) {
613   std::vector<ValueNode *> outputs;
614   for (auto i : alive) {
615     if (interpret.find(i) == interpret.end() && !IsNonLocalValue(i)) {
616       outputs.push_back(i);
617     }
618   }
619   return outputs;
620 }
621 
CollectCapturedAndInterpret()622 void GraphAnalyzer::CollectCapturedAndInterpret() {
623   CollectCapturedInputs();
624   int break_bci = graph_->GetStopTraceBci();
625   std::vector<ValueNode *> alive_nodes = graph_->CollectAliveNode(break_bci, &alive_locals_);
626 
627   GetCaptureInfo().captured_.outputs = CollectGraphOutputs(GetCaptureInfo().interpret_.values, alive_nodes);
628   GetCaptureInfo().interpret_.inputs = graph_->GetFrame(0).GetLocals();
629   GetCaptureInfo().interpret_.outputs = std::move(alive_nodes);
630 }
631 
CollectGraphInputs()632 void GraphAnalyzer::CollectGraphInputs() {
633   PyCodeObject *co_ = graph_->GetCodeObj();
634   auto &interpret_ = GetCaptureInfo().interpret_;
635   auto &captured_ = GetCaptureInfo().captured_;
636   auto &graph_inputs = GetCaptureInfo().graph_inputs_;
637 
638   // NOTE: if *vargs is cell variable, it is not parameter node
639   MS_EXCEPTION_IF_CHECK_FAIL(co_->co_nlocals == static_cast<int>(interpret_.inputs.size()),
640                              "interpret inputs must be same as locals");
641 
642   ValueNode *vargs = nullptr;
643   ValueNode *kwargs = nullptr;
644   int arg_index = co_->co_argcount + co_->co_kwonlyargcount;
645   if ((co_->co_flags & CO_VARARGS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
646     vargs = interpret_.inputs[arg_index];
647   }
648   arg_index += (IntToSize(co_->co_flags) & CO_VARARGS) != 0;
649   if ((IntToSize(co_->co_flags) & CO_VARKEYWORDS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
650     kwargs = interpret_.inputs[arg_index];
651   }
652 
653   // Identify parameters and global variables
654   for (auto input : captured_.inputs) {
655     if (input == graph_inputs.vargs) {
656       graph_inputs.vargs = vargs;
657     } else if (input == graph_inputs.kwargs) {
658       graph_inputs.kwargs = kwargs;
659     } else if (ValidateGraphParameters(input)) {
660       graph_inputs.args.push_back(input);
661     } else {
662       graph_inputs.globals.push_back(input);
663     }
664   }
665 
666   size_t inputs_count = captured_.inputs.size();
667   captured_.inputs = graph_inputs.args;
668   if (graph_inputs.vargs != nullptr) {
669     captured_.inputs.push_back(graph_inputs.vargs);
670   }
671   if (graph_inputs.kwargs != nullptr) {
672     captured_.inputs.push_back(graph_inputs.kwargs);
673   }
674   captured_.inputs.insert(captured_.inputs.end(), graph_inputs.globals.begin(), graph_inputs.globals.end());
675   MS_EXCEPTION_IF_CHECK_FAIL(inputs_count == captured_.inputs.size(), "error parameters");
676 }
677 
CollectCapturedInputs()678 void MindGraphAnalyzer::CollectCapturedInputs() {
679   auto &inputs = GetCaptureInfo().captured_.inputs;
680   const FrameStates &enter_frame = graph_->GetFrame(0);
681   PyCodeObject *co = graph_->GetCodeObj();
682   int argc = co->co_argcount + co->co_kwonlyargcount;
683   argc += SizeToInt(IntToSize(co->co_flags) & CO_VARARGS) ? 1 : 0;
684   argc += SizeToInt(IntToSize(co->co_flags) & CO_VARKEYWORDS) ? 1 : 0;
685   for (Py_ssize_t m = 0; m < argc; ++m) {
686     const auto &local = enter_frame.Local(m);
687     if (local == &ValueNode::kUnboundLocal) {
688       continue;
689     }
690     inputs.push_back(local);
691   }
692 }
693 
Analyze()694 void MindGraphAnalyzer::Analyze() {
695   OptimizeSideEffectRecord();
696 
697   auto origin_stop_bci = graph_->GetStopTraceBci();
698   UseDefAnalyze();
699 
700   const FrameStates &enter_frame = graph_->GetFrame(0);
701   GetCaptureInfo().interpret_.values.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
702 
703   auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
704   MS_EXCEPTION_IF_NULL(mind_graph_builder);
705   auto func_graph_builder = mind_graph_builder->FGBuilder();
706   if (func_graph_builder->graph() == nullptr) {
707     // Graph build failed, add all nodes to ordered_escaped_locals.
708     MS_LOG(DEBUG) << "Failed to build graph";
709     GetCaptureInfo().interpret_.operations.clear();
710     for (const auto &traced_node : graph_->GetTracedNodes()) {
711       if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
712         break;
713       }
714       AddToEscaped(traced_node);
715     }
716     graph_->StopTraceAt(origin_stop_bci, StopTraceReason::kStopTraceDataDependsOnGraphOut);
717     ResetSideEffectRecord();
718 
719     need_interpret_ = true;
720     GetCaptureInfo().captured_.clear();
721     CollectCapturedAndInterpret();
722     return;
723   }
724   ResetSideEffectRecord();
725 
726   CollectCapturedAndInterpret();
727   CollectGraphInputs();
728 
729   need_interpret_ = true;
730   if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().interpret_.operations.empty()) {
731     return;
732   }
733   bool support_ret = graph_->GetRetVal()->GetVobj() && graph_->GetRetVal()->GetVobj()->IsMindSporeSupportedType();
734   if (!support_ret) {
735     return;
736   }
737   need_interpret_ = !graph_->GetSideEffect()->IsEmpty();
738 }
739 
AnalyzeAliveLocals(std::vector<ValueNode * > aliveNodes)740 bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) {
741   bool isAllNodesSupportOutput = true;
742   auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
743   MS_EXCEPTION_IF_NULL(mind_graph_builder);
744   auto func_graph_builder = mind_graph_builder->FGBuilder();
745   MS_EXCEPTION_IF_NULL(func_graph_builder);
746   func_graph_builder->ClearOutputNodes();
747   GetCaptureInfo().captured_.outputs.clear();
748   for (auto node : aliveNodes) {
749     // If the value can get from local, no need to add to graph output.
750     if (IsNonLocalValue(node)) {
751       MS_LOG(DEBUG) << "Skip non local value used as graph return.";
752       continue;
753     }
754     auto capturedLocals = info_.captured_.operations;
755     if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
756       continue;
757     }
758     AObject *o = node->GetVobj();
759     auto out_py_obj = o->GetPyObject();
760     if (func_graph_builder->AddOutput(out_py_obj, true)) {
761       MS_LOG(INFO) << "Add output success for node: " << node->ToString();
762       GetCaptureInfo().captured_.outputs.push_back(node);
763       continue;
764     }
765     MS_LOG(INFO) << "Add output failed for node: " << node->ToString();
766     GetCaptureInfo().captured_.outputs.clear();
767     //  reset break graph point
768     isAllNodesSupportOutput = false;
769     int new_break_point = node->bci();
770     auto curNode = node;
771     if (new_break_point == -1) {
772       // No node is unsupported output since no node in captured output.
773       isAllNodesSupportOutput = true;
774       break;
775     }
776     MS_EXCEPTION_IF_NULL(curNode->GetGraph());
777     if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
778       GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
779     }
780     this->graph_->StopTraceAt(new_break_point, StopTraceReason::kStopTraceDataDependsOnGraphOut);
781     break;
782   }
783   return isAllNodesSupportOutput;
784 }
785 
UpdateCapturedOrder()786 void MindGraphAnalyzer::UpdateCapturedOrder() {
787   const auto &traced_nodes = graph_->GetTracedNodes();
788   auto stop_bci = graph_->GetStopTraceBci();
789   if (stop_bci == -1) {
790     GetCaptureInfo().captured_.operations = traced_nodes;
791   } else {
792     GetCaptureInfo().captured_.operations.clear();
793     for (const auto &traced_node : traced_nodes) {
794       if (traced_node->bci() >= stop_bci) {
795         break;
796       }
797       GetCaptureInfo().captured_.operations.push_back(traced_node);
798     }
799   }
800   const auto &captured_local_order = GetCaptureInfo().captured_.operations;
801   mindspore::CompactSet<ValueNode *> new_capture_local_values;
802   for (auto val : captured_local_order) {
803     new_capture_local_values.insert(val);
804   }
805   GetCaptureInfo().captured_.values = new_capture_local_values;
806 }
807 
CollectCapturedAndInterpret()808 void MindGraphAnalyzer::CollectCapturedAndInterpret() {
809   CollectCapturedInputs();
810   int break_bci = graph_->GetStopTraceBci();
811   std::vector<ValueNode *> alive_nodes = graph_->CollectAliveNode(break_bci, &alive_locals_);
812 
813   GetCaptureInfo().interpret_.inputs = graph_->GetFrame(0).GetLocals();
814   GetCaptureInfo().interpret_.outputs = std::move(alive_nodes);
815 
816   // remove side-effect node
817   auto is_remove = [this](ValueNode *node) { return this->graph_->GetSideEffect()->IsRecord(node); };
818   auto *ops = &GetCaptureInfo().captured_.operations;
819   ops->erase(std::remove_if(ops->begin(), ops->end(), is_remove), ops->end());
820   ops = &GetCaptureInfo().interpret_.operations;
821   ops->erase(std::remove_if(ops->begin(), ops->end(), is_remove), ops->end());
822 }
823 
UseDefAnalyze()824 void MindGraphAnalyzer::UseDefAnalyze() {
825   // UD analyze: alive nodes analysis
826   std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
827   if (!aliveLocals.empty()) {
828     bool stop_analyze = false;
829     while (!stop_analyze) {
830       UpdateCapturedOrder();
831       // Add graph output according to leaf nodes.
832       stop_analyze = AnalyzeAliveLocals(aliveLocals);
833       if (!stop_analyze) {
834         aliveLocals = GetAliveLocals(graph_);
835       }
836     }
837   }
838 }
839 
CollectGraphInputs()840 void MindGraphAnalyzer::CollectGraphInputs() {
841   PyCodeObject *co_ = graph_->GetCodeObj();
842   auto &interpret_ = GetCaptureInfo().interpret_;
843   auto &captured_ = GetCaptureInfo().captured_;
844   auto &graph_inputs = GetCaptureInfo().graph_inputs_;
845 
846   // NOTE: if *vargs is cell variable, it is not parameter node
847   MS_EXCEPTION_IF_CHECK_FAIL(co_->co_nlocals == static_cast<int>(interpret_.inputs.size()),
848                              "interpret inputs must be same as locals");
849 
850   ValueNode *vargs = nullptr;
851   ValueNode *kwargs = nullptr;
852   int arg_index = co_->co_argcount + co_->co_kwonlyargcount;
853   if ((IntToSize(co_->co_flags) & CO_VARARGS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
854     vargs = interpret_.inputs[arg_index];
855   }
856   arg_index += (IntToSize(co_->co_flags) & CO_VARARGS) != 0;
857   if ((IntToSize(co_->co_flags) & CO_VARKEYWORDS) && interpret_.inputs[arg_index] != &ValueNode::kUnboundLocal) {
858     kwargs = interpret_.inputs[arg_index];
859   }
860 
861   // Identify parameters and global variables
862   for (auto input : captured_.inputs) {
863     if (input == graph_inputs.vargs) {
864       graph_inputs.vargs = vargs;
865     } else if (input == graph_inputs.kwargs) {
866       graph_inputs.kwargs = kwargs;
867     } else {
868       graph_inputs.args.push_back(input);
869     }
870   }
871 
872   size_t inputs_count = captured_.inputs.size();
873   captured_.inputs = graph_inputs.args;
874   if (graph_inputs.vargs != nullptr) {
875     captured_.inputs.push_back(graph_inputs.vargs);
876   }
877   if (graph_inputs.kwargs != nullptr) {
878     captured_.inputs.push_back(graph_inputs.kwargs);
879   }
880   captured_.inputs.insert(captured_.inputs.end(), graph_inputs.globals.begin(), graph_inputs.globals.end());
881   MS_EXCEPTION_IF_CHECK_FAIL(inputs_count == captured_.inputs.size(), "error parameters");
882 }
883 
884 }  // namespace pijit
885 }  // namespace mindspore
886