1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/framework/actor/switch_actor.h"
#include <limits>
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/gather_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"
24
25 namespace mindspore {
26 namespace runtime {
Init()27 void SwitchActor::Init() {
28 // Init output data.
29 output_data_.resize(output_branch_arrows_.size());
30 for (size_t i = 0; i < output_branch_arrows_.size(); ++i) {
31 auto &output_branch_arrow = output_branch_arrows_[i];
32 auto &output_data = output_data_[i];
33 for (auto &data_arrow : output_branch_arrow) {
34 MS_EXCEPTION_IF_NULL(data_arrow);
35 auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
36 (void)output_data.emplace_back(std::move(data));
37 }
38 }
39 }
40
RunOpData(OpData<DeviceTensor> * input_data,OpContext<DeviceTensor> * const context)41 void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context) {
42 MS_EXCEPTION_IF_NULL(context);
43 const auto &sequential_num = context->sequential_num_;
44 auto &input_datas = input_data_[sequential_num];
45 input_datas[input_data->index_].push(input_data->data_);
46
47 if (CheckLaunchCondition(context)) {
48 FetchInputDeviceTensor(context);
49 EraseInput(context);
50 SendOutput(context);
51 }
52 }
53
RunOpControl(AID * input_control,OpContext<DeviceTensor> * context)54 void SwitchActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
55 MS_EXCEPTION_IF_NULL(context);
56 auto &sequential_num = context->sequential_num_;
57 if (input_controls_[sequential_num].find(input_control) == input_controls_[sequential_num].end()) {
58 input_controls_[sequential_num][input_control] = 0;
59 }
60 input_controls_[sequential_num][input_control]++;
61
62 if (CheckLaunchCondition(context)) {
63 FetchInputDeviceTensor(context);
64 EraseInput(context);
65 SendOutput(context);
66 }
67 }
68
CollectBranchId(const int branch_id,OpContext<DeviceTensor> * const context)69 void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context) {
70 MS_EXCEPTION_IF_NULL(context);
71 auto &sequential_num = context->sequential_num_;
72 input_branch_ids_[sequential_num].push(branch_id);
73 }
74
ParseInput(const ControlNodeParserPtr & parser)75 void SwitchActor::ParseInput(const ControlNodeParserPtr &parser) {
76 std::vector<AnfNodePtr> inputs = node_->inputs();
77
78 if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
79 ParseSwitchInput();
80 } else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
81 ParseReturnInput(parser);
82 } else {
83 ParseSwitchLayerInput();
84 }
85 backend_parameters_.resize(input_nodes_.size());
86 }
87
ParsePartialInput(const AnfNodePtr & node,const size_t branch_id)88 void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_id) {
89 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
90 CNodePtr cnode = node->cast<CNodePtr>();
91
92 // The inputs of the Partial node is:
93 // [0] ValueNode<Primitive> kPartial.
94 // [1] ValueNode<FuncGraphPtr>.
95 // [2..] Inputs.
96 auto partial_inputs = cnode->inputs();
97 if (partial_inputs.size() <= kPartialFuncGraphPos) {
98 MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
99 }
100
101 auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
102
103 branch_func_graph_[branch_id] = func_graph;
104 for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
105 AddInput(partial_inputs[j], branch_id);
106 }
107 } else if (IsValueNode<FuncGraph>(node)) {
108 const auto func_graph = GetValueNode<FuncGraphPtr>(node);
109 branch_func_graph_[branch_id] = func_graph;
110 } else {
111 AddInput(node, branch_id);
112 }
113 }
114
InitVectorSize(const size_t num)115 void SwitchActor::InitVectorSize(const size_t num) {
116 branch_inputs_pos_.resize(num);
117 branch_func_graph_.resize(num);
118 output_branch_arrows_.resize(num);
119 output_branch_result_arrows_.resize(num);
120 output_branch_control_arrows_.resize(num);
121 output_branch_branch_arrows_.resize(num);
122 }
123
ParseReturnInput(const ControlNodeParserPtr & parser)124 void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
125 const auto &func_graph = node_->func_graph();
126 MS_EXCEPTION_IF_NULL(func_graph);
127 const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
128 InitVectorSize(call_num);
129
130 AddCommonInput(func_graph->output());
131 }
132
ParseSwitchInput()133 void SwitchActor::ParseSwitchInput() {
134 // The inputs of the switch node:
135 // [0] ValueNode<Primitive> kSwitch.
136 // [1] switch condition.
137 // [2] Partial node: true branch.
138 // [3] Partial node: false branch.
139 std::vector<AnfNodePtr> inputs = node_->inputs();
140 if (inputs.size() != kSwitchInputNum) {
141 MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
142 }
143
144 InitVectorSize(kSwitchPartialNum);
145
146 const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0);
147 input_nodes_.push_back(cond_node);
148 input_datas_num_++;
149 // Init the two branches of switch node.
150 ParsePartialInput(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
151 ParsePartialInput(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
152 }
153
ParseSwitchLayerInput()154 void SwitchActor::ParseSwitchLayerInput() {
155 // The inputs of the switch node:
156 // [0] ValueNode<Primitive> kSwitchLayer.
157 // [1] switchLayer index.
158 // [2] MakeTuple node: tuple of branches.
159 std::vector<AnfNodePtr> inputs = node_->inputs();
160 if (inputs.size() != kSwitchLayerInputNum) {
161 MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
162 }
163
164 const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchLayerCondPos], 0);
165 input_nodes_.push_back(cond_node);
166 input_datas_num_++;
167
168 // The second input of SwitchLayer is maketuple node, which includes all branches.
169 auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
170 InitVectorSize(branch_nodes.size() - 1);
171
172 // Parse all branches.
173 for (size_t i = kMakeTupleInputStartPos; i < branch_nodes.size(); ++i) {
174 if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
175 ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
176 } else if (branch_nodes[i]->isa<ValueNode>()) {
177 branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
178 }
179 }
180 }
181
AddCommonInput(const AnfNodePtr & node)182 void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
183 for (size_t i = 0; i < branch_inputs_pos_.size(); ++i) {
184 AddInput(node, i);
185 }
186 }
187
FetchDataNodePosition(const AnfNodePtr & data_node) const188 size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
189 const auto data_node_with_index = AnfAlgo::VisitKernelWithReturnType(data_node, 0);
190 const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node_with_index);
191 if (iter == input_nodes_.end()) {
192 MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
193 << " is not exist in switch actor:" << GetAID();
194 }
195 return iter - input_nodes_.begin();
196 }
197
AddInput(const KernelWithIndex node_with_index,const size_t branch)198 void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
199 const auto &node = node_with_index.first;
200
201 // The value node and weight node need to be placed in the device store. The switch actor has three inputs:
202 // 1) The input of the switch is the value node.
203 // 2) There is a weight node or value node in the return of the sub funcgraph.
204 if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
205 node->isa<ValueNode>()) {
206 const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
207 if (iter != input_nodes_.end()) {
208 branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
209 return;
210 }
211 (void)device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get());
212 branch_inputs_pos_[branch].push_back(input_nodes_.size());
213 input_nodes_.push_back(node_with_index);
214 return;
215 }
216
217 // Output of updatestate node is U, need to be skipped.
218 if (node->isa<Parameter>() && HasAbstractRef(node)) {
219 return;
220 }
221
222 // Add parameter.
223 auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
224 if (iter == input_nodes_.end()) {
225 branch_inputs_pos_[branch].push_back(input_nodes_.size());
226 input_nodes_.push_back(node_with_index);
227 ++input_datas_num_;
228 } else {
229 branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
230 }
231 }
232
AddInput(const AnfNodePtr & node,const size_t branch)233 void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
234 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || HasAbstractMonad(node)) {
235 return;
236 }
237
238 const auto &real_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
239
240 if (AnfAlgo::CheckPrimitiveType(real_input.first, prim::kPrimMakeTuple)) {
241 const auto &inputs = real_input.first->cast<CNodePtr>()->inputs();
242 for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
243 AddInput(inputs[i], branch);
244 }
245 } else if (IsCallNode(real_input.first)) {
246 std::vector<AnfNodePtr> call_nodes;
247 const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
248 if (call_output_num == 0) {
249 MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
250 }
251 for (size_t i = 0; i < call_output_num; ++i) {
252 AddInput({real_input.first, i}, branch);
253 }
254 } else if (real_input.first->isa<ValueNode>() && real_input.first->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
255 const auto &value = real_input.first->cast<ValueNodePtr>()->value();
256 const auto &tuple_value = value->cast<ValueTuplePtr>();
257 for (size_t i = 0; i < tuple_value->value().size(); ++i) {
258 AddInput({real_input.first, i}, branch);
259 }
260 } else {
261 AddInput(real_input, branch);
262 }
263 }
264
GetIndex(const OpContext<DeviceTensor> * const context)265 size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) {
266 if (need_branch_id_input_) {
267 if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
268 input_branch_ids_[context->sequential_num_].empty()) {
269 MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
270 }
271 auto branch_id = input_branch_ids_[context->sequential_num_].top();
272 input_branch_ids_[context->sequential_num_].pop();
273 if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
274 MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
275 " branch id:" + std::to_string(branch_id);
276 }
277 return branch_id_to_index_[branch_id];
278 }
279
280 DeviceTensor *device_tensor = input_device_tensors_[0];
281 MS_EXCEPTION_IF_NULL(device_tensor);
282
283 auto inputs = node_->inputs();
284 TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
285 size_t size = abstract::TypeIdSize(type_id);
286 if (size > sizeof(int64_t)) {
287 MS_LOG(ERROR) << "Index must be Int type.";
288 }
289
290 int64_t index = 0;
291 char buf[kMaxSwitchCondSize] = {0};
292 ShapeVector host_shape;
293 if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
294 MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id)
295 << ", device type:" << std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
296 }
297
298 if (type_id == TypeId::kNumberTypeInt32) {
299 index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
300 } else if (type_id == TypeId::kNumberTypeInt64) {
301 index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
302 } else if (type_id == TypeId::kNumberTypeBool) {
303 bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
304 index = static_cast<int64_t>(cond ? 1 : 0);
305 } else {
306 MS_LOG(ERROR) << "Index must be Int type.";
307 }
308
309 // SwitchLayer node support negative index range [-size, -1].
310 if (index < 0) {
311 index += SizeToInt(branch_func_graph_.size());
312 }
313 return static_cast<size_t>(index);
314 }
315
CheckLaunchCondition(OpContext<DeviceTensor> * const context) const316 bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
317 MS_EXCEPTION_IF_NULL(context);
318 if (input_datas_num_ != 0) {
319 auto data_iter = input_data_.find(context->sequential_num_);
320 if (data_iter == input_data_.end()) {
321 return false;
322 }
323 if (data_iter->second.size() != input_datas_num_) {
324 return false;
325 }
326 if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
327 [](const auto &input_stack) { return input_stack.second.empty(); })) {
328 return false;
329 }
330 }
331
332 if (input_controls_num_ != 0) {
333 auto data_iter = input_controls_.find(context->sequential_num_);
334 if (data_iter == input_controls_.end()) {
335 return false;
336 }
337 if (data_iter->second.size() != input_controls_num_) {
338 return false;
339 }
340 if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
341 [](const auto &input_stack) { return input_stack.second == 0; })) {
342 return false;
343 }
344 }
345
346 return true;
347 }
348
FetchInputDeviceTensor(OpContext<DeviceTensor> * const context)349 void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
350 MS_EXCEPTION_IF_NULL(context);
351 input_device_tensors_.resize(input_nodes_.size());
352 auto data_iter = input_data_.find(context->sequential_num_);
353 if (data_iter != input_data_.end()) {
354 for (auto &input_data : data_iter->second) {
355 input_device_tensors_[input_data.first] = input_data.second.top();
356 input_data.second.pop();
357 }
358 }
359
360 for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
361 auto device_tensor =
362 DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
363 if (device_tensor == nullptr) {
364 std::string error_info =
365 GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
366 ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
367 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
368 }
369 input_device_tensors_[device_tensor_store_key.first] = device_tensor;
370 }
371
372 auto control_iter = input_controls_.find(context->sequential_num_);
373 if (control_iter != input_controls_.end()) {
374 (void)for_each(control_iter->second.begin(), control_iter->second.end(),
375 [](auto &input_control) { input_control.second--; });
376 }
377 }
378
SendOutput(OpContext<DeviceTensor> * context)379 void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
380 MS_EXCEPTION_IF_NULL(context);
381 auto index = GetIndex(context);
382 if (index >= output_branch_arrows_.size()) {
383 std::string error_info = "Switch actor:" + GetAID().Name() + " invalid index:" + std::to_string(index);
384 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
385 }
386
387 // Must be the execution order: send branch id --> send result --> send data --> send control, avoid the illegal
388 // timing problem.
389 // 1.Send branch id.
390 if (local_branch_id_ >= 0) {
391 const auto &branch_arrows = output_branch_branch_arrows_[index];
392 for (const auto &branch_arrow : branch_arrows) {
393 Async(branch_arrow, &GatherActor::CollectBranchId, local_branch_id_, context);
394 }
395 }
396
397 // 2.Send result.
398 auto &output_branch_result_arrow = output_branch_result_arrows_[index];
399 for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
400 auto &result_arrow = output_branch_result_arrow[i];
401 MS_EXCEPTION_IF_NULL(result_arrow);
402 if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
403 std::string error_info =
404 "Invalid from index in switch actor, from index:" + std::to_string(result_arrow->from_output_index_) +
405 " total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
406 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
407 }
408 size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)];
409
410 bool is_send = false;
411 for (const auto &backend_node : backend_parameters_[from_index]) {
412 for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) {
413 if (backend_node.first->kernel_info() != nullptr && AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
414 AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) {
415 auto output_index = j;
416 Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, output_index,
417 result_arrow->to_input_index_, context);
418 is_send = true;
419 break;
420 }
421 }
422 }
423 if (!is_send) {
424 std::string error_info = "Failed to get backend node of switch actor output, actor:" + GetAID().Name() +
425 " branch:" + std::to_string(index) +
426 " index:" + std::to_string(result_arrow->from_output_index_) + " output pos" +
427 std::to_string(branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)]) +
428 " output index" + std::to_string(result_arrow->to_input_index_);
429 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
430 }
431 }
432
433 // 3.Send Data.
434 auto &output_branch_arrow = output_branch_arrows_[index];
435 auto &output_data = output_data_[index];
436 for (size_t i = 0; i < output_branch_arrow.size(); ++i) {
437 auto &data_arrow = output_branch_arrow[i];
438 auto &data = output_data[i];
439 MS_EXCEPTION_IF_NULL(data_arrow);
440 MS_EXCEPTION_IF_NULL(data);
441 data->data_ = input_device_tensors_[IntToSize(data_arrow->from_output_index_)];
442 Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
443 }
444
445 // 4.Send output control.
446 auto source_aid = const_cast<AID *>(&GetAID());
447 for (auto &output_control : output_branch_control_arrows_[index]) {
448 Async(output_control, &OpActor::RunOpControl, source_aid, context);
449 }
450 }
451
EraseInput(OpContext<DeviceTensor> * const context)452 void SwitchActor::EraseInput(OpContext<DeviceTensor> *const context) {
453 MS_EXCEPTION_IF_NULL(context);
454 auto data_iter = input_data_.find(context->sequential_num_);
455 if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
456 [](const auto &input_data) { return input_data.second.empty(); })) {
457 auto ret = input_data_.erase(context->sequential_num_);
458 if (ret == 0) {
459 MS_LOG(WARNING) << "Erase input data failed for switch actor: " << GetAID();
460 }
461 }
462
463 if (input_controls_num_ != 0) {
464 auto control_iter = input_controls_.find(context->sequential_num_);
465 if (control_iter != input_controls_.end() &&
466 std::all_of(control_iter->second.begin(), control_iter->second.end(),
467 [](const auto &input_control) { return input_control.second == 0; })) {
468 auto ret = input_controls_.erase(context->sequential_num_);
469 if (ret == 0) {
470 std::string error_info = "Erase input control failed: " + GetAID().Name();
471 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
472 }
473 }
474 }
475 }
476
SendMemoryFreeReq(OpContext<DeviceTensor> * const context)477 void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
478 Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
479 }
480
FetchInputNode(const ControlNodeParserPtr & parser)481 void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
482 for (size_t i = 0; i < input_nodes_.size(); ++i) {
483 const auto &input_node = input_nodes_[i].first;
484 if (!(input_node->isa<Parameter>() && HasAbstractRef(input_node))) {
485 backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
486 continue;
487 }
488
489 const auto &backend_weight = parser->FetchBackendNodebyWeightNode(input_node);
490 if (backend_weight == nullptr) {
491 MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
492 }
493 (void)backend_parameters_[i].emplace(backend_weight, 0);
494 }
495 }
496 } // namespace runtime
497 } // namespace mindspore
498