• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_IR_CONVERTER_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_IR_CONVERTER_H_
19 
20 #include <atomic>
21 #include <cstdint>
22 #include <memory>
23 #include <utility>
24 #include <string>
25 #include <vector>
26 #include "nlohmann/json.hpp"
27 #include "ir/anf.h"
28 #include "ir/tensor.h"
29 #include "include/backend/device_address.h"
30 #include "include/backend/kernel_graph.h"
31 #include "runtime/hardware/device_context.h"
32 
33 namespace mindspore {
34 namespace pynative {
// Classifies where an Edge comes from. The underlying type is uint8_t, so the
// enum occupies a single byte.
enum class EdgeType : uint8_t {
  kParameterEdge = 0,  // presumably: Edge for a graph Parameter (inferred from name)
  kValueNodeEdge = 1,  // presumably: Edge for a ValueNode constant (inferred from name)
  kOpOutputEdge = 2,   // presumably: Edge produced as an op output (inferred from name)
};
40 
41 // A new simpler IR for PyNative runtime.
42 // Edge:
43 //    DeviceAddress
44 //
45 // SingleOp:
46 //    inputs: list[Edge]
47 //    outputs: list[Edge]
48 //
//    SimpleGraph:
50 //    inputs: list[Edge]
51 //    outputs: list[Edge]
52 //    SingleOps: list[SingleOp]
53 //
54 // This IR has the following three characteristics:
// 1. An Edge holds one DeviceAddress, and every SingleOp that uses that Edge sees the same DeviceAddress,
//    so Ref information does not need to be tracked at runtime.
57 // 2. The Edges of the IR graph inputs are the same as the Edges of SingleOp inputs.
58 //    The Edges of Graph inputs are refreshed according to the input Tensors,
59 //    and the correct DeviceAddress is naturally obtained when SingleOp is executed.
60 // 3. The output Edges of SimpleGraph are the same as the output Edges of SingleOp.
61 //    After the operator is executed, the output Edges of Graph are automatically updated,
62 //    and there is no need to additionally update the outputs of Graph.
63 struct Edge {
64   Edge(EdgeType type, device::DeviceAddressPtr address, device::DeviceAddressPtr origin_address,
65        session::KernelWithIndex node_with_index);
66   nlohmann::json DebugInfo() const;
67   const EdgeType type_;
68   const uint64_t id_;
69   bool ignore_h2d_;
70   bool is_grad_;
71   device::DeviceAddressPtr address_;
72   // For cloning device address faster.
73   const device::DeviceAddressPtr origin_address_;
74   const session::KernelWithIndex node_with_index_;
75 };
76 using EdgePtr = std::shared_ptr<Edge>;
77 
78 // Edge1 Edge2
79 //   \    /
80 // SingleOp
81 //     |
82 //   Edge3
83 struct SingleOp {
84   SingleOp(PrimitivePtr primitive, CNodePtr kernel, std::vector<EdgePtr> inputs, std::vector<EdgePtr> outputs);
85   nlohmann::json DebugInfo() const;
86   const uint64_t id_;
87   const PrimitivePtr primitive_;
88   const CNodePtr kernel_;
89   const std::vector<EdgePtr> inputs_;
90   const std::vector<EdgePtr> outputs_;
91 };
92 using SingleOpPtr = std::unique_ptr<SingleOp>;
93 
94 // SimpleGraph:
95 //
96 // inputs: Edge1, Edge2
97 //
98 // Edge1  Edge2
99 //    \    /
100 //   SingleOp1
101 //      |
102 //    Edge3
103 //      |
104 //   SingleOp2
105 //     |
106 //   Edge4
107 //
108 // outputs: Edge4
109 struct SimpleGraph {
110   SimpleGraph(std::string name, std::vector<SingleOpPtr> single_ops, std::vector<EdgePtr> inputs,
111               std::vector<EdgePtr> outputs, std::vector<EdgePtr> all_edges);
112   nlohmann::json DebugInfo() const;
113   const std::string name_;
114   const std::vector<SingleOpPtr> single_ops_;
115   const std::vector<EdgePtr> inputs_;
116   const std::vector<EdgePtr> outputs_;
117   const std::vector<EdgePtr> all_edges_;
118 };
119 using SimpleGraphPtr = std::unique_ptr<SimpleGraph>;
120 
121 // Convert ANF IR to a simpler IR
122 class IrConverter {
123  public:
124   static SimpleGraphPtr Convert(const std::string &name, const KernelGraphPtr &graph,
125                                 const device::DeviceContext *device_context);
126 };
127 }  // namespace pynative
128 }  // namespace mindspore
129 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_IR_CONVERTER_H_
130