• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/ascend/hal/hardware/ge_utils.h"
18 
19 #include <tuple>
20 #include <utility>
21 #include <nlohmann/json.hpp>
22 #include "include/common/utils/anfalgo.h"
23 #include "include/transform/graph_ir/types.h"
24 #include "include/transform/graph_ir/utils.h"
25 #include "include/common/debug/anf_ir_dump.h"
26 #include "include/common/utils/scoped_long_running.h"
27 #include "abstract/abstract_value.h"
28 #include "include/backend/kernel_graph.h"
29 #include "include/backend/anf_runtime_algorithm.h"
30 #include "plugin/device/ascend/hal/common/ascend_utils.h"
31 #include "transform/symbol/symbol_utils.h"
32 #include "transform/symbol/acl_rt_symbol.h"
33 namespace mindspore {
34 namespace device {
35 namespace ascend {
36 using mindspore::transform::OptionMap;
37 
ShapesToString(const ShapeArray & shapes)38 std::string ShapesToString(const ShapeArray &shapes) {
39   std::stringstream buffer;
40   for (size_t i = 0; i < shapes.size(); ++i) {
41     if (i != 0) {
42       buffer << ",";
43     }
44     buffer << "[";
45     const auto &shape = shapes[i];
46     for (size_t j = 0; j < shape.size(); ++j) {
47       if (j != 0) {
48         buffer << ",";
49       }
50       buffer << shape[j];
51     }
52     buffer << "]";
53   }
54   return buffer.str();
55 }
56 
IsGeTrain()57 bool IsGeTrain() {
58   auto context = MsContext::GetInstance();
59   MS_EXCEPTION_IF_NULL(context);
60   bool enable_ge = context->backend_policy() == "ge";
61   bool enable_training = GetPhasePrefix() == "train";
62   if (enable_ge && enable_training) {
63     return true;
64   }
65   return false;
66 }
67 
GetGraphName(const FuncGraphPtr & graph)68 std::string GetGraphName(const FuncGraphPtr &graph) {
69   MS_EXCEPTION_IF_NULL(graph);
70   if (IsEnableRefMode()) {
71     return graph->ToString();
72   } else {
73     KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
74     std::string name;
75     if (kg == nullptr) {
76       name = graph->ToString();
77     } else {
78       FuncGraphPtr origin_graph = kg->GetFuncGraph();
79       MS_EXCEPTION_IF_NULL(origin_graph);
80       name = origin_graph->ToString();
81     }
82     return name;
83   }
84 }
85 
GetComputeGraphOptions(const ShapeArray & input_shapes,bool is_dynamic_shape)86 OptionMap GetComputeGraphOptions(const ShapeArray &input_shapes, bool is_dynamic_shape) {
87   OptionMap options{};
88   if (IsGeTrain() && GetPhasePrefix() == "train") {
89     (void)options.emplace("ge.exec.variable_acc", "1");
90   }
91   auto ms_context = MsContext::GetInstance();
92   MS_EXCEPTION_IF_NULL(ms_context);
93   auto max_threshold = ms_context->get_param<std::string>(MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD);
94   if (!max_threshold.empty()) {
95     (void)options.emplace("ge.exec.hostSchedulingMaxThreshold", max_threshold);
96   }
97   if (!is_dynamic_shape) {
98     return options;
99   }
100   (void)options.emplace("ge.exec.dynamicGraphExecuteMode", "dynamic_execute");
101   (void)options.emplace("ge.exec.dataInputsShapeRange", ShapesToString(input_shapes));
102   return options;
103 }
104 
GetComputeGraphReuseOptions(const FuncGraphPtr & graph,OptionMap * option)105 void GetComputeGraphReuseOptions(const FuncGraphPtr &graph, OptionMap *option) {
106   MS_EXCEPTION_IF_NULL(graph);
107   MS_EXCEPTION_IF_NULL(option);
108   auto enable_io_reuse = common::GetEnv("MS_ENABLE_IO_REUSE");
109   MS_LOG(INFO) << "Enable io reuse: " << enable_io_reuse;
110   if (enable_io_reuse != "1" || !IsEnableRefMode()) {
111     return;
112   }
113   auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
114   if (!outputs.empty()) {
115     std::string value;
116     for (size_t i = 0; i < outputs.size(); ++i) {
117       auto output = outputs[i];
118       const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
119       auto &output_node = output_with_index.first;
120       MS_EXCEPTION_IF_NULL(output_node);
121       // Parameter and value can not been reused.
122       if (output_node->isa<Parameter>() || output_node->isa<ValueNode>()) {
123         MS_LOG(INFO) << "Output is parameter or value node, not support reuse, index is: " << i;
124         continue;
125       }
126       (void)value.append(std::to_string(i));
127       (void)value.append(",");
128     }
129     if (!value.empty()) {
130       value.pop_back();
131       MS_LOG(INFO) << "key: ge.exec.outputReuseMemIndexes, value: " << value << ",Graph name: " << graph->ToString();
132       (void)option->insert(std::make_pair("ge.exec.outputReuseMemIndexes", value));
133     }
134   }
135 
136   auto ms_context = MsContext::GetInstance();
137   MS_EXCEPTION_IF_NULL(ms_context);
138   if (graph->has_flag(transform::kGraphFlagHasGetNext) && !graph->has_flag(transform::kGraphNeedIteration)) {
139     MS_LOG(INFO) << "key: ge.exec.inputReuseMemIndexes, value: 0."
140                  << ", Graph name: " << graph->ToString();
141     (void)option->insert(std::make_pair("ge.exec.inputReuseMemIndexes", "0"));
142   }
143 }
144 
SetPassthroughGeOptions(bool is_global,OptionMap * options)145 void SetPassthroughGeOptions(bool is_global, OptionMap *options) {
146   auto context = MsContext::GetInstance();
147   MS_EXCEPTION_IF_NULL(context);
148 
149   const auto &ge_options_str = context->get_param<std::string>(MS_CTX_GE_OPTIONS);
150   if (ge_options_str.empty()) {
151     MS_LOG(DEBUG) << "The ge option for passthrough is not set.";
152     return;
153   }
154 
155   string level = is_global ? "global" : "session";
156   nlohmann::json options_json = nlohmann::json::parse(ge_options_str);
157   auto options_iter = options_json.find(level);
158   if (options_iter == options_json.end()) {
159     MS_LOG(INFO) << "GE " << level << " option is not set.";
160     return;
161   }
162 
163   const auto &new_options = *options_iter;
164   for (auto &[key, value] : new_options.items()) {
165     (*options)[key] = value;
166     MS_LOG(INFO) << "Set ge " << level << " option: {" << key << ", " << value << "}";
167   }
168 }
169 
170 namespace {
UpdateTopoOrderOptions(const string & graph_name,OptionMap * option)171 void UpdateTopoOrderOptions(const string &graph_name, OptionMap *option) {
172   auto context = MsContext::GetInstance();
173   MS_EXCEPTION_IF_NULL(context);
174 
175   const auto &topo_order = context->get_param<std::string>(MS_CTX_TOPO_ORDER);
176   if (topo_order.empty()) {
177     return;
178   }
179 
180   nlohmann::json topo_order_json = nlohmann::json::parse(topo_order);
181   auto topo_order_iter = topo_order_json.find(graph_name);
182   if (topo_order_iter == topo_order_json.end()) {
183     return;
184   }
185   MS_LOG(INFO) << "Update topo order for graph " << graph_name << " to " << topo_order_iter.value();
186   std::string topo_sorting_mode = "1";
187   if (topo_order_iter.value() == "bfs") {
188     topo_sorting_mode = "0";
189   } else if (topo_order_iter.value() == "dfs") {
190     topo_sorting_mode = "1";
191   } else if (topo_order_iter.value() == "rdfs") {
192     topo_sorting_mode = "2";
193   }
194   (*option)["ge.topoSortingMode"] = topo_sorting_mode;
195 }
196 }  // namespace
197 
AddFakeGraph(const FuncGraphPtr & anf_graph)198 bool AddFakeGraph(const FuncGraphPtr &anf_graph) {
199   MS_EXCEPTION_IF_NULL(anf_graph);
200   auto converter = transform::NewConverter(anf_graph, GetPhasePrefix());
201   transform::GenFakeGraph(anf_graph->ToString(), converter);
202   auto graph_name = GetGraphName(anf_graph);
203   std::string init_graph = "init_subgraph." + graph_name;
204   std::string checkpoint_name = "save." + graph_name;
205   ShapeArray shape_array;
206   bool dynamic_shape_inputs = false;
207   auto options = GetComputeGraphOptions(shape_array, dynamic_shape_inputs);
208   GetComputeGraphReuseOptions(anf_graph, &options);
209   UpdateTopoOrderOptions(graph_name, &options);
210   MS_LOG(INFO) << "Set options of compute graph: " << graph_name << " to " << MapToString(options);
211   (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter));
212   (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
213   (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));
214 
215   if (!IsEnableRefMode()) {
216     transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter));
217     if (ret == transform::Status::SUCCESS) {
218       transform::SetAnfGraph(checkpoint_name, anf_graph);
219     }
220   }
221   return true;
222 }
223 
AddDFGraph(const FuncGraphPtr & anf_graph,const transform::TensorOrderMap & init_inputs_map,bool export_air)224 bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
225   MS_EXCEPTION_IF_NULL(anf_graph);
226   auto converter = transform::NewConverter(anf_graph, GetPhasePrefix());
227   bool is_cloud = true;
228   bool need_aoe = false;
229   if (export_air) {
230     MS_LOG(INFO) << "Set DfGraphConvertor training : false";
231     transform::SetTraining(converter, false);
232     transform::SetExportAir(converter, true);
233     is_cloud = false;
234   }
235   transform::BuildGraph(anf_graph->ToString(), converter, init_inputs_map);
236   transform::GenerateBroadcastGraph(converter, init_inputs_map);
237   transform::GenerateCheckpointGraph(converter);
238   auto err_code = transform::ErrCode(converter);
239   if (err_code != 0) {
240     transform::ClearGraph();
241     MS_LOG(ERROR) << "Convert df graph failed, err:" << err_code;
242     return false;
243   }
244   if (MsContext::GetInstance()->EnableAoeOnline()) {
245     need_aoe = true;
246   }
247   auto graph_name = GetGraphName(anf_graph);
248   std::string init_graph = "init_subgraph." + graph_name;
249   std::string checkpoint_name = "save." + graph_name;
250   auto options = GetComputeGraphOptions(converter->input_shapes(), converter->dynamic_shape_inputs());
251   GetComputeGraphReuseOptions(anf_graph, &options);
252   UpdateTopoOrderOptions(graph_name, &options);
253   MS_LOG(INFO) << "Set options of compute graph: " << graph_name << " to " << MapToString(options);
254   (void)transform::AddGraph(graph_name, transform::GetComputeGraph(converter), options, is_cloud, need_aoe);
255   if (IsEnableRefMode()) {
256     (void)transform::AddGraph(init_graph, converter->GetInitGraph());
257   } else {
258     (void)transform::AddGraph(init_graph, transform::GetInitGraph(converter));
259   }
260   (void)transform::AddGraph(BROADCAST_GRAPH_NAME, transform::GetBroadcastGraph(converter));
261 
262   if (!IsEnableRefMode()) {
263     transform::Status ret = transform::AddGraph(checkpoint_name, transform::GetSaveCheckpointGraph(converter));
264     if (ret == transform::Status::SUCCESS) {
265       transform::SetAnfGraph(checkpoint_name, anf_graph);
266     }
267   }
268 
269   return true;
270 }
271 
SyncCopyStream(aclrtStream stream)272 void SyncCopyStream(aclrtStream stream) {
273   MS_LOG(INFO) << "Start sync copy data stream";
274   if (CALL_ASCEND_API(aclrtSynchronizeStreamWithTimeout, stream, -1) != ACL_SUCCESS) {
275     MS_LOG(EXCEPTION) << "Exec aclrtSynchronizeStreamWithTimeout failed";
276   }
277   MS_LOG(INFO) << "End sync copy data stream";
278 }
279 
SavePrevStepWeight(const std::vector<AnfNodePtr> & weights,aclrtStream stream)280 void SavePrevStepWeight(const std::vector<AnfNodePtr> &weights, aclrtStream stream) {
281   for (const auto &node : weights) {
282     if (!node->isa<Parameter>()) {
283       continue;
284     }
285     auto param = node->cast<ParameterPtr>();
286     MS_EXCEPTION_IF_NULL(param);
287     if (common::AnfAlgo::IsParameterWeight(param)) {
288       auto tensor = param->default_param()->cast<tensor::TensorPtr>();
289       MS_EXCEPTION_IF_NULL(tensor);
290       auto out_addr = AnfAlgo::GetMutableOutputAddr(param, 0, false);
291       if (out_addr == nullptr || out_addr->GetPtr() == nullptr || IsOneOfHWSpecialFormat(out_addr->format())) {
292         // skip async copy if addr is nullptr.
293         // special format need convert to default format at host, so skip async copy if format is a special format.
294         continue;
295       }
296       auto size = tensor->Size();
297       auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, tensor->data_c(), size, out_addr->GetMutablePtr(), size,
298                                  ACL_MEMCPY_DEVICE_TO_HOST, stream);
299       if (ret != ACL_ERROR_NONE) {
300         MS_LOG(EXCEPTION) << "Call aclrtMemcpyAsync failed, param: " << param->DebugString();
301       }
302       tensor->set_copy_done_flag(true);
303     }
304   }
305 }
306 }  // namespace ascend
307 }  // namespace device
308 }  // namespace mindspore
309