• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_CONTROL_FLOW_CONTROL_FLOW_SCHEDULER_H_
18 #define MINDSPORE_LITE_SRC_CONTROL_FLOW_CONTROL_FLOW_SCHEDULER_H_
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include <utility>
23 #include <algorithm>
24 #include <queue>
25 #include <set>
26 #include <unordered_map>
27 #include "src/common/utils.h"
28 #include "src/common/log_util.h"
29 #include "nnacl/op_base.h"
30 #include "src/litert/inner_context.h"
31 #include "src/tensor.h"
32 #include "src/executor/sub_graph_kernel.h"
33 #include "include/model.h"
34 
35 namespace mindspore::lite {
36 #ifndef CONTROLFLOW_TENSORLIST_CLIP
37 class ControlFlowScheduler {
38  public:
ControlFlowScheduler(InnerContext * ctx,const mindspore::Context *,std::vector<Tensor * > * src_tensors)39   ControlFlowScheduler(InnerContext *ctx, const mindspore::Context *, std::vector<Tensor *> *src_tensors)
40       : context_(ctx), src_tensors_(src_tensors) {}
41   ~ControlFlowScheduler() = default;
42   int Schedule(std::vector<kernel::KernelExec *> *dst_kernels);
43   void SetSubgraphForPartialNode(std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map,
44                                  std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map);
GetNonTailCalls()45   std::vector<kernel::KernelExec *> GetNonTailCalls() const { return non_tail_calls_; }
46   void RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node);
47 
48  protected:
49   int SplitNonTailCallSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels);
50   // We insert entrance subgraph kernel and exit subgraph kernel define the boundary of the subgraph.
51   int BuildBoundaryForMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels);
52   // When graph output is switch call node, output tensors not fixed, we need output subgraph holds the output tensors.
53   int IsolateOutputForCallOutputGraph(std::vector<kernel::KernelExec *> *dst_kernels);
54   // Partial nodes which have same input, need isolate partial's input. For creating actor form this kind of
55   // graph, actor's input will be graph's input tensors, and actor's output will be partial's input tensors. So in this
56   // case, actor input will be same as output. And we can not link the actors in the right order in this situation.
57   int IsolateSameInputPartials(std::vector<kernel::KernelExec *> *dst_kernels);
58   int RecordLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels);
59   int IsolateInputOfMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels);
60 
61  private:
62   int SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel,
63                                      std::vector<kernel::KernelExec *> *subgraph_kernels);
64   int SplitSubGraphNodesIntoTwoParts(kernel::SubGraphKernel *subgraph_kernel,
65                                      std::vector<kernel::KernelExec *> *first_part_nodes,
66                                      std::vector<kernel::KernelExec *> *second_part_nodes);
67   int AdjustNodesForTailCallSubGraph(std::vector<kernel::KernelExec *> *first_part_nodes,
68                                      std::vector<kernel::KernelExec *> *second_part_nodes);
69   std::set<kernel::KernelExec *> GetNonTailCallSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels);
70   void RemoveUselessKernels(std::vector<kernel::KernelExec *> *dst_kernels,
71                             std::set<kernel::KernelExec *> *useless_kernels);
72   void AppendToProcessQ(std::vector<kernel::KernelExec *> *new_subgraphs,
73                         std::set<kernel::KernelExec *> *all_non_tail_subgraphs);
74   kernel::SubGraphKernel *CreateEntranceSubGraph(kernel::SubGraphKernel *subgraph, lite::Tensor *link_tensor);
75   kernel::SubGraphKernel *CreateExitSubGraph(kernel::SubGraphKernel *subgraph, lite::Tensor *link_tensor);
76   kernel::SubGraphKernel *AddOutputKernel(kernel::SubGraphKernel *subgraph);
77   int GetTailCallFinalSubgraphs(std::queue<kernel::KernelExec *> *tail_call_q,
78                                 std::vector<kernel::KernelExec *> *final_graphs,
79                                 std::set<kernel::KernelExec *> reviewed_graphs);
80   kernel::SubGraphKernel *IsolatePartialInputs(kernel::SubGraphKernel *subgraph, kernel::KernelExec *partial);
81   std::set<kernel::KernelExec *> GetSameInputPartials();
82   void UpdateSubGraphMap(kernel::KernelExec *new_subgraph, kernel::KernelExec *old_subgraph);
83   int GetSubGraphsWhichNeedBoundary();
84   // link partial inputs to partial's corresponding subgraph's inputs.
85   int RecordPartialInputLinkInfo();
86   // link partial's corresponding final subgraph's outputs to tail call's outputs.
87   int RecordAllTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels);
88   int RecordTailCallLinkInfo(kernel::KernelExec *tail_call);
89   // link partial's corresponding final subgraph's outputs to non-tail call's outputs.
90   int RecordAllNonTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels);
91   int RecordNonTailCallLinkInfo(kernel::KernelExec *non_tail_call);
92 
93   InnerContext *context_ = nullptr;
94   int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
95   std::vector<Tensor *> *src_tensors_ = nullptr;
96   std::queue<kernel::KernelExec *> to_process_q_{};
97   std::vector<kernel::KernelExec *> non_tail_calls_{};
98   // key is subgraph index, value is the corresponding partial nodes.
99   std::unordered_map<size_t, std::set<kernel::KernelExec *>> more_than_once_called_partial_nodes_{};
100   // record partial nodes which corresponding subgraph need build boundary, key is subgraph, value is corresponding
101   // partial nodes
102   std::unordered_map<kernel::SubGraphKernel *, std::set<kernel::KernelExec *>> subgraphs_need_boundary_{};
103   std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map_{};
104   std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map_{};
105 };
106 
107 #else
108 
109 class ControlFlowScheduler {
110  public:
111   ControlFlowScheduler(InnerContext *ctx, const mindspore::Context *ms_ctx, std::vector<Tensor *> *src_tensors)
112       : context_(ctx), src_tensors_(src_tensors) {}
113   ~ControlFlowScheduler() = default;
114   int Schedule(std::vector<kernel::KernelExec *> *dst_kernels);
115   void SetSubgraphForPartialNode(std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map,
116                                  std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map);
117   std::vector<kernel::KernelExec *> GetNonTailCalls() const { return {}; }
118   void RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node);
119 
120  private:
121   InnerContext *context_ = nullptr;
122   int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
123   std::vector<Tensor *> *src_tensors_ = nullptr;
124 };
125 #endif
126 
127 using ControlFlowSchedulerPtr = std::shared_ptr<ControlFlowScheduler>;
128 }  // namespace mindspore::lite
129 #endif  // MINDSPORE_LITE_SRC_CONTROL_FLOW_CONTROL_FLOW_SCHEDULER_H_
130