1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/optimizer/fusion/gru_fusion_pass.h"
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "src/common/log_adapter.h"
26 #include "include/errorcode.h"
27 #include "nnacl/op_base.h"
28
29 namespace mindspore {
30 namespace lite {
31 namespace {
// Number of output tensors of each Split node in the GRU-cell pattern.
constexpr size_t kSplitOutSize = 3;
// Symbolic node ids of the GRU-cell pattern graph (see
// GruCellFusion::DefinePattern). kAdd0 is the final Add that produces the new
// hidden state and serves as the anchor node of every match attempt.
constexpr uint32_t kAdd0 = 0;
constexpr uint32_t kAdd1 = 1;
constexpr uint32_t kAdd2 = 2;
constexpr uint32_t kAdd3 = 3;
constexpr uint32_t kAdd4 = 4;
constexpr uint32_t kAdd5 = 5;
constexpr uint32_t kSub = 6;
constexpr uint32_t kMul0 = 7;
constexpr uint32_t kMul1 = 8;
constexpr uint32_t kTanh = 9;
constexpr uint32_t kSigmoid0 = 10;
constexpr uint32_t kSigmoid1 = 11;
constexpr uint32_t kSplit0 = 12;
constexpr uint32_t kSplit1 = 13;
constexpr uint32_t kMatmul0 = 14;
constexpr uint32_t kMatmul1 = 15;
// Pseudo ids for the two external inputs of the pattern: the previous hidden
// state and the current input tensor (they are not nodes of the pattern).
constexpr uint32_t kInputH = 16;
constexpr uint32_t kInputI = 17;
// Type string carried by the fused Custom operator.
constexpr auto kCustomGRU = "CustomGRU";
52
CheckCommon(schema::MetaGraphT * graph,uint32_t node_index,schema::PrimitiveType type,size_t in_nums,size_t out_nums)53 bool CheckCommon(schema::MetaGraphT *graph, uint32_t node_index, schema::PrimitiveType type, size_t in_nums,
54 size_t out_nums) {
55 if (graph->nodes.size() <= node_index) {
56 return false;
57 }
58 const auto &node = graph->nodes[node_index];
59 if (node == nullptr || node->primitive == nullptr) {
60 return false;
61 }
62 const auto &value = node->primitive->value;
63 if (value.type != type) {
64 return false;
65 }
66 if (value.value == nullptr) {
67 return false;
68 }
69 if ((in_nums > 0 && node->inputIndex.size() != in_nums) || node->outputIndex.size() != out_nums) {
70 return false;
71 }
72 return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
73 [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }) &&
74 std::all_of(node->outputIndex.begin(), node->outputIndex.end(),
75 [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; });
76 }
77
78 template <schema::PrimitiveType T, typename P>
CheckArithmetic(schema::MetaGraphT * graph,uint32_t node_index)79 bool CheckArithmetic(schema::MetaGraphT *graph, uint32_t node_index) {
80 if (!CheckCommon(graph, node_index, T, kInputSize1, 1)) {
81 return false;
82 }
83 const auto &node = graph->nodes[node_index];
84 const auto &value = node->primitive->value;
85 const auto add_attr = static_cast<const P *>(value.value);
86 if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
87 return false;
88 }
89 auto tensor_indexes = node->inputIndex;
90 (void)tensor_indexes.insert(tensor_indexes.end(), node->outputIndex.begin(), node->outputIndex.end());
91 std::vector<int> shape;
92 for (size_t i = 0; i < tensor_indexes.size(); ++i) {
93 if (i == 0) {
94 shape = graph->allTensors[tensor_indexes[i]]->dims;
95 continue;
96 }
97 if (graph->allTensors[tensor_indexes[i]]->dims != shape) {
98 return false;
99 }
100 }
101 return true;
102 }
103
104 template <schema::ActivationType T>
CheckActivation(schema::MetaGraphT * graph,uint32_t node_index)105 bool CheckActivation(schema::MetaGraphT *graph, uint32_t node_index) {
106 if (!CheckCommon(graph, node_index, schema::PrimitiveType_Activation, 1, 1)) {
107 return false;
108 }
109 const auto &value = graph->nodes[node_index]->primitive->value;
110 const auto add_attr = static_cast<const schema::ActivationT *>(value.value);
111 if (add_attr->activation_type != T) {
112 return false;
113 }
114 return true;
115 }
116
CheckBiasAdd(schema::MetaGraphT * graph,uint32_t node_index)117 bool CheckBiasAdd(schema::MetaGraphT *graph, uint32_t node_index) {
118 if (!CheckCommon(graph, node_index, schema::PrimitiveType_AddFusion, kInputSize1, 1) &&
119 !CheckCommon(graph, node_index, schema::PrimitiveType_BiasAdd, kInputSize1, 1)) {
120 return false;
121 }
122 const auto &node = graph->nodes[node_index];
123 const auto &value = node->primitive->value;
124 if (value.type == schema::PrimitiveType_AddFusion) {
125 const auto add_attr = static_cast<const schema::AddFusionT *>(value.value);
126 if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
127 return false;
128 }
129 }
130 auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims;
131 auto in_shape1 = graph->allTensors[node->inputIndex[1]]->dims;
132 if (in_shape1.size() != 1 || in_shape0.empty() || in_shape0.back() != in_shape1.back()) {
133 return false;
134 }
135 return true;
136 }
137
CheckMatmul(schema::MetaGraphT * graph,uint32_t node_index)138 bool CheckMatmul(schema::MetaGraphT *graph, uint32_t node_index) {
139 if (!CheckCommon(graph, node_index, schema::PrimitiveType_MatMulFusion, kInputSize1, 1)) {
140 return false;
141 }
142 const auto &node = graph->nodes[node_index];
143 const auto &value = node->primitive->value;
144 const auto matmul_attr = static_cast<const schema::MatMulFusionT *>(value.value);
145 if (matmul_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
146 return false;
147 }
148 auto out_shape = graph->allTensors[node->outputIndex.front()]->dims;
149 return out_shape.size() == kInputSize1;
150 }
151
CheckSplit(schema::MetaGraphT * graph,uint32_t node_index)152 bool CheckSplit(schema::MetaGraphT *graph, uint32_t node_index) {
153 if (!CheckCommon(graph, node_index, schema::PrimitiveType_Split, 1, kSplitOutSize)) {
154 return false;
155 }
156 const auto &node = graph->nodes[node_index];
157 if (node->inputIndex.size() != 1 || node->outputIndex.size() != kSplitOutSize) {
158 return false;
159 }
160 auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims;
161 auto out_shape0 = graph->allTensors[node->outputIndex[0]]->dims;
162 auto out_shape1 = graph->allTensors[node->outputIndex[1]]->dims;
163 auto out_shape2 = graph->allTensors[node->outputIndex[kInputSize1]]->dims;
164 if (out_shape0 != out_shape1 || out_shape0 != out_shape2) {
165 return false;
166 }
167 if (in_shape0.empty() || out_shape0.empty()) {
168 return false;
169 }
170 if (in_shape0.back() != (out_shape0.back() + out_shape1.back() + out_shape2.back())) {
171 return false;
172 }
173 return true;
174 }
175
CheckStack(schema::MetaGraphT * graph,uint32_t node_index)176 bool CheckStack(schema::MetaGraphT *graph, uint32_t node_index) {
177 if (!CheckCommon(graph, node_index, schema::PrimitiveType_Stack, 0, 1)) {
178 return false;
179 }
180 const auto &node = graph->nodes[node_index];
181 const auto &value = node->primitive->value;
182 const auto stack_attr = static_cast<const schema::StackT *>(value.value);
183 auto out_shape = graph->allTensors[node->outputIndex.front()]->dims;
184 if (out_shape.empty()) {
185 return false;
186 }
187 auto axis = stack_attr->axis;
188 if (axis < 0) {
189 axis += static_cast<int64_t>(out_shape.size());
190 }
191 return axis == 0;
192 }
193
CheckSqueeze(schema::MetaGraphT * graph,uint32_t node_index)194 bool CheckSqueeze(schema::MetaGraphT *graph, uint32_t node_index) {
195 if (!CheckCommon(graph, node_index, schema::PrimitiveType_Squeeze, 0, 1)) {
196 return false;
197 }
198 const auto &node = graph->nodes[node_index];
199 if (node->inputIndex.size() != 1 && node->inputIndex.size() != kInputSize1) {
200 return false;
201 }
202 int axis = 0;
203 if (node->inputIndex.size() == kInputSize1) {
204 const auto &data = graph->allTensors[node->inputIndex[1]]->data;
205 if (data.size() != sizeof(int)) {
206 return false;
207 }
208 axis = *(reinterpret_cast<const int *>(data.data()));
209 } else {
210 const auto &value = node->primitive->value;
211 const auto squeeze_attr = static_cast<const schema::SqueezeT *>(value.value);
212 if (squeeze_attr->axis.size() != 1) {
213 return false;
214 }
215 axis = squeeze_attr->axis.front();
216 }
217 auto in_shape = graph->allTensors[node->inputIndex[0]]->dims;
218 if (in_shape.empty()) {
219 return false;
220 }
221 if (axis < 0) {
222 axis += static_cast<int>(in_shape.size());
223 }
224 return axis == 0;
225 }
226
GetStridedSlicePoints(const schema::TensorT * tensor,int64_t mask)227 std::vector<int> GetStridedSlicePoints(const schema::TensorT *tensor, int64_t mask) {
228 if (tensor->data.empty()) {
229 return {};
230 }
231 auto origin_data = reinterpret_cast<const int *>(tensor->data.data());
232 size_t num = tensor->data.size() / sizeof(int);
233 std::vector<int> data;
234 for (size_t i = 0; i < num; ++i) {
235 bool ineffective = (mask & (1 << i));
236 int cur_point = ineffective ? 0 : origin_data[i];
237 data.push_back(cur_point);
238 }
239 return data;
240 }
241
CheckStridedSlice(schema::MetaGraphT * graph,uint32_t node_index,int batch_position)242 bool CheckStridedSlice(schema::MetaGraphT *graph, uint32_t node_index, int batch_position) {
243 if (!CheckCommon(graph, node_index, schema::PrimitiveType_StridedSlice, C4NUM, 1)) {
244 return false;
245 }
246 const auto &node = graph->nodes[node_index];
247 const auto &step_tensor = graph->allTensors[node->inputIndex.back()];
248 if (!step_tensor->data.empty()) {
249 const auto data = reinterpret_cast<int *>(step_tensor->data.data());
250 auto size = step_tensor->data.size() / sizeof(int);
251 if (std::any_of(data, data + size, [](int val) { return val != 1; })) {
252 return false;
253 }
254 }
255 auto in_shape = graph->allTensors[node->inputIndex.front()]->dims;
256 auto out_shape = graph->allTensors[node->outputIndex.back()]->dims;
257 if (in_shape.size() != out_shape.size() || in_shape.empty()) {
258 return false;
259 }
260 for (size_t i = 1; i < in_shape.size(); ++i) {
261 if (in_shape[i] != out_shape[i]) {
262 return false;
263 }
264 }
265 const auto &value = node->primitive->value;
266 const auto strided_slice_attr = static_cast<const schema::StridedSliceT *>(value.value);
267 if (strided_slice_attr->ellipsis_mask != 0 || strided_slice_attr->new_axis_mask != 0 ||
268 strided_slice_attr->shrink_axis_mask != 0) {
269 return false;
270 }
271 auto begin = GetStridedSlicePoints(graph->allTensors[node->inputIndex[1]].get(), strided_slice_attr->begin_mask);
272 if (begin.empty()) {
273 return false;
274 }
275 return begin.front() == batch_position;
276 }
277
CheckGruCell(schema::MetaGraphT * graph,uint32_t node_index)278 bool CheckGruCell(schema::MetaGraphT *graph, uint32_t node_index) {
279 if (!CheckCommon(graph, node_index, schema::PrimitiveType_Custom, C6NUM, 1)) {
280 return false;
281 }
282 const auto &node = graph->nodes[node_index];
283 const auto &value = node->primitive->value;
284 const auto gru_attr = static_cast<const schema::CustomT *>(value.value);
285 return gru_attr->type == kCustomGRU;
286 }
287
CreateCustom()288 std::unique_ptr<schema::CustomT> CreateCustom() {
289 auto ConvertToAttr = [](const std::string &key, const std::vector<uint8_t> &value) {
290 auto attr = std::make_unique<schema::AttributeT>();
291 attr->name = key;
292 attr->data = value;
293 return attr;
294 };
295 auto attrs = std::make_unique<schema::CustomT>();
296 MS_CHECK_TRUE_MSG(attrs != nullptr, nullptr, "Create CustomT failed.");
297 attrs->type = kCustomGRU;
298 std::vector<uint8_t> transpose_a{false};
299 std::vector<uint8_t> transpose_b{true};
300 std::vector<uint8_t> built_in{true};
301
302 attrs->attr.push_back(ConvertToAttr("transpose_a", transpose_a));
303 attrs->attr.push_back(ConvertToAttr("transpose_b", transpose_b));
304 attrs->attr.push_back(ConvertToAttr("builtin", built_in));
305 return attrs;
306 }
307
// Consumer record for a tensor: the consuming node and the input slots of
// that node the tensor feeds.
struct InNodeInfo {
  int node_index;
  std::vector<uint32_t> in_indexes;
};

