1 /** 2 * Copyright 2022-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_ 17 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_ 18 19 #include <vector> 20 #include <string> 21 #include <memory> 22 #include <map> 23 #include <set> 24 #include <utility> 25 26 #include "include/api/context.h" 27 #include "include/model.h" 28 #include "include/transform/graph_ir/types.h" 29 #include "extendrt/session/lite_graph_executor.h" 30 #include "common/config_infos.h" 31 #include "include/transform/graph_ir/utils.h" 32 #include "extendrt/delegate/ascend_ge/ge_device_context.h" 33 #include "extendrt/delegate/ascend_ge/ge_memory_manager.h" 34 #include "extendrt/delegate/ascend_ge/ge_context_manager.h" 35 #include "extendrt/delegate/ascend_ge/update_weight.h" 36 #include "mindspore/lite/src/common/common.h" 37 38 namespace mindspore { 39 struct RefDataInfo { 40 std::string name; 41 ShapeVector shape; 42 ShapeVector dyn_shape; 43 TypeId dtype = kTypeUnknown; 44 tensor::TensorPtr host_data = nullptr; // will be released after device tensor allocated 45 size_t offset = 0; 46 size_t size = 0; 47 GeTensor ge_tensor; 48 }; 49 50 struct InOutBufferInfo { 51 ShapeVector shape; 52 TypeId dtype = kTypeUnknown; 53 void *device_addr = nullptr; 54 size_t max_size = 0; 55 GeTensor ge_tensor; 56 }; 57 58 struct OutputInfo { 59 ShapeVector shape; 60 TypeId dtype = kTypeUnknown; 61 }; 62 63 struct GraphRuntimeInfo { 64 void *const_addr = nullptr; 65 size_t const_size = 0; 66 void *feature_addr = nullptr; 67 size_t feature_size = 0; 68 std::vector<ShapeVector> output_shapes; 69 }; 70 71 struct DynKVCacheInfo { 72 bool dynamic_kv_cache = false; 73 bool batch_size_dyn = false; 74 bool seq_length_dyn = false; 75 bool is_ge_graph_static_ = false; 76 int64_t real_batch_size = -1; 77 int64_t real_seq_len_size = -1; 78 int64_t max_batch_size = 32; 79 int64_t max_seq_len_size = 4096; 80 std::vector<std::vector<int64_t>> dynamic_kv_cache_dims; 81 std::string kv_cache_layout = lite::kKVCacheLayoutBNSD; 82 }; 83 84 class GeGraphExecutor : public LiteGraphExecutor { 85 public: GeGraphExecutor(const std::shared_ptr<mindspore::Context> & context,const ConfigInfos & config_infos)86 GeGraphExecutor(const std::shared_ptr<mindspore::Context> &context, const ConfigInfos &config_infos) 87 : context_(context), config_infos_(config_infos) {} 88 ~GeGraphExecutor(); 89 90 bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options, 91 uint32_t *graph_id) override; 92 93 bool RunGraph(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs, 94 const std::map<string, string> &compile_options) override; 95 Resize(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,const std::vector<ShapeVector> & dims)96 bool Resize(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs, 97 const std::vector<ShapeVector> &dims) override { 98 return true; 99 } 100 101 std::vector<tensor::Tensor> GetInputInfos(uint32_t graph_id) override; 102 std::vector<tensor::Tensor> GetOutputInfos(uint32_t graph_id) override; 103 bool Init(); 104 bool AoeTuning(const FuncGraphPtr &graph); 105 bool OfflineBuildGraph(const FuncGraphPtr &graph); 106 bool UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights) override; 107 108 private: 109 std::shared_ptr<UpdateWeight> update_weight_ptr_ = nullptr; 110 bool enable_update_weight_ = false; 111 const std::shared_ptr<mindspore::Context> context_; 112 ConfigInfos config_infos_; 113 std::shared_ptr<ge::Session> ge_session_ = nullptr; 114 std::map<std::string, std::string> session_options_; 115 int64_t session_id_ = -1; 116 std::vector<uint32_t> init_graph_id_list_; 117 std::vector<uint32_t> compute_graph_id_list_; 118 transform::RefModeFlag ref_mode_flag_ = transform::RefModeFlag::kRefModeNone; 119 std::string cache_mode_; 120 std::vector<RefDataInfo> ref_data_infos_; 121 std::vector<InOutBufferInfo> inputs_buffer_infos_; 122 std::vector<InOutBufferInfo> outputs_buffer_infos_; 123 124 std::shared_ptr<GeMemoryManager> memory_manager_ = nullptr; 125 std::shared_ptr<GeContextManager> context_manager_ = nullptr; 126 127 std::shared_ptr<GeDeviceContext> ge_global_context_ = nullptr; 128 std::string graph_name_; 129 std::string build_cache_dir_; 130 std::string build_cache_relative_dir_; 131 132 std::map<uint32_t, std::vector<tensor::Tensor>> graph_inputs_; 133 std::map<uint32_t, std::vector<tensor::Tensor>> graph_outputs_; 134 std::map<uint32_t, std::vector<tensor::TensorPtr>> original_graph_outputs_; 135 bool is_data_flow_graph_ = false; 136 DynKVCacheInfo dyn_kv_cache_info_; 137 138 std::shared_ptr<AscendDeviceInfo> GetAscendDeviceInfo(); 139 uint32_t GetRankID() const; 140 uint32_t GetDeviceID() const; 141 void GetGeGraphOptions(const FuncGraphPtr &anf_graph, std::map<std::string, std::string> *ge_options); 142 void GetGeSessionOptions(std::map<std::string, std::string> *ge_options); 143 void GetGeSessionOptionsFromAscendContext(const std::map<std::string, std::string> &config, 144 std::map<std::string, std::string> *ge_options_ptr); 145 bool CreateSession(const std::map<std::string, std::string> &extra_options); 146 int64_t GetSessionId(); 147 void GetParams(const FuncGraphPtr &anf_graph, transform::TensorOrderMap *param_tensors); 148 149 bool AddGraph(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &options, 150 uint32_t *graph_id); 151 bool RunGeInitGraph(uint32_t init_graph_id, const std::vector<std::string> &init_data_names, 152 const transform::TensorOrderMap ¶ms_vals); 153 tensor::TensorPtr ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx); 154 155 bool RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector<GeTensor> &inputs, 156 std::vector<GeTensor> *outputs); 157 bool InitRefDataList(const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors); 158 bool InitRefDataContext(const FuncGraphPtr &func_graph, 159 const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors, 160 std::map<std::string, std::string> *ge_options_ptr); 161 bool InitRefDataDeviceTensor(); 162 bool InitConstantFeatureDeviceMemory(uint32_t graph_id); 163 bool InitInOutDeviceBuffer(const std::string &name, const ShapeVector &shape, TypeId dtype, 164 InOutBufferInfo *buffer_info); 165 bool InitInputDataTensor(const std::vector<tensor::Tensor> &inputs, std::vector<::ge::Tensor> *ge_inputs, 166 std::vector<::ge::Tensor> *ge_outputs); 167 bool InitMemoryContextManager(); 168 169 bool BuildGraphRefMode(const FuncGraphPtr &anf_graph, uint32_t graph_id); 170 bool RunGraphRefMode(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs, 171 std::vector<tensor::Tensor> *outputs); 172 bool SyncDeviceOutputsToHost(std::vector<tensor::Tensor> *outputs, std::vector<::ge::Tensor> *ge_outputs); 173 174 bool UpdateInputShapeOption(const FuncGraphPtr &func_graph, 175 const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors, 176 std::map<std::string, std::string> *ge_options_ptr); 177 178 static std::atomic_uint32_t global_graph_idx_; 179 static uint32_t GetNextGraphIdx(); 180 181 bool RunGeGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs, std::vector<::ge::Tensor> *outputs); 182 bool RunDataFlowGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs, 183 std::vector<::ge::Tensor> *outputs); 184 185 transform::DfGraphPtr CompileGraphCommon(const FuncGraphPtr &graph, 186 std::map<std::string, std::string> *ge_options_ptr); 187 188 transform::DfGraphPtr CreateGeGraphOnline(const FuncGraphPtr &anf_graph, 189 std::map<std::string, std::string> *ge_options_ptr); 190 transform::DfGraphPtr CreateFakeGraph(const std::map<std::string, std::string> &ge_options); 191 192 void SetOptionsIntoOfflineModel(const std::map<std::string, std::string> &graph_options, 193 std::map<std::string, ValuePtr> *attr_map); 194 195 bool LoadOnlineGraph(const FuncGraphPtr &anf_graph, uint32_t *graph_id); 196 bool UpdateGraphInputs(const FuncGraphPtr &graph); 197 198 bool GetOneRealInputs(const FuncGraphPtr &func_graph, std::vector<ge::Tensor> *ge_tensors); 199 bool CreateAsCustomFuncGraph(const FuncGraphPtr &func_graph, const std::map<std::string, std::string> &graph_options); 200 bool SetModelCacheDir(std::map<std::string, std::string> *session_options_ptr); 201 bool SetOfflineBuildModelCacheDir(std::map<std::string, std::string> *session_options_ptr); 202 bool GetConfigOption(const std::string §ion_name, const std::string &option_name, std::string *option_val); 203 204 bool SetGeTensorShape(GeTensor *ge_tensor, ShapeVector shape); 205 void UpdateOutputShapeInfo(std::vector<::ge::Tensor> *ge_outputs); 206 bool SetDynamicKVCache(const FuncGraphPtr &func_graph); 207 bool InitRefModeConfig(); 208 bool InitRealShapeParam(const std::vector<tensor::Tensor> &inputs); 209 bool CheckRefDataInfo(); 210 bool InitMaxShapeParam(); 211 void SetRefShape(std::vector<int64_t> *ref_shape, bool dyn, std::string tensor_name); 212 bool InitInputDeviceTensor(const FuncGraphPtr &anf_graph); 213 bool InitOutputDeviceTensor(const FuncGraphPtr &anf_graph, uint32_t graph_id); 214 }; 215 216 struct GeSessionContext { 217 std::weak_ptr<ge::Session> ge_session; 218 std::map<std::string, std::string> session_options; 219 std::set<std::string> session_variables; 220 std::map<std::string, RefDataInfo> ref_data_map_; 221 std::weak_ptr<GeMemoryManager> memory_manager; 222 std::weak_ptr<GeContextManager> context_manager; 223 std::vector<void *> ref_data_device_memories; 224 void *feature_memory = nullptr; 225 size_t feature_size = 0; 226 std::map<uint32_t, size_t> feature_graph_ids; 227 }; 228 229 class GeSessionManager { 230 public: 231 static std::shared_ptr<ge::Session> CreateGeSession(int64_t session_id, 232 const std::map<std::string, std::string> &session_options); 233 // return new Variables not in session 234 static std::set<std::string> UpdateSessionVariables(int64_t session_id, 235 const std::vector<std::string> &graph_variables); 236 static void TryReleaseGeSessionContext(int64_t session_id); 237 238 static std::shared_ptr<GeSessionContext> GetGeSessionContext(int64_t session_id); 239 240 private: 241 static std::map<int64_t, std::shared_ptr<GeSessionContext>> ge_session_map_; 242 static std::mutex session_mutex_; 243 }; 244 } // namespace mindspore 245 #endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_GE_GRAPH_EXECUTOR_H_ 246