• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/control_flow/actor/exit_actor.h"
18 #include <algorithm>
19 #include "src/control_flow/kernel/exit_subgraph_kernel.h"
20 #include "src/litert/kernel_exec_util.h"
21 #include "src/common/tensor_util.h"
22 
namespace {
// Input slot of the exit actor that is not forwarded as data (InitInputData starts copying at
// index 1); instead, SetEntranceInputAID records the AID of the actor that sent this slot.
constexpr int kEntranceTensorIndex = 0;
}  // namespace
26 namespace mindspore::lite {
RunOpData(OpData<Tensor> * inputs,OpContext<Tensor> * context)27 void LiteExitOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
28   auto op_uuid = context->sequential_num_;
29   input_op_datas_[op_uuid].push_back(inputs);
30   inputs_data_[inputs->index_] = inputs->data_;
31   SetEntranceInputAID(inputs);
32   if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
33     return;
34   }
35 
36   auto ret = InitInputData();
37   (void)input_op_datas_.erase(op_uuid);
38   if (ret != RET_OK) {
39     context->SetFailed(ret);
40     return;
41   }
42   AsyncOutput(context);
43   return;
44 }
45 
InitInputData()46 int LiteExitOpActor::InitInputData() {
47   auto ret = SetInputShape();
48 
49   for (size_t i = 1; i < inputs_data_.size(); ++i) {
50     auto dst_tensor = kernel_->out_tensors()[i - 1];
51     auto src_tensor = inputs_data_[i];
52     dst_tensor->set_data_type(src_tensor->data_type());
53     if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
54       (void)SetTensorData(dst_tensor, src_tensor);
55     } else {
56       (void)MoveTensorData(dst_tensor, src_tensor);
57     }
58   }
59   return ret;
60 }
61 
SetInputShape()62 int LiteExitOpActor::SetInputShape() {
63   auto ret = RET_OK;
64   for (size_t i = 1; i < inputs_data_.size(); ++i) {
65     auto &output_tensor = kernel_->out_tensors()[i - 1];
66     if (output_tensor->shape() == inputs_data_[i]->shape()) {
67       continue;
68     }
69     ret = SetTensorShape(output_tensor, inputs_data_[i]);
70     MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "set input shape failed.");
71   }
72   return RET_OK;
73 }
74 
SetEntranceInputAID(const OpData<Tensor> * inputs)75 void LiteExitOpActor::SetEntranceInputAID(const OpData<Tensor> *inputs) {
76   if (inputs->index_ == kEntranceTensorIndex) {
77     entrance_input_aid_ = inputs->op_id_;
78   }
79 }
80 
PrepareOutputData()81 int LiteExitOpActor::PrepareOutputData() {
82   // exit actor has not calculating, so send input directly.
83   outputs_data_.resize(output_data_arrows_.size());
84   for (size_t i = 0; i < output_data_arrows_.size(); i++) {
85     auto &arrow = output_data_arrows_[i];
86     auto data = std::make_shared<OpData<Tensor>>(this->GetAID(), (kernel_->out_tensors()).at(arrow->from_output_index_),
87                                                  static_cast<int>(arrow->to_input_index_));
88     if (data == nullptr) {
89       MS_LOG(ERROR) << "new output_data failed.";
90       return RET_NULL_PTR;
91     }
92     outputs_data_.at(i) = data;
93   }
94   return RET_OK;
95 }
96 
AsyncOutput(OpContext<Tensor> * context)97 void LiteExitOpActor::AsyncOutput(OpContext<Tensor> *context) {
98   AID to_op_id;
99   bool find_to_op_aid = false;
100   for (auto info : all_mapping_info_) {
101     if (info.partial_input_aid == entrance_input_aid_) {
102       find_to_op_aid = true;
103       to_op_id = info.call_output_aid;
104     }
105   }
106 
107   if (!find_to_op_aid) {
108     MS_LOG(ERROR) << "exit actor can not find output actor.";
109     context->SetFailed(RET_ERROR);
110     return;
111   }
112 
113   if (to_op_id.Name() == "") {
114     SetOutputData(context);
115   }
116 
117   for (size_t i = 0; i < output_data_arrows_.size(); i++) {
118     if (output_data_arrows_[i]->to_op_id_ != to_op_id && output_data_arrows_[i]->to_op_id_.Name() != "") {
119       continue;
120     }
121     auto data = outputs_data_.at(i);
122     Async(to_op_id, get_actor_mgr(), &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
123   }
124 }
125 
PreInit(std::vector<std::shared_ptr<LiteOpActor>> * actors,std::unordered_map<Tensor *,Tensor * > * input_map)126 int LiteExitOpActor::PreInit(std::vector<std::shared_ptr<LiteOpActor>> *actors,
127                              std::unordered_map<Tensor *, Tensor *> *input_map) {
128   auto ret = IsolateInputData(actors, input_map);
129   if (ret != RET_OK) {
130     MS_LOG(ERROR) << "isolate input data failed.";
131     return ret;
132   }
133 
134   ret = CreateMappingInfo();
135   if (ret != RET_OK) {
136     MS_LOG(ERROR) << "create partial call pairs failed.";
137     return ret;
138   }
139 
140   ret = RecordCallNodeOutputActor(actors);
141   if (ret != RET_OK) {
142     MS_LOG(ERROR) << "record call node outputs AIDs failed";
143     return ret;
144   }
145   return RET_OK;
146 }
147 
IsSubSet(const std::vector<lite::Tensor * > & all_set,const std::vector<lite::Tensor * > & sub_set)148 bool LiteExitOpActor::IsSubSet(const std::vector<lite::Tensor *> &all_set, const std::vector<lite::Tensor *> &sub_set) {
149   if (sub_set.size() > all_set.size()) {
150     return false;
151   }
152   for (auto &sub_item : sub_set) {
153     if (std::find(all_set.begin(), all_set.end(), sub_item) == all_set.end()) {
154       return false;
155     }
156   }
157   return true;
158 }
159 
RecordCallNodeOutputActor(std::vector<std::shared_ptr<LiteOpActor>> * actors)160 int LiteExitOpActor::RecordCallNodeOutputActor(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
161   actors_ = actors;
162   for (auto actor : *actors_) {
163     auto actor_in_tensors = actor->GetKernel()->in_tensors();
164     for (auto &info : all_mapping_info_) {
165       auto &call = info.call_node;
166       if (IsSubSet(actor_in_tensors, call->out_tensors())) {
167         info.call_output_aid = actor->GetAID();
168       }
169     }
170   }
171   return RET_OK;
172 }
173 
CreateMappingInfo()174 int LiteExitOpActor::CreateMappingInfo() {
175   auto exit_subgraph_kernel = reinterpret_cast<kernel::ExitSubGraphKernel *>(kernel_);
176   if (exit_subgraph_kernel == nullptr) {
177     MS_LOG(ERROR) << "cast to exit kernel failed.";
178     return RET_ERROR;
179   }
180   auto partial_set = exit_subgraph_kernel->GetPartials();
181   for (auto partial : partial_set) {
182     auto call_node = kernel::KernelExecUtil::GetPartialOutputCall(partial);
183     if (call_node == nullptr) {
184       MS_LOG(ERROR) << "get partial node: " << partial->name() << " 's call output node failed.";
185       return RET_ERROR;
186     }
187     MappingInfo info(partial, call_node);
188     (void)all_mapping_info_.emplace_back(info);
189   }
190   return RET_OK;
191 }
192 
PostInit()193 int LiteExitOpActor::PostInit() {
194   auto ret = PrepareOutputData();
195   if (ret != RET_OK) {
196     MS_LOG(ERROR) << "prepare output data failed.";
197     return ret;
198   }
199 
200   RecordPartialNodeInputActor();
201   return RET_OK;
202 }
203 
RecordPartialNodeInputActor()204 void LiteExitOpActor::RecordPartialNodeInputActor() {
205   for (auto actor : *actors_) {
206     auto actor_partial_nodes = actor->GetPartialKernels();
207     if (actor_partial_nodes.empty()) {
208       continue;
209     }
210     for (auto &info : all_mapping_info_) {
211       auto partial = info.partial_node;
212       if (actor_partial_nodes.find(partial) == actor_partial_nodes.end()) {
213         continue;
214       }
215       info.partial_input_aid = actor->GetAID();
216     }
217   }
218 }
219 }  // namespace mindspore::lite
220