/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"

namespace mindspore {
namespace runtime {
constexpr size_t kEntranceInputStartPos = 1;

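// Records an input control arrow for the current step. Controls received during loop-body execution are buffered
// separately from the first-step controls, and Run is triggered once the running condition is satisfied.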
void EntranceActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  if (is_loop_body_execution_) {
    (void)loop_body_input_op_controls_[sequential_num].emplace_back(input_control);
  } else {
    (void)input_op_controls_[sequential_num].emplace_back(input_control);
  }

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op control and check running condition:" << is_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (is_run) {
    Run(context);
  }
}

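// Records a real parameter (device tensors and partials) together with its branch id coming from a gather actor,
// then runs the actor once the running condition is satisfied.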
void EntranceActor::RunOpRealParameterWithBranchID(const OpRealParameterWithBranchID &real_parameter_with_branch_id,
                                                   OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  (void)real_parameters_with_branch_id_[sequential_num].emplace(real_parameter_with_branch_id);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op data with branch id and check running condition:" << is_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (is_run) {
    Run(context);
  }
}

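// Resets the loop-body flag at the end of a step and clears any buffered loop-body controls, so that the next step
// starts from the first execution again.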
void EntranceActor::ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(input_control);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the message of clearing data from:" << input_control->Name() << ".";

  is_loop_body_execution_ = false;

  if (loop_body_input_controls_nums_ != 0) {
    loop_body_input_op_controls_.clear();
  }
}

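// Main execution: fetch the inputs, increase the dynamic ref counts, request freeing of the input memory, erase the
// consumed inputs and send the outputs to the downstream actors.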
void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
  // Only the first execution of a step is not loop-body execution; every later execution within the step is.
  is_loop_body_execution_ = true;

  FetchInput(context);

  // Note that IncreaseDynamicRefCounts must come before SendMemoryFreeReq, because SendMemoryFreeReq decreases the
  // dynamic ref count. This avoids the illegal timing problem where the dynamic reference count is decremented and
  // then incremented.
  IncreaseDynamicRefCounts(context);
  SendMemoryFreeReq(context);

  EraseInput(context);
  SendOutput(context);
}

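// Collects the input device tensors, either from the data arrows of the first step execution or from the queued real
// parameters with branch id, and fills them into the prepared output data.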
void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;

  // There are two kinds of run conditions for the entrance actor:
  // 1. Data comes from the data source actor in the form of a data arrow.
  const auto &data_iter = input_op_datas_.find(sequential_num);
  const auto &control_iter = input_op_controls_.find(sequential_num);
  if (data_iter != input_op_datas_.end() || control_iter != input_op_controls_.end()) {
    // If the data comes from the data source actor, use the default branch id.
    output_branch_id_ = 0;

    if (data_iter == input_op_datas_.end()) {
      return;
    }

    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
        std::string error_info = "The input index is out of range, need:" + std::to_string(input_data->index_) +
                                 " current:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
    }
  } else {
    // 2. Data comes from the gather actor in the form of data with a branch id.
    output_branch_id_ = real_parameters_with_branch_id_[sequential_num].front().branch_id_;
    const auto &device_tensors = real_parameters_with_branch_id_[sequential_num].front().device_tensors_;
    const auto &partials = real_parameters_with_branch_id_[sequential_num].front().partials_;

    // Collect the device tensors.
    if (device_tensors.size() + partials.size() != formal_parameters_.size()) {
      std::string error_info = "Invalid input num, need:" + std::to_string(formal_parameters_.size()) +
                               " device tensor num:" + std::to_string(device_tensors.size()) +
                               " partial num:" + std::to_string(partials.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    for (const auto &device_tensor : device_tensors) {
      if (device_tensor.first >= input_device_tensors_.size()) {
        std::string error_info = "Invalid device tensor index:" + std::to_string(device_tensor.first) +
                                 " vector size:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_device_tensors_[device_tensor.first] = device_tensor.second;
    }

    // Collect the partials.
    for (const auto &partial : partials) {
      if (partial.first >= input_partials_.size()) {
        std::string error_info = "Invalid partial index:" + std::to_string(partial.first) +
                                 " vector size:" + std::to_string(input_partials_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_partials_[partial.first] = partial.second;
    }
  }

  // Init the device tensor in output data.
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }
    const auto &data = input_device_tensors_[i];
    if (data == nullptr) {
      std::string error_info = "Input data index:" + std::to_string(i) + " for actor:" + GetAID().Name() + " is empty!";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }
}

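// Checks whether the actor can run. The first execution of a step is triggered by the complete input controls or
// input data of the root graph, or by a queued real parameter with branch id; loop-body executions require the
// complete loop-body controls plus a queued real parameter with branch id.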
bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  // Check the running condition for the first execution of a step.
  // The input controls and input data only exist for the first execution of the root graph, and only one of the two
  // will be present.
  if (!is_loop_body_execution_) {
    if (input_controls_num_ != 0) {
      const auto &control_iter = input_op_controls_.find(context->sequential_num_);
      if ((control_iter != input_op_controls_.end()) && (control_iter->second.size() == input_controls_num_)) {
        return true;
      }
    }

    // Data comes from the data source actor.
    if (input_datas_num_ != 0) {
      const auto &data_iter = input_op_datas_.find(context->sequential_num_);
      if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) {
        return true;
      }
    }
  }

  // Check the controls in the loop-body execution of a step.
  if (is_loop_body_execution_ && (loop_body_input_controls_nums_ != 0)) {
    const auto &control_iter = loop_body_input_op_controls_.find(context->sequential_num_);
    if ((control_iter == loop_body_input_op_controls_.end()) ||
        (control_iter->second.size() != loop_body_input_controls_nums_)) {
      return false;
    }
  }

  // Data comes from the gather actor.
  const auto &iter = real_parameters_with_branch_id_.find(context->sequential_num_);
  if (iter == real_parameters_with_branch_id_.end() || iter->second.empty()) {
    return false;
  }
  return true;
}

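// Erases the inputs consumed by the current execution so that they are not processed again.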
void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;

  const auto &data_iter = input_op_datas_.find(sequential_num);
  if (data_iter != input_op_datas_.end()) {
    (void)input_op_datas_.erase(data_iter);
  }

  const auto &control_iter = input_op_controls_.find(sequential_num);
  if (control_iter != input_op_controls_.end()) {
    (void)input_op_controls_.erase(control_iter);
  }

  const auto &loop_body_control_iter = loop_body_input_op_controls_.find(sequential_num);
  if (loop_body_control_iter != loop_body_input_op_controls_.end()) {
    (void)loop_body_input_op_controls_.erase(loop_body_control_iter);
  }

  const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
  if (iter != real_parameters_with_branch_id_.end()) {
    iter->second.pop();
    if (iter->second.empty()) {
      (void)real_parameters_with_branch_id_.erase(sequential_num);
    }
  }
}

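// Collects the device tensors of the current inputs and asks the memory manager actor to free them, synchronously or
// asynchronously depending on the dispatcher configuration.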
void EntranceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // Collect the input device tensors.
  std::vector<DeviceTensor *> memory_free_list;
  if (input_op_datas_.count(sequential_num) > 0) {
    for (auto &input_data : input_op_datas_[sequential_num]) {
      MS_EXCEPTION_IF_NULL(input_data);
      MS_EXCEPTION_IF_NULL(input_data->data_);
      (void)memory_free_list.emplace_back(input_data->data_);
    }
  }

  const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
  if (iter != real_parameters_with_branch_id_.end()) {
    if (iter->second.empty()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The real parameter with branch id is empty.");
    }
    auto &real_parameters_with_branch_id = iter->second.front();
    GetAllDeviceTensors(real_parameters_with_branch_id, &memory_free_list);
  }

  if (memory_free_list.size() > 0) {
    memory_free_lists_.push(memory_free_list);
    if (ActorDispatcher::is_memory_free_sync()) {
      ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                                device_contexts_[0], context, GetAID());
    } else {
      ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                            device_contexts_[0], context, GetAID());
    }
  }
}
}  // namespace runtime
}  // namespace mindspore