• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/grad/ir/bprop_tensor_replace.h"
18 #include <memory>
19 #include "pipeline/pynative/pynative_utils.h"
20 #include "include/backend/device_address.h"
21 #include "runtime/pipeline/pipeline.h"
22 #include "pybind_api/gil_scoped_long_running.h"
23 
24 namespace mindspore {
25 namespace pynative {
26 namespace {
SaveForwardTensorForReplace(const ValuePtr & value,const TensorIdWithOpInfo & id_with_op_info,bool need_save_tensor_info,OpInfoWithTensorObject * op_info_with_tensor_object)27 void SaveForwardTensorForReplace(const ValuePtr &value, const TensorIdWithOpInfo &id_with_op_info,
28                                  bool need_save_tensor_info, OpInfoWithTensorObject *op_info_with_tensor_object) {
29   MS_EXCEPTION_IF_NULL(value);
30   if (value->isa<tensor::Tensor>()) {
31     auto tensor = value->cast<tensor::TensorPtr>();
32     const auto it = id_with_op_info.find(tensor->id());
33     if (it != id_with_op_info.end() && tensor->device_address() != nullptr) {
34       // For release memory
35       tensor->set_is_forward_output(true);
36       if (!need_save_tensor_info) {
37         return;
38       }
39       MS_EXCEPTION_IF_NULL(op_info_with_tensor_object);
40       (void)(*op_info_with_tensor_object)[it->second.first].emplace_back(std::make_pair(it->second.second, tensor));
41       MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
42                     << " device address: " << tensor->device_address() << ", device ptr: "
43                     << std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->GetPtr()
44                     << ", shape and dtype " << tensor->GetShapeAndDataTypeInfo();
45     }
46   } else if (value->isa<ValueSequence>()) {
47     const auto &value_seq = value->cast<ValueSequencePtr>();
48     for (const auto &v : value_seq->value()) {
49       SaveForwardTensorForReplace(v, id_with_op_info, need_save_tensor_info, op_info_with_tensor_object);
50     }
51   }
52 }
53 
// Overload for value nodes of the bprop graph: unwraps the node's value and
// dispatches to the ValuePtr overload. A BaseTensor that is not already a full
// tensor::Tensor is promoted by copy-construction; the node's stored value is
// swapped to the promoted tensor only when it carries a device address.
void SaveForwardTensorForReplace(const ValueNodePtr &value_node, const TensorIdWithOpInfo &id_with_op_info,
                                 bool need_save_tensor_info, OpInfoWithTensorObject *op_info_with_tensor_object) {
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::Tensor>()) {
    SaveForwardTensorForReplace(value, id_with_op_info, need_save_tensor_info, op_info_with_tensor_object);
  } else if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    auto real_tensor = std::make_shared<tensor::Tensor>(*tensor);
    // Only replace the node's value when there is a device address to preserve;
    // the promoted tensor is processed for replacement either way.
    if (tensor->device_address() != nullptr) {
      value_node->set_value(real_tensor);
    }
    SaveForwardTensorForReplace(real_tensor, id_with_op_info, need_save_tensor_info, op_info_with_tensor_object);
  } else {
    // Non-tensor values (e.g. sequences) are handled by the ValuePtr overload.
    SaveForwardTensorForReplace(value, id_with_op_info, need_save_tensor_info, op_info_with_tensor_object);
  }
}
72 
// Fetch an op's output tensor from its forward output value.
// index == 0 means the op has a single (tensor) output; otherwise index is a
// 1-based position into a multi-output value sequence.
tensor::BaseTensorPtr GetTensorFromOutValue(size_t index, const ValuePtr &v) {
  MS_EXCEPTION_IF_NULL(v);
  // Only one output
  if (index == kIndex0) {
    if (v->isa<tensor::BaseTensor>()) {
      return v->cast<tensor::BaseTensorPtr>();
    }
    // Previously fell through to the multi-output path and computed
    // index - kIndex1 == SIZE_MAX; fail loudly instead.
    MS_LOG(EXCEPTION) << "Index 0 expects a tensor output, but got " << v->ToString();
  }
  // Multi output, index is 1-based
  const auto &v_seq = v->cast<ValueSequencePtr>();
  MS_EXCEPTION_IF_NULL(v_seq);
  if (v_seq->size() < index) {
    MS_LOG(EXCEPTION) << "Get wrong index " << index << " with multi output size " << v_seq->size();
  }
  const auto tensor = v_seq->value()[index - kIndex1]->cast<tensor::BaseTensorPtr>();
  // Guard against a non-tensor element at the requested position.
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor;
}
89 
// Refresh a previously recorded forward-output tensor (old_tensor) with the
// metadata and device data of the tensor produced by the current run
// (new_tensor): shape, dtype and — depending on device target — either the
// device address itself or the bytes behind it.
void UpdatePreTensorInfo(const tensor::BaseTensorPtr &new_tensor, const tensor::BaseTensorPtr &old_tensor) {
  MS_EXCEPTION_IF_NULL(new_tensor);
  MS_EXCEPTION_IF_NULL(old_tensor);
  MS_LOG(DEBUG) << "Replace old tensor id " << old_tensor->id() << " device_address: " << old_tensor->device_address()
                << " shape and type " << old_tensor->GetShapeAndDataTypeInfo() << " with new tensor id "
                << new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
                << new_tensor->GetShapeAndDataTypeInfo();
  (void)old_tensor->set_shape(new_tensor->shape());
  (void)old_tensor->set_data_type(new_tensor->data_type());
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
  // E.g. when CellBackwardHook is the first op, its input is a graph input
  // parameter that has no device address yet — nothing to replace.
  if (device_address == nullptr) {
    return;
  }
  auto forward = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor();
  // Non-CPU: sharing the device address is enough; no host-side copy needed.
  if (forward->device_target() != kCPUDevice && device_address->GetDeviceType() != device::DeviceType::kCPU) {
    old_tensor->set_device_address(device_address);
    return;
  }

  {
    // Release the GIL while waiting for the backend pipeline to drain, so the
    // device data is final before we copy or sync it.
    GilReleaseWithCheck gil_release;
    runtime::Pipeline::Get().backend_stage()->Wait();
  }

  // Replace data in device address when run in CPU device.
  if (old_tensor->device_address() != nullptr) {
    // If tensor is dynamic shape, Just replace device address.
    if (PyNativeAlgo::Common::ValueHasDynamicShape(old_tensor)) {
      old_tensor->set_device_address(device_address);
      return;
    }
    auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(old_tensor->device_address());
    MS_EXCEPTION_IF_NULL(old_device_address);

    // CPU host tensor data_c is different from device address if the address is from mem_pool.
    if (device_address->from_mem_pool()) {
      old_tensor->set_device_address(device_address);
      return;
    }

    auto old_ptr = old_device_address->GetMutablePtr();
    MS_EXCEPTION_IF_NULL(old_ptr);
    auto new_ptr = device_address->GetPtr();
    MS_EXCEPTION_IF_NULL(new_ptr);
    MS_EXCEPTION_IF_CHECK_FAIL(old_device_address->GetSize() == device_address->GetSize(), "Size not equal");
    // memcpy_s rejects copies of SECUREC_MEM_MAX_LEN or more, so fall back to
    // plain memcpy for oversized buffers.
    if (old_device_address->GetSize() < SECUREC_MEM_MAX_LEN) {
      auto ret_code = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, device_address->GetSize());
      MS_EXCEPTION_IF_CHECK_FAIL(ret_code == EOK, "Memory copy failed, ret code: " + std::to_string(ret_code));
    } else {
      // NOTE(review): std::memcpy always returns its dest argument, so this
      // check can never fail — it is a no-op kept for symmetry.
      auto ret_code = std::memcpy(old_ptr, new_ptr, old_device_address->GetSize());
      MS_EXCEPTION_IF_CHECK_FAIL(ret_code == old_ptr, "Memory copy failed");
    }
  } else {
    // No prior device address: temporarily attach the new one to pull the data
    // to host, then detach and mark for host->device sync on next use.
    old_tensor->set_device_address(device_address);
    old_tensor->data_sync();
    old_tensor->set_device_address(nullptr);
    old_tensor->set_sync_status(kNeedSyncHostToDevice);
  }
}
150 }  // namespace
151 
SetIdWithOpInfo(const ValuePtr & v,const std::string & op_info,size_t out_index,TensorIdWithOpInfo * id_with_op_info)152 void SetIdWithOpInfo(const ValuePtr &v, const std::string &op_info, size_t out_index,
153                      TensorIdWithOpInfo *id_with_op_info) {
154   MS_EXCEPTION_IF_NULL(v);
155   MS_EXCEPTION_IF_NULL(id_with_op_info);
156   if (v->isa<tensor::BaseTensor>()) {
157     // Only one output, index will be 0
158     const auto t = v->cast<tensor::BaseTensorPtr>();
159     (*id_with_op_info)[t->id()] = std::make_pair(op_info, out_index);
160   } else if (v->isa<ValueSequence>()) {
161     const auto &v_seq = v->cast<ValueSequencePtr>();
162     // Multi output, index will increase from 1
163     for (const auto &item : v_seq->value()) {
164       SetIdWithOpInfo(item, op_info, ++out_index, id_with_op_info);
165     }
166   }
167 }
168 
UpdateForwardOutputTensorInfo(const std::string & op_info,const ValuePtr & v,const TensorReplaceInfo & replace_info)169 void UpdateForwardOutputTensorInfo(const std::string &op_info, const ValuePtr &v,
170                                    const TensorReplaceInfo &replace_info) {
171   const auto it = replace_info.op_info_with_tensor_object.find(op_info);
172   if (it == replace_info.op_info_with_tensor_object.end()) {
173     return;
174   }
175   for (const auto &elem : it->second) {
176     const auto &new_tensor = GetTensorFromOutValue(elem.first, v);
177     UpdatePreTensorInfo(new_tensor, elem.second);
178   }
179 }
180 
UpdatePipelineTopCellFowardTensor(const TensorReplaceInfo & ir_replace_info,const TensorReplaceInfo & cur_replace_info)181 void UpdatePipelineTopCellFowardTensor(const TensorReplaceInfo &ir_replace_info,
182                                        const TensorReplaceInfo &cur_replace_info) {
183   // Do update for ir top cell, and set it for actor running
184   size_t replace_num = 0;
185   for (const auto &[op_info, forward_output] : cur_replace_info.op_info_with_forward_output) {
186     UpdateForwardOutputTensorInfo(op_info, forward_output, ir_replace_info);
187     ++replace_num;
188   }
189   if (replace_num != ir_replace_info.need_replace_size) {
190     MS_LOG(EXCEPTION) << "Get replace forward output num " << replace_num << ", but need replace num is "
191                       << ir_replace_info.need_replace_size;
192   }
193 }
194 
StoreForwardOutputWithOpInfo(const OpInfoWithTensorObject & op_info_with_tensor_object,const std::string & op_info,const ValuePtr & v,TensorReplaceInfo * replace_info)195 void StoreForwardOutputWithOpInfo(const OpInfoWithTensorObject &op_info_with_tensor_object, const std::string &op_info,
196                                   const ValuePtr &v, TensorReplaceInfo *replace_info) {
197   // Use first ir top cell do opinfo replace
198   const auto it = op_info_with_tensor_object.find(op_info);
199   if (it == op_info_with_tensor_object.end()) {
200     MS_LOG(DEBUG) << "Can not find op info " << op_info << " in ir top cell, no need do replace";
201     return;
202   }
203   replace_info->op_info_with_forward_output[op_info] = v;
204 }
205 
SaveForwardOutputTensorInfo(const FuncGraphPtr & func_graph,bool need_save_tensor_info,TensorReplaceInfo * replace_info)206 void SaveForwardOutputTensorInfo(const FuncGraphPtr &func_graph, bool need_save_tensor_info,
207                                  TensorReplaceInfo *replace_info) {
208   // Get all tensors obj in value node of bprop graph
209   MS_EXCEPTION_IF_NULL(func_graph);
210   MS_EXCEPTION_IF_NULL(replace_info);
211   const auto &value_node_list = func_graph->value_nodes();
212   for (const auto &elem : value_node_list) {
213     auto value_node = elem.first->cast<ValueNodePtr>();
214     MS_EXCEPTION_IF_NULL(value_node);
215     SaveForwardTensorForReplace(value_node, replace_info->id_with_op_info, need_save_tensor_info,
216                                 &(replace_info->op_info_with_tensor_object));
217   }
218   replace_info->need_replace_size = replace_info->op_info_with_tensor_object.size();
219 }
220 }  // namespace pynative
221 }  // namespace mindspore
222