• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PI_JIT_FUNC_WRAPPER_H_
18 #define MINDSPORE_PI_JIT_FUNC_WRAPPER_H_
19 
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <vector>
25 #include "pipeline/jit/pi/graph_compiler/pi_ir/ctrl_flow.h"
26 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h"
27 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_visitor.h"
28 #include "pipeline/jit/pi/graph_compiler/pi_ir/value.h"
29 
30 namespace mindspore {
31 namespace pijit {
32 namespace py = pybind11;
33 using ValuePtrList = std::vector<ir::ValuePtr>;
34 
35 class InputsCollector final : public ir::IRVisitor {
36  public:
InputsCollector(const ir::NodePtrList & nodes)37   explicit InputsCollector(const ir::NodePtrList &nodes) : nodes_(nodes) {}
38   virtual ~InputsCollector() = default;
39   const ValuePtrList &GetInputs();
40   void Visit_(const ir::LoadValueNodePtr &node) override;
41   void Visit_(const ir::StoreNodePtr &node) override;
42 
43  private:
44   void AddInput(const ir::NodePtr &input);
45   void AddAssignedVar(const ir::NodePtr &var);
46 
47   const ir::NodePtrList &nodes_;
48   ValuePtrList inputs_;
49   std::set<py::object> input_names_;
50   std::set<py::object> assigned_vars_;
51 };
52 
53 using InputsCollectorPtr = std::shared_ptr<InputsCollector>;
54 
55 class OutputsCollector final : public ir::IRVisitor {
56  public:
OutputsCollector(const ir::NodePtrList & nodes)57   explicit OutputsCollector(const ir::NodePtrList &nodes) : nodes_(nodes) {}
58   virtual ~OutputsCollector() = default;
59   const ValuePtrList &GetOutputs();
60   void Visit_(const ir::StoreNodePtr &node) override;
61 
62  private:
63   void AddOutput(const ir::NodePtr &output);
64 
65   const ir::NodePtrList &nodes_;
66   ValuePtrList outputs_;
67   std::set<py::object> output_names_;
68 };
69 
70 using OutputsCollectorPtr = std::shared_ptr<OutputsCollector>;
71 
72 // FuncInliner to convert ir graph to function graph
73 class FuncWrapper {
74  public:
FuncWrapper(const std::string & func_name,const ir::NodePtrList & nodes)75   explicit FuncWrapper(const std::string &func_name, const ir::NodePtrList &nodes)
76       : func_(std::make_shared<ir::FunctionNode>(func_name, nodes)) {}
77   virtual ~FuncWrapper() = default;
78   ir::FunctionNodePtr Wrapper();
79   const ValuePtrList &GetOutputs();
SpecifyOutputs(const ValuePtrList & outputs)80   void SpecifyOutputs(const ValuePtrList &outputs) { outputs_ = outputs; }
81 
82  private:
83   void GenerateParameters() const;
84   void GenerateReturn() const;
85 
86   const ir::FunctionNodePtr func_;
87   ValuePtrList outputs_;
88 };
89 
90 using FuncWrapperPtr = std::shared_ptr<FuncWrapper>;
91 }  // namespace pijit
92 }  // namespace mindspore
93 
94 #endif  // MINDSPORE_PI_JIT_FUNC_WRAPPER_H_
95