• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_GRAPH_EXECUTOR_H_
17 #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_GRAPH_EXECUTOR_H_
18 
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <map>
24 #include <set>
25 #include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
26 #include "runtime/hardware/device_context.h"
27 #include "runtime/device/memory_manager.h"
28 #include "utils/ms_context.h"
29 #include "include/transform/graph_ir/types.h"
30 #include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h"
31 #include "plugin/device/ascend/hal/hardware/ge_device_res_manager.h"
32 #include "plugin/device/ascend/mindio/mindio_adapter.h"
33 
34 namespace mindspore {
35 namespace device {
36 namespace ascend {
// Input-side data cached for one kernel graph. GeGraphExecutor keeps one instance per graph in input_datas_.
//   ge_inputs:         GeTensor objects for the graph inputs.
//   device_addrs:      raw DeviceAddress pointers; ownership is not expressed here.
//                      NOTE(review): presumably non-owning and parallel to ge_inputs — confirm in the .cpp.
//   need_update_input: (weak node reference, index) pairs.
//                      NOTE(review): presumably the inputs whose data must be refreshed before each run — confirm.
struct GeInputData {
  std::vector<GeTensor> ge_inputs;
  std::vector<DeviceAddress *> device_addrs;
  std::vector<std::pair<AnfNodeWeakPtr, size_t>> need_update_input;
};
42 
// Output-side data cached for one kernel graph. GeGraphExecutor keeps one instance per graph in output_datas_.
//   ge_outputs:    GeTensor objects for the graph outputs.
//   device_addrs:  raw DeviceAddress pointers; ownership is not expressed here.
//                  NOTE(review): presumably non-owning and parallel to ge_outputs — confirm in the .cpp.
//   graph_outputs: (weak node reference, output index) pairs.
//                  NOTE(review): presumably the graph output nodes each GeTensor maps back to — confirm.
struct GeOutputData {
  std::vector<GeTensor> ge_outputs;
  std::vector<DeviceAddress *> device_addrs;
  std::vector<std::pair<AnfNodeWeakPtr, size_t>> graph_outputs;
};
48 
49 class GeGraphExecutor : public GraphExecutor {
50  public:
51   ~GeGraphExecutor() override = default;
52   bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) override;
53   bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
54                 std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) override;
55 
56   static FuncGraphPtr BuildDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map,
57                                    bool export_air);
58   void PreprocessBeforeRun(const KernelGraphPtr &graph);
59   size_t GetGraphFeatureMemory(const FuncGraphPtr &graph) const override;
60   void InitGraphInfo(const FuncGraphPtr &graph) override;
61 
62  private:
63   bool RunGraphRefMode(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs);
64   void AllocInputHostMemory(const KernelGraphPtr &kernel_graph) const;
65   void AllocOutputHostMemory(const KernelGraphPtr &kernel_graph) const;
66   void AllocConstMemory(const transform::RunOptions &options, const KernelGraphPtr &graph, size_t memory_size) const;
67   void AllocFeatureMemory(const transform::RunOptions &options, size_t memory_size) const;
68   void AllocParameterMemory(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *memo = nullptr) const;
69   void BuildInputDataGeTensor(const KernelGraphPtr &kernel_graph);
70   void BuildOutputDataGeTensor(const KernelGraphPtr &kernel_graph);
71   void AllocOutputMemory(const KernelGraphPtr &kernel_graph) const;
72   bool CompileGraph(const KernelGraphPtr &graph, const std::map<string, string> &compile_options);
73   int64_t CurGraphSinkSize(std::string graph_name);
74   std::vector<GeTensor> GenerateInputGeTensor(const KernelGraphPtr &kernel_graph) const;
75   std::vector<GeTensor> GenerateOutputGeTensor(const KernelGraphPtr &kernel_graph) const;
76   GeDeviceResManager *ResManager() const;
77   void RunInitGraph(const std::string &graph_name);
78   void AddRefCorrespondPairs(const KernelGraphPtr &graph,
79                              const std::vector<std::pair<uint32_t, uint32_t>> &io_indexes) const;
80   bool BuildGraph(const KernelGraphPtr &graph, const transform::TensorOrderMap &tensor_order_map);
81   DeviceAddressPtr CreateOutputDeviceAddress(const KernelGraphPtr &kernel_graph,
82                                              const KernelWithIndex &output_with_index,
83                                              size_t need_alloc_output_cnt) const;
84   void AllocMemory(const KernelGraphPtr &graph);
85   void DoAsyncCkpt(const FuncGraphPtr &graph);
86   bool IsNeedNotifyTTP(const FuncGraphPtr &graph);
87   mindspore::HashMap<session::KernelGraph *, GeInputData> input_datas_;
88   mindspore::HashMap<session::KernelGraph *, GeOutputData> output_datas_;
89   std::map<std::string, int64_t> graph_sink_size_;
90   int64_t pre_sink_size_{-1};
91 };
92 }  // namespace ascend
93 }  // namespace device
94 }  // namespace mindspore
95 
96 #endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_GRAPH_EXECUTOR_H_
97