• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
18 #include "include/backend/mem_reuse/mem_tracker.h"
19 
20 namespace mindspore {
21 namespace runtime {
// Delegating constructor: a condition gather actor is a specialized kernel actor, so
// all construction work is forwarded unchanged to the KernelActor base class. No
// gather-specific state is set up here; that happens later in Init().
ConditionGatherActor::ConditionGatherActor(const std::string &name, const CNodePtr &kernel,
                                           const DeviceContext *device_context, const AID &memory_manager_aid,
                                           const AID *debug_aid, const AID *recorder_aid,
                                           GraphExecutionStrategy strategy,
                                           const std::set<size_t> &modifiable_ref_input_indexes,
                                           const std::set<size_t> &modifiable_ref_output_indexes,
                                           const KernelTransformType &type)
    : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
31 
~ConditionGatherActor()32 ConditionGatherActor::~ConditionGatherActor() {
33   for_each(need_clean_ptr_device_addresses_.begin(), need_clean_ptr_device_addresses_.end(),
34            [](const device::DeviceAddressPtr &device_address) { device_address->set_ptr(nullptr); });
35 }
36 
RunBranchName(const std::string & branch_name,OpContext<DeviceTensor> * const context)37 void ConditionGatherActor::RunBranchName(const std::string &branch_name, OpContext<DeviceTensor> *const context) {
38   MS_LOG(DEBUG) << "Condition gather actor:" << GetAID() << " branch name:" << branch_name;
39   current_branch_name_ = branch_name;
40   if (branch_name_to_input_data_num_.find(current_branch_name_) == branch_name_to_input_data_num_.end()) {
41     input_datas_num_ = 0;
42   } else {
43     input_datas_num_ = branch_name_to_input_data_num_[current_branch_name_];
44   }
45   if (branch_name_to_input_control_num_.find(current_branch_name_) == branch_name_to_input_control_num_.end()) {
46     input_controls_num_ = 0;
47   } else {
48     input_controls_num_ = branch_name_to_input_control_num_[current_branch_name_];
49   }
50   if (input_datas_num_ == 0 && input_controls_num_ == 0) {
51     MS_LOG(EXCEPTION) << "No input data and input control, branch id:" << current_branch_name_
52                       << " for actor:" << GetAID();
53   }
54   MS_LOG(DEBUG) << "Input data num:" << input_datas_num_ << " control num:" << input_controls_num_
55                 << " for actor:" << GetAID();
56 }
57 
// One-time initialization: bind the kernel's output device addresses to this actor,
// register somas-managed outputs (including graph outputs that somas must keep alive),
// and validate that the input slot count matches the output count.
void ConditionGatherActor::Init() {
  // Check device contexts number: a condition gather actor runs on exactly one device context.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  // One input slot per output of the active branch.
  input_device_tensors_.resize(branch_output_num_);
  InitOutputData();

  kernel_info_ = dynamic_cast<KernelInfo *>(kernel_->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info_);
  const auto &output_addresses = kernel_info_->output_address_list();
  const auto &somas_outputs = kernel_info_->somas_output_result();
  // Size mismatch is tolerated (debug log only); the loop below indexes somas_outputs
  // only through IsTensorEnableSomas, which is presumably bounds-aware — TODO confirm.
  if (output_addresses.size() != somas_outputs.size()) {
    MS_LOG(DEBUG) << "Invalid output address size:" << output_addresses.size()
                  << " and somas output size:" << somas_outputs.size() << " for actor:" << GetAID();
  }
  for (size_t i = 0; i < output_addresses.size(); ++i) {
    auto &output_address = output_addresses[i];
    MS_EXCEPTION_IF_NULL(output_address);
    // Stream mismatch is informational only; it does not stop initialization.
    if (output_address->stream_id() != kernel_info_->stream_id()) {
      MS_LOG(DEBUG) << "Output address : " << output_address << " stream id :" << output_address->stream_id()
                    << " is not equal kernel info stream id : " << kernel_info_->stream_id() << ".";
    }
    (void)output_device_tensors_.emplace_back(output_address.get());
    // The output taken over by somas does not need to allocate memory.
    if (kernel_info_->IsTensorEnableSomas(somas_outputs, i)) {
      // Somas outputs use the info of kernelMod, and output address use the info of device address.
      if (somas_outputs[i].second < output_address->GetSize()) {
        MS_LOG(DEBUG) << GetAID().Name() << " check somas size warning, output index:" << i
                      << " somas aligned size:" << somas_outputs[i].second
                      << " is smaller than address size:" << output_address->GetSize();
      }
      // Used to keep graph output address when somas block memory free, and reused by the ref count in other graphs.
      if (somas_graph_output_indexes_.count(i) > 0) {
        MS_LOG(DEBUG) << "Somas keep output device address:" << output_address << " ptr:" << output_address->GetPtr();
        MS_EXCEPTION_IF_NULL(somas_info_);
        (void)somas_info_->InsertGraphOutputInfo(output_address.get(), somas_outputs[i].first, somas_outputs[i].second);
        // Marked as pool memory here; the destructor later detaches the raw ptr via
        // need_clean_ptr_device_addresses_ so the address never frees somas-owned memory.
        output_address->set_from_mem_pool(true);
        need_clean_ptr_device_addresses_.emplace_back(output_address);
      } else {
        UpdateRefCount(output_address.get(), true);
      }
    }
  }
  // Inputs map 1:1 onto outputs; any mismatch is a construction error.
  if (output_device_tensors_.size() != input_device_tensors_.size()) {
    MS_LOG(EXCEPTION) << "Invalid input tensor size:" << input_device_tensors_.size()
                      << " and output size:" << output_device_tensors_.size() << " for actor:" << GetAID();
  }
}
108 
// Collect the device tensors produced by the currently selected branch and wire them
// into the pre-allocated output data that is about to be sent downstream. Inputs of
// all branches share one flat index space; the active branch owns the slice
// [start_index, start_index + branch_output_num_).
void ConditionGatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto iter = std::find(branch_names_.begin(), branch_names_.end(), current_branch_name_);
  if (iter == branch_names_.end()) {
    MS_LOG(EXCEPTION) << "Invalid current branch name:" << current_branch_name_ << " total:" << branch_names_
                      << " for actor:" << GetAID();
  }
  // Offset of the active branch's slice within the flattened input index space.
  size_t start_index = branch_output_num_ * LongToSize(iter - branch_names_.begin());

  memory_free_list_.clear();
  // Fetch input device tensor from input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      // Arriving data must fall inside the active branch's slice.
      if (IntToSize(input_data->index_) < start_index ||
          IntToSize(input_data->index_) - start_index >= input_device_tensors_.size()) {
        std::string error_info =
          "Invalid input index:" + std::to_string(input_data->index_) + " start:" + std::to_string(start_index) +
          " total:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_) - start_index] = input_data->data_;

      // Dynamically received inputs are released after use (store-fetched ones are not).
      memory_free_list_.emplace_back(input_data->data_);
    }
  }

  // Fetch input device tensor from device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    // Keys outside the active branch's slice belong to other branches — skip them.
    if (device_tensor_store_key.first < start_index ||
        device_tensor_store_key.first - start_index >= input_device_tensors_.size()) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
                                                                device_contexts_[0]->GetDeviceType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first - start_index] = device_tensor.get();
  }

  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "Invalid output data size:" << output_data_.size()
                      << " and output data arrow size:" << output_data_arrows_.size() << " for actor:" << GetAID();
  }

  // Point each outgoing data at the gathered input tensor named by its arrow, so the
  // branch's results flow through unchanged.
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_data_arrows_[i]);
    MS_EXCEPTION_IF_NULL(output_data_[i].first);
    const auto &from_index = output_data_arrows_[i]->from_output_index_;
    if (IntToSize(from_index) >= input_device_tensors_.size() || from_index < 0) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index << " to actor:" << output_data_arrows_[i]->to_op_id_
                        << " to index:" << output_data_arrows_[i]->to_input_index_ << " for actor:" << GetAID();
    }
    if (input_device_tensors_[from_index] == nullptr) {
      std::string error_info =
        GetAID().Name() + " get input device tensor index:" + std::to_string(from_index) + " failed.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    output_data_[i].first->data_ = input_device_tensors_[from_index];
    // Propagate pool ownership from the corresponding output address onto the
    // forwarded input so downstream memory management treats it consistently.
    if (output_device_tensors_[from_index]->from_mem_pool()) {
      input_device_tensors_[from_index]->set_from_mem_pool(true);
    }
  }
}
180 
Run(OpContext<DeviceTensor> * const context)181 void ConditionGatherActor::Run(OpContext<DeviceTensor> *const context) {
182   try {
183     MS_EXCEPTION_IF_NULL(kernel_);
184     device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), kernel_->fullname_with_scope(),
185                                                    kernel_->func_graph()->ToString());
186     FetchInput(context);
187     if (memory_free_list_.size() > 0) {
188       SendMemoryFreeReq(context);
189     }
190     MS_LOG(DEBUG) << "Launch kernel:" << kernel_->fullname_with_scope();
191     EraseInput(context);
192     for (const auto &device_address : output_device_tensors_) {
193       device_address->set_ptr(nullptr);
194     }
195     SetSomasMemory(context);
196     SendOutput(context);
197   } catch (const std::exception &e) {
198     MsException::Instance().SetException();
199     std::string error_info =
200       "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
201     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
202   }
203 }
204 }  // namespace runtime
205 }  // namespace mindspore
206