/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/hal/hardware/ge_utils.h"

#include <tuple>
#include <utility>
#include <nlohmann/json.hpp>
#include "include/common/utils/anfalgo.h"
#include "include/transform/graph_ir/types.h"
#include "include/transform/graph_ir/utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/scoped_long_running.h"
#include "abstract/abstract_value.h"
#include "include/backend/kernel_graph.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "plugin/device/ascend/hal/common/ascend_utils.h"
#include "transform/symbol/symbol_utils.h"
#include "transform/symbol/acl_rt_symbol.h"

namespace mindspore {
namespace device {
namespace ascend {
using mindspore::transform::OptionMap;

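// Joins a list of shapes into a string such as "[2,3],[4]", the format used for
// the "ge.exec.dataInputsShapeRange" option below.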
std::string ShapesToString(const ShapeArray &shapes) {
  std::stringstream buffer;
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (i != 0) {
      buffer << ",";
    }
    buffer << "[";
    const auto &shape = shapes[i];
    for (size_t j = 0; j < shape.size(); ++j) {
      if (j != 0) {
        buffer << ",";
      }
      buffer << shape[j];
    }
    buffer << "]";
  }
  return buffer.str();
}

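// Returns true only when the backend policy is "ge" and the current phase is a training phase.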
bool IsGeTrain() {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  bool enable_ge = context->backend_policy() == "ge";
  bool enable_training = GetPhasePrefix() == "train";
  return enable_ge && enable_training;
}

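// Returns the graph's display name. For a KernelGraph, the name of the original
// FuncGraph it was compiled from is preferred when available.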
std::string GetGraphName(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (IsEnableRefMode()) {
    return graph->ToString();
  }
  KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
  if (kg == nullptr) {
    return graph->ToString();
  }
  FuncGraphPtr origin_graph = kg->GetFuncGraph();
  MS_EXCEPTION_IF_NULL(origin_graph);
  return origin_graph->ToString();
}

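// Builds the per-graph GE options: variable accumulation for training, the host
// scheduling threshold, and the dynamic-shape execution options when any input
// shape is dynamic.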
OptionMap GetComputeGraphOptions(const ShapeArray &input_shapes, bool is_dynamic_shape) {
  OptionMap options{};
  if (IsGeTrain()) {  // IsGeTrain() already implies GetPhasePrefix() == "train".
    (void)options.emplace("ge.exec.variable_acc", "1");
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto max_threshold = ms_context->get_param<std::string>(MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD);
  if (!max_threshold.empty()) {
    (void)options.emplace("ge.exec.hostSchedulingMaxThreshold", max_threshold);
  }
  if (!is_dynamic_shape) {
    return options;
  }
  (void)options.emplace("ge.exec.dynamicGraphExecuteMode", "dynamic_execute");
  (void)options.emplace("ge.exec.dataInputsShapeRange", ShapesToString(input_shapes));
  return options;
}

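// When MS_ENABLE_IO_REUSE=1 and ref mode is on, tells GE which graph outputs
// (and, for GetNext-fed graphs, which inputs) may reuse memory in place.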
void GetComputeGraphReuseOptions(const FuncGraphPtr &graph, OptionMap *option) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(option);
  auto enable_io_reuse = common::GetEnv("MS_ENABLE_IO_REUSE");
  MS_LOG(INFO) << "Enable io reuse: " << enable_io_reuse;
  if (enable_io_reuse != "1" || !IsEnableRefMode()) {
    return;
  }
  auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  if (!outputs.empty()) {
    std::string value;
    for (size_t i = 0; i < outputs.size(); ++i) {
      auto output = outputs[i];
      const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
      auto &output_node = output_with_index.first;
      MS_EXCEPTION_IF_NULL(output_node);
      // Parameters and value nodes cannot be reused.
      if (output_node->isa<Parameter>() || output_node->isa<ValueNode>()) {
        MS_LOG(INFO) << "Output is a parameter or value node and does not support reuse, index: " << i;
        continue;
      }
      (void)value.append(std::to_string(i));
      (void)value.append(",");
    }
    if (!value.empty()) {
      value.pop_back();  // Drop the trailing comma.
      MS_LOG(INFO) << "key: ge.exec.outputReuseMemIndexes, value: " << value << ", Graph name: " << graph->ToString();
      (void)option->insert(std::make_pair("ge.exec.outputReuseMemIndexes", value));
    }
  }

  if (graph->has_flag(transform::kGraphFlagHasGetNext) && !graph->has_flag(transform::kGraphNeedIteration)) {
    MS_LOG(INFO) << "key: ge.exec.inputReuseMemIndexes, value: 0, Graph name: " << graph->ToString();
    (void)option->insert(std::make_pair("ge.exec.inputReuseMemIndexes", "0"));
  }
}

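// Merges user-supplied GE options into `options`. MS_CTX_GE_OPTIONS holds a JSON
// string keyed by scope, e.g. {"global": {"key": "value"}, "session": {"key": "value"}};
// only the scope selected by `is_global` is applied.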
void SetPassthroughGeOptions(bool is_global, OptionMap *options) {
  MS_EXCEPTION_IF_NULL(options);
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);

  const auto &ge_options_str = context->get_param<std::string>(MS_CTX_GE_OPTIONS);
  if (ge_options_str.empty()) {
    MS_LOG(DEBUG) << "The ge option for passthrough is not set.";
    return;
  }

  std::string level = is_global ? "global" : "session";
  nlohmann::json options_json = nlohmann::json::parse(ge_options_str);
  auto options_iter = options_json.find(level);
  if (options_iter == options_json.end()) {
    MS_LOG(INFO) << "GE " << level << " option is not set.";
    return;
  }

  const auto &new_options = *options_iter;
  for (auto &[key, value] : new_options.items()) {
    (*options)[key] = value;
    MS_LOG(INFO) << "Set ge " << level << " option: {" << key << ", " << value << "}";
  }
}

namespace {
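// Overrides GE's topological sorting mode for a single graph. MS_CTX_TOPO_ORDER
// holds a JSON string mapping graph names to "bfs" (0), "dfs" (1) or "rdfs" (2).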
void UpdateTopoOrderOptions(const std::string &graph_name, OptionMap *option) {
  MS_EXCEPTION_IF_NULL(option);
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);

  const auto &topo_order = context->get_param<std::string>(MS_CTX_TOPO_ORDER);
  if (topo_order.empty()) {
    return;
  }

  nlohmann::json topo_order_json = nlohmann::json::parse(topo_order);
  auto topo_order_iter = topo_order_json.find(graph_name);
  if (topo_order_iter == topo_order_json.end()) {
    return;
  }
  MS_LOG(INFO) << "Update topo order for graph " << graph_name << " to " << topo_order_iter.value();
  // Unrecognized values fall back to dfs.
  std::string topo_sorting_mode = "1";
  if (topo_order_iter.value() == "bfs") {
    topo_sorting_mode = "0";
  } else if (topo_order_iter.value() == "dfs") {
    topo_sorting_mode = "1";
  } else if (topo_order_iter.value() == "rdfs") {
    topo_sorting_mode = "2";
  }
  (*option)["ge.topoSortingMode"] = topo_sorting_mode;
}
}  // namespace

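// Registers a fake (placeholder) GE graph plus its init, broadcast and, outside
// ref mode, checkpoint subgraphs, without converting the real graph body.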
bool AddFakeGraph(const FuncGraphPtr &anf_graph) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  auto converter = transform::NewConverter(anf_graph, GetPhasePrefix());
  transform::GenFakeGraph(anf_graph->ToString(), converter);
  auto graph_name = GetGraphName(anf_graph);
  std::string init_graph = "init_subgraph." + graph_name;
  std::string checkpoint_name = "save." + graph_name;
  ShapeArray shape_array;
  bool dynamic_shape_inputs = false;
  auto options = GetComputeGraphOptions(shape_array, dynamic_shape_inputs);
  GetComputeGraphReuseOptions(anf_graph, &options);
  UpdateTopoOrderOptions(graph_name, &options);
  MS_LOG(INFO) << "Set options of compute graph: " << graph_name << " to " << MapToString(options);
  (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter));
  (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
  (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));

  if (!IsEnableRefMode()) {
    transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter));
    if (ret == transform::Status::SUCCESS) {
      transform::SetAnfGraph(checkpoint_name, anf_graph);
    }
  }
  return true;
}

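// Converts `anf_graph` into a GE DF graph and registers it together with its
// init, broadcast and (outside ref mode) checkpoint subgraphs. Returns false if
// the conversion fails.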
bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  auto converter = transform::NewConverter(anf_graph, GetPhasePrefix());
  bool is_cloud = true;
  bool need_aoe = false;
  if (export_air) {
    MS_LOG(INFO) << "Set DfGraphConvertor training : false";
    transform::SetTraining(converter, false);
    transform::SetExportAir(converter, true);
    is_cloud = false;
  }
  transform::BuildGraph(anf_graph->ToString(), converter, init_inputs_map);
  transform::GenerateBroadcastGraph(converter, init_inputs_map);
  transform::GenerateCheckpointGraph(converter);
  auto err_code = transform::ErrCode(converter);
  if (err_code != 0) {
    transform::ClearGraph();
    MS_LOG(ERROR) << "Convert df graph failed, err: " << err_code;
    return false;
  }
  if (MsContext::GetInstance()->EnableAoeOnline()) {
    need_aoe = true;
  }
  auto graph_name = GetGraphName(anf_graph);
  std::string init_graph = "init_subgraph." + graph_name;
  std::string checkpoint_name = "save." + graph_name;
  auto options = GetComputeGraphOptions(converter->input_shapes(), converter->dynamic_shape_inputs());
  GetComputeGraphReuseOptions(anf_graph, &options);
  UpdateTopoOrderOptions(graph_name, &options);
  MS_LOG(INFO) << "Set options of compute graph: " << graph_name << " to " << MapToString(options);
  (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), options, is_cloud, need_aoe);
  if (IsEnableRefMode()) {
    (void)transform::AddGraph(init_graph, converter->GetInitGraph());
  } else {
    (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
  }
  (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));

  if (!IsEnableRefMode()) {
    transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter));
    if (ret == transform::Status::SUCCESS) {
      transform::SetAnfGraph(checkpoint_name, anf_graph);
    }
  }

  return true;
}

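// Blocks until all work queued on the copy stream has completed; the -1 timeout
// requests an unbounded wait.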
void SyncCopyStream(aclrtStream stream) {
  MS_LOG(INFO) << "Start sync copy data stream";
  if (CALL_ASCEND_API(aclrtSynchronizeStreamWithTimeout, stream, -1) != ACL_SUCCESS) {
    MS_LOG(EXCEPTION) << "Exec aclrtSynchronizeStreamWithTimeout failed";
  }
  MS_LOG(INFO) << "End sync copy data stream";
}

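// Asynchronously copies each weight parameter's device data back to its host
// tensor so the previous step's values are preserved; the copies complete only
// after `stream` is synchronized (see SyncCopyStream above).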
void SavePrevStepWeight(const std::vector<AnfNodePtr> &weights, aclrtStream stream) {
  for (const auto &node : weights) {
    if (!node->isa<Parameter>()) {
      continue;
    }
    auto param = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (common::AnfAlgo::IsParameterWeight(param)) {
      auto tensor = param->default_param()->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      auto out_addr = AnfAlgo::GetMutableOutputAddr(param, 0, false);
      if (out_addr == nullptr || out_addr->GetPtr() == nullptr || IsOneOfHWSpecialFormat(out_addr->format())) {
        // Skip the async copy if the address is nullptr. Special formats must be
        // converted to the default format on the host, so skip those as well.
        continue;
      }
      auto size = tensor->Size();
      auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, tensor->data_c(), size, out_addr->GetMutablePtr(), size,
                                 ACL_MEMCPY_DEVICE_TO_HOST, stream);
      if (ret != ACL_ERROR_NONE) {
        MS_LOG(EXCEPTION) << "Call aclrtMemcpyAsync failed, param: " << param->DebugString();
      }
      tensor->set_copy_done_flag(true);
    }
  }
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore