1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_CONTROL_FLOW_CONTROL_FLOW_SCHEDULER_H_ 18 #define MINDSPORE_LITE_SRC_CONTROL_FLOW_CONTROL_FLOW_SCHEDULER_H_ 19 #include <string> 20 #include <vector> 21 #include <memory> 22 #include <utility> 23 #include <algorithm> 24 #include <queue> 25 #include <set> 26 #include <unordered_map> 27 #include "src/common/utils.h" 28 #include "src/common/log_util.h" 29 #include "nnacl/op_base.h" 30 #include "src/litert/inner_context.h" 31 #include "src/tensor.h" 32 #include "src/executor/sub_graph_kernel.h" 33 #include "include/model.h" 34 35 namespace mindspore::lite { 36 #ifndef CONTROLFLOW_TENSORLIST_CLIP 37 class ControlFlowScheduler { 38 public: ControlFlowScheduler(InnerContext * ctx,const mindspore::Context *,std::vector<Tensor * > * src_tensors)39 ControlFlowScheduler(InnerContext *ctx, const mindspore::Context *, std::vector<Tensor *> *src_tensors) 40 : context_(ctx), src_tensors_(src_tensors) {} 41 ~ControlFlowScheduler() = default; 42 int Schedule(std::vector<kernel::KernelExec *> *dst_kernels); 43 void SetSubgraphForPartialNode(std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map, 44 std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map); GetNonTailCalls()45 std::vector<kernel::KernelExec *> GetNonTailCalls() const { return non_tail_calls_; } 46 void RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node); 47 48 protected: 49 int SplitNonTailCallSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels); 50 // We insert entrance subgraph kernel and exit subgraph kernel define the boundary of the subgraph. 51 int BuildBoundaryForMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels); 52 // When graph output is switch call node, output tensors not fixed, we need output subgraph holds the output tensors. 53 int IsolateOutputForCallOutputGraph(std::vector<kernel::KernelExec *> *dst_kernels); 54 // Partial nodes which have same input, need isolate partial's input. For creating actor form this kind of 55 // graph, actor's input will be graph's input tensors, and actor's output will be partial's input tensors. So in this 56 // case, actor input will be same as output. And we can not link the actors in the right order in this situation. 57 int IsolateSameInputPartials(std::vector<kernel::KernelExec *> *dst_kernels); 58 int RecordLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels); 59 int IsolateInputOfMultipleCalledGraph(std::vector<kernel::KernelExec *> *dst_kernels); 60 61 private: 62 int SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel, 63 std::vector<kernel::KernelExec *> *subgraph_kernels); 64 int SplitSubGraphNodesIntoTwoParts(kernel::SubGraphKernel *subgraph_kernel, 65 std::vector<kernel::KernelExec *> *first_part_nodes, 66 std::vector<kernel::KernelExec *> *second_part_nodes); 67 int AdjustNodesForTailCallSubGraph(std::vector<kernel::KernelExec *> *first_part_nodes, 68 std::vector<kernel::KernelExec *> *second_part_nodes); 69 std::set<kernel::KernelExec *> GetNonTailCallSubGraphs(std::vector<kernel::KernelExec *> *dst_kernels); 70 void RemoveUselessKernels(std::vector<kernel::KernelExec *> *dst_kernels, 71 std::set<kernel::KernelExec *> *useless_kernels); 72 void AppendToProcessQ(std::vector<kernel::KernelExec *> *new_subgraphs, 73 std::set<kernel::KernelExec *> *all_non_tail_subgraphs); 74 kernel::SubGraphKernel *CreateEntranceSubGraph(kernel::SubGraphKernel *subgraph, lite::Tensor *link_tensor); 75 kernel::SubGraphKernel *CreateExitSubGraph(kernel::SubGraphKernel *subgraph, lite::Tensor *link_tensor); 76 kernel::SubGraphKernel *AddOutputKernel(kernel::SubGraphKernel *subgraph); 77 int GetTailCallFinalSubgraphs(std::queue<kernel::KernelExec *> *tail_call_q, 78 std::vector<kernel::KernelExec *> *final_graphs, 79 std::set<kernel::KernelExec *> reviewed_graphs); 80 kernel::SubGraphKernel *IsolatePartialInputs(kernel::SubGraphKernel *subgraph, kernel::KernelExec *partial); 81 std::set<kernel::KernelExec *> GetSameInputPartials(); 82 void UpdateSubGraphMap(kernel::KernelExec *new_subgraph, kernel::KernelExec *old_subgraph); 83 int GetSubGraphsWhichNeedBoundary(); 84 // link partial inputs to partial's corresponding subgraph's inputs. 85 int RecordPartialInputLinkInfo(); 86 // link partial's corresponding final subgraph's outputs to tail call's outputs. 87 int RecordAllTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels); 88 int RecordTailCallLinkInfo(kernel::KernelExec *tail_call); 89 // link partial's corresponding final subgraph's outputs to non-tail call's outputs. 90 int RecordAllNonTailCallLinkInfo(std::vector<kernel::KernelExec *> *dst_kernels); 91 int RecordNonTailCallLinkInfo(kernel::KernelExec *non_tail_call); 92 93 InnerContext *context_ = nullptr; 94 int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; 95 std::vector<Tensor *> *src_tensors_ = nullptr; 96 std::queue<kernel::KernelExec *> to_process_q_{}; 97 std::vector<kernel::KernelExec *> non_tail_calls_{}; 98 // key is subgraph index, value is the corresponding partial nodes. 99 std::unordered_map<size_t, std::set<kernel::KernelExec *>> more_than_once_called_partial_nodes_{}; 100 // record partial nodes which corresponding subgraph need build boundary, key is subgraph, value is corresponding 101 // partial nodes 102 std::unordered_map<kernel::SubGraphKernel *, std::set<kernel::KernelExec *>> subgraphs_need_boundary_{}; 103 std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map_{}; 104 std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map_{}; 105 }; 106 107 #else 108 109 class ControlFlowScheduler { 110 public: 111 ControlFlowScheduler(InnerContext *ctx, const mindspore::Context *ms_ctx, std::vector<Tensor *> *src_tensors) 112 : context_(ctx), src_tensors_(src_tensors) {} 113 ~ControlFlowScheduler() = default; 114 int Schedule(std::vector<kernel::KernelExec *> *dst_kernels); 115 void SetSubgraphForPartialNode(std::unordered_map<kernel::KernelExec *, size_t> *partial_kernel_subgraph_index_map, 116 std::unordered_map<size_t, kernel::KernelExec *> *subgraph_index_subgraph_kernel_map); 117 std::vector<kernel::KernelExec *> GetNonTailCalls() const { return {}; } 118 void RecordSubgraphCaller(const size_t &subgraph_index, kernel::KernelExec *partial_node); 119 120 private: 121 InnerContext *context_ = nullptr; 122 int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; 123 std::vector<Tensor *> *src_tensors_ = nullptr; 124 }; 125 #endif 126 127 using ControlFlowSchedulerPtr = std::shared_ptr<ControlFlowScheduler>; 128 } // namespace mindspore::lite 129 #endif // MINDSPORE_LITE_SRC_CONTROL_FLOW_CONTROL_FLOW_SCHEDULER_H_ 130