• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/lite_kernel_util.h"
18 #include <queue>
19 #include <unordered_map>
20 #include <set>
21 #include "src/sub_graph_kernel.h"
22 
23 namespace mindspore::kernel {
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26 
AllOutTensor(const std::vector<kernel::LiteKernel * > & kernels)27 std::set<lite::Tensor *> LiteKernelUtil::AllOutTensor(const std::vector<kernel::LiteKernel *> &kernels) {
28   std::set<lite::Tensor *> all_out_tensors{};
29   for (const auto &kernel_in_subgraph : kernels) {
30     for (auto *tensor : kernel_in_subgraph->out_tensors()) {
31       all_out_tensors.insert(tensor);
32     }
33   }
34   return all_out_tensors;
35 }
36 
SubgraphInputNodes(const std::vector<kernel::LiteKernel * > & kernels)37 std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputNodes(const std::vector<kernel::LiteKernel *> &kernels) {
38   std::vector<kernel::LiteKernel *> input_nodes;
39   std::set<lite::Tensor *> all_out_tensors = AllOutTensor(kernels);
40   for (const auto &kernel : kernels) {
41     MS_ASSERT(kernel != nullptr);
42     bool kernel_is_input = false;
43     auto all_input_tensors = kernel->in_tensors();
44     for (auto input : kernel->in_tensors()) {
45       if (input->IsConst()) {
46         continue;
47       }
48       if (all_out_tensors.find(input) != all_out_tensors.end()) {
49         continue;
50       }
51       kernel_is_input = true;
52       break;
53     }
54     if (kernel_is_input && !lite::IsContain(input_nodes, kernel)) {
55       input_nodes.push_back(kernel);
56     }
57   }
58   return input_nodes;
59 }
60 
SubgraphOutputNodes(const std::vector<kernel::LiteKernel * > & kernels)61 std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputNodes(
62   const std::vector<kernel::LiteKernel *> &kernels) {
63   std::set<kernel::LiteKernel *> all_kernels{};
64   for (const auto &kernel : kernels) {
65     all_kernels.insert(kernel);
66   }
67   std::vector<kernel::LiteKernel *> output_nodes;
68   // if kernel has no post-kernel, kernel is a graph output, it must be a subgraph output
69   for (const auto &kernel : kernels) {
70     MS_ASSERT(kernel != nullptr);
71     if (kernel->is_model_output() || (kernel->out_kernels().empty() && !kernel->out_tensors().empty())) {
72       if (!lite::IsContain(output_nodes, kernel)) {
73         output_nodes.push_back(kernel);
74       }
75       continue;
76     }
77     if (std::any_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
78                     [&all_kernels](kernel::LiteKernel *tmp) { return all_kernels.find(tmp) == all_kernels.end(); }) &&
79         !lite::IsContain(output_nodes, kernel)) {
80       output_nodes.push_back(kernel);
81     }
82   }
83   return output_nodes;
84 }
85 
SubgraphInputTensors(const std::vector<kernel::LiteKernel * > & kernels)86 std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
87   std::vector<lite::Tensor *> input_tensors;
88   std::vector<kernel::LiteKernel *> input_nodes = SubgraphInputNodes(kernels);
89   for (const auto &input_node : input_nodes) {
90     auto &in_node_in_kernels = input_node->in_kernels();
91     auto &in_node_in_tensors = input_node->in_tensors();
92     for (auto &in_node_in_tensor : in_node_in_tensors) {
93       if (in_node_in_tensor->IsGraphInput()) {
94         if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
95           input_tensors.push_back(in_node_in_tensor);
96         }
97       }
98     }
99     for (auto in_node_in_kernel : in_node_in_kernels) {
100       auto iter = std::find(kernels.begin(), kernels.end(), in_node_in_kernel);
101       if (iter != kernels.end()) {
102         continue;
103       }
104       auto &outer_in_kernel_out_tensors = in_node_in_kernel->out_tensors();
105       for (auto in_node_in_tensor : in_node_in_tensors) {
106         auto outer_in_kernel_out_tensors_iter =
107           std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_node_in_tensor);
108         if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
109           if (!lite::IsContain(input_tensors, in_node_in_tensor)) {
110             input_tensors.push_back(in_node_in_tensor);
111           }
112         }
113       }
114     }
115   }
116   return input_tensors;
117 }
118 
SubgraphOutputTensors(const std::vector<kernel::LiteKernel * > & kernels)119 std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
120   std::vector<lite::Tensor *> output_tensors;
121   std::vector<kernel::LiteKernel *> output_nodes = SubgraphOutputNodes(kernels);
122   for (const auto &output_kernel : output_nodes) {
123     auto &outer_out_kernels = output_kernel->out_kernels();
124     auto &out_kernel_out_tensors = output_kernel->out_tensors();
125     for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
126       if (out_kernel_out_tensor->IsGraphOutput()) {
127         if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
128           output_tensors.push_back(out_kernel_out_tensor);
129         }
130       }
131     }
132     if (!outer_out_kernels.empty()) {
133       for (auto outer_out_kernel : outer_out_kernels) {
134         auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel);
135         if (iter != kernels.end()) {
136           continue;
137         }
138         auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors();
139         for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
140           auto outer_out_kernel_in_tensors_iter =
141             std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
142           if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
143             if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
144               output_tensors.push_back(out_kernel_out_tensor);
145             }
146           }
147         }
148       }
149     }
150   }
151   return output_tensors;
152 }
153 
TopologicalSortKernels(std::vector<kernel::LiteKernel * > * kernels)154 int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels) {
155   auto old_kernels = *kernels;
156   kernels->clear();
157   std::queue<kernel::LiteKernel *> kernel_queue;
158   for (auto kernel : old_kernels) {
159     if (kernel->in_kernels().empty()) {
160       kernel_queue.push(kernel);
161       kernels->emplace_back(kernel);
162     }
163   }
164   while (!kernel_queue.empty()) {
165     auto cur_kernel = kernel_queue.front();
166     kernel_queue.pop();
167     MS_ASSERT(cur_kernel != nullptr);
168     auto next_kernels = cur_kernel->out_kernels();
169     for (auto next_kernel : next_kernels) {
170       auto in_kernels = next_kernel->in_kernels();
171       if (lite::IsContain(*kernels, const_cast<kernel::LiteKernel *>(next_kernel))) {
172         MS_LOG(ERROR) << "TopologicalSortKernels failed, loop exist";
173         return RET_ERROR;
174       }
175       if (std::all_of(in_kernels.begin(), in_kernels.end(), [&](const kernel::LiteKernel *in_kernel) {
176             return lite::IsContain(*kernels, const_cast<kernel::LiteKernel *>(in_kernel));
177           })) {
178         kernel_queue.push(next_kernel);
179       }
180     }
181   }
182   if (kernels->size() != old_kernels.size()) {
183     MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels size before sort: " << old_kernels.size()
184                   << ", kernels size after sort: " << kernels->size();
185     return RET_ERROR;
186   }
187   return RET_OK;
188 }
189 
InitTensorInitRefCount(const std::vector<kernel::LiteKernel * > & kernels)190 void LiteKernelUtil::InitTensorInitRefCount(const std::vector<kernel::LiteKernel *> &kernels) {
191   for (auto *kernel : kernels) {
192     kernel->InitOutTensorInitRefCount(&kernels);
193   }
194 }
195 
SetInput(const LiteKernel & kernelMod,const std::vector<lite::Tensor * > & inputs)196 int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; }
197 
198 #ifndef CONTROLFLOW_TENSORLIST_CLIP
// Returns true if `kernel` is a subgraph containing the switch-call pattern:
// a Switch node that has a PartialFusion among its inputs and exactly one
// consumer, which is a Call node.
bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) {
#ifndef DELEGATE_CLIP
  // Delegate kernels have no inspectable node list; they can never match.
  if (kernel->desc().arch == kernel::kDelegate) {
    return false;
  }
#endif
  // NOTE(review): reinterpret_cast of a non-null pointer never yields nullptr,
  // so the check below only guards against `kernel` itself being null. The
  // cast also assumes callers always pass a SubGraphKernel — confirm at call
  // sites.
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
  if (subgraph_kernel == nullptr) {
    return false;
  }
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() == schema::PrimitiveType_Switch &&
        InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 &&
        node->out_kernels().front()->type() == schema::PrimitiveType_Call) {
      return true;
    }
  }

  return false;
}
219 #endif
220 
// Returns the first direct input kernel whose primitive type equals
// `primitive_type`, or nullptr when no input matches.
kernel::LiteKernel *LiteKernelUtil::GetInputsSpecificNode(const kernel::LiteKernel *kernel,
                                                          const schema::PrimitiveType &primitive_type) {
  const auto &inputs = kernel->in_kernels();
  auto match = std::find_if(inputs.begin(), inputs.end(),
                            [&primitive_type](kernel::LiteKernel *node) { return node->type() == primitive_type; });
  return match == inputs.end() ? nullptr : *match;
}
230 
InputsContainsSpecificNode(const kernel::LiteKernel * kernel,const schema::PrimitiveType & primitive_type)231 bool LiteKernelUtil::InputsContainsSpecificNode(const kernel::LiteKernel *kernel,
232                                                 const schema::PrimitiveType &primitive_type) {
233   if (GetInputsSpecificNode(kernel, primitive_type)) {
234     return true;
235   }
236   return false;
237 }
238 
FindAllInoutKernels(const std::vector<kernel::LiteKernel * > & kernels)239 void LiteKernelUtil::FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels) {
240   std::unordered_map<lite::Tensor *, kernel::LiteKernel *> tensor_pre_kernel;
241   std::unordered_map<lite::Tensor *, std::vector<kernel::LiteKernel *>> tensor_post_kernels;
242   for (auto *kernel : kernels) {
243     for (auto *tensor : kernel->out_tensors()) {
244       tensor_pre_kernel[tensor] = kernel;
245     }
246     for (auto *tensor : kernel->in_tensors()) {
247       (tensor_post_kernels[tensor]).push_back(kernel);
248     }
249   }
250 
251   for (auto *kernel : kernels) {
252     kernel->set_in_kernels({});
253     for (auto *tensor : kernel->in_tensors()) {
254       auto iter = tensor_pre_kernel.find(tensor);
255       if (iter != tensor_pre_kernel.end() && kernel != iter->second) {
256         kernel->AddInKernel(iter->second);
257       }
258     }
259     kernel->set_out_kernels({});
260     for (auto *tensor : kernel->out_tensors()) {
261       auto iter = tensor_post_kernels.find(tensor);
262       if (iter != tensor_post_kernels.end()) {
263         for (auto *find_kernel : iter->second) {
264           if (kernel == find_kernel) {
265             continue;
266           }
267           kernel->AddOutKernel(find_kernel);
268         }
269       }
270     }
271   }
272 }
273 
274 }  // namespace mindspore::kernel
275