• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/abstract_actor.h"
18 #include "runtime/graph_scheduler/actor/output_actor.h"
19 #include "utils/log_adapter.h"
20 
21 namespace mindspore {
22 namespace runtime {
RunOpData(OpData<DeviceTensor> * const input_data,OpContext<DeviceTensor> * const context)23 void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
24   MS_EXCEPTION_IF_NULL(input_data);
25   MS_EXCEPTION_IF_NULL(input_data->data_);
26   // The unused data may be invalid ptr.
27   if (!ActorDispatcher::enable_async_launch_kernel() && !input_data->data_->IsPtrValid() &&
28       (!TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNotUsed) &&
29        !TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNullptr))) {
30     std::string error_info = "The input_data does not have a valid ptr of actor:" + GetAID().Name() +
31                              " with index:" + std::to_string(input_data->index_) +
32                              ", flag:" + std::to_string(input_data->data_->flag()) +
33                              " device address:" + std::to_string((int64_t)(input_data->data_)) +
34                              " ref count:" + std::to_string(input_data->data_->ref_count()) +
35                              " dynamic ref count:" + std::to_string(input_data->data_->dynamic_ref_count()) +
36                              " origin ref count:" + std::to_string(input_data->data_->original_ref_count());
37     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
38   }
39   auto &sequential_num = context->sequential_num_;
40   (void)input_op_datas_[sequential_num].emplace_back(input_data);
41 
42   auto is_run = CheckRunningCondition(context);
43   MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data and check running condition:" << is_run
44                 << ", sequential num:" << sequential_num << ", the input data:" << input_data->data_
45                 << " input index:" << input_data->index_ << ", size:" << input_data->data_->GetSize()
46                 << " ptr:" << input_data->data_->GetMutablePtr()
47                 << ", origin ref count:" << input_data->data_->original_ref_count()
48                 << ", current ref count:" << input_data->data_->ref_count()
49                 << ", dynamic ref count:" << input_data->data_->dynamic_ref_count()
50                 << ", flag:" << input_data->data_->flag() << " user data:" << input_data->data_->user_data()
51                 << " from memory pool:" << input_data->data_->from_mem_pool();
52 
53   if (is_run) {
54     Run(context);
55   }
56 }
57 
// Entry point for a single control input: records the control arrow's source for this
// step and fires Run() once all expected data/control inputs have arrived.
void AbstractActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  auto &seq_num = context->sequential_num_;
  (void)input_op_controls_[seq_num].emplace_back(input_control);

  // A null input_control is tolerated; it is only used for logging here.
  const bool ready_to_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op control from:" << (input_control == nullptr ? "null" : input_control->Name())
                << " and check running condition:" << ready_to_run << ", sequential num:" << seq_num;
  if (ready_to_run) {
    Run(context);
  }
}
70 
RunBatchOpData(std::vector<OpData<DeviceTensor> * > * const batch_input_data,OpContext<DeviceTensor> * const context)71 void AbstractActor::RunBatchOpData(std::vector<OpData<DeviceTensor> *> *const batch_input_data,
72                                    OpContext<DeviceTensor> *const context) {
73   MS_EXCEPTION_IF_NULL(context);
74   MS_EXCEPTION_IF_NULL(batch_input_data);
75   MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
76                 << ") receive the batch input op data, sequential num:" << context->sequential_num_;
77   for (auto &input_data : *batch_input_data) {
78     RunOpData(input_data, context);
79   }
80 }
81 
// Returns true only when exactly the expected number of data inputs and control inputs
// have been collected for this step (context->sequential_num_). More inputs than declared
// indicates a scheduling bug: it is logged and treated as "not ready".
bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(context);

  if (input_datas_num_ != 0) {
    const auto &data_iter = input_op_datas_.find(context->sequential_num_);
    // Missing entry or too few inputs: simply not ready yet.
    if (data_iter == input_op_datas_.end() || data_iter->second.size() < input_datas_num_) {
      return false;
    }
    if (data_iter->second.size() > input_datas_num_) {
      MS_LOG(ERROR) << "Invalid input data num:" << data_iter->second.size() << " need:" << input_datas_num_
                    << " for actor:" << GetAID() << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }

  if (input_controls_num_ != 0) {
    const auto &control_iter = input_op_controls_.find(context->sequential_num_);
    if (control_iter == input_op_controls_.end() || control_iter->second.size() < input_controls_num_) {
      return false;
    }
    if (control_iter->second.size() > input_controls_num_) {
      MS_LOG(ERROR) << "Invalid input control num:" << control_iter->second.size() << " need:" << input_controls_num_
                    << " for actor:" << GetAID() << ", sequential num:" << context->sequential_num_;
      return false;
    }
  }
  return true;
}
113 
// Drops all collected data/control inputs for the given step, releasing the per-step
// bookkeeping after the actor has run (or the step was aborted).
void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
  // context->sequential_num_ is dereferenced below; every sibling method null-checks
  // context, so do the same here instead of crashing on a raw dereference.
  MS_EXCEPTION_IF_NULL(context);
  (void)input_op_datas_.erase(context->sequential_num_);
  (void)input_op_controls_.erase(context->sequential_num_);
}
118 
FetchInputByTensorStore(std::vector<DeviceTensor * > * const input_device_tensors,std::vector<KernelTensor * > * const input_kernel_tensors,std::vector<abstract::AbstractBasePtr> * const input_kernel_tensors_for_infer,std::vector<DeviceTensor * > * const memory_free_tensors,OpContext<DeviceTensor> * const context) const119 void AbstractActor::FetchInputByTensorStore(
120   std::vector<DeviceTensor *> *const input_device_tensors, std::vector<KernelTensor *> *const input_kernel_tensors,
121   std::vector<abstract::AbstractBasePtr> *const input_kernel_tensors_for_infer,
122   std::vector<DeviceTensor *> *const memory_free_tensors, OpContext<DeviceTensor> *const context) const {
123   for (auto &device_tensor_store_key : device_tensor_store_keys_) {
124     auto device_tensor = DeviceTensorStore::GetInstance()
125                            .Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
126                            .get();
127     if (device_tensor == nullptr) {
128       std::string error_info =
129         GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
130         ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
131       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
132     }
133     if ((*input_device_tensors)[device_tensor_store_key.first] != device_tensor) {
134       (*input_device_tensors)[device_tensor_store_key.first] = device_tensor;
135       (*memory_free_tensors)[device_tensor_store_key.first] = device_tensor;
136     }
137     // Collect the input kernel tensor.
138     const auto &kernel_tensor = (*input_device_tensors)[device_tensor_store_key.first]->kernel_tensor();
139     if (input_kernel_tensors && input_kernel_tensors_for_infer &&
140         ((*input_kernel_tensors)[device_tensor_store_key.first] != kernel_tensor.get())) {
141       (*input_kernel_tensors)[device_tensor_store_key.first] = kernel_tensor.get();
142       (*input_kernel_tensors_for_infer)[device_tensor_store_key.first] = kernel_tensor;
143     }
144   }
145 }
146 
InitOutputData()147 void AbstractActor::InitOutputData() {
148   mindspore::HashMap<std::string, size_t> batch_op_count;
149   for (auto &data_arrow : output_data_arrows_) {
150     MS_EXCEPTION_IF_NULL(data_arrow);
151     auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
152     auto &to_op_name = data_arrow->to_op_id_.Name();
153 
154     // Identify whether the output data flag is kOutputDataFlagToStack.
155     bool is_to_stack = (to_op_name.find(kStackActorNameSuffix) != std::string::npos);
156     size_t output_data_flag = is_to_stack ? kOutputDataFlagToStack : kOutputDataFlagInit;
157 
158     // Add the batch output data.
159     if (TEST_FLAG(data_arrow->flag_, kOutputDataFlagBatch)) {
160       if (is_to_stack) {
161         MS_LOG(EXCEPTION) << "Not support the batch output data to stack actor.";
162       }
163       (void)batch_output_data_[to_op_name].emplace_back(data.get());
164 
165       SET_FLAG(output_data_flag, kOutputDataFlagBatch);
166       // Identify whether the output data flag is kOutputDataFlagLastBatch.
167       ++(batch_op_count[to_op_name]);
168       if (batch_op_count[to_op_name] == batch_output_data_arrows_[to_op_name].size()) {
169         SET_FLAG(output_data_flag, kOutputDataFlagLastBatch);
170       }
171     }
172 
173     // Add the internal fusion flag.
174     if (TEST_FLAG(data_arrow->flag_, kOutputDataFlagBetweenFusion)) {
175       SET_FLAG(output_data_flag, kOutputDataFlagBetweenFusion);
176     }
177 
178     // Add the fusion flag.
179     if (TEST_FLAG(data_arrow->flag_, kOutputDataFlagToFusion)) {
180       SET_FLAG(output_data_flag, kOutputDataFlagToFusion);
181     }
182 
183     // Add the output data.
184     (void)output_data_.emplace_back(std::make_pair(std::move(data), output_data_flag));
185   }
186 }
187 
SendOutputData(OpContext<DeviceTensor> * const context,const std::vector<AnfNodePtr> & output_data_nodes,const std::vector<DataArrowPtr> & output_data_arrows,const std::vector<std::pair<OpDataUniquePtr<DeviceTensor>,size_t>> & output_data_list,const mindspore::HashMap<DataArrow *,size_t> & data_arrow_to_fusion_actor_indexs,mindspore::HashMap<std::string,std::vector<OpData<DeviceTensor> * >> * batch_output_data)188 void AbstractActor::SendOutputData(
189   OpContext<DeviceTensor> *const context, const std::vector<AnfNodePtr> &output_data_nodes,
190   const std::vector<DataArrowPtr> &output_data_arrows,
191   const std::vector<std::pair<OpDataUniquePtr<DeviceTensor>, size_t>> &output_data_list,
192   const mindspore::HashMap<DataArrow *, size_t> &data_arrow_to_fusion_actor_indexs,
193   mindspore::HashMap<std::string, std::vector<OpData<DeviceTensor> *>> *batch_output_data) {
194   for (size_t i = 0; i < output_data_list.size(); ++i) {
195     auto &output_data = output_data_list[i];
196     auto &to_op_id = output_data.first->op_id_;
197     auto &output_data_arrow = output_data_arrows[i];
198     UpdateOutputData(output_data.first.get(), output_data_arrow, output_data_nodes[i], context);
199     // The index of output data will be modified the real actor input index in the fusion actor, so need recovery the
200     // fusion actor index before sending output data to the fusion actor.
201     if (TEST_FLAG(output_data.second, kOutputDataFlagToFusion)) {
202       output_data.first->index_ = SizeToInt(data_arrow_to_fusion_actor_indexs.at(output_data_arrow.get()));
203     }
204 
205     if (TEST_FLAG(output_data.second, kOutputDataFlagLastBatch)) {
206       // Send batch output data. As the data need update, so all data must be collected completely before sending.
207       if (TEST_FLAG(output_data.second, kOutputDataFlagBetweenFusion)) {
208         const auto &to_actor = FetchSubActorInFusionActor(to_op_id.Name());
209         MS_EXCEPTION_IF_NULL(to_actor);
210         ActorDispatcher::SendSync(to_actor, &AbstractActor::RunBatchOpData, &((*batch_output_data)[to_op_id.Name()]),
211                                   context);
212       } else {
213         ActorDispatcher::Send(to_op_id, &AbstractActor::RunBatchOpData, &((*batch_output_data)[to_op_id.Name()]),
214                               context);
215       }
216     } else if (TEST_FLAG(output_data.second, kOutputDataFlagToStack)) {
217       // Create a new op data for stack actor.
218       auto to_stack_data =
219         std::make_unique<OpData<DeviceTensor>>(to_op_id, output_data.first->data_, output_data.first->index_);
220       (void)to_stack_data_.emplace_back(std::move(to_stack_data));
221       if (TEST_FLAG(output_data.second, kOutputDataFlagBetweenFusion)) {
222         const auto &to_actor = FetchSubActorInFusionActor(to_op_id.Name());
223         MS_EXCEPTION_IF_NULL(to_actor);
224         ActorDispatcher::SendSync(to_actor, &OpActor::RunOpData, to_stack_data_.back().get(), context);
225       } else {
226         ActorDispatcher::Send(to_op_id, &OpActor::RunOpData, to_stack_data_.back().get(), context);
227       }
228     } else if (!TEST_FLAG(output_data.second, kOutputDataFlagBatch)) {
229       // The batch output data only send when the output flag is kOutputDataFlagLastBatch.
230       if (TEST_FLAG(output_data.second, kOutputDataFlagBetweenFusion)) {
231         const auto &to_actor = FetchSubActorInFusionActor(to_op_id.Name());
232         if (to_actor == nullptr) {
233           MS_LOG(EXCEPTION) << "Failed to fetch to actor:" << to_op_id << " in actor:" << GetAID();
234         }
235         ActorDispatcher::SendSync(to_actor, &OpActor::RunOpData, output_data.first.get(), context);
236       } else {
237         ActorDispatcher::Send(to_op_id, &OpActor::RunOpData, output_data.first.get(), context);
238       }
239     }
240   }
241 }
242 
// Sends all of this actor's outputs for the current step, strictly in the order
// data -> control -> recorder info, to avoid illegal timing between receivers.
void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
  // Must be the execution order: send data --> send control, avoid the illegal timing problem.
  // 1.Send output data.
  SendOutputData(context, output_data_nodes_, output_data_arrows_, output_data_, data_arrow_to_fusion_actor_indexs_,
                 &batch_output_data_);

  // 2.Send output control.
  if (!output_control_arrows_.empty()) {
    auto from_aid = const_cast<AID *>(&GetAID());
    for (auto &output_control : output_control_arrows_) {
      if (TEST_FLAG(output_control->flag_, kOutputDataFlagBetweenFusion)) {
        const auto &to_actor = FetchSubActorInFusionActor(output_control->to_op_id_.Name());
        // FetchSubActorInFusionActor can return nullptr; check it like every other call site
        // instead of handing a null receiver to SendSync.
        MS_EXCEPTION_IF_NULL(to_actor);
        ActorDispatcher::SendSync(to_actor, &OpActor::RunOpControl, from_aid, context);
      } else {
        ActorDispatcher::Send(output_control->to_op_id_, &OpActor::RunOpControl, from_aid, context);
      }
    }
  }

  // 3.Send recorder info.
  SendRecorderInfo(context);
}
265 
FetchSubActorInFusionActor(const std::string & sub_actor_name) const266 AbstractActor *AbstractActor::FetchSubActorInFusionActor(const std::string &sub_actor_name) const {
267   if (parent_fusion_actor_ == nullptr) {
268     return nullptr;
269   }
270   return (parent_fusion_actor_->sub_actors_[sub_actor_name]).get();
271 }
272 
// Decides whether an output device address must be treated as persisted (i.e. its ptr
// may not be replaced/freed by the runtime): explicitly persisted addresses, value nodes,
// parameters (whose device address may come from an input tensor), and ref outputs whose
// recursively-resolved origin node is a value node or parameter.
bool AbstractActor::IsOutputAddressPersisted(const DeviceTensor *output_device_tensor,
                                             const KernelWithIndex &output_node) {
  MS_EXCEPTION_IF_NULL(output_node.first);
  MS_EXCEPTION_IF_NULL(output_device_tensor);
  // The persisted address can't be replaced.
  if (output_device_tensor->is_ptr_persisted()) {
    return true;
  }

  // Value nodes own their data; parameters may share the device address of an input
  // tensor, so both are reported as persisted to avoid cleaning up borrowed device data.
  if (output_node.first->isa<ValueNode>() || output_node.first->isa<Parameter>()) {
    return true;
  }

  // For a ref output, persistence is decided by the origin node it refers to.
  const auto &graph = AnfAlgo::FetchKernelGraph(output_node.first.get());
  if (graph == nullptr || !graph->IsInRefOutputMap(output_node)) {
    return false;
  }
  const auto &origin_node = graph->GetRefNodeRecursive(output_node).first;
  MS_EXCEPTION_IF_NULL(origin_node);
  return origin_node->isa<ValueNode>() || origin_node->isa<Parameter>();
}
304 }  // namespace runtime
305 }  // namespace mindspore
306