1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/actor/control_flow/condition_switch_actor.h"
18 #include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
19
20 namespace mindspore {
21 namespace runtime {
ConditionSwitchActor(const std::string & name,const CNodePtr & kernel,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,GraphExecutionStrategy strategy,const std::set<size_t> & modifiable_ref_input_indexes,const std::set<size_t> & modifiable_ref_output_indexes,const KernelTransformType & type)22 ConditionSwitchActor::ConditionSwitchActor(const std::string &name, const CNodePtr &kernel,
23 const DeviceContext *device_context, const AID &memory_manager_aid,
24 const AID *debug_aid, const AID *recorder_aid,
25 GraphExecutionStrategy strategy,
26 const std::set<size_t> &modifiable_ref_input_indexes,
27 const std::set<size_t> &modifiable_ref_output_indexes,
28 const KernelTransformType &type)
29 : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
30 modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
31
Init()32 void ConditionSwitchActor::Init() {
33 // Check device contexts number.
34 if (device_contexts_.size() != device::kDeviceContextsNumOne) {
35 MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
36 }
37 MS_EXCEPTION_IF_NULL(device_contexts_[0]);
38 input_device_tensors_.resize(common::AnfAlgo::GetInputTensorNum(kernel_));
39
40 InitOutputData();
41 output_data_by_output_index_.resize(AnfAlgo::GetOutputTensorNum(kernel_));
42 if (output_data_.size() != output_data_arrows_.size()) {
43 MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
44 }
45
46 for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
47 const auto &output_data = output_data_[i].first;
48 const auto &data_arrow = output_data_arrows_[i];
49 MS_EXCEPTION_IF_NULL(output_data);
50 MS_EXCEPTION_IF_NULL(data_arrow);
51 const auto &from_index = data_arrow->from_output_index_;
52 if (IntToSize(from_index) >= output_data_by_output_index_.size()) {
53 MS_LOG(EXCEPTION) << "Invalid from index:" << from_index
54 << " and output size:" << output_data_by_output_index_.size() << " for actor:" << GetAID();
55 }
56 output_data_by_output_index_[from_index].emplace_back(output_data.get());
57 }
58 MS_LOG(DEBUG) << "Condition switch actor:" << GetAID() << " init: branch name:" << branch_names_
59 << " branch origin ref count:" << branch_origin_ref_count_
60 << " output data branch index:" << output_data_branch_indexes_
61 << " output control branch index:" << output_control_branch_indexes_;
62 }
63
SendOutput(OpContext<DeviceTensor> * const context,size_t index)64 void ConditionSwitchActor::SendOutput(OpContext<DeviceTensor> *const context, size_t index) {
65 MS_EXCEPTION_IF_NULL(gather_aid_);
66 MS_LOG(DEBUG) << "condition actor run for index:" << index << " branch name:" << branch_names_[index]
67 << " for actor:" << GetAID();
68 ActorDispatcher::Send(*gather_aid_, &ConditionGatherActor::RunBranchName, branch_names_[index], context);
69
70 if (output_data_arrows_.size() != output_data_nodes_.size() || output_data_nodes_.size() != output_data_.size() ||
71 output_data_.size() != output_data_branch_indexes_.size()) {
72 MS_LOG(EXCEPTION) << "Invalid data arrow size:" << output_data_arrows_.size()
73 << " node size:" << output_data_nodes_.size() << " data size:" << output_data_.size()
74 << " index size:" << output_data_branch_indexes_.size() << " for actor:" << GetAID();
75 }
76 for (size_t i = 0; i < output_data_branch_indexes_.size(); ++i) {
77 if (TEST_FLAG(output_data_[i].second, kOutputDataFlagToFusion)) {
78 if (data_arrow_to_fusion_actor_indexs_.find(output_data_arrows_[i].get()) ==
79 data_arrow_to_fusion_actor_indexs_.end()) {
80 MS_LOG(EXCEPTION) << "Failed to get real from index by output data arrow from index:"
81 << output_data_arrows_[i]->from_output_index_ << " to " << output_data_arrows_[i]->to_op_id_
82 << " by actor:" << GetAID();
83 }
84 output_data_[i].first->index_ = SizeToInt(data_arrow_to_fusion_actor_indexs_.at(output_data_arrows_[i].get()));
85 }
86 if (output_data_branch_indexes_[i] == index) {
87 ActorDispatcher::Send(output_data_arrows_[i]->to_op_id_, &OpActor::RunOpData, output_data_[i].first.get(),
88 context);
89 }
90 }
91
92 if (output_control_arrows_.size() != output_control_branch_indexes_.size()) {
93 MS_LOG(EXCEPTION) << "Invalid control arrow size:" << output_control_arrows_.size()
94 << output_control_branch_indexes_.size() << " for actor:" << GetAID();
95 }
96 for (size_t i = 0; i < output_control_branch_indexes_.size(); ++i) {
97 MS_EXCEPTION_IF_NULL(output_control_arrows_[i]);
98 if (output_control_branch_indexes_[i] == index) {
99 ActorDispatcher::Send(output_control_arrows_[i]->to_op_id_, &OpActor::RunOpControl, const_cast<AID *>(&GetAID()),
100 context);
101 }
102 }
103 }
104
Run(OpContext<DeviceTensor> * const context)105 void ConditionSwitchActor::Run(OpContext<DeviceTensor> *const context) {
106 try {
107 if (!WaitRuntimePipelineFinish(context)) {
108 MS_LOG(INFO) << "Run failed and early stop.";
109 return;
110 }
111 FetchInput(context);
112 MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
113 MS_EXCEPTION_IF_NULL(input_device_tensors_[0]->kernel_tensor());
114 bool index = input_device_tensors_[0]->kernel_tensor()->GetValueWithCheck<bool>();
115 if (common::IsNeedProfileMemory()) {
116 index = true;
117 }
118 MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
119 if (index >= branch_names_.size()) {
120 std::string error_info = "Invalid index:" + std::to_string(index) +
121 " and branch size:" + std::to_string(branch_names_.size()) +
122 " for actor:" + GetAID().Name();
123 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
124 }
125 EraseInput(context);
126 CollectMemoryFreeList(index);
127 if (memory_free_list_.size() > 0) {
128 SendMemoryFreeReq(context);
129 }
130 MS_LOG(DEBUG) << "Launch kernel:" << kernel_->fullname_with_scope() << " by index:" << index;
131 SendOutput(context, index);
132 } catch (const std::exception &e) {
133 MsException::Instance().SetException();
134 std::string error_info =
135 "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
136 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
137 }
138 }
139
CollectMemoryFreeList(size_t index)140 void ConditionSwitchActor::CollectMemoryFreeList(size_t index) {
141 memory_free_list_.clear();
142 memory_free_list_.insert(memory_free_list_.end(), input_device_tensors_.begin(), input_device_tensors_.end());
143 memory_free_list_.insert(memory_free_list_.end(), input_device_tensors_.begin() + 1, input_device_tensors_.end());
144 for (size_t i = 0; i < branch_origin_ref_count_.size(); ++i) {
145 if (i == index) {
146 continue;
147 }
148 if (branch_origin_ref_count_[i].size() + 1 != input_device_tensors_.size()) {
149 MS_LOG(EXCEPTION) << "Invalid origin ref count size:" << branch_origin_ref_count_[i]
150 << " and input size:" << input_device_tensors_.size() << " for actor:" << GetAID();
151 }
152 MS_LOG(DEBUG) << "Free memory for branch:" << i << " for actor:" << GetAID();
153 for (size_t j = 0; j < branch_origin_ref_count_[i].size(); ++j) {
154 std::fill_n(back_inserter(memory_free_list_), branch_origin_ref_count_[i][j], input_device_tensors_[j + 1]);
155 }
156 }
157 }
158
FetchInput(OpContext<DeviceTensor> * const context)159 void ConditionSwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
160 MS_EXCEPTION_IF_NULL(context);
161
162 // Fetch input device tensor from input data.
163 const auto &data_iter = input_op_datas_.find(context->sequential_num_);
164 if (data_iter != input_op_datas_.end()) {
165 for (auto &input_data : data_iter->second) {
166 MS_EXCEPTION_IF_NULL(input_data);
167 if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
168 std::string error_info = "Invalid input index, need:" + std::to_string(input_data->index_) +
169 " current:" + std::to_string(input_device_tensors_.size()) +
170 " for actor:" + GetAID().Name();
171 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
172 }
173 MS_EXCEPTION_IF_NULL(input_data->data_);
174 input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
175 }
176 }
177
178 // Fetch input device tensor from device tensor store.
179 for (auto &device_tensor_store_key : device_tensor_store_keys_) {
180 MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
181 auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
182 device_contexts_[0]->GetDeviceType());
183 if (device_tensor == nullptr) {
184 std::string error_info =
185 GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
186 ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
187 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
188 }
189
190 if (device_tensor_store_key.first >= input_device_tensors_.size()) {
191 std::string error_info =
192 "The input index is out of range, need:" + std::to_string(device_tensor_store_key.first) +
193 " current:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
194 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
195 }
196 MS_EXCEPTION_IF_NULL(device_tensor);
197 input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
198 }
199
200 if (output_data_by_output_index_.size() + 1 != input_device_tensors_.size()) {
201 MS_LOG(EXCEPTION) << "Invalid output size:" << output_data_by_output_index_.size()
202 << " and input device tensor size:" << input_device_tensors_.size() << " for actor:" << GetAID();
203 }
204
205 for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
206 if (output_data_by_output_index_[i].empty()) {
207 continue;
208 }
209 const auto &data = input_device_tensors_[i + 1];
210 MS_EXCEPTION_IF_NULL(data);
211 for (auto &output_data : output_data_by_output_index_[i]) {
212 MS_EXCEPTION_IF_NULL(output_data);
213 output_data->data_ = data;
214 }
215 }
216 }
217 } // namespace runtime
218 } // namespace mindspore
219