/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/pynative_utils.h"
#include "ir/tensor.h"
#include "include/backend/device_address.h"
#include "include/common/profiler.h"

namespace mindspore {
namespace pynative {
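// Append a backward hook op to the list recorded for the cell identified by cell_order.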
void TopCellInfo::RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op) {
  MS_EXCEPTION_IF_NULL(hook_op);
  (void)cell_backward_hook_op_[cell_order].emplace_back(hook_op);
}

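// Build the op_info string ("op_name-op_index") used to identify a forward op during value node
// replacement; skipped for dynamic shape unless the op belongs to a jit graph.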
void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info, bool is_jit_graph) const {
  // Dynamic shape does not need value node replacement
  if (use_dynamic_shape_process() && !is_jit_graph) {
    return;
  }
  MS_EXCEPTION_IF_NULL(op_run_info);
  op_run_info->op_info.clear();
  op_run_info->op_info += op_run_info->base_op_run_info.op_name + "-" + std::to_string(op_index_);
}

void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
  need_compile_graph_ = need_compile_graph;
  forward_already_run_ = forward_already_run;
  vm_compile_ = vm_compile;
}

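// Release the device memory held by tensors in the value nodes of the bprop graph.
// Persistent-memory tensors, parameters, and forward outputs are skipped.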
void TopCellInfo::ClearDeviceMemory() const {
  MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target == kCPUDevice) {
    MS_LOG(DEBUG) << "No need to clear device address when running on CPU device.";
    return;
  }
  // The top cell has already called Clear(); this may happen in the no-need-compile-grad scenario
  if (resource_ == nullptr) {
    MS_LOG(DEBUG) << "This top cell " << this << " has already been cleared";
    return;
  }

  const auto &bprop_graph = resource_->func_graph();
  if (bprop_graph == nullptr) {
    return;
  }
  const auto &value_node_list = bprop_graph->value_nodes();
  // Collect all tensor objects held by value nodes of the running graph
  std::vector<tensor::BaseTensorPtr> tensors_in_bprop_graph;
  for (const auto &elem : value_node_list) {
    auto &node = elem.first;
    MS_EXCEPTION_IF_NULL(node);
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto device_sync = tensor->device_address();
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
    if (device_address == nullptr) {
      continue;
    }
    if (!device_address->from_persistent_mem() && !tensor->is_parameter() && !IsOutputTensor(tensor)) {
      // Parameters cannot be cleaned up. In the case of Parameter(Tensor(xxx).view(xxx), requires_grad=False),
      // the parameter is converted to a value node in the bprop graph, and the tensor would read as zero after
      // cleaning.
      MS_LOG(DEBUG) << "Clear device address for tensor: " << tensor->id() << ", device address " << device_address
                    << ", device ptr " << device_address->GetPtr();
      tensor->set_device_address(nullptr);
    }
  }
}

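// Cache the auto-grad meta data of a tensor so that it can be re-attached later by ResumeMetaGradInfo().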
void TopCellInfo::AddMetaGradInfo(const tensor::BaseTensorPtr &tensor, const AutoGradMetaDataPtr &auto_grad_meta_data) {
  meta_grad_info_[tensor] = auto_grad_meta_data;
}

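// Recursively back up the auto-grad meta data of every tensor reachable from value
// (tensors, value sequences, and stub nodes) into meta_grad_info_.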
void TopCellInfo::BackUpValueMetaGradInfo(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor_value = value->cast<tensor::BaseTensorPtr>();
    auto auto_grad_meta_data = tensor_value->auto_grad_meta_data();
    if (auto_grad_meta_data != nullptr) {
      meta_grad_info_[tensor_value] = auto_grad_meta_data;
    }
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    for (const auto &elem : value_seq->value()) {
      BackUpValueMetaGradInfo(elem);
    }
  } else if (value->isa<stub::StubNode>()) {
    auto stub_node = value->cast<stub::StubNodePtr>();
    MS_EXCEPTION_IF_NULL(stub_node);
    BackUpValueMetaGradInfo(stub_node->WaitValue());
  }
}

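// Recursively detach auto-grad meta data from every tensor reachable from value,
// except tensors that have a registered hook.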
void TopCellInfo::ClearValueMetaGradInfo(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor_value = value->cast<tensor::BaseTensorPtr>();
    // Keep the meta data if a hook was registered before the op ran
    if (tensor_value->auto_grad_meta_data() != nullptr && tensor_value->auto_grad_meta_data()->is_register_hook()) {
      return;
    }
    tensor_value->set_auto_grad_meta_data(nullptr);
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    for (const auto &elem : value_seq->value()) {
      ClearValueMetaGradInfo(elem);
    }
  } else if (value->isa<stub::StubNode>()) {
    auto stub_node = value->cast<stub::StubNodePtr>();
    MS_EXCEPTION_IF_NULL(stub_node);
    ClearValueMetaGradInfo(stub_node->WaitValue());
  }
}

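// Temporarily detach auto-grad meta data from all backed-up tensors; the backup map is kept
// so that ResumeMetaGradInfo() can re-attach it.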
void TopCellInfo::ResetMetaGradInfo() {
  if (meta_grad_info_.empty()) {
    return;
  }
  for (auto &item : meta_grad_info_) {
    item.first->set_auto_grad_meta_data(nullptr);
  }
  need_resume_meta_grad_ = true;
}

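// Re-attach the backed-up auto-grad meta data after a ResetMetaGradInfo() call.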
void TopCellInfo::ResumeMetaGradInfo() {
  if (!need_resume_meta_grad_ || meta_grad_info_.empty()) {
    return;
  }

  for (auto &item : meta_grad_info_) {
    item.first->set_auto_grad_meta_data(item.second);
  }
  need_resume_meta_grad_ = false;
}

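// Detach auto-grad meta data from all backed-up tensors and drop the backup map itself,
// so the info can no longer be resumed.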
void TopCellInfo::ClearMetaGradInfo() {
  for (auto &item : meta_grad_info_) {
    item.first->set_auto_grad_meta_data(nullptr);
  }
  meta_grad_info_.clear();
}

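// Reset the top cell to its initial state, releasing the auto-grad cell, resource, graphs,
// and cached per-graph info.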
void TopCellInfo::Clear() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyNativeGradClearTopCell,
                                     runtime::ProfilerRecorder::kNoName, true);
  MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
  auto_grad_cell_ptr_ = nullptr;
  hook_changed_ = false;
  is_init_kpynative_ = false;
  need_compile_graph_ = false;
  forward_already_run_ = false;
  vm_compile_ = false;
  op_index_ = 0;
  resource_ = nullptr;
  fg_ = nullptr;
  shadow_top_cell_ = nullptr;
  graph_info_map_.clear();
  replace_info_.clear();
  input_args_info_ = nullptr;
}

void TopCellInfo::DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) const {
  auto &graph_info = graph_info_map().at(g);
  MS_EXCEPTION_IF_NULL(graph_info);
  (void)graph_info->input_params.erase(id);
}

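// Register a parameter node in the graph info of the current func graph, keyed by id;
// weights and ordinary inputs go into separate maps, and ids without a 'T' are skipped.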
void TopCellInfo::SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param,
                                                bool is_weight) const {
  if (id.find('T') == std::string::npos) {
    return;
  }
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  if (is_weight) {
    graph_info->weight_params[id] = param;
  } else {
    graph_info->input_params[id] = param;
  }
}

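// Map an output id to its producing node (with output index) in the graph info; for tuple/list
// outputs, the helpers below also record every nested element id on the same node.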
void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index,
                                           bool need_save_sub_id) const {
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  if (id.find('T') == std::string::npos) {
    return;
  }
  graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
  // For example, set id of ((A,B),C) = {CNode, -1}
  if (need_save_sub_id) {
    SetMultipleOutputToGraphInfoMap(id, node);
  }
}

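// Record each top-level element of a tuple/list output with its positional index,
// e.g. for ((A,B),C): (A,B) = {CNode, 0} and C = {CNode, 1}.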
void TopCellInfo::SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const {
  if (id.find("Tuple") == std::string::npos && id.find("List") == std::string::npos) {
    return;
  }
  std::vector<std::string> id_vec;
  PyNativeAlgo::Common::SplitString(id, &id_vec);
  auto tuple_size = static_cast<int64_t>(id_vec.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    // Set id of (A,B) = {CNode, 0}; Set id of C = {CNode, 1}
    SetNodeMapInGraphInfoMap(id_vec[i], node, i, false);
    SetNestedMultipleOutputToGraphInfoMap(id_vec[i], node, std::vector<int64_t>{i});
  }
}

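// Recursively record nested tuple/list elements with their full index path,
// e.g. A = {CNode, [0, 0]} and B = {CNode, [0, 1]} for ((A,B),C).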
void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node,
                                                        const std::vector<int64_t> &index_sequence) const {
  if (id.find("Tuple") == std::string::npos && id.find("List") == std::string::npos) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  std::vector<std::string> id_vec;
  PyNativeAlgo::Common::SplitString(id, &id_vec);
  auto tuple_size = static_cast<int64_t>(id_vec.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    std::vector<int64_t> tmp = index_sequence;
    (void)tmp.emplace_back(i);
    // Set id of A = {CNode, [0, 0]}; Set id of B = {CNode, [0, 1]};
    SetUnpackOutputToGraphInfoMap(id_vec[i], node, tmp);
    // Recurse if the output contains more nested tuples or lists
    SetNestedMultipleOutputToGraphInfoMap(id_vec[i], node, tmp);
  }
}

void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
                                                const std::vector<int64_t> &index) const {
  if (id.find('T') == std::string::npos) {
    return;
  }
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  graph_info->node_map[id] = std::make_pair(node, index);
}

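// Remember the initial parameter count of the bprop graph and save the forward output tensor
// info used for value node replacement; skipped when the top cell contains a bprop cut op.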
void TopCellInfo::SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph) {
  initial_graph_param_size_ = func_graph->parameters().size();
  if (has_bprop_cut_op()) {
    MS_LOG(DEBUG) << "Top cell has bprop cut, no need to save forward output tensor info";
    return;
  }
  MS_LOG(DEBUG) << "Save top cell forward output tensor info";
  SaveForwardOutputTensorInfo(func_graph, !use_dynamic_shape_process_, &replace_info_);
}

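// Mark every tensor in the last output whose id appears in the replace info as a forward output,
// recursing into value sequences.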
void TopCellInfo::SetLastOutputValueForwardOutputFlag(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    const auto it = replace_info_.id_with_op_info.find(tensor->id());
    if (it != replace_info_.id_with_op_info.end()) {
      tensor->set_is_forward_output(true);
    }
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>();
    for (const auto &v : value_seq->value()) {
      SetLastOutputValueForwardOutputFlag(v);
    }
  }
}

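// Switch this top cell to unknown-shape mode: rebuild the cell id (and the already-run cell id)
// from the new input shapes and update the stored input args info accordingly.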
void TopCellInfo::ChangeTopCellInfo(const std::vector<BaseShapePtr> &args_new_shape) {
  input_args_info_->input_arg_base_shape_vec = args_new_shape;
  // Update cell id
  const auto &new_cell_id = PyNativeAlgo::Common::GetCellId(
    input_args_info_->obj_id, input_args_info_->input_arg_id_vec, input_args_info_->input_arg_value_vec);
  MS_LOG(DEBUG) << "Change top cell " << this->cell_id() << " to be unknown shape " << new_cell_id;
  cell_id_ = new_cell_id;
  input_args_info_->cell_id = new_cell_id;
  already_run_cell_id_ = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->GetAlreadyRunCellId(new_cell_id);
  MS_LOG(DEBUG) << "Get new already run top cell id " << already_run_cell_id_;
  input_args_info_->already_run_cell_id = already_run_cell_id_;
  is_unknown_shape_ = true;
}

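// Return true if the tensor id matches one of the recorded output ids of this top cell.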
bool TopCellInfo::IsOutputTensor(const tensor::BaseTensorPtr &tensor) const {
  return std::any_of(output_ids().begin(), output_ids().end(),
                     [&tensor](const std::string &output_id) { return tensor->id() == output_id; });
}
}  // namespace pynative
}  // namespace mindspore