/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17 #include "src/control_flow/control_flow_scheduler.h"
18 #ifndef CONTROLFLOW_TENSORLIST_CLIP
19 #include <algorithm>
20 #include <set>
21 #include "src/litert/kernel_exec_util.h"
22 #include "src/litert/kernel/cpu/base/partial_fusion.h"
23 #include "nnacl/call_parameter.h"
24 #include "src/control_flow/kernel/exit_subgraph_kernel.h"
25 #include "src/control_flow/kernel/identity_kernel.h"
26 #include "src/tensorlist.h"
27 #include "src/common/prim_inner.h"
28
29 namespace {
30 const constexpr int kMinNonTailCallCount = 2;
31 }
32 #endif
33
34 namespace mindspore::lite {
35 #ifndef CONTROLFLOW_TENSORLIST_CLIP
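// Entry point: rewrite the scheduled kernels so control-flow (partial/call) constructs execute correctly.
// The passes isolate shared partial inputs and call outputs, build boundaries around multiply-called
// subgraphs, record tensor link info, and finally split subgraphs at non-tail calls.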
int ControlFlowScheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
  auto ret = this->IsolateSameInputPartials(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
  ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed.");
  ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
  ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "BuildBoundaryForMultipleCalledGraph failed.");
  ret = this->RecordLinkInfo(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordLinkInfo failed.");
  ret = this->SplitNonTailCallSubGraphs(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed.");
  return ret;
}

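// Iteratively split every subgraph that contains a non-tail call until none remains. Newly created
// subgraphs are appended to dst_kernels and re-examined; the replaced subgraph kernels are removed at
// the end.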
int ControlFlowScheduler::SplitNonTailCallSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels) {
  std::set<kernel::KernelExec *> all_non_tail_subgraphs = GetNonTailCallSubGraphs(dst_kernels);
  for (auto item : all_non_tail_subgraphs) {
    to_process_q_.push(item);
  }

  while (!to_process_q_.empty()) {
    auto cur = to_process_q_.front();
    to_process_q_.pop();
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(cur);
    if (subgraph_kernel == nullptr) {
      MS_LOG(ERROR) << "kernel is not a subgraph kernel";
      return RET_ERROR;
    }
    std::vector<kernel::KernelExec *> new_subgraphs{};
    auto ret = SplitSingleNonTailCallSubGraph(subgraph_kernel, &new_subgraphs);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "SplitSingleNonTailCallSubGraph failed, ret: " << ret;
      return ret;
    }
    // append the new subgraphs to dst_kernels
    (void)std::copy(new_subgraphs.begin(), new_subgraphs.end(), std::back_inserter(*dst_kernels));
    // replace the old subgraph with the new subgraphs in each partial kernel's subgraph list
    for (auto &item : *partial_kernel_subgraph_index_map_) {
      auto &partial_node = item.first;
      auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
      MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
      auto subgraphs = partial_kernel->subgraph_kernels();
      auto iter = std::find(subgraphs.begin(), subgraphs.end(), subgraph_kernel);
      if (iter == subgraphs.end()) {
        continue;
      }
      // erase returns the iterator after the removed element; reuse it so the insert position stays
      // valid and the new subgraphs keep their order.
      iter = subgraphs.erase(iter);
      (void)subgraphs.insert(iter, new_subgraphs.begin(), new_subgraphs.end());
      partial_kernel->set_subgraph_kernels(subgraphs);
    }
    AppendToProcessQ(&new_subgraphs, &all_non_tail_subgraphs);
  }

  RemoveUselessKernels(dst_kernels, &all_non_tail_subgraphs);

  return RET_OK;
}

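// Collect every non-delegate subgraph kernel that contains a non-tail call.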
std::set<kernel::KernelExec *> ControlFlowScheduler::GetNonTailCallSubGraphs(
  std::vector<kernel::KernelExec *> *dst_kernels) {
  std::set<kernel::KernelExec *> non_tail_subgraph_kernels{};

  // find subgraphs which contain a non-tail call
  for (auto &kernel : *dst_kernels) {
    if (kernel->desc().arch == kernel::kDelegate) {
      continue;
    }
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
    if (subgraph_kernel == nullptr) {
      continue;
    }
    if (!kernel::KernelExecUtil::IsNonTailCallSubGraph(subgraph_kernel)) {
      continue;
    }
    (void)non_tail_subgraph_kernels.insert(kernel);
  }
  return non_tail_subgraph_kernels;
}

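// A tail call at the end of the second part depends on its input kernels, its input partials and the
// identity inputs of those partials; move any of these dependencies that landed in the first part into
// the second part so the split subgraphs stay self-contained.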
int ControlFlowScheduler::AdjustNodesForTailCallSubGraph(std::vector<kernel::KernelExec *> *first_part_nodes,
                                                         std::vector<kernel::KernelExec *> *second_part_nodes) {
  auto tail_call = second_part_nodes->back();
  std::vector<kernel::KernelExec *> all_need_nodes{};
  (void)std::copy(tail_call->in_kernels().begin(), tail_call->in_kernels().end(), std::back_inserter(all_need_nodes));
  auto partials = kernel::KernelExecUtil::GetCallInputPartials(tail_call);
  (void)std::copy(partials.begin(), partials.end(), std::back_inserter(all_need_nodes));
  for (auto partial : partials) {
    for (auto input : partial->in_kernels()) {
      MS_CHECK_TRUE_MSG(input != nullptr, RET_ERROR, "input is nullptr");
      auto parameter = input->op_parameter();
      MS_CHECK_TRUE_MSG(parameter != nullptr, RET_ERROR, "parameter is nullptr");
      if (parameter->type_ == static_cast<int>(PRIM_IDENTITY)) {
        all_need_nodes.push_back(input);
      }
    }
  }

  for (auto need : all_need_nodes) {
    if (IsContain(*second_part_nodes, need)) {
      continue;
    }
    auto iter = std::find(first_part_nodes->begin(), first_part_nodes->end(), need);
    MS_CHECK_TRUE_MSG(iter != first_part_nodes->end(), RET_ERROR, "graph is not right");
    (void)second_part_nodes->insert(second_part_nodes->begin(), *iter);
    (void)first_part_nodes->erase(iter);
  }
  return RET_OK;
}

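// Split the subgraph's node list right after the last non-tail call; that call is re-marked as a tail
// call so the first part ends with a (now tail) call and the second part holds the remaining nodes.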
int ControlFlowScheduler::SplitSubGraphNodesIntoTwoParts(kernel::SubGraphKernel *subgraph_kernel,
                                                         std::vector<kernel::KernelExec *> *first_part_nodes,
                                                         std::vector<kernel::KernelExec *> *second_part_nodes) {
  auto nodes = subgraph_kernel->nodes();

  // get the position just past the last non-tail call op.
  auto is_non_tail_call = [](const kernel::KernelExec *node) { return kernel::KernelExecUtil::IsNonTailCall(node); };
  auto last_non_tail_call_iter = std::find_if(nodes.rbegin(), nodes.rend(), is_non_tail_call);
  auto distance = nodes.rend() - last_non_tail_call_iter;
  if (distance == 0) {
    MS_LOG(ERROR) << "subgraph is not a non-tail call subgraph.";
    return RET_ERROR;
  }

  // mark the last non-tail call as a tail call
  MS_CHECK_TRUE_MSG(*last_non_tail_call_iter != nullptr, RET_ERROR, "last_non_tail_call_iter is nullptr");
  auto parameter = reinterpret_cast<CallParameter *>((*last_non_tail_call_iter)->op_parameter());
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr";
    return RET_ERROR;
  }
  parameter->is_tail_call = true;

  for (auto iter = nodes.begin(); iter != nodes.begin() + distance; ++iter) {
    first_part_nodes->push_back(*iter);
  }

  for (auto iter = nodes.begin() + distance; iter != nodes.end(); ++iter) {
    second_part_nodes->push_back(*iter);
  }

  // if the second part ends with a tail call, it also needs the call's input partials and their inputs.
  if (kernel::KernelExecUtil::IsTailCall(second_part_nodes->back())) {
    auto ret = AdjustNodesForTailCallSubGraph(first_part_nodes, second_part_nodes);
    MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "AdjustNodesForTailCallSubGraph failed.");
  }
  return RET_OK;
}

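// Split one subgraph containing a non-tail call into two new subgraph kernels of the same type.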
int ControlFlowScheduler::SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel,
                                                         std::vector<kernel::KernelExec *> *subgraph_kernels) {
  std::vector<kernel::KernelExec *> first_part_nodes{};
  std::vector<kernel::KernelExec *> second_part_nodes{};

  auto ret = SplitSubGraphNodesIntoTwoParts(subgraph_kernel, &first_part_nodes, &second_part_nodes);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitSubGraphNodesIntoTwoParts failed.");

  auto cur_subgraph_type = subgraph_kernel->subgraph_type();
  auto first_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(first_part_nodes, nullptr, nullptr,
                                                                     cur_subgraph_type, *context_, schema_version_);
  MS_CHECK_TRUE_MSG(first_subgraph != nullptr, RET_ERROR, "create first subgraph failed.");
  subgraph_kernels->push_back(first_subgraph);

  auto second_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(second_part_nodes, nullptr, nullptr,
                                                                      cur_subgraph_type, *context_, schema_version_);
  MS_CHECK_TRUE_MSG(second_subgraph != nullptr, RET_ERROR, "create second subgraph failed.");
  subgraph_kernels->push_back(second_subgraph);
  return RET_OK;
}

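// Remove the replaced subgraph kernels from dst_kernels and free them; their nodes now belong to the
// new subgraphs, so the node lists are cleared before deletion.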
void ControlFlowScheduler::RemoveUselessKernels(std::vector<kernel::KernelExec *> *dst_kernels,
                                                std::set<kernel::KernelExec *> *useless_kernels) {
  for (auto iter = dst_kernels->begin(); iter != dst_kernels->end();) {
    if (useless_kernels->find(*iter) != useless_kernels->end()) {
      iter = dst_kernels->erase(iter);
    } else {
      ++iter;
    }
  }

  for (auto &kernel : *useless_kernels) {
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
    if (subgraph_kernel == nullptr) {
      continue;
    }
    subgraph_kernel->set_nodes({});
    delete subgraph_kernel;
  }
  useless_kernels->clear();

  return;
}

void ControlFlowScheduler::AppendToProcessQ(std::vector<kernel::KernelExec *> *new_subgraphs,
                                            std::set<kernel::KernelExec *> *all_non_tail_subgraphs) {
  auto new_non_tail_call_subgraphs = GetNonTailCallSubGraphs(new_subgraphs);
  for (auto &item : new_non_tail_call_subgraphs) {
    if (all_non_tail_subgraphs->find(item) == all_non_tail_subgraphs->end()) {
      to_process_q_.push(item);
      (void)all_non_tail_subgraphs->insert(item);
    }
  }
  return;
}

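// For each partial feeding this non-tail call, link the output tensors of the subgraph(s) that finally
// produce the call's results to the call's own outputs. If a partial's last subgraph ends in a tail
// call, the chain of tail calls is followed first to find the final output subgraphs.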
int ControlFlowScheduler::RecordNonTailCallLinkInfo(kernel::KernelExec *non_tail_call) {
  size_t non_tail_call_output_size = non_tail_call->out_tensors().size();
  auto partial_nodes = kernel::KernelExecUtil::GetCallInputPartials(non_tail_call);
  for (auto node : partial_nodes) {
    auto partial_node = reinterpret_cast<kernel::PartialFusionKernel *>(node->kernel());
    MS_CHECK_TRUE_MSG(partial_node != nullptr, RET_ERROR, "node cast to partial node failed.");
    auto kernels = partial_node->subgraph_kernels();
    MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty.");
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernels.back());
    MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr.");
    if (kernel::KernelExecUtil::IsTailCallSubGraph(subgraph)) {
      std::queue<kernel::KernelExec *> tail_call_q{};
      tail_call_q.push(subgraph->out_nodes().front());
      std::vector<kernel::KernelExec *> final_graphs{};
      std::set<kernel::KernelExec *> reviewed_graphs{};
      auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetTailCallFinalSubgraphs failed.");
      for (auto item : final_graphs) {
        MS_CHECK_TRUE_MSG(item->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
                          "subgraph outputs and corresponding call outputs size not same.");
        for (size_t i = 0; i < non_tail_call_output_size; ++i) {
          context_->SetLinkInfo(item->out_tensors()[i], non_tail_call->out_tensors()[i]);
        }
      }
    } else {
      MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
                        "subgraph outputs and corresponding call outputs size not same.");
      for (size_t i = 0; i < non_tail_call_output_size; ++i) {
        context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]);
      }
    }
  }
  return RET_OK;
}

int ControlFlowScheduler::RecordAllNonTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels) {
  for (auto dst_kernel : *dst_kernels) {
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(dst_kernel);
    MS_CHECK_TRUE_MSG(subgraph_kernel != nullptr, RET_ERROR, "node cast to subgraph kernel failed.");
    for (auto node : subgraph_kernel->nodes()) {
      if (kernel::KernelExecUtil::IsNonTailCall(node)) {
        non_tail_calls_.push_back(node);
      }
    }
  }

  for (auto non_tail_call : non_tail_calls_) {
    auto ret = RecordNonTailCallLinkInfo(non_tail_call);
    MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordNonTailCallLinkInfo failed.");
  }
  return RET_OK;
}

void ControlFlowScheduler::RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node) {
  if (more_than_once_called_partial_nodes_.find(subgraph_index) == more_than_once_called_partial_nodes_.end()) {
    std::set<kernel::KernelExec *> tmp_set{partial_node};
    (void)more_than_once_called_partial_nodes_.insert(
      std::pair<size_t, std::set<kernel::KernelExec *>>{subgraph_index, tmp_set});
  } else {
    (void)more_than_once_called_partial_nodes_[subgraph_index].insert(partial_node);
  }
}

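// Build an entrance subgraph kernel in front of `subgraph`: its inputs are the original input tensors
// and its outputs are the link tensor followed by fresh copies of those inputs, which replace the
// subgraph's own inputs.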
kernel::SubGraphKernel *ControlFlowScheduler::CreateEntranceSubGraph(kernel::SubGraphKernel *subgraph,
                                                                     lite::Tensor *link_tensor) {
  if (subgraph == nullptr || link_tensor == nullptr) {
    MS_LOG(ERROR) << "input is nullptr.";
    return nullptr;
  }
  size_t in_tensor_size = subgraph->in_tensors().size();
  std::vector<Tensor *> old_input_tensors{};
  // the entrance subgraph kernel's first output tensor is the first input of the corresponding exit subgraph kernel.
  std::vector<Tensor *> new_input_tensors{link_tensor};
  for (size_t i = 0; i < in_tensor_size; i++) {
    Tensor *old_tensor = subgraph->in_tensors()[i];
    old_input_tensors.push_back(old_tensor);
    auto allocator = old_tensor->allocator();
    auto new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return nullptr;
    }
    src_tensors_->push_back(new_tensor);
    new_input_tensors.push_back(new_tensor);
    auto ret = kernel::KernelExecUtil::ReplaceSubGraphNodesInTensor(subgraph, old_tensor, new_tensor);
    MS_CHECK_FALSE_MSG(ret != RET_OK, nullptr, "ReplaceSubGraphNodesInTensor failed.");
    subgraph->set_in_tensor(new_tensor, i);
  }
  auto entrance_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(
    {}, &old_input_tensors, &new_input_tensors, kernel::kEntranceSubGraph, *context_, schema_version_);
  return entrance_subgraph;
}

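// Build an exit subgraph kernel behind `subgraph`, mirroring CreateEntranceSubGraph on the output side.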
kernel::SubGraphKernel *ControlFlowScheduler::CreateExitSubGraph(kernel::SubGraphKernel *subgraph,
                                                                 lite::Tensor *link_tensor) {
  if (subgraph == nullptr || link_tensor == nullptr) {
    MS_LOG(ERROR) << "input is nullptr.";
    return nullptr;
  }
  size_t out_tensor_size = subgraph->out_tensors().size();
  std::vector<Tensor *> old_output_tensors{};
  // the exit subgraph kernel's first input tensor is the first output of the corresponding entrance subgraph kernel.
  std::vector<Tensor *> new_output_tensors{link_tensor};
  for (size_t i = 0; i < out_tensor_size; i++) {
    Tensor *old_tensor = subgraph->out_tensors()[i];
    old_output_tensors.push_back(old_tensor);
    auto allocator = old_tensor->allocator();
    auto new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return nullptr;
    }
    src_tensors_->push_back(new_tensor);
    new_output_tensors.push_back(new_tensor);
    (void)kernel::KernelExecUtil::ReplaceSubGraphNodesOutTensor(subgraph, old_tensor, new_tensor);
    subgraph->set_out_tensor(new_tensor, i);
  }
  auto exit_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel({}, &new_output_tensors, &old_output_tensors,
                                                                    kernel::kExitSubGraph, *context_, schema_version_);
  return exit_subgraph;
}

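// Append an identity kernel after the output call node so the graph outputs no longer alias the call's
// output tensors; returns a rebuilt subgraph that contains the extra node.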
kernel::SubGraphKernel *ControlFlowScheduler::AddOutputKernel(kernel::SubGraphKernel *subgraph) {
  auto inputs = subgraph->in_tensors();
  auto outputs = subgraph->out_tensors();
  auto nodes = subgraph->nodes();

  auto call_node = subgraph->out_nodes().front();
  auto call_param = reinterpret_cast<CallParameter *>(call_node->op_parameter());
  MS_CHECK_TRUE_MSG(call_param != nullptr, nullptr, "call node's op_parameter is nullptr.");
  call_param->is_tail_call = false;

  size_t out_tensor_size = call_node->out_tensors().size();
  std::vector<Tensor *> old_output_tensors{};
  std::vector<Tensor *> new_output_tensors{};
  for (size_t i = 0; i < out_tensor_size; i++) {
    Tensor *old_tensor = subgraph->out_tensors()[i];
    old_output_tensors.push_back(old_tensor);
    auto allocator = old_tensor->allocator();
    auto new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return nullptr;
    }
    src_tensors_->push_back(new_tensor);
    new_output_tensors.push_back(new_tensor);
    (void)kernel::KernelExecUtil::ReplaceSubGraphNodesOutTensor(subgraph, old_tensor, new_tensor);
    call_node->set_out_tensor(new_tensor, i);
    context_->ReplaceLinkInfoReceiverWithNewOne(new_tensor, old_tensor);
  }
  auto output_node = kernel::IdentityKernel::Create(new_output_tensors, old_output_tensors, this->context_);
  MS_CHECK_FALSE_MSG(output_node == nullptr, nullptr, "Create Identity failed.");
  output_node->set_name(call_node->name() + "_output");
  kernel::KernelKey output_desc = call_node->desc();
  output_desc.type = PrimType_Inner_Identity;
  output_node->set_desc(output_desc);
  output_node->AddInKernel(call_node);
  call_node->AddOutKernel(output_node);
  nodes.push_back(output_node);
  auto subgraph_type = subgraph->subgraph_type();
  auto new_subgraph =
    kernel::KernelExecUtil::CreateSubGraphKernel(nodes, &inputs, &outputs, subgraph_type, *context_, schema_version_);
  return new_subgraph;
}

int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
  // among the subgraphs called more than once, pick those whose partial nodes feed at least two non-tail calls.
  for (auto item : more_than_once_called_partial_nodes_) {
    if (item.second.size() == 1) {
      MS_LOG(DEBUG) << "subgraph is only called once.";
      continue;
    }
    auto node = item.second.begin();
    kernel::PartialFusionKernel *partial = reinterpret_cast<kernel::PartialFusionKernel *>((*node)->kernel());
    MS_CHECK_TRUE_MSG(partial != nullptr, RET_ERROR, "cast to partial node failed.");
    auto aim_kernels = partial->subgraph_kernels();
    MS_CHECK_TRUE_MSG(aim_kernels.size() == 1, RET_ERROR, "partial subgraph kernels size not right.");
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(aim_kernels.front());
    MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "subgraph is nullptr");

    std::vector<kernel::KernelExec *> all_call_nodes{};
    for (auto partial_node : item.second) {
      auto call_node = kernel::KernelExecUtil::GetPartialOutputCall(partial_node);
      if (call_node == nullptr) {
        MS_LOG(ERROR) << "call_node is nullptr.";
        return RET_ERROR;
      }
      all_call_nodes.push_back(call_node);
    }

    // fewer than kMinNonTailCallCount non-tail calls: no boundary needed.
    int non_tail_call_size = 0;
    for (auto call_node : all_call_nodes) {
      if (kernel::KernelExecUtil::IsNonTailCall(call_node)) {
        non_tail_call_size++;
      }
    }
    if (non_tail_call_size < kMinNonTailCallCount) {
      MS_LOG(DEBUG) << "no need to build boundary.";
      continue;
    }
    for (auto partial_node : item.second) {
      (void)subgraphs_need_boundary_[subgraph].insert(partial_node);
    }
  }
  return RET_OK;
}

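// Wrap each subgraph that needs a boundary with an entrance and an exit subgraph, connected through a
// dummy link tensor, and register the triple as the subgraph kernels of every calling partial node.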
int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels) {
  for (auto &item : subgraphs_need_boundary_) {
    auto subgraph = item.first;
    // new link tensor
    auto link_tensor = new Tensor(kNumberTypeFloat32, {1});
    if (link_tensor == nullptr) {
      MS_LOG(ERROR) << "new link tensor failed.";
      return RET_NULL_PTR;
    }
    link_tensor->set_tensor_name(subgraph->name() + "_link_tensor");
    link_tensor->set_category(Category::CONST_TENSOR);
    src_tensors_->push_back(link_tensor);

    auto entrance_subgraph = CreateEntranceSubGraph(subgraph, link_tensor);
    if (entrance_subgraph == nullptr) {
      MS_LOG(ERROR) << "create entrance subgraph failed.";
      return RET_NULL_PTR;
    }
    entrance_subgraph->set_name(subgraph->name() + "_entrance");
    dst_kernels->push_back(entrance_subgraph);

    auto exit_subgraph = CreateExitSubGraph(subgraph, link_tensor);
    if (exit_subgraph == nullptr) {
      MS_LOG(ERROR) << "create exit subgraph failed.";
      return RET_NULL_PTR;
    }
    exit_subgraph->set_name(subgraph->name() + "_exit");
    dst_kernels->push_back(exit_subgraph);

    // update partial's subgraph kernels
    std::vector<kernel::KernelExec *> subgraph_kernels{};
    subgraph_kernels.push_back(entrance_subgraph);
    subgraph_kernels.push_back(subgraph);
    subgraph_kernels.push_back(exit_subgraph);

    // record partial nodes of this subgraph.
    auto exit_subgraph_kernel = reinterpret_cast<kernel::ExitSubGraphKernel *>(exit_subgraph);
    for (auto partial_node : item.second) {
      exit_subgraph_kernel->SetPartial(partial_node);
      auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
      MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
      partial_kernel->set_subgraph_kernels(subgraph_kernels);
    }
  }
  return RET_OK;
}

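// If the main graph's only output node is a call, append an identity output kernel so the graph outputs
// are isolated from the call's outputs, then replace the old main graph kernel.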
int ControlFlowScheduler::IsolateOutputForCallOutputGraph(std::vector<kernel::KernelExec *> *dst_kernels) {
  kernel::KernelExec *main_graph_kernel = dst_kernels->front();
  if (!kernel::KernelExecUtil::IsOutputSubGraph(main_graph_kernel)) {
    MS_LOG(DEBUG) << "main graph is not an output graph.";
    return RET_OK;
  }

  auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(main_graph_kernel);
  MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "cast to subgraph failed.");
  if (!(subgraph->out_nodes().size() == 1 && subgraph->out_nodes().front()->type() == schema::PrimitiveType_Call)) {
    MS_LOG(DEBUG) << "main graph output is not a call node.";
    return RET_OK;
  }

  auto new_subgraph = AddOutputKernel(subgraph);
  MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create output subgraph failed.");
  new_subgraph->set_name(subgraph->name());
  std::replace(dst_kernels->begin(), dst_kernels->end(), subgraph, new_subgraph);

  subgraph->set_nodes({});
  delete subgraph;
  return RET_OK;
}

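// Follow chains of tail calls through their input partials and collect the subgraphs that ultimately
// produce the outputs; reviewed_graphs guards against revisiting a subgraph on recursive models.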
int ControlFlowScheduler::GetTailCallFinalSubgraphs(std::queue<kernel::KernelExec *> *tail_call_q,
                                                    std::vector<kernel::KernelExec *> *final_graphs,
                                                    std::set<kernel::KernelExec *> reviewed_graphs) {
  if (tail_call_q->empty()) {
    return RET_OK;
  }
  auto tail_call = tail_call_q->front();
  tail_call_q->pop();
  auto partials = kernel::KernelExecUtil::GetCallInputPartials(tail_call);
  for (auto partial : partials) {
    auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial->kernel());
    MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
    // only the output subgraph is needed; the last subgraph kernel is the output subgraph.
    auto subgraphs = partial_kernel->subgraph_kernels();
    auto subgraph = subgraphs.back();
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    MS_CHECK_TRUE_MSG(subgraph_kernel != nullptr, RET_ERROR, "cast to subgraph kernel failed.");
    if (kernel::KernelExecUtil::IsTailCallSubGraph(subgraph_kernel)) {
      if (reviewed_graphs.find(subgraph_kernel) == reviewed_graphs.end()) {
        tail_call_q->push(subgraph_kernel->out_nodes().front());
      }
    } else {
      final_graphs->push_back(subgraph);
    }
    (void)reviewed_graphs.insert(subgraph);
  }
  return GetTailCallFinalSubgraphs(tail_call_q, final_graphs, reviewed_graphs);
}

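// Link the output tensors of every final subgraph reached through this tail call to the tail call's own
// outputs; skipped when the output counts differ (not a MindIR model).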
int ControlFlowScheduler::RecordTailCallLinkInfo(kernel::KernelExec *tail_call) {
  std::queue<kernel::KernelExec *> tail_call_q{};
  tail_call_q.push(tail_call);
  std::vector<kernel::KernelExec *> final_graphs{};
  std::set<kernel::KernelExec *> reviewed_graphs{};
  auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetTailCallFinalSubgraphs failed.";
    return ret;
  }

  if (std::any_of(final_graphs.begin(), final_graphs.end(), [&tail_call](const kernel::KernelExec *item) {
        return item->out_tensors().size() != tail_call->out_tensors().size();
      })) {
    MS_LOG(DEBUG) << "not a MindIR model, return ok.";
    return RET_OK;
  }

  for (auto final_graph : final_graphs) {
    for (size_t i = 0; i < final_graph->out_tensors().size(); ++i) {
      context_->SetLinkInfo(final_graph->out_tensors()[i], tail_call->out_tensors()[i]);
    }
  }
  return RET_OK;
}

int ControlFlowScheduler::RecordAllTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels) {
  std::vector<kernel::KernelExec *> all_tail_calls{};
  for (auto dst_kernel : *dst_kernels) {
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(dst_kernel);
    if (kernel::KernelExecUtil::IsTailCallSubGraph(subgraph_kernel)) {
      all_tail_calls.push_back(subgraph_kernel->out_nodes().front());
    }
  }

  for (auto tail_call : all_tail_calls) {
    auto ret = RecordTailCallLinkInfo(tail_call);
    MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordTailCallLinkInfo failed.");
  }
  return RET_OK;
}

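// Insert an identity kernel in front of `partial` so it reads private copies of its input tensors;
// returns a rebuilt subgraph that contains the identity kernel.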
kernel::SubGraphKernel *ControlFlowScheduler::IsolatePartialInputs(kernel::SubGraphKernel *subgraph,
                                                                   kernel::KernelExec *partial) {
  auto inputs = subgraph->in_tensors();
  auto outputs = subgraph->out_tensors();
  auto nodes = subgraph->nodes();

  auto old_partial_inputs = partial->in_tensors();

  std::vector<Tensor *> new_partial_inputs{};
  for (size_t i = 0; i < old_partial_inputs.size(); i++) {
    Tensor *old_tensor = old_partial_inputs[i];
    auto allocator = old_tensor->allocator();
    Tensor *new_tensor = nullptr;
    if (old_tensor->data_type() == kObjectTypeTensorType) {
      auto old_tensor_list = reinterpret_cast<TensorList *>(old_tensor);
      new_tensor = TensorList::CopyTensorList(*old_tensor_list, false, allocator);
    } else {
      new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    }
    MS_CHECK_TRUE_MSG(new_tensor != nullptr, nullptr, "new tensor failed.");
    new_tensor->set_category(VAR);
    partial->set_in_tensor(new_tensor, i);
    src_tensors_->push_back(new_tensor);
    new_partial_inputs.push_back(new_tensor);
  }
  auto identity_node = kernel::IdentityKernel::Create(old_partial_inputs, new_partial_inputs, this->context_);
  MS_CHECK_TRUE_MSG(identity_node != nullptr, nullptr, "Create Identity kernel failed.");
  identity_node->set_name(partial->name() + "_input_identity");
  kernel::KernelKey identity_desc = partial->desc();
  identity_desc.type = PrimType_Inner_Identity;
  identity_node->set_desc(identity_desc);
  // relink in/out kernels: the identity kernel now sits between the partial and its former input kernels
  for (auto partial_in_kernel : partial->in_kernels()) {
    auto output_kernels = partial_in_kernel->out_kernels();
    std::replace(output_kernels.begin(), output_kernels.end(), partial, identity_node);
    partial_in_kernel->set_out_kernels(output_kernels);
    identity_node->AddInKernel(partial_in_kernel);
  }
  identity_node->AddOutKernel(partial);
  partial->set_in_kernels({identity_node});
  auto partial_iter = std::find(nodes.begin(), nodes.end(), partial);
  (void)nodes.insert(partial_iter, identity_node);
  auto subgraph_type = subgraph->subgraph_type();
  auto new_subgraph =
    kernel::KernelExecUtil::CreateSubGraphKernel(nodes, &inputs, &outputs, subgraph_type, *context_, schema_version_);
  return new_subgraph;
}

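// Return every partial kernel that shares at least one input tensor with another partial kernel.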
std::set<kernel::KernelExec *> ControlFlowScheduler::GetSameInputPartials() {
  std::unordered_map<Tensor *, std::set<kernel::KernelExec *>> input_partial_pairs{};
  for (auto item : *partial_kernel_subgraph_index_map_) {
    for (auto input : item.first->in_tensors()) {
      if (input_partial_pairs.find(input) == input_partial_pairs.end()) {
        std::set<kernel::KernelExec *> partials{};
        (void)partials.insert(item.first);
        input_partial_pairs[input] = partials;
      } else {
        (void)input_partial_pairs[input].insert(item.first);
      }
    }
  }

  std::set<kernel::KernelExec *> same_input_partials{};
  for (auto item : input_partial_pairs) {
    if (item.second.size() > 1) {
      for (auto partial : item.second) {
        (void)same_input_partials.insert(partial);
      }
    }
  }
  return same_input_partials;
}

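// Give every partial that shares inputs with another partial its own isolated input tensors, then
// refresh the partial-to-subgraph bookkeeping.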
int ControlFlowScheduler::IsolateSameInputPartials(std::vector<kernel::KernelExec *> *dst_kernels) {
  auto same_input_partials = GetSameInputPartials();

  for (auto partial : same_input_partials) {
    auto subgraph = kernel::KernelExecUtil::BelongToWhichSubGraph(*dst_kernels, partial);
    MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "can not find belonging graph.");
    kernel::SubGraphKernel *new_subgraph = IsolatePartialInputs(subgraph, partial);
    MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create new subgraph failed.");
    new_subgraph->set_name(subgraph->name());

    std::replace(dst_kernels->begin(), dst_kernels->end(), subgraph, new_subgraph);
    UpdateSubGraphMap(new_subgraph, subgraph);

    subgraph->set_nodes({});
    delete subgraph;
  }

  SetSubgraphForPartialNode(partial_kernel_subgraph_index_map_, subgraph_index_subgraph_kernel_map_);
  return RET_OK;
}

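// For every subgraph that needs a boundary, isolate its partial input nodes with identity kernels, then
// point the calling partial kernels, the boundary bookkeeping and dst_kernels at the rebuilt subgraph.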
int ControlFlowScheduler::IsolateInputOfMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels) {
  auto ret = GetSubGraphsWhichNeedBoundary();
  MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetSubGraphsWhichNeedBoundary failed.");
  std::unordered_map<kernel::SubGraphKernel *, kernel::SubGraphKernel *> replace_pair{};

  for (auto &item : subgraphs_need_boundary_) {
    auto subgraph = item.first;
    std::vector<kernel::KernelExec *> input_partials{};
    for (auto input : subgraph->in_nodes()) {
      MS_CHECK_TRUE_MSG(input->op_parameter() != nullptr, RET_ERROR, "op_parameter is nullptr.");
      if (input->op_parameter()->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion)) {
        input_partials.push_back(input);
      }
    }
    kernel::SubGraphKernel *new_subgraph = nullptr;
    kernel::SubGraphKernel *cur_subgraph = subgraph;
    for (auto cur_partial : input_partials) {
      new_subgraph = IsolatePartialInputs(cur_subgraph, cur_partial);
      MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create new subgraph failed.");
      new_subgraph->set_name(cur_subgraph->name());

      cur_subgraph->set_nodes({});
      delete cur_subgraph;
      cur_subgraph = new_subgraph;
    }

    if (new_subgraph != nullptr) {
      replace_pair[subgraph] = new_subgraph;
    }
  }

  // update all partial nodes' subgraph
  for (auto item : replace_pair) {
    auto old_subgraph = item.first;
    auto new_subgraph = item.second;
    for (auto partial_node : subgraphs_need_boundary_[old_subgraph]) {
      auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
      MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
      partial_kernel->set_subgraph_kernels({new_subgraph});
      (void)subgraphs_need_boundary_[new_subgraph].insert(partial_node);
    }
  }

  for (auto item : replace_pair) {
    auto old_subgraph = item.first;
    (void)subgraphs_need_boundary_.erase(old_subgraph);
  }

  // update dst_kernels
  for (auto item : replace_pair) {
    auto old_subgraph = item.first;
    auto new_subgraph = item.second;
    std::replace(dst_kernels->begin(), dst_kernels->end(), old_subgraph, new_subgraph);
  }

  return RET_OK;
}

void ControlFlowScheduler::SetSubgraphForPartialNode(
  std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map,
  std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map) {
  partial_kernel_subgraph_index_map_ = partial_kernel_subgraph_index_map;
  subgraph_index_subgraph_kernel_map_ = subgraph_index_subgraph_kernel_map;

  for (auto &pair : *partial_kernel_subgraph_index_map) {
    auto partial_kernel = static_cast<kernel::PartialFusionKernel *>((pair.first)->kernel());
    auto &subgraph_index = pair.second;
    partial_kernel->set_subgraph_kernels({subgraph_index_subgraph_kernel_map->at(subgraph_index)});
  }
}

void ControlFlowScheduler::UpdateSubGraphMap(kernel::KernelExec *new_subgraph, kernel::KernelExec *old_subgraph) {
  for (auto &item : *subgraph_index_subgraph_kernel_map_) {
    if (item.second == old_subgraph) {
      item.second = new_subgraph;
    }
  }
  return;
}

int ControlFlowScheduler::RecordLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels) {
  auto ret = RecordPartialInputLinkInfo();
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordPartialInputLinkInfo failed.");
  ret = this->RecordAllTailCallLinkInfo(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordAllTailCallLinkInfo failed.");
  ret = this->RecordAllNonTailCallLinkInfo(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordAllNonTailCallLinkInfo failed.");
  return RET_OK;
}

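// Link each partial kernel's input tensors to the matching input tensors of its first subgraph kernel.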
int ControlFlowScheduler::RecordPartialInputLinkInfo() {
  for (auto &pair : *partial_kernel_subgraph_index_map_) {
    auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>((pair.first)->kernel());
    MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
    auto subgraph_kernels = partial_kernel->subgraph_kernels();
    MS_CHECK_TRUE_MSG(!subgraph_kernels.empty(), RET_ERROR, "partial corresponding subgraph kernels empty.");
    auto subgraph_kernel = subgraph_kernels.front();
    MS_CHECK_TRUE_MSG(partial_kernel->in_tensors().size() == subgraph_kernel->in_tensors().size(), RET_ERROR,
                      "partial inputs and corresponding subgraph inputs size not same.");
    for (size_t i = 0; i < partial_kernel->in_tensors().size(); ++i) {
      context_->SetLinkInfo(partial_kernel->in_tensors()[i], subgraph_kernel->in_tensors()[i]);
    }
  }
  return RET_OK;
}

#else
int ControlFlowScheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) { return RET_OK; }
void ControlFlowScheduler::SetSubgraphForPartialNode(
  std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map,
  std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map) {
  return;
}
void ControlFlowScheduler::RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node) {
  return;
}
#endif
}  // namespace mindspore::lite