• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/control_flow/stack_actor.h"
18 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
19 #include "runtime/graph_scheduler/control_node_parser.h"
20 
21 namespace mindspore {
22 namespace runtime {
StackActor(const std::string & name,const AID & memory_manager_aid,const std::vector<KernelWithIndex> & parameters)23 StackActor::StackActor(const std::string &name, const AID &memory_manager_aid,
24                        const std::vector<KernelWithIndex> &parameters)
25     : ControlActor(name, KernelTransformType::kStackActor, memory_manager_aid, parameters, nullptr) {
26   input_device_tensors_.resize(parameters.size());
27 }
28 
Init()29 void StackActor::Init() {
30   ControlActor::Init();
31   // The stack actor has 6 parts of input :
32   // 1. Directly input data.
33   // 2. Direct input partial.
34   // 3. Weight.
35   // 4. Local tensor.
36   // 5. Call input data.
37   // 6. Call input partial.
38   input_datas_num_ = formal_parameters_.size() - input_stack_data_num_ - input_stack_partials_num_;
39   if (input_stack_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) {
40     MS_LOG(EXCEPTION) << "Invalid input stack data num:" << input_stack_data_num_
41                       << " device store num:" << device_tensor_store_keys_.size()
42                       << " local device tensor num:" << local_device_tensors_.size()
43                       << " input stack data num:" << input_stack_data_num_
44                       << " input stack partial num:" << input_stack_partials_num_ << " for actor:" << GetAID();
45   }
46 
47   // Fetch the total number of input partial.
48   size_t total_partials_num = 0;
49   for (const auto &formal_parameter : formal_parameters_) {
50     MS_EXCEPTION_IF_NULL(formal_parameter.first);
51     const auto &abstract = formal_parameter.first->abstract();
52     MS_EXCEPTION_IF_NULL(abstract);
53     const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
54     MS_EXCEPTION_IF_NULL(real_abstract);
55     if (real_abstract->isa<abstract::AbstractFunction>()) {
56       total_partials_num++;
57     }
58   }
59 
60   // Fetch call input data num.
61   input_datas_num_ = formal_parameters_.size() - total_partials_num - input_stack_data_num_;
62   input_partials_num_ = total_partials_num - input_stack_partials_num_;
63   // Fetch call input partial num.
64   input_stack_data_num_ -= (device_tensor_store_keys_.size() + local_device_tensors_.size());
65   // Check if the input num is valid.
66   if (input_stack_data_num_ + input_stack_partials_num_ + input_datas_num_ + input_partials_num_ +
67         device_tensor_store_keys_.size() + local_device_tensors_.size() !=
68       formal_parameters_.size()) {
69     MS_LOG(EXCEPTION) << "Invalid input num, input stack data num:" << input_stack_data_num_
70                       << " input stack partial num:" << input_stack_partials_num_
71                       << " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_
72                       << " device tensor store size:" << device_tensor_store_keys_.size()
73                       << " need total size:" << formal_parameters_.size() << " for actor:" << GetAID();
74   }
75   MS_LOG(DEBUG) << "Stack actor input stack data num:" << input_stack_data_num_
76                 << " stack partial num:" << input_stack_partials_num_ << " input data num:" << input_datas_num_
77                 << " input partial num:" << input_partials_num_
78                 << " device tensor store num:" << device_tensor_store_keys_.size()
79                 << " local tensor num:" << local_device_tensors_.size()
80                 << " formal parameter num:" << formal_parameters_.size();
81 }
82 
RunOpData(OpData<DeviceTensor> * const input_data,OpContext<DeviceTensor> * const context)83 void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
84   MS_EXCEPTION_IF_NULL(context);
85   MS_EXCEPTION_IF_NULL(input_data);
86   MS_EXCEPTION_IF_NULL(input_data->data_);
87   MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input data:" << input_data->data_
88                 << " input index:" << input_data->index_ << ", size:" << input_data->data_->GetSize()
89                 << " ptr:" << input_data->data_->GetMutablePtr()
90                 << ", origin ref count:" << input_data->data_->original_ref_count()
91                 << ", current ref count:" << input_data->data_->ref_count()
92                 << ", dynamic ref count:" << input_data->data_->dynamic_ref_count()
93                 << ", flag:" << input_data->data_->flag() << " user data:" << input_data->data_->user_data();
94   // The parameters from the inside of the subgraph need to be put into the stack.
95   if (IntToSize(input_data->index_) < input_stack_data_num_ + device_tensor_store_keys_.size() +
96                                         input_stack_partials_num_ + local_device_tensors_.size()) {
97     input_stack_data_[context->sequential_num_][input_data->index_].push(input_data->data_);
98   } else {
99     // The outputs of call nodes are placed directly in the input data.
100     (void)input_op_datas_[context->sequential_num_].emplace_back(input_data);
101   }
102 
103   auto is_run = CheckRunningCondition(context);
104   MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data and check running condition:" << is_run;
105   if (is_run) {
106     Run(context);
107   }
108 }
109 
RunOpControl(AID * const input_control,OpContext<DeviceTensor> * const context)110 void StackActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
111   MS_EXCEPTION_IF_NULL(context);
112   auto &sequential_num = context->sequential_num_;
113   if (control_aid_to_indexs_.find(*input_control) != control_aid_to_indexs_.end()) {
114     if ((input_stack_controls_.find(sequential_num) == input_stack_controls_.end()) ||
115         (input_stack_controls_[sequential_num].find(control_aid_to_indexs_[*input_control]) ==
116          input_stack_controls_[sequential_num].end())) {
117       input_stack_controls_[sequential_num][control_aid_to_indexs_[*input_control]] = 1;
118     } else {
119       input_stack_controls_[sequential_num][control_aid_to_indexs_[*input_control]]++;
120     }
121   } else {
122     (void)input_op_controls_[sequential_num].emplace_back(input_control);
123   }
124 
125   if (CheckRunningCondition(context)) {
126     Run(context);
127   }
128 }
129 
RunOpPartial(const OpPartialPtr & partial,size_t position,OpContext<DeviceTensor> * const context)130 void StackActor::RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) {
131   MS_EXCEPTION_IF_NULL(context);
132   auto self_partial = std::make_shared<OpPartial>();
133   *self_partial = *partial;
134   // The parameters from the inside of the subgraph need to be put into the stack.
135   if (position < input_stack_data_num_ + device_tensor_store_keys_.size() + input_stack_partials_num_ +
136                    local_device_tensors_.size()) {
137     input_stack_partials_[context->sequential_num_][position].push(self_partial);
138   } else {
139     (void)input_op_partials_[context->sequential_num_].emplace_back(position, self_partial);
140   }
141 
142   auto is_run = CheckRunningCondition(context);
143   MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
144                 << ") receive the input op partial and check running condition:" << is_run;
145   if (is_run) {
146     Run(context);
147   }
148 }
149 
CheckRunningCondition(const OpContext<DeviceTensor> * context) const150 bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
151   MS_EXCEPTION_IF_NULL(context);
152   if (!ControlActor::CheckRunningCondition(context)) {
153     return false;
154   }
155 
156   if (CheckStackDataRunningCondition(context) && CheckStackPartialRunningCondition(context) &&
157       CheckStackControlRunningCondition(context)) {
158     return true;
159   }
160   return false;
161 }
162 
CheckStackDataRunningCondition(const OpContext<DeviceTensor> * context) const163 bool StackActor::CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const {
164   MS_EXCEPTION_IF_NULL(context);
165   auto iter = input_branch_ids_.find(context->sequential_num_);
166   bool is_branch_id_invalid = (is_branch_id_enable_ && (iter == input_branch_ids_.end() || iter->second.empty()));
167 
168   if (input_stack_data_num_ != 0) {
169     const auto &data_iter = input_stack_data_.find(context->sequential_num_);
170     if (data_iter == input_stack_data_.end()) {
171       return false;
172     }
173     if (data_iter->second.size() < input_stack_data_num_) {
174       return false;
175     } else if (data_iter->second.size() > input_stack_data_num_) {
176       MS_LOG(ERROR) << "Invalid input stack data num:" << data_iter->second.size() << " need:" << input_stack_data_num_
177                     << " for actor:" << GetAID();
178       return false;
179     }
180 
181     if (is_branch_id_invalid) {
182       MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID().Name();
183       return false;
184     }
185     size_t branch_id_size = 1;
186     if (is_branch_id_enable_) {
187       branch_id_size = iter->second.size();
188     }
189     for (const auto &one_stack : data_iter->second) {
190       if (one_stack.second.size() < branch_id_size) {
191         return false;
192       } else if (one_stack.second.size() > branch_id_size) {
193         MS_LOG(ERROR) << "Invalid input stack data num:" << one_stack.second.size()
194                       << " for input index:" << one_stack.first << " need:" << branch_id_size
195                       << " for actor:" << GetAID();
196         return false;
197       }
198     }
199   }
200   return true;
201 }
202 
CheckStackPartialRunningCondition(const OpContext<DeviceTensor> * context) const203 bool StackActor::CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const {
204   MS_EXCEPTION_IF_NULL(context);
205   auto iter = input_branch_ids_.find(context->sequential_num_);
206   bool is_branch_id_invalid = (is_branch_id_enable_ && (iter == input_branch_ids_.end() || iter->second.empty()));
207 
208   if (input_stack_partials_num_ != 0) {
209     const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
210     if (partial_iter == input_stack_partials_.end()) {
211       return false;
212     }
213     if (partial_iter->second.size() < input_stack_partials_num_) {
214       return false;
215     } else if (partial_iter->second.size() > input_stack_partials_num_) {
216       MS_LOG(ERROR) << "Invalid input stack partial num:" << partial_iter->second.size()
217                     << " need:" << input_stack_partials_num_ << " for actor:" << GetAID();
218       return false;
219     }
220 
221     if (is_branch_id_invalid) {
222       MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID().Name();
223       return false;
224     }
225     size_t branch_id_size = 1;
226     if (is_branch_id_enable_) {
227       branch_id_size = iter->second.size();
228     }
229     for (const auto &one_stack : partial_iter->second) {
230       if (one_stack.second.size() < branch_id_size) {
231         return false;
232       } else if (one_stack.second.size() > branch_id_size) {
233         MS_LOG(ERROR) << "Invalid input stack partial num:" << one_stack.second.size()
234                       << " for input index:" << one_stack.first << " need:" << branch_id_size
235                       << " for actor:" << GetAID();
236         return false;
237       }
238     }
239   }
240   return true;
241 }
242 
CheckStackControlRunningCondition(const OpContext<DeviceTensor> * context) const243 bool StackActor::CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const {
244   MS_EXCEPTION_IF_NULL(context);
245   auto iter = input_branch_ids_.find(context->sequential_num_);
246   bool is_branch_id_invalid = (is_branch_id_enable_ && (iter == input_branch_ids_.end() || iter->second.empty()));
247 
248   if (input_stack_controls_num_ != 0) {
249     const auto &control_iter = input_stack_controls_.find(context->sequential_num_);
250     if (control_iter == input_stack_controls_.end()) {
251       return false;
252     }
253     if (control_iter->second.size() < input_stack_controls_num_) {
254       return false;
255     } else if (control_iter->second.size() > input_stack_controls_num_) {
256       MS_LOG(ERROR) << "Invalid input stack control num:" << control_iter->second.size()
257                     << " need:" << input_stack_controls_num_ << " for actor:" << GetAID();
258       return false;
259     }
260 
261     if (is_branch_id_invalid) {
262       MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID().Name();
263       return false;
264     }
265     size_t branch_id_size = 1;
266     if (is_branch_id_enable_) {
267       branch_id_size = iter->second.size();
268     }
269     for (const auto &one_stack : control_iter->second) {
270       if (one_stack.second < branch_id_size) {
271         return false;
272       } else if (one_stack.second > branch_id_size) {
273         MS_LOG(ERROR) << "Invalid input stack control num:" << one_stack.second
274                       << " for input actor index:" << one_stack.first << " need:" << branch_id_size
275                       << " for actor:" << GetAID();
276         return false;
277       }
278     }
279   }
280   return true;
281 }
282 
FetchInput(OpContext<DeviceTensor> * const context)283 void StackActor::FetchInput(OpContext<DeviceTensor> *const context) {
284   MS_EXCEPTION_IF_NULL(context);
285   if (input_stack_data_num_ != 0) {
286     ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
287     const auto &data_iter = input_stack_data_.find(context->sequential_num_);
288     if (data_iter == input_stack_data_.end()) {
289       std::string error_info = "Invalid input for actor:" + GetAID().Name();
290       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
291     }
292     for (const auto &one_stack : data_iter->second) {
293       if (one_stack.first >= input_stack_data_num_ + device_tensor_store_keys_.size() + local_device_tensors_.size() +
294                                input_stack_partials_num_) {
295         std::string error_info = "Invalid input index:" + std::to_string(one_stack.first) +
296                                  " need:" + std::to_string(input_stack_data_num_) + " for actor:" + GetAID().Name();
297         SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
298       }
299       MS_EXCEPTION_IF_NULL(one_stack.second.top());
300       input_device_tensors_[one_stack.first] = one_stack.second.top();
301     }
302   }
303 
304   if (input_stack_partials_num_ != 0) {
305     ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
306     const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
307     if (partial_iter == input_stack_partials_.end()) {
308       std::string error_info = "Invalid input for actor:" + GetAID().Name();
309       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
310     }
311     for (const auto &one_stack : partial_iter->second) {
312       if (one_stack.first >= input_stack_data_num_ + device_tensor_store_keys_.size() + local_device_tensors_.size() +
313                                input_stack_partials_num_) {
314         std::string error_info = "Invalid input index:" + std::to_string(one_stack.first) +
315                                  " need:" + std::to_string(input_stack_partials_num_) + " for actor:" + GetAID().Name();
316         SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
317       }
318       input_partials_[one_stack.first] = one_stack.second.top();
319     }
320   }
321   ControlActor::FetchInput(context);
322 }
323 
EraseInput(const OpContext<DeviceTensor> * const context)324 void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
325   MS_EXCEPTION_IF_NULL(context);
326   ControlActor::EraseInput(context);
327 
328   if (input_stack_data_num_ != 0) {
329     const auto &data_iter = input_stack_data_.find(context->sequential_num_);
330     if (data_iter == input_stack_data_.end()) {
331       MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
332       return;
333     }
334 
335     for (auto &one_stack : data_iter->second) {
336       if (one_stack.second.empty()) {
337         MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
338         return;
339       }
340       one_stack.second.pop();
341     }
342   }
343 
344   if (input_stack_partials_num_ != 0) {
345     const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
346     if (partial_iter == input_stack_partials_.end()) {
347       MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
348       return;
349     }
350 
351     for (auto &one_stack : partial_iter->second) {
352       if (one_stack.second.empty()) {
353         MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
354         return;
355       }
356       one_stack.second.pop();
357     }
358   }
359 
360   if (input_stack_controls_num_ != 0) {
361     const auto &control_iter = input_stack_controls_.find(context->sequential_num_);
362     if (control_iter == input_stack_controls_.end()) {
363       MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
364       return;
365     }
366 
367     mindspore::HashMap<size_t, size_t> tmp_stack_controls;
368     for (auto stack_iter = control_iter->second.begin(); stack_iter != control_iter->second.end(); ++stack_iter) {
369       if (stack_iter->second == 0) {
370         MS_LOG(ERROR) << "Input stack control aid:" << stack_iter->first << " is null in actor:" << GetAID();
371         return;
372       } else if (stack_iter->second == 1) {
373         continue;
374       } else {
375         tmp_stack_controls[stack_iter->first] = stack_iter->second - 1;
376       }
377     }
378     if (tmp_stack_controls.empty()) {
379       (void)input_stack_controls_.erase(control_iter);
380     } else {
381       control_iter->second.swap(tmp_stack_controls);
382     }
383   }
384 }
385 
SendMemoryFreeReq(OpContext<DeviceTensor> * const context)386 void StackActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
387   MS_EXCEPTION_IF_NULL(context);
388   const auto &sequential_num = context->sequential_num_;
389 
390   // Collect the input device tensors.
391   std::vector<DeviceTensor *> memory_free_list;
392   if (input_op_datas_.find(sequential_num) != input_op_datas_.end()) {
393     for (auto &input_data : input_op_datas_[sequential_num]) {
394       MS_EXCEPTION_IF_NULL(input_data);
395       MS_EXCEPTION_IF_NULL(input_data->data_);
396       (void)memory_free_list.emplace_back(input_data->data_);
397     }
398   }
399 
400   if (input_op_partials_.find(sequential_num) != input_op_partials_.end()) {
401     for (auto &input_partial_pair : input_op_partials_[sequential_num]) {
402       GetAllDeviceTensors(input_partial_pair.second, &memory_free_list);
403     }
404   }
405 
406   if ((input_stack_data_num_ != 0) && (input_stack_data_.count(sequential_num) > 0)) {
407     for (auto &stack_data_pair : input_stack_data_[sequential_num]) {
408       if (!stack_data_pair.second.empty()) {
409         (void)memory_free_list.emplace_back(stack_data_pair.second.top());
410       }
411     }
412   }
413 
414   if ((input_stack_partials_num_ != 0) && (input_stack_partials_.count(sequential_num) > 0)) {
415     for (auto &stack_partial_pair : input_stack_partials_[sequential_num]) {
416       if (!stack_partial_pair.second.empty()) {
417         GetAllDeviceTensors(stack_partial_pair.second.top(), &memory_free_list);
418       }
419     }
420   }
421 
422   if (memory_free_list.size() > 0) {
423     memory_free_lists_.push(memory_free_list);
424     if (ActorDispatcher::is_memory_free_sync()) {
425       ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
426                                 device_contexts_[0], context, GetAID());
427     } else {
428       ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
429                             device_contexts_[0], context, GetAID());
430     }
431   }
432 }
433 }  // namespace runtime
434 }  // namespace mindspore
435