• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/control_flow/control_flow_scheduler.h"
18 #ifndef CONTROLFLOW_TENSORLIST_CLIP
19 #include <algorithm>
20 #include <set>
21 #include "src/litert/kernel_exec_util.h"
22 #include "src/litert/kernel/cpu/base/partial_fusion.h"
23 #include "nnacl/call_parameter.h"
24 #include "src/control_flow/kernel/exit_subgraph_kernel.h"
25 #include "src/control_flow/kernel/identity_kernel.h"
26 #include "src/tensorlist.h"
27 #include "src/common/prim_inner.h"
28 
29 namespace {
30 const constexpr int kMinNonTailCallCount = 2;
31 }
32 #endif
33 
34 namespace mindspore::lite {
35 #ifndef CONTROLFLOW_TENSORLIST_CLIP
// Entry point of control-flow scheduling: runs all graph transformation passes
// in a fixed order over the scheduled kernels. Each pass depends on the graph
// state produced by the previous one, so the order must not change.
// Returns RET_OK on success, otherwise the error code of the first failing pass.
int ControlFlowScheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
  auto ret = this->IsolateSameInputPartials(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
  ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
  ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
  ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "BuildBoundaryForMultipleCalledGraph failed.");
  ret = this->RecordLinkInfo(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordLinkInfo failed.");
  ret = this->SplitNonTailCallSubGraphs(dst_kernels);
  MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
  return ret;
}
51 
SplitNonTailCallSubGraphs(std::vector<kernel::KernelExec * > * dst_kernels)52 int ControlFlowScheduler::SplitNonTailCallSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels) {
53   std::set<kernel::KernelExec *> all_non_tail_subgraphs = GetNonTailCallSubGraphs(dst_kernels);
54   for (auto item : all_non_tail_subgraphs) {
55     to_process_q_.push(item);
56   }
57 
58   while (!to_process_q_.empty()) {
59     auto cur = to_process_q_.front();
60     to_process_q_.pop();
61     auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(cur);
62     if (subgraph_kernel == nullptr) {
63       MS_LOG(ERROR) << "kernel is not a subgraph kernel";
64       return RET_ERROR;
65     }
66     std::vector<kernel::KernelExec *> new_subgraphs{};
67     auto ret = SplitSingleNonTailCallSubGraph(subgraph_kernel, &new_subgraphs);
68     if (ret != RET_OK) {
69       MS_LOG(ERROR) << "SplitSingleNonTailCallSubGraph failed, ret: " << ret;
70       return ret;
71     }
72     // append dst_kernels
73     (void)std::copy(new_subgraphs.begin(), new_subgraphs.end(), std::back_inserter(*dst_kernels));
74     // update partial_kernel_map
75     for (auto &item : *partial_kernel_subgraph_index_map_) {
76       auto &partial_node = item.first;
77       auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
78       MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
79       auto subgraphs = partial_kernel->subgraph_kernels();
80       auto iter = std::find(subgraphs.begin(), subgraphs.end(), subgraph_kernel);
81       if (iter == subgraphs.end()) {
82         continue;
83       }
84       (void)subgraphs.erase(iter);
85       for (auto &new_subgraph : new_subgraphs) {
86         (void)subgraphs.insert(iter, new_subgraph);
87       }
88       partial_kernel->set_subgraph_kernels(subgraphs);
89     }
90     AppendToProcessQ(&new_subgraphs, &all_non_tail_subgraphs);
91   }
92 
93   RemoveUselessKernels(dst_kernels, &all_non_tail_subgraphs);
94 
95   return RET_OK;
96 }
97 
GetNonTailCallSubGraphs(std::vector<kernel::KernelExec * > * dst_kernels)98 std::set<kernel::KernelExec *> ControlFlowScheduler::GetNonTailCallSubGraphs(
99   std::vector<kernel::KernelExec *> *dst_kernels) {
100   std::set<kernel::KernelExec *> non_tail_subgraph_kernels{};
101 
102   // found non-tail call subgraph
103   for (auto &kernel : *dst_kernels) {
104     if (kernel->desc().arch == kernel::kDelegate) {
105       continue;
106     }
107     auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
108     if (subgraph_kernel == nullptr) {
109       continue;
110     }
111     if (!kernel::KernelExecUtil::IsNonTailCallSubGraph(subgraph_kernel)) {
112       continue;
113     }
114     (void)non_tail_subgraph_kernels.insert(kernel);
115   }
116   return non_tail_subgraph_kernels;
117 }
118 
// When the second half of a split subgraph ends with a tail call, the nodes the
// call depends on — its direct input kernels, its input partials, and identity
// kernels feeding those partials — must live in the same subgraph as the call.
// Moves any such node found in `first_part_nodes` to the front of
// `second_part_nodes`. Returns RET_OK, or RET_ERROR if a needed node is in
// neither part (malformed graph).
int ControlFlowScheduler::AdjustNodesForTailCallSubGraph(std::vector<kernel::KernelExec *> *first_part_nodes,
                                                         std::vector<kernel::KernelExec *> *second_part_nodes) {
  auto tail_call = second_part_nodes->back();
  // gather all nodes the tail call needs in its own subgraph
  std::vector<kernel::KernelExec *> all_need_nodes{};
  (void)std::copy(tail_call->in_kernels().begin(), tail_call->in_kernels().end(), std::back_inserter(all_need_nodes));
  auto partials = kernel::KernelExecUtil::GetCallInputPartials(tail_call);
  (void)std::copy(partials.begin(), partials.end(), std::back_inserter(all_need_nodes));
  for (auto partial : partials) {
    for (auto input : partial->in_kernels()) {
      MS_CHECK_TRUE_MSG(input != nullptr, RET_ERROR, "input is nullptr");
      auto parameter = input->op_parameter();
      MS_CHECK_TRUE_MSG(parameter != nullptr, RET_ERROR, "parameter is nullptr");
      // identity kernels that feed a partial travel with the partial
      if (parameter->type_ == static_cast<int>(PRIM_IDENTITY)) {
        all_need_nodes.push_back(input);
      }
    }
  }

  for (auto need : all_need_nodes) {
    if (IsContain(*second_part_nodes, need)) {
      continue;
    }
    // not in the second part yet, so it must be in the first part: move it
    auto is_need = [&need](const kernel::KernelExec *node) { return node == need; };
    auto iter = std::find_if(first_part_nodes->begin(), first_part_nodes->end(), is_need);
    MS_CHECK_TRUE_MSG(iter != first_part_nodes->end(), RET_ERROR, "graph is not right");
    (void)second_part_nodes->insert(second_part_nodes->begin(), *iter);
    (void)first_part_nodes->erase(iter);
  }
  return RET_OK;
}
149 
// Partitions the nodes of `subgraph_kernel` at its last non-tail call: nodes up
// to and including that call go to `first_part_nodes`, the rest to
// `second_part_nodes`. The last non-tail call becomes a tail call of the first
// part. Returns RET_ERROR if the subgraph contains no non-tail call or its call
// parameter is missing.
int ControlFlowScheduler::SplitSubGraphNodesIntoTwoParts(kernel::SubGraphKernel *subgraph_kernel,
                                                         std::vector<kernel::KernelExec *> *first_part_nodes,
                                                         std::vector<kernel::KernelExec *> *second_part_nodes) {
  auto nodes = subgraph_kernel->nodes();

  // get the position of the last non-tail call op.
  auto is_non_tail_call = [](const kernel::KernelExec *node) { return kernel::KernelExecUtil::IsNonTailCall(node); };
  auto last_non_tail_call_iter = std::find_if(nodes.rbegin(), nodes.rend(), is_non_tail_call);
  // distance counts nodes from the front up to and including the found call
  auto distance = nodes.rend() - last_non_tail_call_iter;
  if (distance == 0) {
    MS_LOG(ERROR) << "not is a non tail call subgraph.";
    return RET_ERROR;
  }

  // change last non-tail call property as is tail call
  MS_CHECK_TRUE_MSG(*last_non_tail_call_iter != nullptr, RET_ERROR, "last_non_tail_call_iter is nullptr");
  auto parameter = reinterpret_cast<CallParameter *>((*last_non_tail_call_iter)->op_parameter());
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr";
    return RET_ERROR;
  }
  parameter->is_tail_call = true;

  // first part: everything up to and including the (now tail) call
  for (auto iter = nodes.begin(); iter != nodes.begin() + distance; ++iter) {
    first_part_nodes->push_back(*iter);
  }

  // second part: the remaining nodes
  for (auto iter = nodes.begin() + distance; iter != nodes.end(); ++iter) {
    second_part_nodes->push_back(*iter);
  }

  // if second part nodes contains call node, we need call node input partials and partials' inputs.
  if (kernel::KernelExecUtil::IsTailCall(second_part_nodes->back())) {
    auto ret = AdjustNodesForTailCallSubGraph(first_part_nodes, second_part_nodes);
    MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "AdjustNodesForTailCallSubGraph failed.");
  }
  return RET_OK;
}
188 
SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel * subgraph_kernel,std::vector<kernel::KernelExec * > * subgraph_kernels)189 int ControlFlowScheduler::SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel,
190                                                          std::vector<kernel::KernelExec *> *subgraph_kernels) {
191   std::vector<kernel::KernelExec *> first_part_nodes{};
192   std::vector<kernel::KernelExec *> second_part_nodes{};
193 
194   auto ret = SplitSubGraphNodesIntoTwoParts(subgraph_kernel, &first_part_nodes, &second_part_nodes);
195   MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitSubGraphNodesIntoTwoParts failed.");
196 
197   auto cur_subgraph_type = subgraph_kernel->subgraph_type();
198   auto first_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(first_part_nodes, nullptr, nullptr,
199                                                                      cur_subgraph_type, *context_, schema_version_);
200   subgraph_kernels->push_back(first_subgraph);
201 
202   auto second_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(second_part_nodes, nullptr, nullptr,
203                                                                       cur_subgraph_type, *context_, schema_version_);
204   subgraph_kernels->push_back(second_subgraph);
205   return RET_OK;
206 }
207 
RemoveUselessKernels(std::vector<kernel::KernelExec * > * dst_kernels,std::set<kernel::KernelExec * > * useless_kernels)208 void ControlFlowScheduler::RemoveUselessKernels(std::vector<kernel::KernelExec *> *dst_kernels,
209                                                 std::set<kernel::KernelExec *> *useless_kernels) {
210   for (auto iter = dst_kernels->begin(); iter != dst_kernels->end();) {
211     if (useless_kernels->find(*iter) != useless_kernels->end()) {
212       iter = dst_kernels->erase(iter);
213     } else {
214       iter++;
215     }
216   }
217 
218   for (auto &kernel : *useless_kernels) {
219     auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
220     if (subgraph_kernel == nullptr) {
221       continue;
222     }
223     subgraph_kernel->set_nodes({});
224     delete subgraph_kernel;
225   }
226   useless_kernels->clear();
227 
228   return;
229 }
230 
AppendToProcessQ(std::vector<kernel::KernelExec * > * new_subgraphs,std::set<kernel::KernelExec * > * all_non_tail_subgraphs)231 void ControlFlowScheduler::AppendToProcessQ(std::vector<kernel::KernelExec *> *new_subgraphs,
232                                             std::set<kernel::KernelExec *> *all_non_tail_subgraphs) {
233   auto new_non_tail_call_subgraphs = GetNonTailCallSubGraphs(new_subgraphs);
234   for (auto &item : new_non_tail_call_subgraphs) {
235     if (all_non_tail_subgraphs->find(item) == all_non_tail_subgraphs->end()) {
236       to_process_q_.push(item);
237       (void)all_non_tail_subgraphs->insert(item);
238     }
239   }
240   return;
241 }
242 
// Records tensor link info for one non-tail call: each output tensor of the
// subgraphs that may return to this call is linked to the corresponding output
// tensor of the call itself. For partials whose last subgraph is itself a
// tail-call subgraph, the final (leaf) subgraphs are resolved recursively via
// GetTailCallFinalSubgraphs first.
int ControlFlowScheduler::RecordNonTailCallLinkInfo(kernel::KernelExec *non_tail_call) {
  size_t non_tail_call_output_size = non_tail_call->out_tensors().size();
  auto partial_nodes = kernel::KernelExecUtil::GetCallInputPartials(non_tail_call);
  for (auto node : partial_nodes) {
    auto partial_node = reinterpret_cast<kernel::PartialFusionKernel *>(node->kernel());
    MS_CHECK_TRUE_MSG(partial_node != nullptr, RET_ERROR, "node cast to partial node failed.");
    auto kernels = partial_node->subgraph_kernels();
    MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty.");
    // the last subgraph kernel of a partial is its output subgraph
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernels.back());
    MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr.");
    if (kernel::KernelExecUtil::IsTailCallSubGraph(subgraph)) {
      // the subgraph ends in another tail call: follow the call chain down to
      // the final subgraphs that actually produce the outputs.
      std::queue<kernel::KernelExec *> tail_call_q{};
      tail_call_q.push(subgraph->out_nodes().front());
      std::vector<kernel::KernelExec *> final_graphs{};
      std::set<kernel::KernelExec *> reviewed_graphs{};
      auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetTailCallFinalSubgraphs failed.");
      for (auto item : final_graphs) {
        MS_CHECK_TRUE_MSG(item->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
                          "subgraph outputs and corresponding call outputs size not same.");
        for (size_t i = 0; i < non_tail_call_output_size; ++i) {
          context_->SetLinkInfo(item->out_tensors()[i], non_tail_call->out_tensors()[i]);
        }
      }
    } else {
      // plain subgraph: link its outputs directly to the call's outputs
      MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
                        "partial inputs and corresponding call outputs size not same.");
      for (size_t i = 0; i < non_tail_call_output_size; ++i) {
        context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]);
      }
    }
  }
  return RET_OK;
}
277 
// Scans all subgraph kernels for non-tail call nodes, caches them in
// non_tail_calls_, then records link info for each one via
// RecordNonTailCallLinkInfo. Returns RET_OK on success.
int ControlFlowScheduler::RecordAllNonTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels) {
  for (auto dst_kernel : *dst_kernels) {
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(dst_kernel);
    MS_CHECK_TRUE_MSG(subgraph_kernel != nullptr, RET_ERROR, "node cast to subgraph kernel failed.");
    for (auto node : subgraph_kernel->nodes()) {
      if (kernel::KernelExecUtil::IsNonTailCall(node)) {
        non_tail_calls_.push_back(node);
      }
    }
  }

  for (auto non_tail_call : non_tail_calls_) {
    auto ret = RecordNonTailCallLinkInfo(non_tail_call);
    MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordNonTailCallLinkInfo, failed");
  }
  return RET_OK;
}
295 
RecordSubgraphCaller(const size_t & subgraph_index,kernel::KernelExec * partial_node)296 void ControlFlowScheduler::RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node) {
297   if (more_than_once_called_partial_nodes_.find(subgraph_index) == more_than_once_called_partial_nodes_.end()) {
298     std::set<kernel::KernelExec *> tmp_set{partial_node};
299     (void)more_than_once_called_partial_nodes_.insert(
300       std::pair<size_t, std::set<kernel::KernelExec *>>{subgraph_index, tmp_set});
301   } else {
302     (void)more_than_once_called_partial_nodes_[subgraph_index].insert(partial_node);
303   }
304 }
305 
// Builds an entrance subgraph kernel in front of `subgraph`: copies each of the
// subgraph's input tensors, rewires the subgraph (and its nodes) to consume the
// copies, and creates an entrance subgraph mapping old inputs -> new inputs.
// `link_tensor` becomes the entrance subgraph's first output, pairing it with
// the corresponding exit subgraph. Returns nullptr on failure; new tensors are
// registered in src_tensors_ so their lifetime is owned by the model.
kernel::SubGraphKernel *ControlFlowScheduler::CreateEntranceSubGraph(kernel::SubGraphKernel *subgraph,
                                                                     lite::Tensor *link_tensor) {
  if (subgraph == nullptr || link_tensor == nullptr) {
    MS_LOG(ERROR) << "input is nullptr.";
    return nullptr;
  }
  size_t in_tensor_size = subgraph->in_tensors().size();
  std::vector<Tensor *> old_input_tensors{};
  // entrance subgraph kernel first output tensor is the first input of the corresponding exit subgraph kernel.
  std::vector<Tensor *> new_input_tensors{link_tensor};
  for (size_t i = 0; i < in_tensor_size; i++) {
    Tensor *old_tensor = subgraph->in_tensors()[i];
    old_input_tensors.push_back(old_tensor);
    auto allocator = old_tensor->allocator();
    // shallow copy: tensor metadata only, no data buffer
    auto new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return nullptr;
    }
    src_tensors_->push_back(new_tensor);
    new_input_tensors.push_back(new_tensor);
    // point the subgraph's internal nodes and its own input list at the copy
    auto ret = kernel::KernelExecUtil::ReplaceSubGraphNodesInTensor(subgraph, old_tensor, new_tensor);
    MS_CHECK_FALSE_MSG(ret != RET_OK, nullptr, "ReplaceSubGraphNodesInTensor failed.");
    subgraph->set_in_tensor(new_tensor, i);
  }
  auto entrance_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(
    {}, &old_input_tensors, &new_input_tensors, kernel::kEntranceSubGraph, *context_, schema_version_);
  return entrance_subgraph;
}
335 
// Builds an exit subgraph kernel behind `subgraph`: copies each of the
// subgraph's output tensors, rewires the subgraph to produce the copies, and
// creates an exit subgraph mapping new outputs -> old outputs. `link_tensor`
// becomes the exit subgraph's first input, pairing it with the corresponding
// entrance subgraph (see CreateEntranceSubGraph). Returns nullptr on failure.
kernel::SubGraphKernel *ControlFlowScheduler::CreateExitSubGraph(kernel::SubGraphKernel *subgraph,
                                                                 lite::Tensor *link_tensor) {
  if (subgraph == nullptr || link_tensor == nullptr) {
    MS_LOG(ERROR) << "input is nullptr.";
    return nullptr;
  }
  size_t out_tensor_size = subgraph->out_tensors().size();
  std::vector<Tensor *> old_output_tensors{};
  // exit subgraph kernel first input tensor is the first output of the corresponding entrance subgraph kernel.
  std::vector<Tensor *> new_output_tensors{link_tensor};
  for (size_t i = 0; i < out_tensor_size; i++) {
    Tensor *old_tensor = subgraph->out_tensors()[i];
    old_output_tensors.push_back(old_tensor);
    auto allocator = old_tensor->allocator();
    // shallow copy: tensor metadata only, no data buffer
    auto new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    if (new_tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed.";
      return nullptr;
    }
    src_tensors_->push_back(new_tensor);
    new_output_tensors.push_back(new_tensor);
    // NOTE(review): return value deliberately(?) ignored here, unlike the
    // checked ReplaceSubGraphNodesInTensor call in CreateEntranceSubGraph —
    // confirm this asymmetry is intended.
    (void)kernel::KernelExecUtil::ReplaceSubGraphNodesOutTensor(subgraph, old_tensor, new_tensor);
    subgraph->set_out_tensor(new_tensor, i);
  }
  auto exit_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel({}, &new_output_tensors, &old_output_tensors,
                                                                    kernel::kExitSubGraph, *context_, schema_version_);
  return exit_subgraph;
}
364 
AddOutputKernel(kernel::SubGraphKernel * subgraph)365 kernel::SubGraphKernel *ControlFlowScheduler::AddOutputKernel(kernel::SubGraphKernel *subgraph) {
366   auto inputs = subgraph->in_tensors();
367   auto outputs = subgraph->out_tensors();
368   auto nodes = subgraph->nodes();
369 
370   auto call_node = subgraph->out_nodes().front();
371   reinterpret_cast<CallParameter *>(call_node->op_parameter())->is_tail_call = false;
372 
373   size_t out_tensor_size = call_node->out_tensors().size();
374   std::vector<Tensor *> old_output_tensors{};
375   std::vector<Tensor *> new_output_tensors{};
376   for (size_t i = 0; i < out_tensor_size; i++) {
377     Tensor *old_tensor = subgraph->out_tensors()[i];
378     old_output_tensors.push_back(old_tensor);
379     auto allocator = old_tensor->allocator();
380     auto new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
381     if (new_tensor == nullptr) {
382       MS_LOG(ERROR) << "new Tensor failed.";
383       return nullptr;
384     }
385     src_tensors_->push_back(new_tensor);
386     new_output_tensors.push_back(new_tensor);
387     (void)kernel::KernelExecUtil::ReplaceSubGraphNodesOutTensor(subgraph, old_tensor, new_tensor);
388     call_node->set_out_tensor(new_tensor, i);
389     context_->ReplaceLinkInfoReceiverWithNewOne(new_tensor, old_tensor);
390   }
391   auto output_node = kernel::IdentityKernel::Create(new_output_tensors, old_output_tensors, this->context_);
392   MS_CHECK_FALSE_MSG(output_node == nullptr, nullptr, "Create Identity failed.");
393   output_node->set_name(call_node->name() + "_output");
394   kernel::KernelKey output_desc = call_node->desc();
395   output_desc.type = PrimType_Inner_Identity;
396   output_node->set_desc(output_desc);
397   output_node->AddInKernel(call_node);
398   call_node->AddOutKernel(output_node);
399   nodes.push_back(output_node);
400   auto subgraph_type = subgraph->subgraph_type();
401   auto new_subgraph =
402     kernel::KernelExecUtil::CreateSubGraphKernel(nodes, &inputs, &outputs, subgraph_type, *context_, schema_version_);
403   return new_subgraph;
404 }
405 
// Decides which multiply-called subgraphs need entrance/exit boundary kernels:
// a subgraph qualifies when it is called by at least two partials whose output
// calls are non-tail calls. Qualifying subgraphs and their partial callers are
// recorded in subgraphs_need_boundary_.
int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
  // among the more than once call subgraphs, if one of it's corresponding partial nodes' call node is non-tail call.
  for (auto item : more_than_once_called_partial_nodes_) {
    if (item.second.size() == 1) {
      MS_LOG(DEBUG) << "subgraph call only once.";
      continue;
    }
    // all partials in the set point at the same subgraph; use the first one
    auto node = item.second.begin();
    kernel::PartialFusionKernel *partial = reinterpret_cast<kernel::PartialFusionKernel *>((*node)->kernel());
    MS_CHECK_TRUE_MSG(partial != nullptr, RET_ERROR, "cast to partial node failed.");
    auto aim_kernels = partial->subgraph_kernels();
    MS_CHECK_TRUE_MSG(aim_kernels.size() == 1, RET_ERROR, "partial subgraph kernels size not right.");
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(aim_kernels.front());
    MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "subgraph is nullptr");

    // collect the call node that consumes each partial's output
    std::vector<kernel::KernelExec *> all_call_nodes{};
    for (auto partial_node : item.second) {
      auto call_node = kernel::KernelExecUtil::GetPartialOutputCall(partial_node);
      if (call_node == nullptr) {
        MS_LOG(ERROR) << "call_node is nullptr.";
        return RET_ERROR;
      }
      all_call_nodes.push_back(call_node);
    }

    // non-tail call size less than 2, continue
    int non_tail_call_size = 0;
    for (auto call_node : all_call_nodes) {
      if (kernel::KernelExecUtil::IsNonTailCall(call_node)) {
        non_tail_call_size++;
      }
    }
    if (non_tail_call_size < kMinNonTailCallCount) {
      MS_LOG(DEBUG) << "no need to build boundary.";
      continue;
    }
    for (auto partial_node : item.second) {
      (void)subgraphs_need_boundary_[subgraph].insert(partial_node);
    }
  }
  return RET_OK;
}
448 
BuildBoundaryForMultipleCalledGraph(std::vector<kernel::KernelExec * > * dst_kernels)449 int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels) {
450   for (auto &item : subgraphs_need_boundary_) {
451     auto subgraph = item.first;
452     // new link tensor
453     auto link_tensor = new Tensor(kNumberTypeFloat32, {1});
454     if (link_tensor == nullptr) {
455       MS_LOG(ERROR) << "";
456       return RET_NULL_PTR;
457     }
458     link_tensor->set_tensor_name(subgraph->name() + "_link_tensor");
459     link_tensor->set_category(Category::CONST_TENSOR);
460     src_tensors_->push_back(link_tensor);
461 
462     auto entrance_subgraph = CreateEntranceSubGraph(subgraph, link_tensor);
463     if (entrance_subgraph == nullptr) {
464       MS_LOG(ERROR) << "create entrance subgraph failed.";
465       return RET_NULL_PTR;
466     }
467     entrance_subgraph->set_name(subgraph->name() + "_entrance");
468     dst_kernels->push_back(entrance_subgraph);
469 
470     auto exit_subgraph = CreateExitSubGraph(subgraph, link_tensor);
471     if (exit_subgraph == nullptr) {
472       MS_LOG(ERROR) << "create exit subgraph failed.";
473       return RET_NULL_PTR;
474     }
475     exit_subgraph->set_name(subgraph->name() + "_exit");
476     dst_kernels->push_back(exit_subgraph);
477 
478     // update partial's subgraph kernels
479     std::vector<kernel::KernelExec *> subgraph_kernels{};
480     subgraph_kernels.push_back(entrance_subgraph);
481     subgraph_kernels.push_back(subgraph);
482     subgraph_kernels.push_back(exit_subgraph);
483 
484     // record partial nodes of this subgraph.
485     auto exit_subgraph_kernel = reinterpret_cast<kernel::ExitSubGraphKernel *>(exit_subgraph);
486     for (auto partial_node : item.second) {
487       exit_subgraph_kernel->SetPartial(partial_node);
488       auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
489       MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
490       partial_kernel->set_subgraph_kernels(subgraph_kernels);
491     }
492   }
493   return RET_OK;
494 }
495 
// If the main graph (first kernel) is an output subgraph whose single output
// node is a Call, rebuilds it via AddOutputKernel so the graph outputs come
// from an identity node instead of the call directly, then replaces the old
// subgraph kernel in `dst_kernels` and frees it. No-op otherwise.
int ControlFlowScheduler::IsolateOutputForCallOutputGraph(std::vector<kernel::KernelExec *> *dst_kernels) {
  kernel::KernelExec *main_graph_kernel = dst_kernels->front();
  if (!kernel::KernelExecUtil::IsOutputSubGraph(main_graph_kernel)) {
    MS_LOG(DEBUG) << "Not is output graph.";
    return RET_OK;
  }

  auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(main_graph_kernel);
  MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "cast to subgraph failed.");
  if (!(subgraph->out_nodes().size() == 1 && subgraph->out_nodes().front()->type() == schema::PrimitiveType_Call)) {
    MS_LOG(DEBUG) << "main graph output is not call node.";
    return RET_OK;
  }

  auto new_subgraph = AddOutputKernel(subgraph);
  MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create output subgraph failed.");
  new_subgraph->set_name(subgraph->name());
  std::replace(dst_kernels->begin(), dst_kernels->end(), subgraph, new_subgraph);

  // the nodes now belong to new_subgraph; detach before deleting the old shell
  subgraph->set_nodes({});
  delete subgraph;
  return RET_OK;
}
519 
// Recursively follows chains of tail calls starting from the calls queued in
// `tail_call_q`, collecting in `final_graphs` the leaf subgraphs that actually
// produce the outputs (i.e. whose output is not another tail call).
// NOTE(review): `reviewed_graphs` is passed by value, so each recursion level
// works on the caller's accumulated copy — confirm this copy-per-level
// behavior (rather than a shared by-reference set) is intended.
int ControlFlowScheduler::GetTailCallFinalSubgraphs(std::queue<kernel::KernelExec *> *tail_call_q,
                                                    std::vector<kernel::KernelExec *> *final_graphs,
                                                    std::set<kernel::KernelExec *> reviewed_graphs) {
  if (tail_call_q->empty()) {
    return RET_OK;
  }
  auto tail_call = tail_call_q->front();
  tail_call_q->pop();
  auto partials = kernel::KernelExecUtil::GetCallInputPartials(tail_call);
  for (auto partial : partials) {
    auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial->kernel());
    MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
    // only get the output subgraph, the last subgraph is the output subgraph.
    auto subgraphs = partial_kernel->subgraph_kernels();
    auto subgraph = subgraphs.back();
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    MS_CHECK_TRUE_MSG(subgraph_kernel != nullptr, RET_ERROR, "cast to subgraph kernel failed.");
    if (kernel::KernelExecUtil::IsTailCallSubGraph(subgraph_kernel)) {
      // the chain continues: queue the next tail call unless already visited
      if (reviewed_graphs.find(subgraph_kernel) == reviewed_graphs.end()) {
        tail_call_q->push(subgraph_kernel->out_nodes().front());
      }
    } else {
      final_graphs->push_back(subgraph);
    }
    (void)reviewed_graphs.insert(subgraph);
  }
  return GetTailCallFinalSubgraphs(tail_call_q, final_graphs, reviewed_graphs);
}
548 
// Records tensor link info for one tail call: resolves the final subgraphs
// reachable through the tail-call chain and links each of their output tensors
// to the corresponding output tensor of the call. If any final subgraph's
// output count differs from the call's, the model is assumed not to be a
// mindir model and the function returns RET_OK without recording anything.
int ControlFlowScheduler::RecordTailCallLinkInfo(kernel::KernelExec *tail_call) {
  std::queue<kernel::KernelExec *> tail_call_q{};
  tail_call_q.push(tail_call);
  std::vector<kernel::KernelExec *> final_graphs{};
  std::set<kernel::KernelExec *> reviewed_graphs{};
  auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetTailCallFinalSubgraphs failed.";
    return ret;
  }

  if (std::any_of(final_graphs.begin(), final_graphs.end(), [&tail_call](const kernel::KernelExec *item) {
        return item->out_tensors().size() != tail_call->out_tensors().size();
      })) {
    MS_LOG(DEBUG) << "not is mindir model, return ok.";
    return RET_OK;
  }

  for (auto final_graph : final_graphs) {
    for (size_t i = 0; i < final_graph->out_tensors().size(); ++i) {
      context_->SetLinkInfo(final_graph->out_tensors()[i], tail_call->out_tensors()[i]);
    }
  }
  return RET_OK;
}
574 
RecordAllTailCallLinkInfo(std::vector<kernel::KernelExec * > * dst_kernels)575 int ControlFlowScheduler::RecordAllTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels) {
576   std::vector<kernel::KernelExec *> all_tail_calls{};
577   for (auto dst_kernel : *dst_kernels) {
578     auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(dst_kernel);
579     if (kernel::KernelExecUtil::IsTailCallSubGraph(subgraph_kernel)) {
580       all_tail_calls.push_back(subgraph_kernel->out_nodes().front());
581     }
582   }
583 
584   for (auto tail_call : all_tail_calls) {
585     auto ret = RecordTailCallLinkInfo(tail_call);
586     MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordTailCallLinkInfo, failed");
587   }
588   return RET_OK;
589 }
590 
// Inserts an identity kernel in front of `partial` inside `subgraph`: the
// partial's input tensors are shallow-copied (TensorList-aware), the partial is
// rewired to consume the copies, and the identity node copies old -> new. This
// decouples partials that previously shared input tensors. Returns a rebuilt
// subgraph kernel containing the identity node, or nullptr on failure.
kernel::SubGraphKernel *ControlFlowScheduler::IsolatePartialInputs(kernel::SubGraphKernel *subgraph,
                                                                   kernel::KernelExec *partial) {
  auto inputs = subgraph->in_tensors();
  auto outputs = subgraph->out_tensors();
  auto nodes = subgraph->nodes();

  auto old_partial_inputs = partial->in_tensors();

  std::vector<Tensor *> new_partial_inputs{};
  for (size_t i = 0; i < old_partial_inputs.size(); i++) {
    Tensor *old_tensor = old_partial_inputs[i];
    auto allocator = old_tensor->allocator();
    Tensor *new_tensor = nullptr;
    // TensorType tensors are TensorLists and need the TensorList copy routine
    if (old_tensor->data_type() == kObjectTypeTensorType) {
      auto old_tensor_list = reinterpret_cast<TensorList *>(old_tensor);
      new_tensor = TensorList::CopyTensorList(*old_tensor_list, false, allocator);
    } else {
      new_tensor = Tensor::CopyTensor(*old_tensor, false, allocator);
    }
    MS_CHECK_TRUE_MSG(new_tensor != nullptr, nullptr, "new tensor failed.");
    // the copy is produced at runtime by the identity node, so it is variable
    new_tensor->set_category(VAR);
    partial->set_in_tensor(new_tensor, i);
    src_tensors_->push_back(new_tensor);
    new_partial_inputs.push_back(new_tensor);
  }
  auto identity_node = kernel::IdentityKernel::Create(old_partial_inputs, new_partial_inputs, this->context_);
  MS_CHECK_TRUE_MSG(identity_node != nullptr, nullptr, "Create Identity kernel failed.");
  identity_node->set_name(partial->name() + "_input_identity");
  kernel::KernelKey identity_desc = partial->desc();
  identity_desc.type = PrimType_Inner_Identity;
  identity_node->set_desc(identity_desc);
  // update identity and partial in kernels and out kernels
  for (auto partial_in_kernel : partial->in_kernels()) {
    auto output_kernels = partial_in_kernel->out_kernels();
    std::replace(output_kernels.begin(), output_kernels.end(), partial, identity_node);
    partial_in_kernel->set_out_kernels(output_kernels);
    identity_node->AddInKernel(partial_in_kernel);
  }
  identity_node->AddOutKernel(partial);
  partial->set_in_kernels({identity_node});
  // keep execution order: the identity node runs immediately before the partial
  auto partial_iter = std::find(nodes.begin(), nodes.end(), partial);
  (void)nodes.insert(partial_iter, identity_node);
  auto subgraph_type = subgraph->subgraph_type();
  auto new_subgraph =
    kernel::KernelExecUtil::CreateSubGraphKernel(nodes, &inputs, &outputs, subgraph_type, *context_, schema_version_);
  return new_subgraph;
}
638 
GetSameInputPartials()639 std::set<kernel::KernelExec *> ControlFlowScheduler::GetSameInputPartials() {
640   std::unordered_map<Tensor *, std::set<kernel::KernelExec *>> input_partial_pairs{};
641   for (auto item : *partial_kernel_subgraph_index_map_) {
642     for (auto input : item.first->in_tensors()) {
643       if (input_partial_pairs.find(input) == input_partial_pairs.end()) {
644         std::set<kernel::KernelExec *> partials{};
645         (void)partials.insert(item.first);
646         input_partial_pairs[input] = partials;
647       } else {
648         (void)input_partial_pairs[input].insert(item.first);
649       }
650     }
651   }
652 
653   std::set<kernel::KernelExec *> same_input_partials{};
654   for (auto item : input_partial_pairs) {
655     if (item.second.size() > 1) {
656       for (auto partial : item.second) {
657         (void)same_input_partials.insert(partial);
658       }
659     }
660   }
661   return same_input_partials;
662 }
663 
IsolateSameInputPartials(std::vector<kernel::KernelExec * > * dst_kernels)664 int ControlFlowScheduler::IsolateSameInputPartials(std::vector<kernel::KernelExec *> *dst_kernels) {
665   auto same_input_partials = GetSameInputPartials();
666 
667   for (auto partial : same_input_partials) {
668     auto subgraph = kernel::KernelExecUtil::BelongToWhichSubGraph(*dst_kernels, partial);
669     MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "can not find belong graph.");
670     kernel::SubGraphKernel *new_subgraph = IsolatePartialInputs(subgraph, partial);
671     MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create new subgraph failed.");
672     new_subgraph->set_name(subgraph->name());
673 
674     std::replace(dst_kernels->begin(), dst_kernels->end(), subgraph, new_subgraph);
675     UpdateSubGraphMap(new_subgraph, subgraph);
676 
677     subgraph->set_nodes({});
678     delete subgraph;
679   }
680 
681   SetSubgraphForPartialNode(partial_kernel_subgraph_index_map_, subgraph_index_subgraph_kernel_map_);
682   return RET_OK;
683 }
684 
IsolateInputOfMultipleCalledGraph(std::vector<kernel::KernelExec * > * dst_kernels)685 int ControlFlowScheduler::IsolateInputOfMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels) {
686   auto ret = GetSubGraphsWhichNeedBoundary();
687   MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetSubGraphsWhichNeedBoundary failed.");
688   std::unordered_map<kernel::SubGraphKernel *, kernel::SubGraphKernel *> replace_pair{};
689 
690   for (auto &item : subgraphs_need_boundary_) {
691     auto subgraph = item.first;
692     std::vector<kernel::KernelExec *> input_partials{};
693     for (auto input : subgraph->in_nodes()) {
694       MS_CHECK_TRUE_MSG(input->op_parameter() != nullptr, RET_ERROR, "op_parameter is nullptr.");
695       if (input->op_parameter()->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion)) {
696         input_partials.push_back(input);
697       }
698     }
699     kernel::SubGraphKernel *new_subgraph = nullptr;
700     kernel::SubGraphKernel *cur_subgraph = subgraph;
701     for (auto cur_partial : input_partials) {
702       new_subgraph = IsolatePartialInputs(cur_subgraph, cur_partial);
703       MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create new subgraph failed.");
704       new_subgraph->set_name(cur_subgraph->name());
705 
706       cur_subgraph->set_nodes({});
707       delete cur_subgraph;
708       cur_subgraph = new_subgraph;
709     }
710 
711     if (new_subgraph != nullptr) {
712       replace_pair[subgraph] = new_subgraph;
713     }
714   }
715 
716   // update all partial nodes' subgraph
717   for (auto item : replace_pair) {
718     auto old_subgrpah = item.first;
719     auto new_subgraph = item.second;
720     for (auto partial_node : subgraphs_need_boundary_[old_subgrpah]) {
721       auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
722       MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
723       partial_kernel->set_subgraph_kernels({new_subgraph});
724       (void)subgraphs_need_boundary_[new_subgraph].insert(partial_node);
725     }
726   }
727 
728   for (auto item : replace_pair) {
729     auto old_subgrpah = item.first;
730     (void)subgraphs_need_boundary_.erase(old_subgrpah);
731   }
732 
733   // update all dst_kernels
734   for (auto item : replace_pair) {
735     auto old_subgrpah = item.first;
736     auto new_subgraph = item.second;
737     std::replace(dst_kernels->begin(), dst_kernels->end(), old_subgrpah, new_subgraph);
738   }
739 
740   return RET_OK;
741 }
742 
SetSubgraphForPartialNode(std::unordered_map<kernel::KernelExec *,size_t> * partial_kernel_subgraph_index_map,std::unordered_map<size_t,kernel::KernelExec * > * subgraph_index_subgraph_kernel_map)743 void ControlFlowScheduler::SetSubgraphForPartialNode(
744   std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map,
745   std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map) {
746   partial_kernel_subgraph_index_map_ = partial_kernel_subgraph_index_map;
747   subgraph_index_subgraph_kernel_map_ = subgraph_index_subgraph_kernel_map;
748 
749   for (auto &pair : *partial_kernel_subgraph_index_map) {
750     auto partial_kernel = static_cast<kernel::PartialFusionKernel *>((pair.first)->kernel());
751     auto &subgraph_index = pair.second;
752     partial_kernel->set_subgraph_kernels({subgraph_index_subgraph_kernel_map->at(subgraph_index)});
753   }
754 }
755 
UpdateSubGraphMap(kernel::KernelExec * new_subgraph,kernel::KernelExec * old_subgraph)756 void ControlFlowScheduler::UpdateSubGraphMap(kernel::KernelExec *new_subgraph, kernel::KernelExec *old_subgraph) {
757   for (auto &item : *subgraph_index_subgraph_kernel_map_) {
758     if (item.second == old_subgraph) {
759       item.second = new_subgraph;
760     }
761   }
762   return;
763 }
764 
RecordLinkInfo(std::vector<kernel::KernelExec * > * dst_kernels)765 int ControlFlowScheduler::RecordLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels) {
766   auto ret = RecordPartialInputLinkInfo();
767   MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordPartialInputLinkInfo failed.");
768   ret = this->RecordAllTailCallLinkInfo(dst_kernels);
769   MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordAllTailCallLinkInfo failed");
770   ret = this->RecordAllNonTailCallLinkInfo(dst_kernels);
771   MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordAllNonTailCallLinkInfo failed");
772   return RET_OK;
773 }
774 
RecordPartialInputLinkInfo()775 int ControlFlowScheduler::RecordPartialInputLinkInfo() {
776   for (auto &pair : *partial_kernel_subgraph_index_map_) {
777     auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>((pair.first)->kernel());
778     MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
779     auto subgraph_kernels = partial_kernel->subgraph_kernels();
780     MS_CHECK_TRUE_MSG(!subgraph_kernels.empty(), RET_ERROR, "partial corresponding subgraph kernels empty.");
781     auto subgraph_kernel = subgraph_kernels.front();
782     MS_CHECK_TRUE_MSG(partial_kernel->in_tensors().size() == subgraph_kernel->in_tensors().size(), RET_ERROR,
783                       "partial inputs and corresponding subgraph inputs size not same.");
784     for (size_t i = 0; i < partial_kernel->in_tensors().size(); ++i) {
785       context_->SetLinkInfo(partial_kernel->in_tensors()[i], subgraph_kernel->in_tensors()[i]);
786     }
787   }
788   return RET_OK;
789 }
790 
#else
// CONTROLFLOW_TENSORLIST_CLIP is defined: control-flow scheduling support is
// compiled out, so provide no-op stubs that keep callers linking and let
// Schedule() succeed trivially.
int ControlFlowScheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) { return RET_OK; }
void ControlFlowScheduler::SetSubgraphForPartialNode(
  std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map,
  std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map) {
  return;
}
void ControlFlowScheduler::RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node) {
  return;
}
#endif
802 }  // namespace mindspore::lite
803