/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"

namespace mindspore {
namespace runtime {
constexpr size_t kEntranceInputStartPos = 1;

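// Records an incoming control arrow for the current step. Controls received during the loop body execution are
// kept separately from the step-begin controls, and the actor runs once all required inputs have arrived.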
void EntranceActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  if (is_loop_body_execution_) {
    (void)loop_body_input_op_controls_[sequential_num].emplace_back(input_control);
  } else {
    (void)input_op_controls_[sequential_num].emplace_back(input_control);
  }

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op control and check running condition:" << is_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (is_run) {
    Run(context);
  }
}

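// Records a real parameter (device tensors and partials) together with its branch id, sent by a gather actor,
// and runs the actor once the running condition is satisfied.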
void EntranceActor::RunOpRealParameterWithBranchID(const OpRealParameterWithBranchID &real_parameter_with_branch_id,
                                                   OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  (void)real_parameters_with_branch_id_[sequential_num].emplace(real_parameter_with_branch_id);

  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op data with branch id and check running condition:" << is_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (is_run) {
    Run(context);
  }
}

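// Resets the per-step state when a step ends: leaves the loop body execution mode and drops any buffered
// loop body control arrows.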
void EntranceActor::ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(input_control);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the message of clearing data from:" << input_control->Name() << ".";

  is_loop_body_execution_ = false;

  if (loop_body_input_controls_nums_ != 0) {
    loop_body_input_op_controls_.clear();
  }
}

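// Executes one iteration: fetch the inputs, adjust the dynamic ref counts, free the input memory, erase the
// consumed inputs and forward the outputs to the downstream actors.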
void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
  // The flag is false for the first execution of a step and true for all subsequent executions within the step.
  is_loop_body_execution_ = true;

  FetchInput(context);

  // Note that IncreaseDynamicRefCounts must be called before SendMemoryFreeReq, because SendMemoryFreeReq decreases
  // the dynamic ref count. This avoids the illegal timing problem where the dynamic reference count is decremented
  // before it is incremented.
  IncreaseDynamicRefCounts(context);
  SendMemoryFreeReq(context);

  EraseInput(context);
  SendOutput(context);
}

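// Gathers the inputs for this execution, either from the data source actor (data and control arrows) or from
// the gather actor (real parameters with a branch id), and binds them to the output data.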
void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;

  // There are two kinds of run conditions for the entrance actor:
  // 1. Data comes from the data source actor, in the form of data arrows.
  const auto &data_iter = input_op_datas_.find(sequential_num);
  const auto &control_iter = input_op_controls_.find(sequential_num);
  if (data_iter != input_op_datas_.end() || control_iter != input_op_controls_.end()) {
    // If the data comes from the data source actor, use the default branch id.
    output_branch_id_ = 0;

    if (data_iter == input_op_datas_.end()) {
      return;
    }

    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
        std::string error_info = "The input index is out of range, need:" + std::to_string(input_data->index_) +
                                 " current:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
    }
  } else {
    // 2. Data comes from the gather actor, in the form of data with a branch id.
    output_branch_id_ = real_parameters_with_branch_id_[sequential_num].front().branch_id_;
    const auto &device_tensors = real_parameters_with_branch_id_[sequential_num].front().device_tensors_;
    const auto &partials = real_parameters_with_branch_id_[sequential_num].front().partials_;

    // Collect the device tensors.
    if (device_tensors.size() + partials.size() != formal_parameters_.size()) {
      std::string error_info = "Invalid input num, need:" + std::to_string(formal_parameters_.size()) +
                               " device tensor num:" + std::to_string(device_tensors.size()) +
                               " partial num:" + std::to_string(partials.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    for (const auto &device_tensor : device_tensors) {
      if (device_tensor.first >= input_device_tensors_.size()) {
        std::string error_info = "Invalid device tensor index:" + std::to_string(device_tensor.first) +
                                 " vector size:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_device_tensors_[device_tensor.first] = device_tensor.second;
    }

    // Collect the partials.
    for (const auto &partial : partials) {
      if (partial.first >= input_partials_.size()) {
        std::string error_info = "Invalid partial index:" + std::to_string(partial.first) +
                                 " vector size:" + std::to_string(input_partials_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_partials_[partial.first] = partial.second;
    }
  }

  // Init the device tensor in output data.
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }
    const auto &data = input_device_tensors_[i];
    if (data == nullptr) {
      std::string error_info = "Input data index:" + std::to_string(i) + " for actor:" + GetAID().Name() + " is empty!";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }
}

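// Returns true when all inputs required for the current execution have arrived. The required inputs differ
// between the step-begin execution and the loop body execution.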
bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  // Check the running condition for the first execution of a step.
  // At the first execution of the root graph, either input controls or input data exist, but only one of the two.
  if (!is_loop_body_execution_) {
    if (input_controls_num_ != 0) {
      const auto &control_iter = input_op_controls_.find(context->sequential_num_);
      if ((control_iter != input_op_controls_.end()) && (control_iter->second.size() == input_controls_num_)) {
        return true;
      }
    }

    // Data comes from the data source actor.
    if (input_datas_num_ != 0) {
      const auto &data_iter = input_op_datas_.find(context->sequential_num_);
      if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) {
        return true;
      }
    }
  }

  // Check the controls during the loop body execution of a step.
  if (is_loop_body_execution_ && (loop_body_input_controls_nums_ != 0)) {
    const auto &control_iter = loop_body_input_op_controls_.find(context->sequential_num_);
    if ((control_iter == loop_body_input_op_controls_.end()) ||
        (control_iter->second.size() != loop_body_input_controls_nums_)) {
      return false;
    }
  }

  // Data comes from the gather actor.
  const auto &iter = real_parameters_with_branch_id_.find(context->sequential_num_);
  if (iter == real_parameters_with_branch_id_.end() || iter->second.empty()) {
    return false;
  }
  return true;
}

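// Removes the consumed inputs of the current sequential number, including the front element of the queued
// real parameters with branch id.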
void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;

  const auto &data_iter = input_op_datas_.find(sequential_num);
  if (data_iter != input_op_datas_.end()) {
    (void)input_op_datas_.erase(data_iter);
  }

  const auto &control_iter = input_op_controls_.find(sequential_num);
  if (control_iter != input_op_controls_.end()) {
    (void)input_op_controls_.erase(control_iter);
  }

  const auto &loop_body_control_iter = loop_body_input_op_controls_.find(sequential_num);
  if (loop_body_control_iter != loop_body_input_op_controls_.end()) {
    (void)loop_body_input_op_controls_.erase(loop_body_control_iter);
  }

  const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
  if (iter != real_parameters_with_branch_id_.end()) {
    iter->second.pop();
    if (iter->second.empty()) {
      (void)real_parameters_with_branch_id_.erase(sequential_num);
    }
  }
}

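// Collects the device tensors of all inputs and asks the memory manager actor to free them, synchronously or
// asynchronously depending on the dispatcher mode.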
void EntranceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // Collect the input device tensors.
  std::vector<DeviceTensor *> memory_free_list;
  if (input_op_datas_.count(sequential_num) > 0) {
    for (auto &input_data : input_op_datas_[sequential_num]) {
      MS_EXCEPTION_IF_NULL(input_data);
      MS_EXCEPTION_IF_NULL(input_data->data_);
      (void)memory_free_list.emplace_back(input_data->data_);
    }
  }

  const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
  if (iter != real_parameters_with_branch_id_.end()) {
    if (iter->second.empty()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The real parameter with branch id is empty.");
    }
    auto &real_parameters_with_branch_id = iter->second.front();
    GetAllDeviceTensors(real_parameters_with_branch_id, &memory_free_list);
  }

  if (memory_free_list.size() > 0) {
    memory_free_lists_.push(memory_free_list);
    if (ActorDispatcher::is_memory_free_sync()) {
      ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                                device_contexts_[0], context, GetAID());
    } else {
      ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                            device_contexts_[0], context, GetAID());
    }
  }
}
}  // namespace runtime
}  // namespace mindspore