1 /** 2 * Copyright 2022-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_ 17 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_ 18 19 #include <utility> 20 #include <vector> 21 #include <memory> 22 #include <map> 23 #include <string> 24 #include <unordered_map> 25 #include <set> 26 #include "utils/ms_utils.h" 27 #include "include/backend/kernel_graph.h" 28 #include "backend/common/session/session_basic.h" 29 #include "runtime/hardware/device_context.h" 30 #include "runtime/pynative/ir_converter.h" 31 32 namespace mindspore { 33 using device::DeviceContext; 34 using session::KernelWithIndex; 35 namespace pynative { 36 constexpr size_t kAlignSize = 64; 37 struct OpCompilerInfo { OpCompilerInfoOpCompilerInfo38 OpCompilerInfo(GraphInfo graph_info, GraphId graph_id, KernelGraphPtr graph, DeviceContext *device_context, 39 bool need_erase, bool need_refresh_abstract, std::vector<KernelWithIndex> graph_output_nodes, 40 std::vector<size_t> graph_outputs_tensor_num, std::vector<std::string> graph_outputs_padding_type, 41 SimpleGraphPtr simple_graph) 42 : graph_info_(std::move(graph_info)), 43 graph_id_(graph_id), 44 graph_(std::move(graph)), 45 device_context_(device_context), 46 need_erase_(need_erase), 47 need_refresh_abstract_(need_refresh_abstract), 48 graph_output_nodes_(std::move(graph_output_nodes)), 49 graph_outputs_tensor_num_(std::move(graph_outputs_tensor_num)), 50 graph_outputs_padding_type_(std::move(graph_outputs_padding_type)), 51 simple_graph_(std::move(simple_graph)) {} 52 ~OpCompilerInfo() = default; 53 void UpdateStatus(bool ready); 54 void WaitReady() const; 55 const mindspore::GraphInfo graph_info_; 56 const GraphId graph_id_; 57 const KernelGraphPtr graph_; 58 const DeviceContext *device_context_; 59 const bool need_erase_; 60 const bool need_refresh_abstract_; 61 const std::vector<KernelWithIndex> graph_output_nodes_; 62 const std::vector<size_t> graph_outputs_tensor_num_; 63 const std::vector<std::string> graph_outputs_padding_type_; 64 const SimpleGraphPtr simple_graph_; 65 alignas(kAlignSize) std::atomic<bool> ready_{true}; 66 }; 67 using OpCompilerInfoPtr = std::shared_ptr<OpCompilerInfo>; 68 69 // FuncGraph, Backend and GraphCompiler correspond one-to-one, 70 // and GraphCompiler stores the compilation cache of operators. 71 // When the graph structure changes, the front-end will send multiple graphs, 72 // the operators of each graph will be compiled separately, which will result in very poor performance. 73 // Therefore, the OpCompiler class is required to save all operator caches and make them independent of Graph. 74 class BACKEND_EXPORT OpCompiler { 75 public: 76 static OpCompiler &GetInstance(); 77 78 // Compile RunOpInfo into a KernelGraph. 79 OpCompilerInfoPtr Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit, 80 const std::string &device_name, const uint32_t &device_id); 81 82 // Clear op cache in dynamic scenes. 83 // Otherwise, the operator cache will keep growing, resulting in insufficient memory. 84 void ClearOpCache(const mindspore::GraphInfo &graph_info); 85 86 // Accumulate a certain number of operators, 87 // and then compile the operators in parallel to improve compilation efficiency. 88 void KernelBuild(const OpCompilerInfoPtr &op_compiler_info, const DeviceContext *device_context, 89 bool is_dynamic_shape = false) const; 90 91 std::string GetSingleOpGraphInfo(const pynative::BaseOpRunInfo &op_info, const PrimitivePtr &op_prim) const; 92 93 // Clear anf resources before process exit. 94 void ClearAllCache(); 95 96 bool IsInvalidInferResultOp(const std::string &op_name) const; 97 98 static void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph); 99 100 private: 101 OpCompiler(); 102 ~OpCompiler() = default; 103 DISABLE_COPY_AND_ASSIGN(OpCompiler); 104 KernelGraphPtr GenerateKernelGraph(const session::BackendOpRunInfoPtr &op_run_info, 105 const device::DeviceContext *device_context) const; 106 void AssignStreamIdForSingleOpGraph(const KernelGraphPtr &graph, uint32_t stream_id); 107 // All operators shared the same session. 108 session::SessionPtr session_; 109 mindspore::HashMap<mindspore::GraphInfo, OpCompilerInfoPtr> op_compiler_infos_; 110 }; 111 } // namespace pynative 112 using OpCompilerInfoPtr = pynative::OpCompilerInfoPtr; 113 } // namespace mindspore 114 #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_ 115