/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/pynative_utils.h"
#include "ir/tensor.h"
#include "include/backend/device_address.h"
#include "include/common/profiler.h"

namespace mindspore {
namespace pynative {
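// Record a backward hook op for the cell identified by cell_order; a cell may register several hook ops.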
void TopCellInfo::RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op) {
  MS_EXCEPTION_IF_NULL(hook_op);
  (void)cell_backward_hook_op_[cell_order].emplace_back(hook_op);
}

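// Build the op info key "op_name-op_index" that identifies this op's forward output for value node replacement.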
void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info, bool is_jit_graph) const {
  // Dynamic shape does not need value node replacement
  if (use_dynamic_shape_process() && !is_jit_graph) {
    return;
  }
  MS_EXCEPTION_IF_NULL(op_run_info);
  op_run_info->op_info.clear();
  op_run_info->op_info += op_run_info->base_op_run_info.op_name + "-" + std::to_string(op_index_);
}

void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
  need_compile_graph_ = need_compile_graph;
  forward_already_run_ = forward_already_run;
  vm_compile_ = vm_compile;
}

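// Release device memory held by value nodes of the bprop graph. Skipped on CPU; parameters and
// forward output tensors are kept alive.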
void TopCellInfo::ClearDeviceMemory() const {
  MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target == kCPUDevice) {
    MS_LOG(DEBUG) << "No need to clear device address when running on CPU device.";
    return;
  }
  // The top cell has already called Clear(); this may happen in the no-need-compile-grad scenario
  if (resource_ == nullptr) {
    MS_LOG(DEBUG) << "This top cell " << this << " has already been cleared";
    return;
  }

  const auto &bprop_graph = resource_->func_graph();
  if (bprop_graph == nullptr) {
    return;
  }
  const auto &value_node_list = bprop_graph->value_nodes();
  // Collect all tensor objects held in value nodes of the running graph
  std::vector<tensor::BaseTensorPtr> tensors_in_bprop_graph;
  for (const auto &elem : value_node_list) {
    auto &node = elem.first;
    MS_EXCEPTION_IF_NULL(node);
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto device_sync = tensor->device_address();
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
    if (device_address == nullptr) {
      continue;
    }
    if (!device_address->from_persistent_mem() && !tensor->is_parameter() && !IsOutputTensor(tensor)) {
      // Parameters cannot be cleaned up. In the case of Parameter(Tensor(xxx).view(xxx), requires_grad=False),
      // the param is converted to a value node in the bprop graph; its tensor would be zeroed after cleaning.
      MS_LOG(DEBUG) << "Clear device address for tensor: " << tensor->id() << ", device address " << device_address
                    << ", device ptr " << device_address->GetPtr();
      tensor->set_device_address(nullptr);
    }
  }
}

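// Cache a tensor's auto grad meta data so ResumeMetaGradInfo() can restore it later.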
void TopCellInfo::AddMetaGradInfo(const tensor::BaseTensorPtr &tensor, const AutoGradMetaDataPtr &auto_grad_meta_data) {
  meta_grad_info_[tensor] = auto_grad_meta_data;
}

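// Recursively back up the auto grad meta data of every tensor reachable from value:
// plain tensors, value sequences, and stub nodes that resolve to values.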
void TopCellInfo::BackUpValueMetaGradInfo(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor_value = value->cast<tensor::BaseTensorPtr>();
    auto auto_grad_meta_data = tensor_value->auto_grad_meta_data();
    if (auto_grad_meta_data != nullptr) {
      meta_grad_info_[tensor_value] = auto_grad_meta_data;
    }
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    for (const auto &elem : value_seq->value()) {
      BackUpValueMetaGradInfo(elem);
    }
  } else if (value->isa<stub::StubNode>()) {
    auto stub_node = value->cast<stub::StubNodePtr>();
    MS_EXCEPTION_IF_NULL(stub_node);
    BackUpValueMetaGradInfo(stub_node->WaitValue());
  }
}

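// Recursively detach auto grad meta data from every tensor reachable from value,
// keeping the meta data of tensors that have a registered hook.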
void TopCellInfo::ClearValueMetaGradInfo(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor_value = value->cast<tensor::BaseTensorPtr>();
    // Keep the meta data if a hook was registered before the op ran
    if (tensor_value->auto_grad_meta_data() != nullptr && tensor_value->auto_grad_meta_data()->is_register_hook()) {
      return;
    }
    tensor_value->set_auto_grad_meta_data(nullptr);
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    for (const auto &elem : value_seq->value()) {
      ClearValueMetaGradInfo(elem);
    }
  } else if (value->isa<stub::StubNode>()) {
    auto stub_node = value->cast<stub::StubNodePtr>();
    MS_EXCEPTION_IF_NULL(stub_node);
    ClearValueMetaGradInfo(stub_node->WaitValue());
  }
}

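// Temporarily detach the cached meta grad info from its tensors; ResumeMetaGradInfo() restores it.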
void TopCellInfo::ResetMetaGradInfo() {
  if (meta_grad_info_.empty()) {
    return;
  }
  for (auto &item : meta_grad_info_) {
    item.first->set_auto_grad_meta_data(nullptr);
  }
  need_resume_meta_grad_ = true;
}

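// Re-attach the cached auto grad meta data that ResetMetaGradInfo() detached.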
void TopCellInfo::ResumeMetaGradInfo() {
  if (!need_resume_meta_grad_ || meta_grad_info_.empty()) {
    return;
  }

  for (auto &item : meta_grad_info_) {
    item.first->set_auto_grad_meta_data(item.second);
  }
  need_resume_meta_grad_ = false;
}

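// Detach and drop all cached auto grad meta data for this top cell.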
void TopCellInfo::ClearMetaGradInfo() {
  for (auto &item : meta_grad_info_) {
    item.first->set_auto_grad_meta_data(nullptr);
  }
  meta_grad_info_.clear();
}

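// Reset the top cell to its initial state, releasing graphs, resources, and bookkeeping maps.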
void TopCellInfo::Clear() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyNativeGradClearTopCell,
                                     runtime::ProfilerRecorder::kNoName, true);
  MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
  auto_grad_cell_ptr_ = nullptr;
  hook_changed_ = false;
  is_init_kpynative_ = false;
  need_compile_graph_ = false;
  forward_already_run_ = false;
  vm_compile_ = false;
  op_index_ = 0;
  resource_ = nullptr;
  fg_ = nullptr;
  shadow_top_cell_ = nullptr;
  graph_info_map_.clear();
  replace_info_.clear();
  input_args_info_ = nullptr;
}

void TopCellInfo::DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) const {
  auto &graph_info = graph_info_map().at(g);
  MS_EXCEPTION_IF_NULL(graph_info);
  (void)graph_info->input_params.erase(id);
}

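// Register a parameter node under its value id, as a weight or as a graph input.
// Ids without 'T' (likely non-tensor ids) are skipped.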
void TopCellInfo::SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param,
                                                bool is_weight) const {
  if (id.find('T') == std::string::npos) {
    return;
  }
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  if (is_weight) {
    graph_info->weight_params[id] = param;
  } else {
    graph_info->input_params[id] = param;
  }
}

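// Map a value id to its producing node and output index; tuple/list ids also record their element ids
// so unpacked outputs can be looked up later.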
void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index,
                                           bool need_save_sub_id) const {
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  if (id.find('T') == std::string::npos) {
    return;
  }
  graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
  // For example, set id of ((A,B),C) = {CNode, -1}
  if (need_save_sub_id) {
    SetMultipleOutputToGraphInfoMap(id, node);
  }
}

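// Split a tuple/list id into its element ids and map each element to the same node with its position index.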
void TopCellInfo::SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const {
  if (id.find("Tuple") == std::string::npos && id.find("List") == std::string::npos) {
    return;
  }
  std::vector<std::string> id_vec;
  PyNativeAlgo::Common::SplitString(id, &id_vec);
  auto tuple_size = static_cast<int64_t>(id_vec.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    // Set id of (A,B) = {CNode, 0}; Set id of C = {CNode, 1}
    SetNodeMapInGraphInfoMap(id_vec[i], node, i, false);
    SetNestedMultipleOutputToGraphInfoMap(id_vec[i], node, std::vector<int64_t>{i});
  }
}

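// Recursively map the elements of nested tuple/list outputs, accumulating the index path to each element.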
void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node,
                                                        const std::vector<int64_t> &index_sequence) const {
  if (id.find("Tuple") == std::string::npos && id.find("List") == std::string::npos) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  std::vector<std::string> id_vec;
  PyNativeAlgo::Common::SplitString(id, &id_vec);
  auto tuple_size = static_cast<int64_t>(id_vec.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    std::vector<int64_t> tmp = index_sequence;
    (void)tmp.emplace_back(i);
    // Set id of A = {CNode, [0, 0]}; Set id of B = {CNode, [0, 1]};
    SetUnpackOutputToGraphInfoMap(id_vec[i], node, tmp);
    // Recurse if the output contains more nested tuples or lists
    SetNestedMultipleOutputToGraphInfoMap(id_vec[i], node, tmp);
  }
}

void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
                                                const std::vector<int64_t> &index) const {
  if (id.find('T') == std::string::npos) {
    return;
  }
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  graph_info->node_map[id] = std::make_pair(node, index);
}

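// Record forward output tensor info of the bprop graph for later value node replacement;
// skipped when the top cell contains a bprop cut op.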
void TopCellInfo::SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph) {
  initial_graph_param_size_ = func_graph->parameters().size();
  if (has_bprop_cut_op()) {
    MS_LOG(DEBUG) << "Top cell has bprop cut, no need to save forward output tensor info";
    return;
  }
  MS_LOG(DEBUG) << "Save top cell forward output tensor info";
  SaveForwardOutputTensorInfo(func_graph, !use_dynamic_shape_process_, &replace_info_);
}

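// Mark every tensor in the last output value as a forward output if its id was recorded in replace info.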
void TopCellInfo::SetLastOutputValueForwardOutputFlag(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    const auto it = replace_info_.id_with_op_info.find(tensor->id());
    if (it != replace_info_.id_with_op_info.end()) {
      tensor->set_is_forward_output(true);
    }
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    for (const auto &v : value_seq->value()) {
      SetLastOutputValueForwardOutputFlag(v);
    }
  }
}

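// Re-key the top cell for new input shapes: refresh the cell id and the already-run cell id,
// then mark the cell as unknown shape.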
void TopCellInfo::ChangeTopCellInfo(const std::vector<BaseShapePtr> &args_new_shape) {
  input_args_info_->input_arg_base_shape_vec = args_new_shape;
  // Update cell id
  const auto &new_cell_id = PyNativeAlgo::Common::GetCellId(
    input_args_info_->obj_id, input_args_info_->input_arg_id_vec, input_args_info_->input_arg_value_vec);
  MS_LOG(DEBUG) << "Change top cell " << this->cell_id() << " to be unknown shape " << new_cell_id;
  cell_id_ = new_cell_id;
  input_args_info_->cell_id = new_cell_id;
  already_run_cell_id_ = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->GetAlreadyRunCellId(new_cell_id);
  MS_LOG(DEBUG) << "Get new already run top cell id " << already_run_cell_id_;
  input_args_info_->already_run_cell_id = already_run_cell_id_;
  is_unknown_shape_ = true;
}

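// Return true if the tensor is one of this top cell's forward outputs.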
bool TopCellInfo::IsOutputTensor(const tensor::BaseTensorPtr &tensor) const {
  return std::any_of(output_ids().begin(), output_ids().end(),
                     [&tensor](const std::string &output_id) { return tensor->id() == output_id; });
}
}  // namespace pynative
}  // namespace mindspore