• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_ANALYZER_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_ANALYZER_H
18 
19 #include <set>
20 #include <vector>
21 #include <memory>
22 #include <string>
23 #include "pipeline/jit/pi/graph_capture/cfg.h"
24 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
25 #include "pipeline/jit/pi/graph_capture/graph_build.h"
26 #include "utils/convert_utils_base.h"
27 namespace mindspore {
28 namespace pijit {
29 
30 class Graph;
31 class AbstractNode;
32 class ValueNode;
33 class CallNode;
34 class GraphAnalyzer;
35 class MindGraphAnalyzer;
36 using GraphAnalyzerPtr = std::shared_ptr<GraphAnalyzer>;
37 using MindGraphAnalyzerPtr = std::shared_ptr<MindGraphAnalyzer>;
38 
39 class GraphAnalyzer {
40  public:
41   // escaped_locals and captured.values do not intersect
42   struct CapturedInfo {
43     struct Info {
44       // contains inputs and operations, used to find
45       mindspore::CompactSet<ValueNode *> values;
46       // the inputs of operations
47       std::vector<ValueNode *> inputs;
48       // bytecode operations
49       std::vector<ValueNode *> operations;
50       // ordered outputs, used to restore stack and locals
51       std::vector<ValueNode *> outputs;
52 
53       void clear();
54       std::string ToString();
55     };
56 
57     struct GraphInputs {
58       std::vector<ValueNode *> args;
59       std::vector<ValueNode *> globals;
60       ValueNode *vargs = nullptr;
61       ValueNode *kwargs = nullptr;
62 
63       void clear();
64       std::string ToString();
65     };
66 
67     /**
68      * for captured inputs, it's parameters, maybe unordered.
69      * for captured outputs, it's ordered by stack values and alive locals.
70      */
71     Info captured_;
72 
73     /**
74      * for interpret inputs, it's ordered and same as original function.
75      * if not break graph, outputs is return value, else outputs is ordered by stack values and alive locals.
76      */
77     Info interpret_;
78 
79     /**
80      * Store all collected graph inputs.
81      * If no graph is generated, graph_inputs_ should be empty.
82      */
83     GraphInputs graph_inputs_;
84 
85     bool has_grad_ = false;
86 
87     void clear();
88     std::string ToString();
89   };
90 
GraphAnalyzer(Graph * g)91   explicit GraphAnalyzer(Graph *g) : graph_(g) {}
Creator(const GraphBuilderPtr & g)92   static GraphAnalyzerPtr Creator(const GraphBuilderPtr &g) {
93     return g->trace_flag() ? std::static_pointer_cast<GraphAnalyzer>(std::make_shared<MindGraphAnalyzer>(g))
94                            : std::make_shared<GraphAnalyzer>(g->GetGraph());
95   }
GetCaptureInfo()96   auto &GetCaptureInfo() { return info_; }
GetCaptureInfo()97   const auto &GetCaptureInfo() const { return info_; }
98   virtual void Analyze();
99   bool HasTensorOperation() const;
NeedInterpret()100   virtual bool NeedInterpret() const { return need_interpret_; }
101 
alive_locals()102   const auto &alive_locals() const { return alive_locals_; }
103 
104  protected:
105   // optimize
106   void OptimizeSideEffectRecord() const;
107 
108   // rollback
109   void ResetSideEffectRecord() const;
110 
111   void AddToEscaped(ValueNode *value);
112   // UD analyze
113   virtual void UseDefAnalyze();
114   std::vector<ValueNode *> GetAliveLocals(Graph *g);
115   virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
116   virtual void CollectCapturedInputs();
117   virtual void CollectCapturedAndInterpret();
118   virtual void CollectGraphInputs();
119   bool need_interpret_;
120   Graph *graph_;
121   CapturedInfo info_;
122   std::vector<int> alive_locals_;
123 
124  private:
125   bool AnalyzeRecursive(Graph *g);
126   bool AnalyzeCall(CallNode *);
127   bool TryToCapture(AbstractNode *value);
128   bool HandleSideEffectNodeForCapture(AbstractNode *capture_node);
129   bool AddToCaptured(ValueNode *value);
130   bool HandleCallableToGraph(AObject *f);
131   bool ProduceInterpretValue(ValueNode *v);
132   void CleanCapturedValue();
133 };
134 
135 class MindGraphAnalyzer : public GraphAnalyzer {
136  public:
MindGraphAnalyzer(const GraphBuilderPtr & g)137   explicit MindGraphAnalyzer(const GraphBuilderPtr &g) : GraphAnalyzer(g->GetGraph()), graph_builder_(g) {}
138   void Analyze() override;
139 
140  protected:
141   // UD analyze
142   void UseDefAnalyze() override;
143   void CollectCapturedInputs() override;
144   void CollectGraphInputs() override;
145   void UpdateCapturedOrder();
146   void CollectCapturedAndInterpret() override;
147   bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) override;
148   GraphBuilderPtr graph_builder_ = nullptr;
149 };
150 }  // namespace pijit
151 }  // namespace mindspore
152 
153 #endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_ANALYZER_H
154