// Producer record for a tensor: the producing node (-1 for graph inputs and
// constants) and which of its outputs the tensor is.
struct OutNodeInfo {
  int node_index;
  uint32_t out_index;
};

// Descending comparator for node-index sets, so deletions proceed from the
// highest index down and earlier erases never shift pending indexes.
struct camp {
  bool operator()(uint32_t lhs, uint32_t rhs) const { return rhs < lhs; }
};
321 } // namespace
322
// Bookkeeping helper mapping every tensor index of the graph to its producer
// node and its consumer nodes. Fusion passes use it to swap in a fused node,
// queue obsolete nodes for deletion, and finally rewrite the meta graph (node
// list, subgraph indices, compacted tensor list) in one step.
class LinkInfoManager {
 public:
  // Scans all nodes once, recording per tensor index the consuming nodes
  // (with the input slots they use) and the producing node.
  explicit LinkInfoManager(schema::MetaGraphT *graph) : graph_{graph} {
    auto &all_nodes = graph->nodes;
    for (int node_index = 0; node_index < static_cast<int>(all_nodes.size()); ++node_index) {
      auto in_indexes = all_nodes[node_index]->inputIndex;
      for (uint32_t index = 0; index < static_cast<uint32_t>(in_indexes.size()); ++index) {
        if (link_info_manager_.find(in_indexes[index]) == link_info_manager_.end()) {
          // Producer defaults to {-1, 0}: graph input or constant tensor.
          link_info_manager_[in_indexes[index]] = std::make_pair(std::vector<InNodeInfo>{}, OutNodeInfo{-1, 0});
        }
        auto &in_infos = link_info_manager_[in_indexes[index]].first;
        auto iter = in_infos.begin();
        for (; iter != in_infos.end(); ++iter) {
          if (iter->node_index == node_index) {
            break;
          }
        }
        if (iter != in_infos.end()) {
          // Same node consumes this tensor at an additional input slot.
          iter->in_indexes.push_back(index);
        } else {
          in_infos.push_back({node_index, {index}});
        }
      }

      auto out_indexes = all_nodes[node_index]->outputIndex;
      for (uint32_t index = 0; index < static_cast<uint32_t>(out_indexes.size()); ++index) {
        link_info_manager_[out_indexes[index]].second = OutNodeInfo{node_index, index};
      }
    }
  }

