/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_

#include <utility>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <unordered_map>
#include <set>
#include "utils/ms_utils.h"
#include "include/backend/kernel_graph.h"
#include "backend/common/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "runtime/pynative/ir_converter.h"

namespace mindspore {
using device::DeviceContext;
using session::KernelWithIndex;
namespace pynative {
constexpr size_t kAlignSize = 64;
struct OpCompilerInfo {
  OpCompilerInfo(GraphInfo graph_info, GraphId graph_id, KernelGraphPtr graph, DeviceContext *device_context,
                 bool need_erase, bool need_refresh_abstract, std::vector<KernelWithIndex> graph_output_nodes,
                 std::vector<size_t> graph_outputs_tensor_num, std::vector<std::string> graph_outputs_padding_type,
                 SimpleGraphPtr simple_graph)
      : graph_info_(std::move(graph_info)),
        graph_id_(graph_id),
        graph_(std::move(graph)),
        device_context_(device_context),
        need_erase_(need_erase),
        need_refresh_abstract_(need_refresh_abstract),
        graph_output_nodes_(std::move(graph_output_nodes)),
        graph_outputs_tensor_num_(std::move(graph_outputs_tensor_num)),
        graph_outputs_padding_type_(std::move(graph_outputs_padding_type)),
        simple_graph_(std::move(simple_graph)) {}
  ~OpCompilerInfo() = default;
  void UpdateStatus(bool ready);
  void WaitReady() const;
  const mindspore::GraphInfo graph_info_;
  const GraphId graph_id_;
  const KernelGraphPtr graph_;
  const DeviceContext *device_context_;
  const bool need_erase_;
  const bool need_refresh_abstract_;
  const std::vector<KernelWithIndex> graph_output_nodes_;
  const std::vector<size_t> graph_outputs_tensor_num_;
  const std::vector<std::string> graph_outputs_padding_type_;
  const SimpleGraphPtr simple_graph_;
  alignas(kAlignSize) std::atomic<bool> ready_{true};
};
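
// A minimal sketch (an assumption based only on the declarations above) of how the readiness
// flag is presumably coordinated between threads: the thread that builds the kernels marks the
// info ready, and the thread that launches the op blocks until then.
//
//   // building thread, once kernel build has finished (hypothetical call site)
//   op_compiler_info->UpdateStatus(true);
//
//   // launching thread, before touching graph_ or the output nodes (hypothetical call site)
//   op_compiler_info->WaitReady();
//
// The alignas(kAlignSize) on ready_ keeps the flag on its own 64-byte cache line, a common way
// to avoid false sharing between the two threads.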
using OpCompilerInfoPtr = std::shared_ptr<OpCompilerInfo>;

// FuncGraph, Backend and GraphCompiler correspond one-to-one,
// and GraphCompiler stores the compilation cache of operators.
// When the graph structure changes, the front end sends multiple graphs, and the operators of each
// graph would be compiled separately, which results in very poor performance.
// Therefore, the OpCompiler class is needed to hold all operator caches and keep them independent of
// any single graph (a usage sketch follows the class declaration below).
class BACKEND_EXPORT OpCompiler {
 public:
  static OpCompiler &GetInstance();

  // Compile RunOpInfo into a KernelGraph.
  OpCompilerInfoPtr Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                            const std::string &device_name, const uint32_t &device_id);

  // Clear op cache in dynamic scenes.
  // Otherwise, the operator cache will keep growing, resulting in insufficient memory.
  void ClearOpCache(const mindspore::GraphInfo &graph_info);

  // Accumulate a certain number of operators,
  // and then compile them in parallel to improve compilation efficiency.
  void KernelBuild(const OpCompilerInfoPtr &op_compiler_info, const DeviceContext *device_context,
                   bool is_dynamic_shape = false) const;

  std::string GetSingleOpGraphInfo(const pynative::BaseOpRunInfo &op_info, const PrimitivePtr &op_prim) const;

  // Clear anf resources before process exit.
  void ClearAllCache();

  bool IsInvalidInferResultOp(const std::string &op_name) const;

  static void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph);

 private:
  OpCompiler();
  ~OpCompiler() = default;
  DISABLE_COPY_AND_ASSIGN(OpCompiler);
  KernelGraphPtr GenerateKernelGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                     const device::DeviceContext *device_context) const;
  void AssignStreamIdForSingleOpGraph(const KernelGraphPtr &graph, uint32_t stream_id);
  // All operators share the same session.
  session::SessionPtr session_;
  mindspore::HashMap<mindspore::GraphInfo, OpCompilerInfoPtr> op_compiler_infos_;
};
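
// A rough usage sketch, based only on the declarations above. The control flow, the "Ascend"
// device name and the op_run_info/device_id variables are illustrative assumptions, not the
// actual backend call sites.
//
//   bool cache_hit = false;
//   auto &compiler = OpCompiler::GetInstance();
//   auto op_compiler_info = compiler.Compile(op_run_info, &cache_hit, "Ascend", device_id);
//   if (!cache_hit) {
//     // Kernels only need to be built for graphs that were newly compiled.
//     compiler.KernelBuild(op_compiler_info, op_compiler_info->device_context_);
//   }
//   // In dynamic scenes, drop the cached entry afterwards to keep memory bounded.
//   compiler.ClearOpCache(op_compiler_info->graph_info_);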
}  // namespace pynative
using OpCompilerInfoPtr = pynative::OpCompilerInfoPtr;
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_