/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "runtime/graph_scheduler/actor/control_flow/stack_actor.h"

#include <utility>

#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
#include "runtime/graph_scheduler/control_node_parser.h"
20
21 namespace mindspore {
22 namespace runtime {
StackActor(const std::string & name,const AID & memory_manager_aid,const std::vector<KernelWithIndex> & parameters)23 StackActor::StackActor(const std::string &name, const AID &memory_manager_aid,
24 const std::vector<KernelWithIndex> ¶meters)
25 : ControlActor(name, KernelTransformType::kStackActor, memory_manager_aid, parameters, nullptr) {
26 input_device_tensors_.resize(parameters.size());
27 }
28
Init()29 void StackActor::Init() {
30 ControlActor::Init();
31 // The stack actor has 6 parts of input :
32 // 1. Directly input data.
33 // 2. Direct input partial.
34 // 3. Weight.
35 // 4. Local tensor.
36 // 5. Call input data.
37 // 6. Call input partial.
38 input_datas_num_ = formal_parameters_.size() - input_stack_data_num_ - input_stack_partials_num_;
39 if (input_stack_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) {
40 MS_LOG(EXCEPTION) << "Invalid input stack data num:" << input_stack_data_num_
41 << " device store num:" << device_tensor_store_keys_.size()
42 << " local device tensor num:" << local_device_tensors_.size()
43 << " input stack data num:" << input_stack_data_num_
44 << " input stack partial num:" << input_stack_partials_num_ << " for actor:" << GetAID();
45 }
46
47 // Fetch the total number of input partial.
48 size_t total_partials_num = 0;
49 for (const auto &formal_parameter : formal_parameters_) {
50 MS_EXCEPTION_IF_NULL(formal_parameter.first);
51 const auto &abstract = formal_parameter.first->abstract();
52 MS_EXCEPTION_IF_NULL(abstract);
53 const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
54 MS_EXCEPTION_IF_NULL(real_abstract);
55 if (real_abstract->isa<abstract::AbstractFunction>()) {
56 total_partials_num++;
57 }
58 }
59
60 // Fetch call input data num.
61 input_datas_num_ = formal_parameters_.size() - total_partials_num - input_stack_data_num_;
62 input_partials_num_ = total_partials_num - input_stack_partials_num_;
63 // Fetch call input partial num.
64 input_stack_data_num_ -= (device_tensor_store_keys_.size() + local_device_tensors_.size());
65 // Check if the input num is valid.
66 if (input_stack_data_num_ + input_stack_partials_num_ + input_datas_num_ + input_partials_num_ +
67 device_tensor_store_keys_.size() + local_device_tensors_.size() !=
68 formal_parameters_.size()) {
69 MS_LOG(EXCEPTION) << "Invalid input num, input stack data num:" << input_stack_data_num_
70 << " input stack partial num:" << input_stack_partials_num_
71 << " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_
72 << " device tensor store size:" << device_tensor_store_keys_.size()
73 << " need total size:" << formal_parameters_.size() << " for actor:" << GetAID();
74 }
75 MS_LOG(DEBUG) << "Stack actor input stack data num:" << input_stack_data_num_
76 << " stack partial num:" << input_stack_partials_num_ << " input data num:" << input_datas_num_
77 << " input partial num:" << input_partials_num_
78 << " device tensor store num:" << device_tensor_store_keys_.size()
79 << " local tensor num:" << local_device_tensors_.size()
80 << " formal parameter num:" << formal_parameters_.size();
81 }
82
RunOpData(OpData<DeviceTensor> * const input_data,OpContext<DeviceTensor> * const context)83 void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
84 MS_EXCEPTION_IF_NULL(context);
85 MS_EXCEPTION_IF_NULL(input_data);
86 MS_EXCEPTION_IF_NULL(input_data->data_);
87 MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input data:" << input_data->data_
88 << " input index:" << input_data->index_ << ", size:" << input_data->data_->GetSize()
89 << " ptr:" << input_data->data_->GetMutablePtr()
90 << ", origin ref count:" << input_data->data_->original_ref_count()
91 << ", current ref count:" << input_data->data_->ref_count()
92 << ", dynamic ref count:" << input_data->data_->dynamic_ref_count()
93 << ", flag:" << input_data->data_->flag() << " user data:" << input_data->data_->user_data();
94 // The parameters from the inside of the subgraph need to be put into the stack.
95 if (IntToSize(input_data->index_) < input_stack_data_num_ + device_tensor_store_keys_.size() +
96 input_stack_partials_num_ + local_device_tensors_.size()) {
97 input_stack_data_[context->sequential_num_][input_data->index_].push(input_data->data_);
98 } else {
99 // The outputs of call nodes are placed directly in the input data.
100 (void)input_op_datas_[context->sequential_num_].emplace_back(input_data);
101 }
102
103 auto is_run = CheckRunningCondition(context);
104 MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data and check running condition:" << is_run;
105 if (is_run) {
106 Run(context);
107 }
108 }
109
RunOpControl(AID * const input_control,OpContext<DeviceTensor> * const context)110 void StackActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
111 MS_EXCEPTION_IF_NULL(context);
112 auto &sequential_num = context->sequential_num_;
113 if (control_aid_to_indexs_.find(*input_control) != control_aid_to_indexs_.end()) {
114 if ((input_stack_controls_.find(sequential_num) == input_stack_controls_.end()) ||
115 (input_stack_controls_[sequential_num].find(control_aid_to_indexs_[*input_control]) ==
116 input_stack_controls_[sequential_num].end())) {
117 input_stack_controls_[sequential_num][control_aid_to_indexs_[*input_control]] = 1;
118 } else {
119 input_stack_controls_[sequential_num][control_aid_to_indexs_[*input_control]]++;
120 }
121 } else {
122 (void)input_op_controls_[sequential_num].emplace_back(input_control);
123 }
124
125 if (CheckRunningCondition(context)) {
126 Run(context);
127 }
128 }
129
RunOpPartial(const OpPartialPtr & partial,size_t position,OpContext<DeviceTensor> * const context)130 void StackActor::RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) {
131 MS_EXCEPTION_IF_NULL(context);
132 auto self_partial = std::make_shared<OpPartial>();
133 *self_partial = *partial;
134 // The parameters from the inside of the subgraph need to be put into the stack.
135 if (position < input_stack_data_num_ + device_tensor_store_keys_.size() + input_stack_partials_num_ +
136 local_device_tensors_.size()) {
137 input_stack_partials_[context->sequential_num_][position].push(self_partial);
138 } else {
139 (void)input_op_partials_[context->sequential_num_].emplace_back(position, self_partial);
140 }
141
142 auto is_run = CheckRunningCondition(context);
143 MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
144 << ") receive the input op partial and check running condition:" << is_run;
145 if (is_run) {
146 Run(context);
147 }
148 }
149
CheckRunningCondition(const OpContext<DeviceTensor> * context) const150 bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
151 MS_EXCEPTION_IF_NULL(context);
152 if (!ControlActor::CheckRunningCondition(context)) {
153 return false;
154 }
155
156 if (CheckStackDataRunningCondition(context) && CheckStackPartialRunningCondition(context) &&
157 CheckStackControlRunningCondition(context)) {
158 return true;
159 }
160 return false;
161 }
162
CheckStackDataRunningCondition(const OpContext<DeviceTensor> * context) const163 bool StackActor::CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const {
164 MS_EXCEPTION_IF_NULL(context);
165 auto iter = input_branch_ids_.find(context->sequential_num_);
166 bool is_branch_id_invalid = (is_branch_id_enable_ && (iter == input_branch_ids_.end() || iter->second.empty()));
167
168 if (input_stack_data_num_ != 0) {
169 const auto &data_iter = input_stack_data_.find(context->sequential_num_);
170 if (data_iter == input_stack_data_.end()) {
171 return false;
172 }
173 if (data_iter->second.size() < input_stack_data_num_) {
174 return false;
175 } else if (data_iter->second.size() > input_stack_data_num_) {
176 MS_LOG(ERROR) << "Invalid input stack data num:" << data_iter->second.size() << " need:" << input_stack_data_num_
177 << " for actor:" << GetAID();
178 return false;
179 }
180
181 if (is_branch_id_invalid) {
182 MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID().Name();
183 return false;
184 }
185 size_t branch_id_size = 1;
186 if (is_branch_id_enable_) {
187 branch_id_size = iter->second.size();
188 }
189 for (const auto &one_stack : data_iter->second) {
190 if (one_stack.second.size() < branch_id_size) {
191 return false;
192 } else if (one_stack.second.size() > branch_id_size) {
193 MS_LOG(ERROR) << "Invalid input stack data num:" << one_stack.second.size()
194 << " for input index:" << one_stack.first << " need:" << branch_id_size
195 << " for actor:" << GetAID();
196 return false;
197 }
198 }
199 }
200 return true;
201 }
202
CheckStackPartialRunningCondition(const OpContext<DeviceTensor> * context) const203 bool StackActor::CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const {
204 MS_EXCEPTION_IF_NULL(context);
205 auto iter = input_branch_ids_.find(context->sequential_num_);
206 bool is_branch_id_invalid = (is_branch_id_enable_ && (iter == input_branch_ids_.end() || iter->second.empty()));
207
208 if (input_stack_partials_num_ != 0) {
209 const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
210 if (partial_iter == input_stack_partials_.end()) {
211 return false;
212 }
213 if (partial_iter->second.size() < input_stack_partials_num_) {
214 return false;
215 } else if (partial_iter->second.size() > input_stack_partials_num_) {
216 MS_LOG(ERROR) << "Invalid input stack partial num:" << partial_iter->second.size()
217 << " need:" << input_stack_partials_num_ << " for actor:" << GetAID();
218 return false;
219 }
220
221 if (is_branch_id_invalid) {
222 MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID().Name();
223 return false;
224 }
225 size_t branch_id_size = 1;
226 if (is_branch_id_enable_) {
227 branch_id_size = iter->second.size();
228 }
229 for (const auto &one_stack : partial_iter->second) {
230 if (one_stack.second.size() < branch_id_size) {
231 return false;
232 } else if (one_stack.second.size() > branch_id_size) {
233 MS_LOG(ERROR) << "Invalid input stack partial num:" << one_stack.second.size()
234 << " for input index:" << one_stack.first << " need:" << branch_id_size
235 << " for actor:" << GetAID();
236 return false;
237 }
238 }
239 }
240 return true;
241 }
242
CheckStackControlRunningCondition(const OpContext<DeviceTensor> * context) const243 bool StackActor::CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const {
244 MS_EXCEPTION_IF_NULL(context);
245 auto iter = input_branch_ids_.find(context->sequential_num_);
246 bool is_branch_id_invalid = (is_branch_id_enable_ && (iter == input_branch_ids_.end() || iter->second.empty()));
247
248 if (input_stack_controls_num_ != 0) {
249 const auto &control_iter = input_stack_controls_.find(context->sequential_num_);
250 if (control_iter == input_stack_controls_.end()) {
251 return false;
252 }
253 if (control_iter->second.size() < input_stack_controls_num_) {
254 return false;
255 } else if (control_iter->second.size() > input_stack_controls_num_) {
256 MS_LOG(ERROR) << "Invalid input stack control num:" << control_iter->second.size()
257 << " need:" << input_stack_controls_num_ << " for actor:" << GetAID();
258 return false;
259 }
260
261 if (is_branch_id_invalid) {
262 MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID().Name();
263 return false;
264 }
265 size_t branch_id_size = 1;
266 if (is_branch_id_enable_) {
267 branch_id_size = iter->second.size();
268 }
269 for (const auto &one_stack : control_iter->second) {
270 if (one_stack.second < branch_id_size) {
271 return false;
272 } else if (one_stack.second > branch_id_size) {
273 MS_LOG(ERROR) << "Invalid input stack control num:" << one_stack.second
274 << " for input actor index:" << one_stack.first << " need:" << branch_id_size
275 << " for actor:" << GetAID();
276 return false;
277 }
278 }
279 }
280 return true;
281 }
282
FetchInput(OpContext<DeviceTensor> * const context)283 void StackActor::FetchInput(OpContext<DeviceTensor> *const context) {
284 MS_EXCEPTION_IF_NULL(context);
285 if (input_stack_data_num_ != 0) {
286 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
287 const auto &data_iter = input_stack_data_.find(context->sequential_num_);
288 if (data_iter == input_stack_data_.end()) {
289 std::string error_info = "Invalid input for actor:" + GetAID().Name();
290 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
291 }
292 for (const auto &one_stack : data_iter->second) {
293 if (one_stack.first >= input_stack_data_num_ + device_tensor_store_keys_.size() + local_device_tensors_.size() +
294 input_stack_partials_num_) {
295 std::string error_info = "Invalid input index:" + std::to_string(one_stack.first) +
296 " need:" + std::to_string(input_stack_data_num_) + " for actor:" + GetAID().Name();
297 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
298 }
299 MS_EXCEPTION_IF_NULL(one_stack.second.top());
300 input_device_tensors_[one_stack.first] = one_stack.second.top();
301 }
302 }
303
304 if (input_stack_partials_num_ != 0) {
305 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
306 const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
307 if (partial_iter == input_stack_partials_.end()) {
308 std::string error_info = "Invalid input for actor:" + GetAID().Name();
309 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
310 }
311 for (const auto &one_stack : partial_iter->second) {
312 if (one_stack.first >= input_stack_data_num_ + device_tensor_store_keys_.size() + local_device_tensors_.size() +
313 input_stack_partials_num_) {
314 std::string error_info = "Invalid input index:" + std::to_string(one_stack.first) +
315 " need:" + std::to_string(input_stack_partials_num_) + " for actor:" + GetAID().Name();
316 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
317 }
318 input_partials_[one_stack.first] = one_stack.second.top();
319 }
320 }
321 ControlActor::FetchInput(context);
322 }
323
EraseInput(const OpContext<DeviceTensor> * const context)324 void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
325 MS_EXCEPTION_IF_NULL(context);
326 ControlActor::EraseInput(context);
327
328 if (input_stack_data_num_ != 0) {
329 const auto &data_iter = input_stack_data_.find(context->sequential_num_);
330 if (data_iter == input_stack_data_.end()) {
331 MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
332 return;
333 }
334
335 for (auto &one_stack : data_iter->second) {
336 if (one_stack.second.empty()) {
337 MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
338 return;
339 }
340 one_stack.second.pop();
341 }
342 }
343
344 if (input_stack_partials_num_ != 0) {
345 const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
346 if (partial_iter == input_stack_partials_.end()) {
347 MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
348 return;
349 }
350
351 for (auto &one_stack : partial_iter->second) {
352 if (one_stack.second.empty()) {
353 MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
354 return;
355 }
356 one_stack.second.pop();
357 }
358 }
359
360 if (input_stack_controls_num_ != 0) {
361 const auto &control_iter = input_stack_controls_.find(context->sequential_num_);
362 if (control_iter == input_stack_controls_.end()) {
363 MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
364 return;
365 }
366
367 mindspore::HashMap<size_t, size_t> tmp_stack_controls;
368 for (auto stack_iter = control_iter->second.begin(); stack_iter != control_iter->second.end(); ++stack_iter) {
369 if (stack_iter->second == 0) {
370 MS_LOG(ERROR) << "Input stack control aid:" << stack_iter->first << " is null in actor:" << GetAID();
371 return;
372 } else if (stack_iter->second == 1) {
373 continue;
374 } else {
375 tmp_stack_controls[stack_iter->first] = stack_iter->second - 1;
376 }
377 }
378 if (tmp_stack_controls.empty()) {
379 (void)input_stack_controls_.erase(control_iter);
380 } else {
381 control_iter->second.swap(tmp_stack_controls);
382 }
383 }
384 }
385
SendMemoryFreeReq(OpContext<DeviceTensor> * const context)386 void StackActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
387 MS_EXCEPTION_IF_NULL(context);
388 const auto &sequential_num = context->sequential_num_;
389
390 // Collect the input device tensors.
391 std::vector<DeviceTensor *> memory_free_list;
392 if (input_op_datas_.find(sequential_num) != input_op_datas_.end()) {
393 for (auto &input_data : input_op_datas_[sequential_num]) {
394 MS_EXCEPTION_IF_NULL(input_data);
395 MS_EXCEPTION_IF_NULL(input_data->data_);
396 (void)memory_free_list.emplace_back(input_data->data_);
397 }
398 }
399
400 if (input_op_partials_.find(sequential_num) != input_op_partials_.end()) {
401 for (auto &input_partial_pair : input_op_partials_[sequential_num]) {
402 GetAllDeviceTensors(input_partial_pair.second, &memory_free_list);
403 }
404 }
405
406 if ((input_stack_data_num_ != 0) && (input_stack_data_.count(sequential_num) > 0)) {
407 for (auto &stack_data_pair : input_stack_data_[sequential_num]) {
408 if (!stack_data_pair.second.empty()) {
409 (void)memory_free_list.emplace_back(stack_data_pair.second.top());
410 }
411 }
412 }
413
414 if ((input_stack_partials_num_ != 0) && (input_stack_partials_.count(sequential_num) > 0)) {
415 for (auto &stack_partial_pair : input_stack_partials_[sequential_num]) {
416 if (!stack_partial_pair.second.empty()) {
417 GetAllDeviceTensors(stack_partial_pair.second.top(), &memory_free_list);
418 }
419 }
420 }
421
422 if (memory_free_list.size() > 0) {
423 memory_free_lists_.push(memory_free_list);
424 if (ActorDispatcher::is_memory_free_sync()) {
425 ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
426 device_contexts_[0], context, GetAID());
427 } else {
428 ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
429 device_contexts_[0], context, GetAID());
430 }
431 }
432 }
433 } // namespace runtime
434 } // namespace mindspore
435