  const auto &GetLinkInfos() const { return link_info_manager_; }

  // Swaps the node at `node_index` with `node`; used to substitute the fused
  // node for the anchor node of a matched pattern.
  void Replace(uint32_t node_index, std::unique_ptr<CNodeT> node) { graph_->nodes[node_index].swap(node); }

  // Queues node indexes for deletion; actual removal happens in UpdateMetaGraph().
  void AddDeleteNodes(const std::set<uint32_t> &node_indexes) {
    delete_nodes_.insert(node_indexes.begin(), node_indexes.end());
  }

  // Applies pending deletions, renumbers the surviving nodes, remaps all
  // tensor indexes into a compact 0..N-1 range, and rewrites the single
  // subgraph's node/tensor/input/output index lists to match.
  void UpdateMetaGraph() {
    auto &main_graph = graph_->subGraph.front();
    // delete_nodes_ is ordered descending (see `camp`), so erasing by index
    // never invalidates the indexes still queued for erasure.
    for (auto node_index : delete_nodes_) {
      graph_->nodes.erase(graph_->nodes.begin() + node_index);
    }
    main_graph->nodeIndices.clear();
    for (uint32_t index = 0; index < static_cast<uint32_t>(graph_->nodes.size()); ++index) {
      main_graph->nodeIndices.push_back(index);
    }
    // old tensor index -> new compact tensor index.
    std::map<uint32_t, uint32_t> tensor_maps;
    BuildTensorMap(&tensor_maps);
    auto UpdateTensorIndex = [&tensor_maps](std::vector<uint32_t> *origin) {
      auto origin_indexes = *origin;
      origin->clear();
      (void)std::transform(origin_indexes.begin(), origin_indexes.end(), std::back_inserter(*origin),
                           [&tensor_maps](uint32_t origin_index) { return tensor_maps[origin_index]; });
    };
    UpdateTensorIndex(&graph_->inputIndex);
    for (auto &node : graph_->nodes) {
      UpdateTensorIndex(&node->inputIndex);
      UpdateTensorIndex(&node->outputIndex);
    }
    UpdateTensorIndex(&graph_->outputIndex);
    main_graph->inputIndices = graph_->inputIndex;
    main_graph->outputIndices = graph_->outputIndex;
    main_graph->tensorIndices.clear();
    for (uint32_t index = 0; index < static_cast<uint32_t>(tensor_maps.size()); ++index) {
      main_graph->tensorIndices.push_back(index);
    }
    // Move the surviving tensors into their new compacted slots.
    std::vector<std::unique_ptr<TensorT>> tensors;
    graph_->allTensors.swap(tensors);
    graph_->allTensors.resize(tensor_maps.size());
    for (auto &tensor_map : tensor_maps) {
      graph_->allTensors[tensor_map.second].swap(tensors[tensor_map.first]);
    }
  }

