1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_ANALYZER_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_ANALYZER_H 18 19 #include <set> 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include "pipeline/jit/pi/graph_capture/cfg.h" 24 #include "pipeline/jit/pi/graph_capture/abstract_object.h" 25 #include "pipeline/jit/pi/graph_capture/graph_build.h" 26 #include "utils/convert_utils_base.h" 27 namespace mindspore { 28 namespace pijit { 29 30 class Graph; 31 class AbstractNode; 32 class ValueNode; 33 class CallNode; 34 class GraphAnalyzer; 35 class MindGraphAnalyzer; 36 using GraphAnalyzerPtr = std::shared_ptr<GraphAnalyzer>; 37 using MindGraphAnalyzerPtr = std::shared_ptr<MindGraphAnalyzer>; 38 39 class GraphAnalyzer { 40 public: 41 // escaped_locals and captured.values do not intersect 42 struct CapturedInfo { 43 struct Info { 44 // contains inputs and operations, used to find 45 mindspore::CompactSet<ValueNode *> values; 46 // the inputs of operations 47 std::vector<ValueNode *> inputs; 48 // bytecode operations 49 std::vector<ValueNode *> operations; 50 // ordered outputs, used to restore stack and locals 51 std::vector<ValueNode *> outputs; 52 53 void clear(); 54 std::string ToString(); 55 }; 56 57 struct GraphInputs { 58 std::vector<ValueNode *> args; 59 std::vector<ValueNode *> globals; 60 ValueNode *vargs = nullptr; 61 ValueNode *kwargs = nullptr; 62 63 void clear(); 64 std::string ToString(); 65 }; 66 67 /** 68 * for captured inputs, it's parameters, maybe unordered. 69 * for captured outputs, it's ordered by stack values and alive locals. 70 */ 71 Info captured_; 72 73 /** 74 * for interpret inputs, it's ordered and same as original function. 75 * if not break graph, outputs is return value, else outputs is ordered by stack values and alive locals. 76 */ 77 Info interpret_; 78 79 /** 80 * Store all collected graph inputs. 81 * If no graph is generated, graph_inputs_ should be empty. 82 */ 83 GraphInputs graph_inputs_; 84 85 bool has_grad_ = false; 86 87 void clear(); 88 std::string ToString(); 89 }; 90 GraphAnalyzer(Graph * g)91 explicit GraphAnalyzer(Graph *g) : graph_(g) {} Creator(const GraphBuilderPtr & g)92 static GraphAnalyzerPtr Creator(const GraphBuilderPtr &g) { 93 return g->trace_flag() ? std::static_pointer_cast<GraphAnalyzer>(std::make_shared<MindGraphAnalyzer>(g)) 94 : std::make_shared<GraphAnalyzer>(g->GetGraph()); 95 } GetCaptureInfo()96 auto &GetCaptureInfo() { return info_; } GetCaptureInfo()97 const auto &GetCaptureInfo() const { return info_; } 98 virtual void Analyze(); 99 bool HasTensorOperation() const; NeedInterpret()100 virtual bool NeedInterpret() const { return need_interpret_; } 101 alive_locals()102 const auto &alive_locals() const { return alive_locals_; } 103 104 protected: 105 // optimize 106 void OptimizeSideEffectRecord() const; 107 108 // rollback 109 void ResetSideEffectRecord() const; 110 111 void AddToEscaped(ValueNode *value); 112 // UD analyze 113 virtual void UseDefAnalyze(); 114 std::vector<ValueNode *> GetAliveLocals(Graph *g); 115 virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes); 116 virtual void CollectCapturedInputs(); 117 virtual void CollectCapturedAndInterpret(); 118 virtual void CollectGraphInputs(); 119 bool need_interpret_; 120 Graph *graph_; 121 CapturedInfo info_; 122 std::vector<int> alive_locals_; 123 124 private: 125 bool AnalyzeRecursive(Graph *g); 126 bool AnalyzeCall(CallNode *); 127 bool TryToCapture(AbstractNode *value); 128 bool HandleSideEffectNodeForCapture(AbstractNode *capture_node); 129 bool AddToCaptured(ValueNode *value); 130 bool HandleCallableToGraph(AObject *f); 131 bool ProduceInterpretValue(ValueNode *v); 132 void CleanCapturedValue(); 133 }; 134 135 class MindGraphAnalyzer : public GraphAnalyzer { 136 public: MindGraphAnalyzer(const GraphBuilderPtr & g)137 explicit MindGraphAnalyzer(const GraphBuilderPtr &g) : GraphAnalyzer(g->GetGraph()), graph_builder_(g) {} 138 void Analyze() override; 139 140 protected: 141 // UD analyze 142 void UseDefAnalyze() override; 143 void CollectCapturedInputs() override; 144 void CollectGraphInputs() override; 145 void UpdateCapturedOrder(); 146 void CollectCapturedAndInterpret() override; 147 bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) override; 148 GraphBuilderPtr graph_builder_ = nullptr; 149 }; 150 } // namespace pijit 151 } // namespace mindspore 152 153 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_ANALYZER_H 154