• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/control_flow/condition_switch_actor.h"
18 #include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
19 
20 namespace mindspore {
21 namespace runtime {
ConditionSwitchActor(const std::string & name,const CNodePtr & kernel,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,GraphExecutionStrategy strategy,const std::set<size_t> & modifiable_ref_input_indexes,const std::set<size_t> & modifiable_ref_output_indexes,const KernelTransformType & type)22 ConditionSwitchActor::ConditionSwitchActor(const std::string &name, const CNodePtr &kernel,
23                                            const DeviceContext *device_context, const AID &memory_manager_aid,
24                                            const AID *debug_aid, const AID *recorder_aid,
25                                            GraphExecutionStrategy strategy,
26                                            const std::set<size_t> &modifiable_ref_input_indexes,
27                                            const std::set<size_t> &modifiable_ref_output_indexes,
28                                            const KernelTransformType &type)
29     : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
30                   modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
31 
Init()32 void ConditionSwitchActor::Init() {
33   // Check device contexts number.
34   if (device_contexts_.size() != device::kDeviceContextsNumOne) {
35     MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
36   }
37   MS_EXCEPTION_IF_NULL(device_contexts_[0]);
38   input_device_tensors_.resize(common::AnfAlgo::GetInputTensorNum(kernel_));
39 
40   InitOutputData();
41   output_data_by_output_index_.resize(AnfAlgo::GetOutputTensorNum(kernel_));
42   if (output_data_.size() != output_data_arrows_.size()) {
43     MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
44   }
45 
46   for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
47     const auto &output_data = output_data_[i].first;
48     const auto &data_arrow = output_data_arrows_[i];
49     MS_EXCEPTION_IF_NULL(output_data);
50     MS_EXCEPTION_IF_NULL(data_arrow);
51     const auto &from_index = data_arrow->from_output_index_;
52     if (IntToSize(from_index) >= output_data_by_output_index_.size()) {
53       MS_LOG(EXCEPTION) << "Invalid from index:" << from_index
54                         << " and output size:" << output_data_by_output_index_.size() << " for actor:" << GetAID();
55     }
56     output_data_by_output_index_[from_index].emplace_back(output_data.get());
57   }
58   MS_LOG(DEBUG) << "Condition switch actor:" << GetAID() << " init: branch name:" << branch_names_
59                 << " branch origin ref count:" << branch_origin_ref_count_
60                 << " output data branch index:" << output_data_branch_indexes_
61                 << " output control branch index:" << output_control_branch_indexes_;
62 }
63 
SendOutput(OpContext<DeviceTensor> * const context,size_t index)64 void ConditionSwitchActor::SendOutput(OpContext<DeviceTensor> *const context, size_t index) {
65   MS_EXCEPTION_IF_NULL(gather_aid_);
66   MS_LOG(DEBUG) << "condition actor run for index:" << index << " branch name:" << branch_names_[index]
67                 << " for actor:" << GetAID();
68   ActorDispatcher::Send(*gather_aid_, &ConditionGatherActor::RunBranchName, branch_names_[index], context);
69 
70   if (output_data_arrows_.size() != output_data_nodes_.size() || output_data_nodes_.size() != output_data_.size() ||
71       output_data_.size() != output_data_branch_indexes_.size()) {
72     MS_LOG(EXCEPTION) << "Invalid data arrow size:" << output_data_arrows_.size()
73                       << " node size:" << output_data_nodes_.size() << " data size:" << output_data_.size()
74                       << " index size:" << output_data_branch_indexes_.size() << " for actor:" << GetAID();
75   }
76   for (size_t i = 0; i < output_data_branch_indexes_.size(); ++i) {
77     if (TEST_FLAG(output_data_[i].second, kOutputDataFlagToFusion)) {
78       if (data_arrow_to_fusion_actor_indexs_.find(output_data_arrows_[i].get()) ==
79           data_arrow_to_fusion_actor_indexs_.end()) {
80         MS_LOG(EXCEPTION) << "Failed to get real from index by output data arrow from index:"
81                           << output_data_arrows_[i]->from_output_index_ << " to " << output_data_arrows_[i]->to_op_id_
82                           << " by actor:" << GetAID();
83       }
84       output_data_[i].first->index_ = SizeToInt(data_arrow_to_fusion_actor_indexs_.at(output_data_arrows_[i].get()));
85     }
86     if (output_data_branch_indexes_[i] == index) {
87       ActorDispatcher::Send(output_data_arrows_[i]->to_op_id_, &OpActor::RunOpData, output_data_[i].first.get(),
88                             context);
89     }
90   }
91 
92   if (output_control_arrows_.size() != output_control_branch_indexes_.size()) {
93     MS_LOG(EXCEPTION) << "Invalid control arrow size:" << output_control_arrows_.size()
94                       << output_control_branch_indexes_.size() << " for actor:" << GetAID();
95   }
96   for (size_t i = 0; i < output_control_branch_indexes_.size(); ++i) {
97     MS_EXCEPTION_IF_NULL(output_control_arrows_[i]);
98     if (output_control_branch_indexes_[i] == index) {
99       ActorDispatcher::Send(output_control_arrows_[i]->to_op_id_, &OpActor::RunOpControl, const_cast<AID *>(&GetAID()),
100                             context);
101     }
102   }
103 }
104 
Run(OpContext<DeviceTensor> * const context)105 void ConditionSwitchActor::Run(OpContext<DeviceTensor> *const context) {
106   try {
107     if (!WaitRuntimePipelineFinish(context)) {
108       MS_LOG(INFO) << "Run failed and early stop.";
109       return;
110     }
111     FetchInput(context);
112     MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
113     MS_EXCEPTION_IF_NULL(input_device_tensors_[0]->kernel_tensor());
114     bool index = input_device_tensors_[0]->kernel_tensor()->GetValueWithCheck<bool>();
115     if (common::IsNeedProfileMemory()) {
116       index = true;
117     }
118     MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
119     if (index >= branch_names_.size()) {
120       std::string error_info = "Invalid index:" + std::to_string(index) +
121                                " and branch size:" + std::to_string(branch_names_.size()) +
122                                " for actor:" + GetAID().Name();
123       SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
124     }
125     EraseInput(context);
126     CollectMemoryFreeList(index);
127     if (memory_free_list_.size() > 0) {
128       SendMemoryFreeReq(context);
129     }
130     MS_LOG(DEBUG) << "Launch kernel:" << kernel_->fullname_with_scope() << " by index:" << index;
131     SendOutput(context, index);
132   } catch (const std::exception &e) {
133     MsException::Instance().SetException();
134     std::string error_info =
135       "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
136     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
137   }
138 }
139 
CollectMemoryFreeList(size_t index)140 void ConditionSwitchActor::CollectMemoryFreeList(size_t index) {
141   memory_free_list_.clear();
142   memory_free_list_.insert(memory_free_list_.end(), input_device_tensors_.begin(), input_device_tensors_.end());
143   memory_free_list_.insert(memory_free_list_.end(), input_device_tensors_.begin() + 1, input_device_tensors_.end());
144   for (size_t i = 0; i < branch_origin_ref_count_.size(); ++i) {
145     if (i == index) {
146       continue;
147     }
148     if (branch_origin_ref_count_[i].size() + 1 != input_device_tensors_.size()) {
149       MS_LOG(EXCEPTION) << "Invalid origin ref count size:" << branch_origin_ref_count_[i]
150                         << " and input size:" << input_device_tensors_.size() << " for actor:" << GetAID();
151     }
152     MS_LOG(DEBUG) << "Free memory for branch:" << i << " for actor:" << GetAID();
153     for (size_t j = 0; j < branch_origin_ref_count_[i].size(); ++j) {
154       std::fill_n(back_inserter(memory_free_list_), branch_origin_ref_count_[i][j], input_device_tensors_[j + 1]);
155     }
156   }
157 }
158 
FetchInput(OpContext<DeviceTensor> * const context)159 void ConditionSwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
160   MS_EXCEPTION_IF_NULL(context);
161 
162   // Fetch input device tensor from input data.
163   const auto &data_iter = input_op_datas_.find(context->sequential_num_);
164   if (data_iter != input_op_datas_.end()) {
165     for (auto &input_data : data_iter->second) {
166       MS_EXCEPTION_IF_NULL(input_data);
167       if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
168         std::string error_info = "Invalid input index, need:" + std::to_string(input_data->index_) +
169                                  " current:" + std::to_string(input_device_tensors_.size()) +
170                                  " for actor:" + GetAID().Name();
171         SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
172       }
173       MS_EXCEPTION_IF_NULL(input_data->data_);
174       input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
175     }
176   }
177 
178   // Fetch input device tensor from device tensor store.
179   for (auto &device_tensor_store_key : device_tensor_store_keys_) {
180     MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
181     auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
182                                                                 device_contexts_[0]->GetDeviceType());
183     if (device_tensor == nullptr) {
184       std::string error_info =
185         GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
186         ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
187       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
188     }
189 
190     if (device_tensor_store_key.first >= input_device_tensors_.size()) {
191       std::string error_info =
192         "The input index is out of range, need:" + std::to_string(device_tensor_store_key.first) +
193         " current:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
194       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
195     }
196     MS_EXCEPTION_IF_NULL(device_tensor);
197     input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
198   }
199 
200   if (output_data_by_output_index_.size() + 1 != input_device_tensors_.size()) {
201     MS_LOG(EXCEPTION) << "Invalid output size:" << output_data_by_output_index_.size()
202                       << " and input device tensor size:" << input_device_tensors_.size() << " for actor:" << GetAID();
203   }
204 
205   for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
206     if (output_data_by_output_index_[i].empty()) {
207       continue;
208     }
209     const auto &data = input_device_tensors_[i + 1];
210     MS_EXCEPTION_IF_NULL(data);
211     for (auto &output_data : output_data_by_output_index_[i]) {
212       MS_EXCEPTION_IF_NULL(output_data);
213       output_data->data_ = data;
214     }
215   }
216 }
217 }  // namespace runtime
218 }  // namespace mindspore
219