1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/lite_kernel_util.h"
18 #include <queue>
19 #include <unordered_map>
20 #include <set>
21 #include "src/sub_graph_kernel.h"
22
23 namespace mindspore::kernel {
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26
AllOutTensor(const std::vector<kernel::LiteKernel * > & kernels)27 std::set<lite::Tensor *> LiteKernelUtil::AllOutTensor(const std::vector<kernel::LiteKernel *> &kernels) {
28 std::set<lite::Tensor *> all_out_tensors{};
29 for (const auto &kernel_in_subgraph : kernels) {
30 for (auto *tensor : kernel_in_subgraph->out_tensors()) {
31 all_out_tensors.insert(tensor);
32 }
33 }
34 return all_out_tensors;
35 }
36
SubgraphInputNodes(const std::vector<kernel::LiteKernel * > & kernels)37 std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputNodes(const std::vector<kernel::LiteKernel *> &kernels) {
38 std::vector<kernel::LiteKernel *> input_nodes;
39 std::set<lite::Tensor *> all_out_tensors = AllOutTensor(kernels);
40 for (const auto &kernel : kernels) {
41 MS_ASSERT(kernel != nullptr);
42 bool kernel_is_input = false;
43 auto all_input_tensors = kernel->in_tensors();
44 for (auto input : kernel->in_tensors()) {
45 if (input->IsConst()) {
46 continue;
47 }
48 if (all_out_tensors.find(input) != all_out_tensors.end()) {
49 continue;
50 }
51 kernel_is_input = true;
52 break;
53 }
54 if (kernel_is_input && !lite::IsContain(input_nodes, kernel)) {
55 input_nodes.push_back(kernel);
56 }
57 }
58 return input_nodes;
59 }
60
SubgraphOutputNodes(const std::vector<kernel::LiteKernel * > & kernels)61 std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputNodes(
62 const std::vector<kernel::LiteKernel *> &kernels) {
63 std::set<kernel::LiteKernel *> all_kernels{};
64 for (const auto &kernel : kernels) {
65 all_kernels.insert(kernel);
66 }
67 std::vector<kernel::LiteKernel *> output_nodes;
68 // if kernel has no post-kernel, kernel is a graph output, it must be a subgraph output
69 for (const auto &kernel : kernels) {
70 MS_ASSERT(kernel != nullptr);
71 if (kernel->is_model_output() || (kernel->out_kernels().empty() && !kernel->out_tensors().empty())) {
72 if (!lite::IsContain(output_nodes, kernel)) {
73 output_nodes.push_back(kernel);
74 }
75 continue;
76 }
77 if (std::any_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
78 [&all_kernels](kernel::LiteKernel *tmp) { return all_kernels.find(tmp) == all_kernels.end(); }) &&
79 !lite::IsContain(output_nodes, kernel)) {
80 output_nodes.push_back(kernel);
81 }
82 }
83 return output_nodes;
84 }
85
SubgraphInputTensors(const std::vector<kernel::LiteKernel * > & kernels)86 std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
87 std::vector<lite::Tensor *> input_tensors;
88 std::vector<kernel::LiteKernel *> input_nodes = SubgraphInputNodes(kernels);
89 for (const auto &input_node : input_nodes) {
90 auto &in_node_in_kernels = input_node->in_kernels();
91 auto &in_node_in_tensors = input_node->in_tensors();
92 for (auto &in_node_in_tensor : in_node_in_tensors) {
93 if (in_node_in_tensor->IsGraphInput()) {
94 if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
95 input_tensors.push_back(in_node_in_tensor);
96 }
97 }
98 }
99 for (auto in_node_in_kernel : in_node_in_kernels) {
100 auto iter = std::find(kernels.begin(), kernels.end(), in_node_in_kernel);
101 if (iter != kernels.end()) {
102 continue;
103 }
104 auto &outer_in_kernel_out_tensors = in_node_in_kernel->out_tensors();
105 for (auto in_node_in_tensor : in_node_in_tensors) {
106 auto outer_in_kernel_out_tensors_iter =
107 std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_node_in_tensor);
108 if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
109 if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
110 input_tensors.push_back(in_node_in_tensor);
111 }
112 }
113 }
114 }
115 }
116 return input_tensors;
117 }
118
SubgraphOutputTensors(const std::vector<kernel::LiteKernel * > & kernels)119 std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
120 std::vector<lite::Tensor *> output_tensors;
121 std::vector<kernel::LiteKernel *> output_nodes = SubgraphOutputNodes(kernels);
122 for (const auto &output_kernel : output_nodes) {
123 auto &outer_out_kernels = output_kernel->out_kernels();
124 auto &out_kernel_out_tensors = output_kernel->out_tensors();
125 for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
126 if (out_kernel_out_tensor->IsGraphOutput()) {
127 if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
128 output_tensors.push_back(out_kernel_out_tensor);
129 }
130 }
131 }
132 if (!outer_out_kernels.empty()) {
133 for (auto outer_out_kernel : outer_out_kernels) {
134 auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel);
135 if (iter != kernels.end()) {
136 continue;
137 }
138 auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors();
139 for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
140 auto outer_out_kernel_in_tensors_iter =
141 std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
142 if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
143 if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
144 output_tensors.push_back(out_kernel_out_tensor);
145 }
146 }
147 }
148 }
149 }
150 }
151 return output_tensors;
152 }
153
TopologicalSortKernels(std::vector<kernel::LiteKernel * > * kernels)154 int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels) {
155 auto old_kernels = *kernels;
156 kernels->clear();
157 std::queue<kernel::LiteKernel *> kernel_queue;
158 for (auto kernel : old_kernels) {
159 if (kernel->in_kernels().empty()) {
160 kernel_queue.push(kernel);
161 kernels->emplace_back(kernel);
162 }
163 }
164 while (!kernel_queue.empty()) {
165 auto cur_kernel = kernel_queue.front();
166 kernel_queue.pop();
167 MS_ASSERT(cur_kernel != nullptr);
168 auto next_kernels = cur_kernel->out_kernels();
169 for (auto next_kernel : next_kernels) {
170 auto in_kernels = next_kernel->in_kernels();
171 if (lite::IsContain(*kernels, const_cast<kernel::LiteKernel *>(next_kernel))) {
172 MS_LOG(ERROR) << "TopologicalSortKernels failed, loop exist";
173 return RET_ERROR;
174 }
175 if (std::all_of(in_kernels.begin(), in_kernels.end(), [&](const kernel::LiteKernel *in_kernel) {
176 return lite::IsContain(*kernels, const_cast<kernel::LiteKernel *>(in_kernel));
177 })) {
178 kernel_queue.push(next_kernel);
179 }
180 }
181 }
182 if (kernels->size() != old_kernels.size()) {
183 MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels size before sort: " << old_kernels.size()
184 << ", kernels size after sort: " << kernels->size();
185 return RET_ERROR;
186 }
187 return RET_OK;
188 }
189
InitTensorInitRefCount(const std::vector<kernel::LiteKernel * > & kernels)190 void LiteKernelUtil::InitTensorInitRefCount(const std::vector<kernel::LiteKernel *> &kernels) {
191 for (auto *kernel : kernels) {
192 kernel->InitOutTensorInitRefCount(&kernels);
193 }
194 }
195
SetInput(const LiteKernel & kernelMod,const std::vector<lite::Tensor * > & inputs)196 int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; }
197
198 #ifndef CONTROLFLOW_TENSORLIST_CLIP
// Returns true when `kernel` is a subgraph containing the Switch->Call
// pattern: a Switch node that has a PartialFusion input and exactly one
// output kernel whose type is Call.
bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) {
#ifndef DELEGATE_CLIP
  // Delegate kernels are not subgraph kernels, so they can never match.
  if (kernel->desc().arch == kernel::kDelegate) {
    return false;
  }
#endif
  // NOTE(review): reinterpret_cast never yields nullptr for a non-null input,
  // so this null check only guards against `kernel` itself being null; it does
  // not validate that `kernel` really is a SubGraphKernel — confirm callers
  // guarantee that invariant.
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() == schema::PrimitiveType_Switch &&
        InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 &&
        node->out_kernels().front()->type() == schema::PrimitiveType_Call) {
      return true;
    }
  }

  return false;
}
219 #endif
220
// Returns the first input kernel of `kernel` whose type equals
// `primitive_type`, or nullptr when no input kernel matches.
kernel::LiteKernel *LiteKernelUtil::GetInputsSpecificNode(const kernel::LiteKernel *kernel,
                                                          const schema::PrimitiveType &primitive_type) {
  auto &preds = kernel->in_kernels();
  auto match = std::find_if(preds.begin(), preds.end(),
                            [&primitive_type](kernel::LiteKernel *in) { return in->type() == primitive_type; });
  return match == preds.end() ? nullptr : *match;
}
230
InputsContainsSpecificNode(const kernel::LiteKernel * kernel,const schema::PrimitiveType & primitive_type)231 bool LiteKernelUtil::InputsContainsSpecificNode(const kernel::LiteKernel *kernel,
232 const schema::PrimitiveType &primitive_type) {
233 if (GetInputsSpecificNode(kernel, primitive_type)) {
234 return true;
235 }
236 return false;
237 }
238
FindAllInoutKernels(const std::vector<kernel::LiteKernel * > & kernels)239 void LiteKernelUtil::FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels) {
240 std::unordered_map<lite::Tensor *, kernel::LiteKernel *> tensor_pre_kernel;
241 std::unordered_map<lite::Tensor *, std::vector<kernel::LiteKernel *>> tensor_post_kernels;
242 for (auto *kernel : kernels) {
243 for (auto *tensor : kernel->out_tensors()) {
244 tensor_pre_kernel[tensor] = kernel;
245 }
246 for (auto *tensor : kernel->in_tensors()) {
247 (tensor_post_kernels[tensor]).push_back(kernel);
248 }
249 }
250
251 for (auto *kernel : kernels) {
252 kernel->set_in_kernels({});
253 for (auto *tensor : kernel->in_tensors()) {
254 auto iter = tensor_pre_kernel.find(tensor);
255 if (iter != tensor_pre_kernel.end() && kernel != iter->second) {
256 kernel->AddInKernel(iter->second);
257 }
258 }
259 kernel->set_out_kernels({});
260 for (auto *tensor : kernel->out_tensors()) {
261 auto iter = tensor_post_kernels.find(tensor);
262 if (iter != tensor_post_kernels.end()) {
263 for (auto *find_kernel : iter->second) {
264 if (kernel == find_kernel) {
265 continue;
266 }
267 kernel->AddOutKernel(find_kernel);
268 }
269 }
270 }
271 }
272 }
273
274 } // namespace mindspore::kernel
275