 private:
  // Assigns compact new indexes to every tensor still referenced by the graph
  // (inputs first, then node in/out tensors in node order, then outputs).
  void BuildTensorMap(std::map<uint32_t, uint32_t> *tensor_maps) {
    uint32_t new_index = 0;
    auto InsertElements = [tensor_maps, &new_index](const std::vector<uint32_t> &indexes) mutable {
      for (auto index : indexes) {
        if (tensor_maps->find(index) != tensor_maps->end()) {
          continue;
        }
        (*tensor_maps)[index] = new_index++;
      }
    };
    InsertElements(graph_->inputIndex);
    for (auto &node : graph_->nodes) {
      InsertElements(node->inputIndex);
      InsertElements(node->outputIndex);
    }
    InsertElements(graph_->outputIndex);
  }

  schema::MetaGraphT *graph_{nullptr};
  // Descending-ordered set (see `camp`) of node indexes queued for deletion.
  std::set<uint32_t, camp> delete_nodes_;
  // tensor_index -> <in_node_infos, out_node_info>
  std::map<uint32_t, std::pair<std::vector<InNodeInfo>, OutNodeInfo>> link_info_manager_;
};
423
// Matches the unrolled single-step GRU sub-graph (two MatMul+BiasAdd branches,
// two 3-way Splits, and the sigmoid/tanh gate arithmetic) and replaces each
// full match with one Custom node of type "CustomGRU".
class GruCellFusion {
 public:
  GruCellFusion() = default;
  ~GruCellFusion() = default;
  // Entry point: tries every node as the match anchor (the final Add, kAdd0)
  // and fuses each complete match. Expects a single-subgraph model.
  STATUS Run(schema::MetaGraphT *graph) {
    MS_ASSERT(graph != nullptr);
    MS_ASSERT(graph->subGraph.size() == 1);
    link_info_manager_ = std::make_shared<LinkInfoManager>(graph);
    graph_ = graph;
    DefinePattern();
    for (uint32_t node_index = 0; node_index < static_cast<uint32_t>(graph->nodes.size()); ++node_index) {
      if (!MatchPattern(node_index)) {
        continue;
      }
      if (CreateCustomGruCell() != RET_OK) {
        MS_LOG(ERROR) << "Create Custom-Gru failed.";
        return RET_ERROR;
      }
    }
    link_info_manager_->UpdateMetaGraph();
    return RET_OK;
  }

