/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/abstract_actor.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
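// Receive input op data from an upstream actor. The data is cached under the step's sequential number, and the
// actor runs once all expected data and control inputs of that step have arrived. A typical dispatch from an
// upstream actor (see SendOutputData below) looks like:
//   ActorDispatcher::Send(to_op_id, &OpActor::RunOpData, output_data.first.get(), context);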
void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(input_data);
  MS_EXCEPTION_IF_NULL(input_data->data_);
  // Unused data may hold an invalid ptr.
  if (!ActorDispatcher::enable_async_launch_kernel() && !input_data->data_->IsPtrValid() &&
      (!TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNotUsed) &&
       !TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNullptr))) {
    std::string error_info = "The input_data of actor:" + GetAID().Name() +
                             " with index:" + std::to_string(input_data->index_) +
                             " does not have a valid ptr, flag:" + std::to_string(input_data->data_->flag()) +
                             " device address:" + std::to_string(reinterpret_cast<int64_t>(input_data->data_)) +
                             " ref count:" + std::to_string(input_data->data_->ref_count()) +
                             " dynamic ref count:" + std::to_string(input_data->data_->dynamic_ref_count()) +
                             " origin ref count:" + std::to_string(input_data->data_->original_ref_count());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }
  auto &sequential_num = context->sequential_num_;
  (void)input_op_datas_[sequential_num].emplace_back(input_data);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data and check running condition:" << is_run
                << ", sequential num:" << sequential_num << ", the input data:" << input_data->data_
                << " input index:" << input_data->index_ << ", size:" << input_data->data_->GetSize()
                << " ptr:" << input_data->data_->GetMutablePtr()
                << ", origin ref count:" << input_data->data_->original_ref_count()
                << ", current ref count:" << input_data->data_->ref_count()
                << ", dynamic ref count:" << input_data->data_->dynamic_ref_count()
                << ", flag:" << input_data->data_->flag() << " user data:" << input_data->data_->user_data()
                << " from memory pool:" << input_data->data_->from_mem_pool();

  if (is_run) {
    Run(context);
  }
}

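// Receive an input op control (the AID of an upstream actor) and run once the running condition is satisfied.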
void AbstractActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  (void)input_op_controls_[sequential_num].emplace_back(input_control);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op control from:" << (input_control == nullptr ? "null" : input_control->Name())
                << " and check running condition:" << is_run << ", sequential num:" << sequential_num;
  if (is_run) {
    Run(context);
  }
}

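// Receive a batch of input op data and feed each element to RunOpData.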
void AbstractActor::RunBatchOpData(std::vector<OpData<DeviceTensor> *> *const batch_input_data,
                                   OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(batch_input_data);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the batch input op data, sequential num:" << context->sequential_num_;
  for (auto &input_data : *batch_input_data) {
    RunOpData(input_data, context);
  }
}

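// The running condition holds when exactly input_datas_num_ op data and input_controls_num_ op controls have been
// collected for the current sequential number; receiving more inputs than expected is reported as an error.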
bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);
  if (input_datas_num_ != 0) {
    const auto &data_iter = input_op_datas_.find(context->sequential_num_);
    if (data_iter == input_op_datas_.end()) {
      return false;
    }
    if (data_iter->second.size() < input_datas_num_) {
      return false;
    } else if (data_iter->second.size() > input_datas_num_) {
      MS_LOG(ERROR) << "Invalid input data num:" << data_iter->second.size() << " need:" << input_datas_num_
                    << " for actor:" << GetAID() << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }

  if (input_controls_num_ != 0) {
    const auto &control_iter = input_op_controls_.find(context->sequential_num_);
    if (control_iter == input_op_controls_.end()) {
      return false;
    }
    if (control_iter->second.size() < input_controls_num_) {
      return false;
    } else if (control_iter->second.size() > input_controls_num_) {
      MS_LOG(ERROR) << "Invalid input control num:" << control_iter->second.size() << " need:" << input_controls_num_
                    << " for actor:" << GetAID() << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }
  return true;
}

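// Clear the cached inputs of the current step after the actor has finished running it.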
void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  (void)input_op_datas_.erase(context->sequential_num_);
  (void)input_op_controls_.erase(context->sequential_num_);
}

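// Fill the input device tensors (and the corresponding kernel tensors) from the device tensor store, which is
// fetched by node and device type; report a failure to the op context when the store misses.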
void AbstractActor::FetchInputByTensorStore(
  std::vector<DeviceTensor *> *const input_device_tensors, std::vector<KernelTensor *> *const input_kernel_tensors,
  std::vector<abstract::AbstractBasePtr> *const input_kernel_tensors_for_infer,
  std::vector<DeviceTensor *> *const memory_free_tensors, OpContext<DeviceTensor> *const context) const {
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto device_tensor = DeviceTensorStore::GetInstance()
                           .Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
                           .get();
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    if ((*input_device_tensors)[device_tensor_store_key.first] != device_tensor) {
      (*input_device_tensors)[device_tensor_store_key.first] = device_tensor;
      (*memory_free_tensors)[device_tensor_store_key.first] = device_tensor;
    }
    // Collect the input kernel tensor.
    const auto &kernel_tensor = (*input_device_tensors)[device_tensor_store_key.first]->kernel_tensor();
    if (input_kernel_tensors && input_kernel_tensors_for_infer &&
        ((*input_kernel_tensors)[device_tensor_store_key.first] != kernel_tensor.get())) {
      (*input_kernel_tensors)[device_tensor_store_key.first] = kernel_tensor.get();
      (*input_kernel_tensors_for_infer)[device_tensor_store_key.first] = kernel_tensor;
    }
  }
}

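// Pre-build one OpData per output data arrow and compute its flag: to-stack, batch, last-batch, between-fusion
// and to-fusion.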
void AbstractActor::InitOutputData() {
  mindspore::HashMap<std::string, size_t> batch_op_count;
  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
    auto &to_op_name = data_arrow->to_op_id_.Name();

    // Identify whether the output data flag is kOutputDataFlagToStack.
    bool is_to_stack = (to_op_name.find(kStackActorNameSuffix) != std::string::npos);
    size_t output_data_flag = is_to_stack ? kOutputDataFlagToStack : kOutputDataFlagInit;

    // Add the batch output data.
    if (TEST_FLAG(data_arrow->flag_, kOutputDataFlagBatch)) {
      if (is_to_stack) {
        MS_LOG(EXCEPTION) << "Batch output data to a stack actor is not supported.";
      }
      (void)batch_output_data_[to_op_name].emplace_back(data.get());

      SET_FLAG(output_data_flag, kOutputDataFlagBatch);
      // Identify whether the output data flag is kOutputDataFlagLastBatch.
      ++(batch_op_count[to_op_name]);
      if (batch_op_count[to_op_name] == batch_output_data_arrows_[to_op_name].size()) {
        SET_FLAG(output_data_flag, kOutputDataFlagLastBatch);
      }
    }

    // Add the internal fusion flag.
    if (TEST_FLAG(data_arrow->flag_, kOutputDataFlagBetweenFusion)) {
      SET_FLAG(output_data_flag, kOutputDataFlagBetweenFusion);
    }

    // Add the fusion flag.
    if (TEST_FLAG(data_arrow->flag_, kOutputDataFlagToFusion)) {
      SET_FLAG(output_data_flag, kOutputDataFlagToFusion);
    }

    // Add the output data.
    (void)output_data_.emplace_back(std::make_pair(std::move(data), output_data_flag));
  }
}

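// Send the pre-built output data to the downstream actors. Batch data is sent only on the arrow marked
// kOutputDataFlagLastBatch; data to a sub actor inside the same fusion actor is dispatched synchronously.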
void AbstractActor::SendOutputData(
  OpContext<DeviceTensor> *const context, const std::vector<AnfNodePtr> &output_data_nodes,
  const std::vector<DataArrowPtr> &output_data_arrows,
  const std::vector<std::pair<OpDataUniquePtr<DeviceTensor>, size_t>> &output_data_list,
  const mindspore::HashMap<DataArrow *, size_t> &data_arrow_to_fusion_actor_indexs,
  mindspore::HashMap<std::string, std::vector<OpData<DeviceTensor> *>> *batch_output_data) {
  for (size_t i = 0; i < output_data_list.size(); ++i) {
    auto &output_data = output_data_list[i];
    auto &to_op_id = output_data.first->op_id_;
    auto &output_data_arrow = output_data_arrows[i];
    UpdateOutputData(output_data.first.get(), output_data_arrow, output_data_nodes[i], context);
    // The index of the output data is modified to the real actor input index inside the fusion actor, so the
    // fusion actor index must be restored before sending output data to the fusion actor.
    if (TEST_FLAG(output_data.second, kOutputDataFlagToFusion)) {
      output_data.first->index_ = SizeToInt(data_arrow_to_fusion_actor_indexs.at(output_data_arrow.get()));
    }

    if (TEST_FLAG(output_data.second, kOutputDataFlagLastBatch)) {
      // Send batch output data. Since the data needs updating, all data must be collected completely before sending.
      if (TEST_FLAG(output_data.second, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(to_op_id.Name());
        MS_EXCEPTION_IF_NULL(to_actor);
        ActorDispatcher::SendSync(to_actor, &AbstractActor::RunBatchOpData, &((*batch_output_data)[to_op_id.Name()]),
                                  context);
      } else {
        ActorDispatcher::Send(to_op_id, &AbstractActor::RunBatchOpData, &((*batch_output_data)[to_op_id.Name()]),
                              context);
      }
    } else if (TEST_FLAG(output_data.second, kOutputDataFlagToStack)) {
      // Create a new op data for the stack actor.
      auto to_stack_data =
        std::make_unique<OpData<DeviceTensor>>(to_op_id, output_data.first->data_, output_data.first->index_);
      (void)to_stack_data_.emplace_back(std::move(to_stack_data));
      if (TEST_FLAG(output_data.second, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(to_op_id.Name());
        MS_EXCEPTION_IF_NULL(to_actor);
        ActorDispatcher::SendSync(to_actor, &OpActor::RunOpData, to_stack_data_.back().get(), context);
      } else {
        ActorDispatcher::Send(to_op_id, &OpActor::RunOpData, to_stack_data_.back().get(), context);
      }
    } else if (!TEST_FLAG(output_data.second, kOutputDataFlagBatch)) {
      // Batch output data is only sent when the output flag is kOutputDataFlagLastBatch.
      if (TEST_FLAG(output_data.second, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(to_op_id.Name());
        if (to_actor == nullptr) {
          MS_LOG(EXCEPTION) << "Failed to fetch to actor:" << to_op_id << " in actor:" << GetAID();
        }
        ActorDispatcher::SendSync(to_actor, &OpActor::RunOpData, output_data.first.get(), context);
      } else {
        ActorDispatcher::Send(to_op_id, &OpActor::RunOpData, output_data.first.get(), context);
      }
    }
  }
}

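// Send all outputs of this actor: output data first, then output controls, and finally the recorder info.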
void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
  // The execution order must be: send data --> send control, to avoid illegal timing problems.
  // 1. Send output data.
  SendOutputData(context, output_data_nodes_, output_data_arrows_, output_data_, data_arrow_to_fusion_actor_indexs_,
                 &batch_output_data_);

  // 2. Send output control.
  if (!output_control_arrows_.empty()) {
    auto from_aid = const_cast<AID *>(&GetAID());
    for (auto &output_control : output_control_arrows_) {
      if (TEST_FLAG(output_control->flag_, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(output_control->to_op_id_.Name());
        MS_EXCEPTION_IF_NULL(to_actor);
        ActorDispatcher::SendSync(to_actor, &OpActor::RunOpControl, from_aid, context);
      } else {
        ActorDispatcher::Send(output_control->to_op_id_, &OpActor::RunOpControl, from_aid, context);
      }
    }
  }

  // 3. Send recorder info.
  SendRecorderInfo(context);
}

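// Fetch a sub actor by name from the parent fusion actor; return nullptr when this actor is not inside a fusion
// actor or the name is not found.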
AbstractActor *AbstractActor::FetchSubActorInFusionActor(const std::string &sub_actor_name) const {
  if (parent_fusion_actor_ == nullptr) {
    return nullptr;
  }
  const auto &iter = parent_fusion_actor_->sub_actors_.find(sub_actor_name);
  return (iter == parent_fusion_actor_->sub_actors_.end()) ? nullptr : iter->second.get();
}

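// Judge whether the output device address is persisted, i.e. must not be freed or replaced at runtime: a persisted
// ptr, a value node, a parameter, or a ref output whose origin node is a value node or parameter.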
bool AbstractActor::IsOutputAddressPersisted(const DeviceTensor *output_device_tensor,
                                             const KernelWithIndex &output_node) {
  MS_EXCEPTION_IF_NULL(output_node.first);
  MS_EXCEPTION_IF_NULL(output_device_tensor);
  // The persisted address can't be replaced.
  if (output_device_tensor->is_ptr_persisted()) {
    return true;
  }

  if (output_node.first->isa<ValueNode>()) {
    return true;
  }

  // The device address of a parameter may come from the device address of an input tensor.
  // To avoid mistakenly cleaning up the device data of the input tensor, treat it as a persisted address.
  if (output_node.first->isa<Parameter>()) {
    return true;
  }

  // A ref node needs to check its origin node.
  const auto &graph = AnfAlgo::FetchKernelGraph(output_node.first.get());
  if ((graph != nullptr) && graph->IsInRefOutputMap(output_node)) {
    const auto &origin_node = graph->GetRefNodeRecursive(output_node).first;
    MS_EXCEPTION_IF_NULL(origin_node);
    if (origin_node->isa<ValueNode>() || origin_node->isa<Parameter>()) {
      return true;
    }
  }

  return false;
}
}  // namespace runtime
}  // namespace mindspore