1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_GRAPH_EXECUTOR_H_ 17 #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_GRAPH_EXECUTOR_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <map> 24 #include <set> 25 #include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h" 26 #include "runtime/hardware/device_context.h" 27 #include "runtime/device/memory_manager.h" 28 #include "utils/ms_context.h" 29 #include "include/transform/graph_ir/types.h" 30 #include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h" 31 #include "plugin/device/ascend/hal/hardware/ge_device_res_manager.h" 32 #include "plugin/device/ascend/mindio/mindio_adapter.h" 33 34 namespace mindspore { 35 namespace device { 36 namespace ascend { 37 struct GeInputData { 38 std::vector<GeTensor> ge_inputs; 39 std::vector<DeviceAddress *> device_addrs; 40 std::vector<std::pair<AnfNodeWeakPtr, size_t>> need_update_input; 41 }; 42 43 struct GeOutputData { 44 std::vector<GeTensor> ge_outputs; 45 std::vector<DeviceAddress *> device_addrs; 46 std::vector<std::pair<AnfNodeWeakPtr, size_t>> graph_outputs; 47 }; 48 49 class GeGraphExecutor : public GraphExecutor { 50 public: 51 ~GeGraphExecutor() override = default; 52 bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) override; 53 bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs, 54 std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) override; 55 56 static FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, 57 bool export_air); 58 void PreprocessBeforeRun(const KernelGraphPtr &graph); 59 size_t GetGraphFeatureMemory(const FuncGraphPtr &graph) const override; 60 void InitGraphInfo(const FuncGraphPtr &graph) override; 61 62 private: 63 bool RunGraphRefMode(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs); 64 void AllocInputHostMemory(const KernelGraphPtr &kernel_graph) const; 65 void AllocOutputHostMemory(const KernelGraphPtr &kernel_graph) const; 66 void AllocConstMemory(const transform::RunOptions &options, const KernelGraphPtr &graph, size_t memory_size) const; 67 void AllocFeatureMemory(const transform::RunOptions &options, size_t memory_size) const; 68 void AllocParameterMemory(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *memo = nullptr) const; 69 void BuildInputDataGeTensor(const KernelGraphPtr &kernel_graph); 70 void BuildOutputDataGeTensor(const KernelGraphPtr &kernel_graph); 71 void AllocOutputMemory(const KernelGraphPtr &kernel_graph) const; 72 bool CompileGraph(const KernelGraphPtr &graph, const std::map<string, string> &compile_options); 73 int64_t CurGraphSinkSize(std::string graph_name); 74 std::vector<GeTensor> GenerateInputGeTensor(const KernelGraphPtr &kernel_graph) const; 75 std::vector<GeTensor> GenerateOutputGeTensor(const KernelGraphPtr &kernel_graph) const; 76 GeDeviceResManager *ResManager() const; 77 void RunInitGraph(const std::string &graph_name); 78 void AddRefCorrespondPairs(const KernelGraphPtr &graph, 79 const std::vector<std::pair<uint32_t, uint32_t>> &io_indexes) const; 80 bool BuildGraph(const KernelGraphPtr &graph, const transform::TensorOrderMap &tensor_order_map); 81 DeviceAddressPtr CreateOutputDeviceAddress(const KernelGraphPtr &kernel_graph, 82 const KernelWithIndex &output_with_index, 83 size_t need_alloc_output_cnt) const; 84 void AllocMemory(const KernelGraphPtr &graph); 85 void DoAsyncCkpt(const FuncGraphPtr &graph); 86 bool IsNeedNotifyTTP(const FuncGraphPtr &graph); 87 mindspore::HashMap<session::KernelGraph *, GeInputData> input_datas_; 88 mindspore::HashMap<session::KernelGraph *, GeOutputData> output_datas_; 89 std::map<std::string, int64_t> graph_sink_size_; 90 int64_t pre_sink_size_{-1}; 91 }; 92 } // namespace ascend 93 } // namespace device 94 } // namespace mindspore 95 96 #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_GRAPH_EXECUTOR_H_ 97