 private:
  // Expected connectivity of one pattern node.
  struct NodeInfo {
    // One expected input edge: either a const tensor (is_const == true) or
    // output `tensor_index_` of pattern node `node_index_`.
    struct InTensorInfo {
      bool is_const{false};
      uint32_t node_index_{0};
      uint32_t tensor_index_{0};
    };
    // One expected consumer edge: input slot `tensor_index_` of pattern node
    // `node_index_`.
    struct OutTensorInfo {
      uint32_t node_index_{0};
      uint32_t tensor_index_{0};
    };
    // Structural validator for this node (CheckArithmetic, CheckSplit, ...).
    bool (*checker)(schema::MetaGraphT *graph, uint32_t node_index);
    std::vector<InTensorInfo> in_infos;
    std::vector<OutTensorInfo> out_infos;
  };

  // Builds the GRU-cell pattern table. Keys are {match_order, pattern_id}:
  // nodes are visited in match_order, which guarantees each node's id has
  // been bound to a real node index before its own entry is checked.
  void DefinePattern() {
    int match_order = 0;
    pattern_[{match_order++, kAdd0}] = {
      CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, {{false, kTanh, 0}, {false, kMul0, 0}}, {}};
    pattern_[{match_order++, kTanh}] = {
      CheckActivation<schema::ActivationType_TANH>, {{false, kAdd1, 0}}, {{kSub, 1}, {kAdd0, 0}}};
    pattern_[{match_order++, kMul0}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>,
                                        {{false, kSigmoid0, 0}, {false, kSub, 0}},
                                        {{kAdd0, 1}}};
    pattern_[{match_order++, kAdd1}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
                                        {{false, kSplit0, 2}, {false, kMul1, 0}},
                                        {{kTanh, 0}}};
    pattern_[{match_order++, kSub}] = {CheckArithmetic<schema::PrimitiveType_SubFusion, schema::SubFusionT>,
                                       {{false, kInputH, 0}, {false, kTanh, 0}},
                                       {{kMul0, 1}}};
    pattern_[{match_order++, kSigmoid0}] = {
      CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd2, 0}}, {{kMul0, 0}}};
    pattern_[{match_order++, kSplit0}] = {CheckSplit, {{false, kAdd3, 0}}, {{kAdd4, 0}, {kAdd2, 0}, {kAdd1, 0}}};
    pattern_[{match_order++, kMul1}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>,
                                        {{false, kSigmoid1, 0}, {false, kSplit1, 2}},
                                        {{kAdd1, 1}}};
    pattern_[{match_order++, kAdd2}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
                                        {{false, kSplit0, 1}, {false, kSplit1, 1}},
                                        {{kSigmoid0, 0}}};
    pattern_[{match_order++, kSigmoid1}] = {
      CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd4, 0}}, {{kMul1, 0}}};
    pattern_[{match_order++, kAdd3}] = {CheckBiasAdd, {{false, kMatmul0, 0}, {true}}, {{kSplit0, 0}}};
    pattern_[{match_order++, kSplit1}] = {CheckSplit, {{false, kAdd5, 0}}, {{kAdd4, 1}, {kAdd2, 1}, {kMul1, 1}}};
    pattern_[{match_order++, kAdd4}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
                                        {{false, kSplit0, 0}, {false, kSplit1, 0}},
                                        {{kSigmoid1, 0}}};
    pattern_[{match_order++, kAdd5}] = {CheckBiasAdd, {{false, kMatmul1, 0}, {true}}, {{kSplit1, 0}}};
    pattern_[{match_order++, kMatmul0}] = {CheckMatmul, {{false, kInputI, 0}, {true}}, {{kAdd3, 0}}};
    pattern_[{match_order++, kMatmul1}] = {CheckMatmul, {{false, kInputH, 0}, {true}}, {{kAdd5, 0}}};
  }

  // Records the actual connectivity of the real node `node_index` into
  // `real_pattern` (const inputs, producer node/output for each input, and
  // every consumer node/slot for each output). Fails on duplicates or tensors
  // unknown to the link-info map.
  bool FillRealPattern(uint32_t node_index, std::map<uint32_t, NodeInfo> *real_pattern) {
    const auto &link_infos = link_info_manager_->GetLinkInfos();
    if (real_pattern->find(node_index) != real_pattern->end()) {
      return false;
    }
    real_pattern->insert({node_index, {nullptr}});
    auto in_tensor_indexes = graph_->nodes[node_index]->inputIndex;
    for (auto tensor_index : in_tensor_indexes) {
      if (link_infos.find(tensor_index) == link_infos.end()) {
        return false;
      }
      const auto &tensor_out_info = link_infos.at(tensor_index).second;
      if (tensor_out_info.node_index < 0) {
        // Produced by no node: graph input or constant.
        real_pattern->at(node_index).in_infos.push_back({true});
      } else {
        real_pattern->at(node_index)
          .in_infos.push_back({false, static_cast<uint32_t>(tensor_out_info.node_index), tensor_out_info.out_index});
      }
    }
    auto out_tensor_indexes = graph_->nodes[node_index]->outputIndex;
    for (auto tensor_index : out_tensor_indexes) {
      if (link_infos.find(tensor_index) == link_infos.end()) {
        return false;
      }
      const auto &in_tensor_out_info = link_infos.at(tensor_index).first;
      for (const auto &in_node_info : in_tensor_out_info) {
        for (auto index : in_node_info.in_indexes) {
          real_pattern->at(node_index).out_infos.push_back({static_cast<uint32_t>(in_node_info.node_index), index});
        }
      }
    }
    return true;
  }

  // Compares the real connectivity of the node bound to `pattern_node_index`
  // against the expected pattern edges, extending real_node_map_ with newly
  // discovered pattern-id -> real-node bindings and rejecting contradictions.
  bool CheckPattern(const std::map<uint32_t, NodeInfo> &real_pattern,
                    const std::pair<int, uint32_t> &pattern_node_index) {
    const auto &real_in_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).in_infos;
    const auto &virtual_in_infos = pattern_.at(pattern_node_index).in_infos;
    if (real_in_infos.size() != virtual_in_infos.size()) {
      return false;
    }
    for (size_t i = 0; i < virtual_in_infos.size(); ++i) {
      if (virtual_in_infos[i].is_const) {
        if (!real_in_infos[i].is_const) {
          return false;
        }
        continue;
      }
      if (virtual_in_infos[i].tensor_index_ != real_in_infos[i].tensor_index_) {
        return false;
      }
      if (real_node_map_.find(virtual_in_infos[i].node_index_) == real_node_map_.end()) {
        real_node_map_.insert({virtual_in_infos[i].node_index_, real_in_infos[i].node_index_});
      } else if (real_node_map_.at(virtual_in_infos[i].node_index_) != real_in_infos[i].node_index_) {
        return false;
      }
    }
    const auto &real_out_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).out_infos;
    const auto &virtual_out_infos = pattern_.at(pattern_node_index).out_infos;
    // An empty expected-consumer list (only kAdd0) means "don't care".
    if (virtual_out_infos.empty()) {
      return true;
    }
    if (real_out_infos.size() != virtual_out_infos.size()) {
      return false;
    }
    for (size_t i = 0; i < virtual_out_infos.size(); ++i) {
      if (virtual_out_infos[i].tensor_index_ != real_out_infos[i].tensor_index_) {
        return false;
      }
      if (real_node_map_.find(virtual_out_infos[i].node_index_) == real_node_map_.end()) {
        real_node_map_.insert({virtual_out_infos[i].node_index_, real_out_infos[i].node_index_});
      } else if (real_node_map_.at(virtual_out_infos[i].node_index_) != real_out_infos[i].node_index_) {
        return false;
      }
    }
    return true;
  }

  // Verifies the match is closed: the binding is injective, every interior
  // node (kAdd1..kMatmul1) is bound, and no interior node's output is
  // consumed outside the matched set — otherwise deleting it would orphan an
  // external consumer.
  bool CheckClosure(const std::map<uint32_t, uint32_t> &node_map) {
    std::set<uint32_t> real_nodes;
    (void)std::for_each(node_map.begin(), node_map.end(),
                        [&real_nodes](std::pair<uint32_t, uint32_t> pair) { real_nodes.insert(pair.second); });
    if (real_nodes.size() != node_map.size()) {
      return false;
    }
    const auto &link_infos = link_info_manager_->GetLinkInfos();
    for (uint32_t start = kAdd1; start <= kMatmul1; ++start) {
      if (node_map.find(start) == node_map.end()) {
        return false;
      }
      const auto &node = graph_->nodes[node_map.at(start)];
      auto out_tensor_indexes = node->outputIndex;
      for (auto out_index : out_tensor_indexes) {
        if (link_infos.find(out_index) == link_infos.end()) {
          return false;
        }
        for (const auto &in_node_info : link_infos.at(out_index).first) {
          if (real_nodes.find(in_node_info.node_index) == real_nodes.end()) {
            return false;
          }
        }
      }
    }
    return true;
  }

  // Attempts a full pattern match anchored at `add_index` (bound to kAdd0).
  // On success real_node_map_ holds the pattern-id -> real-node binding.
  bool MatchPattern(uint32_t add_index) {
    real_node_map_.clear();
    real_node_map_[kAdd0] = add_index;
    std::map<uint32_t, NodeInfo> real_pattern;
    for (const auto &pair : pattern_) {
      // Match order guarantees this id was bound by an earlier entry.
      if (real_node_map_.find(pair.first.second) == real_node_map_.end()) {
        return false;
      }
      auto node_index = real_node_map_[pair.first.second];
      if (!pair.second.checker(graph_, node_index)) {
        return false;
      }
      if (!FillRealPattern(node_index, &real_pattern)) {
        return false;
      }
      if (!CheckPattern(real_pattern, pair.first)) {
        return false;
      }
    }
    // The hidden-path weight must be [3*hidden, hidden] (gate-concatenated).
    auto weight_hidden_index = graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1];
    auto weight_hidden_shape = graph_->allTensors[weight_hidden_index]->dims;
    if (weight_hidden_shape.size() != C2NUM || weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
      return false;
    }
    return CheckClosure(real_node_map_);
  }

  // Replaces the matched sub-graph with one CustomGRU node: the anchor node
  // is swapped for the fused node and all interior nodes are queued for
  // deletion.
  STATUS CreateCustomGruCell() {
    std::vector<uint32_t> inputs;
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[0]);  // x
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[1]);  // weight_input
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]);  // weight_hidden
    inputs.push_back(graph_->nodes[real_node_map_[kAdd3]]->inputIndex[1]);     // bias_input
    inputs.push_back(graph_->nodes[real_node_map_[kAdd5]]->inputIndex[1]);     // bias_hidden
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[0]);  // init_h
    auto outputs = graph_->nodes[real_node_map_[kAdd0]]->outputIndex;
    auto attrs = CreateCustom();
    MS_CHECK_TRUE_RET(attrs != nullptr, RET_NULL_PTR);
    auto prim_t = std::make_unique<schema::PrimitiveT>();
    MS_CHECK_TRUE_MSG(prim_t != nullptr, RET_ERROR, "Create PrimitiveT failed.");
    prim_t->value.type = schema::PrimitiveType_Custom;
    prim_t->value.value = attrs.release();
    auto custom_gru = std::make_unique<schema::CNodeT>();
    MS_CHECK_TRUE_MSG(custom_gru != nullptr, RET_ERROR, "Create Custom-Gru failed.");
    custom_gru->name = graph_->nodes[real_node_map_[kAdd0]]->name;
    custom_gru->inputIndex = inputs;
    custom_gru->outputIndex = outputs;
    custom_gru->primitive = std::move(prim_t);
    link_info_manager_->Replace(real_node_map_[kAdd0], std::move(custom_gru));
    std::set<uint32_t> delete_nodes;
    for (uint32_t i = kAdd1; i <= kMatmul1; ++i) {
      delete_nodes.insert(real_node_map_[i]);
    }
    link_info_manager_->AddDeleteNodes(delete_nodes);
    return RET_OK;
  }

  // {match_order, pattern_id} -> expected node info.
  std::map<std::pair<int, uint32_t>, NodeInfo> pattern_;
  // pattern_id -> real node index of the current match attempt.
  std::map<uint32_t, uint32_t> real_node_map_;
  schema::MetaGraphT *graph_{nullptr};
  std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr};
};
667
Run(schema::MetaGraphT * graph)668 STATUS GruFusionPass::Run(schema::MetaGraphT *graph) {
669 #ifndef ENABLE_ARM64
670 return RET_OK;
671 #endif
672 if (graph == nullptr) {
673 MS_LOG(ERROR) << "graph is a nullptr.";
674 return RET_NULL_PTR;
675 }
676 if (graph->subGraph.size() != 1) {
677 return RET_OK;
678 }
679 if (FuseToGruCell(graph) != RET_OK) {
680 return RET_ERROR;
681 }
682 return FuseGruCell(graph);
683 }
684
FuseToGruCell(schema::MetaGraphT * graph)685 STATUS GruFusionPass::FuseToGruCell(schema::MetaGraphT *graph) {
686 GruCellFusion gru_cell_fusion{};
687 if (gru_cell_fusion.Run(graph) != RET_OK) {
688 MS_LOG(ERROR) << "Fuse GruCell failed.";
689 return RET_ERROR;
690 }
691 return RET_OK;
692 }
693
FuseGruCell(schema::MetaGraphT * graph)694 STATUS GruFusionPass::FuseGruCell(schema::MetaGraphT *graph) {
695 link_info_manager_ = std::make_shared<LinkInfoManager>(graph);
696 for (uint32_t i = 0; i < static_cast<uint32_t>(graph->nodes.size()); ++i) {
697 if (!CheckStack(graph, i)) {
698 continue;
699 }
700 std::vector<uint32_t> strided_slices;
701 std::vector<uint32_t> squeezes;
702 std::vector<uint32_t> gru_cells;
703 if (!MatchPatten(graph, i, &strided_slices, &squeezes, &gru_cells)) {
704 continue;
705 }
706 if (CreateGru(graph, i, strided_slices, squeezes, gru_cells) != RET_OK) {
707 MS_LOG(ERROR) << "Fuse GruCell failed.";
708 return RET_ERROR;
709 }
710 }
711 link_info_manager_->UpdateMetaGraph();
712 link_info_manager_ = nullptr;
713 return RET_OK;
714 }
715
MatchPatten(schema::MetaGraphT * graph,uint32_t stack_index,std::vector<uint32_t> * strided_slices,std::vector<uint32_t> * squeezes,std::vector<uint32_t> * gru_cells)716 bool GruFusionPass::MatchPatten(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices,
717 std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells) {
718 auto &link_infos = link_info_manager_->GetLinkInfos();
719 auto &stack_node = graph->nodes[stack_index];
720 int batch_point = 0;
721 auto CommonCheck = [&link_infos](uint32_t tensor_index) {
722 if (link_infos.find(tensor_index) == link_infos.end()) {
723 return std::make_pair(false, 0);
724 }
725 const auto &in_node_info = link_infos.at(tensor_index).first;
726 if (in_node_info.size() != 1 && in_node_info.front().in_indexes.size() != 1) {
727 return std::make_pair(false, 0);
728 }
729 auto node_index = link_infos.at(tensor_index).second.node_index;
730 if (node_index < 0) {
731 return std::make_pair(false, 0);
732 }
733 return std::make_pair(true, node_index);
734 };
735 for (auto tensor_index : stack_node->inputIndex) {
736 auto check_info = CommonCheck(tensor_index);
737 if (!check_info.first || !CheckGruCell(graph, check_info.second)) {
738 return false;
739 }
740 gru_cells->push_back(check_info.second);
741 auto &gru_cell_node = graph->nodes[check_info.second];
742 check_info = CommonCheck(gru_cell_node->inputIndex.front());
743 if (!check_info.first || !CheckSqueeze(graph, check_info.second)) {
744 return false;
745 }
746 squeezes->push_back(check_info.second);
747 auto &squeeze_node = graph->nodes[check_info.second];
748 check_info = CommonCheck(squeeze_node->inputIndex.front());
749 if (!check_info.first || !CheckStridedSlice(graph, check_info.second, batch_point)) {
750 return false;
751 }
752 strided_slices->push_back(check_info.second);
753 ++batch_point;
754 }
755 if (strided_slices->empty()) {
756 return false;
757 }
758 uint32_t input_index = graph->nodes[strided_slices->front()]->inputIndex.front();
759 if (std::any_of(strided_slices->begin(), strided_slices->end(), [input_index, graph](uint32_t strided_slice) {
760 return graph->nodes[strided_slice]->inputIndex.front() != input_index;
761 })) {
762 return false;
763 }
764 auto in_shape = graph->allTensors[input_index]->dims;
765 if (in_shape.empty() || in_shape.front() != batch_point) {
766 return false;
767 }
768 return CheckGruCellConnection(graph, *gru_cells);
769 }
770
CheckGruCellConnection(schema::MetaGraphT * graph,const std::vector<uint32_t> & gru_cells)771 bool GruFusionPass::CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells) {
772 auto &first_node = graph->nodes[gru_cells.front()];
773 if (first_node->inputIndex.size() != C6NUM) {
774 return false;
775 }
776 auto init_h = first_node->outputIndex.front();
777 for (size_t i = 1; i < gru_cells.size(); ++i) {
778 auto &node = graph->nodes[gru_cells[i]];
779 if (node->inputIndex.size() != first_node->inputIndex.size()) {
780 return false;
781 }
782 for (size_t j = 1; j < C5NUM; ++j) {
783 if (node->inputIndex[j] != first_node->inputIndex[j]) {
784 return false;
785 }
786 }
787 if (node->inputIndex[C5NUM] != init_h) {
788 return false;
789 }
790 init_h = node->outputIndex.front();
791 }
792 return true;
793 }
794
CreateGru(schema::MetaGraphT * graph,uint32_t stack_index,const std::vector<uint32_t> & strided_slices,const std::vector<uint32_t> & squeezes,const std::vector<uint32_t> & gru_cells)795 STATUS GruFusionPass::CreateGru(schema::MetaGraphT *graph, uint32_t stack_index,
796 const std::vector<uint32_t> &strided_slices, const std::vector<uint32_t> &squeezes,
797 const std::vector<uint32_t> &gru_cells) {
798 auto &gru_cell_node = graph->nodes[gru_cells.front()];
799 gru_cell_node->inputIndex[0] = graph->nodes[strided_slices.front()]->inputIndex[0];
800 gru_cell_node->outputIndex[0] = graph->nodes[stack_index]->outputIndex[0];
801 std::set<uint32_t> delete_node{stack_index};
802 (void)delete_node.insert(strided_slices.begin(), strided_slices.end());
803 (void)delete_node.insert(squeezes.begin(), squeezes.end());
804 (void)delete_node.insert(gru_cells.begin() + 1, gru_cells.end());
805 link_info_manager_->AddDeleteNodes(delete_node);
806 return RET_OK;
807 }
808 } // namespace lite
809 } // namespace mindspore
810