/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/train/optimizer/fusion/gru_fusion_pass.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"

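// This pass runs on the exported MetaGraphT in two stages: GruCellFusion first
// collapses the decomposed single-step GRU cell (two MatMul/BiasAdd branches, two
// Split ops, the sigmoid/tanh gates and the element-wise combination) into one
// "CustomGRU" node; GruFusionPass::FuseGruCell then folds a per-timestep chain of
// StridedSlice -> Squeeze -> CustomGRU -> Stack into a single CustomGRU over the
// whole input sequence.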
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kSplitOutSize = 3;
constexpr uint32_t kAdd0 = 0;
constexpr uint32_t kAdd1 = 1;
constexpr uint32_t kAdd2 = 2;
constexpr uint32_t kAdd3 = 3;
constexpr uint32_t kAdd4 = 4;
constexpr uint32_t kAdd5 = 5;
constexpr uint32_t kSub = 6;
constexpr uint32_t kMul0 = 7;
constexpr uint32_t kMul1 = 8;
constexpr uint32_t kTanh = 9;
constexpr uint32_t kSigmoid0 = 10;
constexpr uint32_t kSigmoid1 = 11;
constexpr uint32_t kSplit0 = 12;
constexpr uint32_t kSplit1 = 13;
constexpr uint32_t kMatmul0 = 14;
constexpr uint32_t kMatmul1 = 15;
constexpr uint32_t kInputH = 16;
constexpr uint32_t kInputI = 17;
constexpr auto kCustomGRU = "CustomGRU";

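// Shared validation: the node at node_index must exist, carry a primitive of the
// given type, have the expected input/output counts (in_nums == 0 skips the input
// check), and reference only tensor indices that are inside graph->allTensors.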
bool CheckCommon(schema::MetaGraphT *graph, uint32_t node_index, schema::PrimitiveType type, size_t in_nums,
                 size_t out_nums) {
  if (graph->nodes.size() <= node_index) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  if (node == nullptr || node->primitive == nullptr) {
    return false;
  }
  const auto &value = node->primitive->value;
  if (value.type != type) {
    return false;
  }
  if (value.value == nullptr) {
    return false;
  }
  if ((in_nums > 0 && node->inputIndex.size() != in_nums) || node->outputIndex.size() != out_nums) {
    return false;
  }
  return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                     [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }) &&
         std::all_of(node->outputIndex.begin(), node->outputIndex.end(),
                     [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; });
}

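// A binary arithmetic node (Add/Mul/Sub fusion) qualifies only if it has no fused
// activation and its two inputs and its output all share the same shape.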
template <schema::PrimitiveType T, typename P>
bool CheckArithmetic(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, T, kInputSize1, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  const auto &value = node->primitive->value;
  const auto add_attr = static_cast<const P *>(value.value);
  if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
    return false;
  }
  auto tensor_indexes = node->inputIndex;
  (void)tensor_indexes.insert(tensor_indexes.end(), node->outputIndex.begin(), node->outputIndex.end());
  std::vector<int> shape;
  for (size_t i = 0; i < tensor_indexes.size(); ++i) {
    if (i == 0) {
      shape = graph->allTensors[tensor_indexes[i]]->dims;
      continue;
    }
    if (graph->allTensors[tensor_indexes[i]]->dims != shape) {
      return false;
    }
  }
  return true;
}

template <schema::ActivationType T>
bool CheckActivation(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Activation, 1, 1)) {
    return false;
  }
  const auto &value = graph->nodes[node_index]->primitive->value;
  const auto add_attr = static_cast<const schema::ActivationT *>(value.value);
  if (add_attr->activation_type != T) {
    return false;
  }
  return true;
}

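// A bias add may appear either as AddFusion (without fused activation) or as
// BiasAdd; the second input must be a 1-D tensor whose length matches the last
// dimension of the first input.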
bool CheckBiasAdd(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_AddFusion, kInputSize1, 1) &&
      !CheckCommon(graph, node_index, schema::PrimitiveType_BiasAdd, kInputSize1, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  const auto &value = node->primitive->value;
  if (value.type == schema::PrimitiveType_AddFusion) {
    const auto add_attr = static_cast<const schema::AddFusionT *>(value.value);
    if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
      return false;
    }
  }
  auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims;
  auto in_shape1 = graph->allTensors[node->inputIndex[1]]->dims;
  if (in_shape1.size() != 1 || in_shape0.empty() || in_shape0.back() != in_shape1.back()) {
    return false;
  }
  return true;
}

bool CheckMatmul(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_MatMulFusion, kInputSize1, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  const auto &value = node->primitive->value;
  const auto matmul_attr = static_cast<const schema::MatMulFusionT *>(value.value);
  if (matmul_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
    return false;
  }
  auto out_shape = graph->allTensors[node->outputIndex.front()]->dims;
  return out_shape.size() == kInputSize1;
}

bool CheckSplit(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Split, 1, kSplitOutSize)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  if (node->inputIndex.size() != 1 || node->outputIndex.size() != kSplitOutSize) {
    return false;
  }
  auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims;
  auto out_shape0 = graph->allTensors[node->outputIndex[0]]->dims;
  auto out_shape1 = graph->allTensors[node->outputIndex[1]]->dims;
  auto out_shape2 = graph->allTensors[node->outputIndex[kInputSize1]]->dims;
  if (out_shape0 != out_shape1 || out_shape0 != out_shape2) {
    return false;
  }
  if (in_shape0.empty() || out_shape0.empty()) {
    return false;
  }
  if (in_shape0.back() != (out_shape0.back() + out_shape1.back() + out_shape2.back())) {
    return false;
  }
  return true;
}

bool CheckStack(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Stack, 0, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  const auto &value = node->primitive->value;
  const auto stack_attr = static_cast<const schema::StackT *>(value.value);
  auto out_shape = graph->allTensors[node->outputIndex.front()]->dims;
  if (out_shape.empty()) {
    return false;
  }
  auto axis = stack_attr->axis;
  if (axis < 0) {
    axis += static_cast<int64_t>(out_shape.size());
  }
  return axis == 0;
}

bool CheckSqueeze(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Squeeze, 0, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  if (node->inputIndex.size() != 1 && node->inputIndex.size() != kInputSize1) {
    return false;
  }
  int axis = 0;
  if (node->inputIndex.size() == kInputSize1) {
    const auto &data = graph->allTensors[node->inputIndex[1]]->data;
    if (data.size() != sizeof(int)) {
      return false;
    }
    axis = *(reinterpret_cast<const int *>(data.data()));
  } else {
    const auto &value = node->primitive->value;
    const auto squeeze_attr = static_cast<const schema::SqueezeT *>(value.value);
    if (squeeze_attr->axis.size() != 1) {
      return false;
    }
    axis = squeeze_attr->axis.front();
  }
  auto in_shape = graph->allTensors[node->inputIndex[0]]->dims;
  if (in_shape.empty()) {
    return false;
  }
  if (axis < 0) {
    axis += static_cast<int>(in_shape.size());
  }
  return axis == 0;
}

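// Reads the begin points of a StridedSlice from a constant tensor; positions whose
// mask bit is set are treated as 0 because the mask makes them ineffective.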
std::vector<int> GetStridedSlicePoints(const schema::TensorT *tensor, int64_t mask) {
  if (tensor->data.empty()) {
    return {};
  }
  auto origin_data = reinterpret_cast<const int *>(tensor->data.data());
  size_t num = tensor->data.size() / sizeof(int);
  std::vector<int> data;
  for (size_t i = 0; i < num; ++i) {
    bool ineffective = (mask & (1 << i));
    int cur_point = ineffective ? 0 : origin_data[i];
    data.push_back(cur_point);
  }
  return data;
}

bool CheckStridedSlice(schema::MetaGraphT *graph, uint32_t node_index, int batch_position) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_StridedSlice, C4NUM, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  const auto &step_tensor = graph->allTensors[node->inputIndex.back()];
  if (!step_tensor->data.empty()) {
    const auto data = reinterpret_cast<int *>(step_tensor->data.data());
    auto size = step_tensor->data.size() / sizeof(int);
    if (std::any_of(data, data + size, [](int val) { return val != 1; })) {
      return false;
    }
  }
  auto in_shape = graph->allTensors[node->inputIndex.front()]->dims;
  auto out_shape = graph->allTensors[node->outputIndex.back()]->dims;
  if (in_shape.size() != out_shape.size() || in_shape.empty()) {
    return false;
  }
  for (size_t i = 1; i < in_shape.size(); ++i) {
    if (in_shape[i] != out_shape[i]) {
      return false;
    }
  }
  const auto &value = node->primitive->value;
  const auto strided_slice_attr = static_cast<const schema::StridedSliceT *>(value.value);
  if (strided_slice_attr->ellipsis_mask != 0 || strided_slice_attr->new_axis_mask != 0 ||
      strided_slice_attr->shrink_axis_mask != 0) {
    return false;
  }
  auto begin = GetStridedSlicePoints(graph->allTensors[node->inputIndex[1]].get(), strided_slice_attr->begin_mask);
  if (begin.empty()) {
    return false;
  }
  return begin.front() == batch_position;
}

bool CheckGruCell(schema::MetaGraphT *graph, uint32_t node_index) {
  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Custom, C6NUM, 1)) {
    return false;
  }
  const auto &node = graph->nodes[node_index];
  const auto &value = node->primitive->value;
  const auto gru_attr = static_cast<const schema::CustomT *>(value.value);
  return gru_attr->type == kCustomGRU;
}

std::unique_ptr<schema::CustomT> CreateCustom() {
  auto ConvertToAttr = [](const std::string &key, const std::vector<uint8_t> &value) {
    auto attr = std::make_unique<schema::AttributeT>();
    attr->name = key;
    attr->data = value;
    return attr;
  };
  auto attrs = std::make_unique<schema::CustomT>();
  MS_CHECK_TRUE_MSG(attrs != nullptr, nullptr, "Create CustomT failed.");
  attrs->type = kCustomGRU;
  std::vector<uint8_t> transpose_a{false};
  std::vector<uint8_t> transpose_b{true};
  std::vector<uint8_t> built_in{true};

  attrs->attr.push_back(ConvertToAttr("transpose_a", transpose_a));
  attrs->attr.push_back(ConvertToAttr("transpose_b", transpose_b));
  attrs->attr.push_back(ConvertToAttr("builtin", built_in));
  return attrs;
}

struct InNodeInfo {
  int node_index;
  std::vector<uint32_t> in_indexes;
};

struct OutNodeInfo {
  int node_index;
  uint32_t out_index;
};

struct camp {
  bool operator()(uint32_t left, uint32_t right) const { return left > right; }
};
}  // namespace

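// Records, for every tensor index, which nodes consume it (and at which input
// slots) and which node produces it. It also collects the nodes scheduled for
// deletion and rewrites the MetaGraphT afterwards: nodes are erased in descending
// index order, tensors are compacted and re-indexed, and the single subgraph's
// node/tensor/input/output indices are rebuilt.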
class LinkInfoManager {
 public:
  explicit LinkInfoManager(schema::MetaGraphT *graph) : graph_{graph} {
    auto &all_nodes = graph->nodes;
    for (int node_index = 0; node_index < static_cast<int>(all_nodes.size()); ++node_index) {
      auto in_indexes = all_nodes[node_index]->inputIndex;
      for (uint32_t index = 0; index < static_cast<uint32_t>(in_indexes.size()); ++index) {
        if (link_info_manager_.find(in_indexes[index]) == link_info_manager_.end()) {
          link_info_manager_[in_indexes[index]] = std::make_pair(std::vector<InNodeInfo>{}, OutNodeInfo{-1, 0});
        }
        auto &in_infos = link_info_manager_[in_indexes[index]].first;
        auto iter = in_infos.begin();
        for (; iter != in_infos.end(); ++iter) {
          if (iter->node_index == node_index) {
            break;
          }
        }
        if (iter != in_infos.end()) {
          iter->in_indexes.push_back(index);
        } else {
          in_infos.push_back({node_index, {index}});
        }
      }

      auto out_indexes = all_nodes[node_index]->outputIndex;
      for (uint32_t index = 0; index < static_cast<uint32_t>(out_indexes.size()); ++index) {
        link_info_manager_[out_indexes[index]].second = OutNodeInfo{node_index, index};
      }
    }
  }

  const auto &GetLinkInfos() const { return link_info_manager_; }

  void Replace(uint32_t node_index, std::unique_ptr<CNodeT> node) { graph_->nodes[node_index].swap(node); }

  void AddDeleteNodes(const std::set<uint32_t> &node_indexes) {
    delete_nodes_.insert(node_indexes.begin(), node_indexes.end());
  }

  void UpdateMetaGraph() {
    auto &main_graph = graph_->subGraph.front();
    for (auto node_index : delete_nodes_) {
      graph_->nodes.erase(graph_->nodes.begin() + node_index);
    }
    main_graph->nodeIndices.clear();
    for (uint32_t index = 0; index < static_cast<uint32_t>(graph_->nodes.size()); ++index) {
      main_graph->nodeIndices.push_back(index);
    }
    std::map<uint32_t, uint32_t> tensor_maps;
    BuildTensorMap(&tensor_maps);
    auto UpdateTensorIndex = [&tensor_maps](std::vector<uint32_t> *origin) {
      auto origin_indexes = *origin;
      origin->clear();
      (void)std::transform(origin_indexes.begin(), origin_indexes.end(), std::back_inserter(*origin),
                           [&tensor_maps](uint32_t origin_index) { return tensor_maps[origin_index]; });
    };
    UpdateTensorIndex(&graph_->inputIndex);
    for (auto &node : graph_->nodes) {
      UpdateTensorIndex(&node->inputIndex);
      UpdateTensorIndex(&node->outputIndex);
    }
    UpdateTensorIndex(&graph_->outputIndex);
    main_graph->inputIndices = graph_->inputIndex;
    main_graph->outputIndices = graph_->outputIndex;
    main_graph->tensorIndices.clear();
    for (uint32_t index = 0; index < static_cast<uint32_t>(tensor_maps.size()); ++index) {
      main_graph->tensorIndices.push_back(index);
    }
    std::vector<std::unique_ptr<TensorT>> tensors;
    graph_->allTensors.swap(tensors);
    graph_->allTensors.resize(tensor_maps.size());
    for (auto &tensor_map : tensor_maps) {
      graph_->allTensors[tensor_map.second].swap(tensors[tensor_map.first]);
    }
  }

 private:
  void BuildTensorMap(std::map<uint32_t, uint32_t> *tensor_maps) {
    uint32_t new_index = 0;
    auto InsertElements = [tensor_maps, &new_index](const std::vector<uint32_t> &indexes) mutable {
      for (auto index : indexes) {
        if (tensor_maps->find(index) != tensor_maps->end()) {
          continue;
        }
        (*tensor_maps)[index] = new_index++;
      }
    };
    InsertElements(graph_->inputIndex);
    for (auto &node : graph_->nodes) {
      InsertElements(node->inputIndex);
      InsertElements(node->outputIndex);
    }
    InsertElements(graph_->outputIndex);
  }

  schema::MetaGraphT *graph_{nullptr};
  std::set<uint32_t, camp> delete_nodes_;
  // tensor_index, <in_node_infos, out_node_info>
  std::map<uint32_t, std::pair<std::vector<InNodeInfo>, OutNodeInfo>> link_info_manager_;
};

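// Matches the decomposed single-step GRU-cell subgraph and replaces it with one
// CustomGRU node. Matching starts from every AddFusion candidate (kAdd0, the node
// that produces the new hidden state) and walks the pattern defined below.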
class GruCellFusion {
 public:
  GruCellFusion() = default;
  ~GruCellFusion() = default;
  STATUS Run(schema::MetaGraphT *graph) {
    MS_ASSERT(graph != nullptr);
    MS_ASSERT(graph->subGraph.size() == 1);
    link_info_manager_ = std::make_shared<LinkInfoManager>(graph);
    graph_ = graph;
    DefinePattern();
    for (uint32_t node_index = 0; node_index < static_cast<uint32_t>(graph->nodes.size()); ++node_index) {
      if (!MatchPattern(node_index)) {
        continue;
      }
      if (CreateCustomGruCell() != RET_OK) {
        MS_LOG(ERROR) << "Create Custom-Gru failed.";
        return RET_ERROR;
      }
    }
    link_info_manager_->UpdateMetaGraph();
    return RET_OK;
  }

 private:
  struct NodeInfo {
    struct InTensorInfo {
      bool is_const{false};
      uint32_t node_index_{0};
      uint32_t tensor_index_{0};
    };
    struct OutTensorInfo {
      uint32_t node_index_{0};
      uint32_t tensor_index_{0};
    };
    bool (*checker)(schema::MetaGraphT *graph, uint32_t node_index);
    std::vector<InTensorInfo> in_infos;
    std::vector<OutTensorInfo> out_infos;
  };
  void DefinePattern() {
    int match_order = 0;
    pattern_[{match_order++, kAdd0}] = {
      CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, {{false, kTanh, 0}, {false, kMul0, 0}}, {}};
    pattern_[{match_order++, kTanh}] = {
      CheckActivation<schema::ActivationType_TANH>, {{false, kAdd1, 0}}, {{kSub, 1}, {kAdd0, 0}}};
    pattern_[{match_order++, kMul0}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>,
                                        {{false, kSigmoid0, 0}, {false, kSub, 0}},
                                        {{kAdd0, 1}}};
    pattern_[{match_order++, kAdd1}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
                                        {{false, kSplit0, 2}, {false, kMul1, 0}},
                                        {{kTanh, 0}}};
    pattern_[{match_order++, kSub}] = {CheckArithmetic<schema::PrimitiveType_SubFusion, schema::SubFusionT>,
                                       {{false, kInputH, 0}, {false, kTanh, 0}},
                                       {{kMul0, 1}}};
    pattern_[{match_order++, kSigmoid0}] = {
      CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd2, 0}}, {{kMul0, 0}}};
    pattern_[{match_order++, kSplit0}] = {CheckSplit, {{false, kAdd3, 0}}, {{kAdd4, 0}, {kAdd2, 0}, {kAdd1, 0}}};
    pattern_[{match_order++, kMul1}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>,
                                        {{false, kSigmoid1, 0}, {false, kSplit1, 2}},
                                        {{kAdd1, 1}}};
    pattern_[{match_order++, kAdd2}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
                                        {{false, kSplit0, 1}, {false, kSplit1, 1}},
                                        {{kSigmoid0, 0}}};
    pattern_[{match_order++, kSigmoid1}] = {
      CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd4, 0}}, {{kMul1, 0}}};
    pattern_[{match_order++, kAdd3}] = {CheckBiasAdd, {{false, kMatmul0, 0}, {true}}, {{kSplit0, 0}}};
    pattern_[{match_order++, kSplit1}] = {CheckSplit, {{false, kAdd5, 0}}, {{kAdd4, 1}, {kAdd2, 1}, {kMul1, 1}}};
    pattern_[{match_order++, kAdd4}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
                                        {{false, kSplit0, 0}, {false, kSplit1, 0}},
                                        {{kSigmoid1, 0}}};
    pattern_[{match_order++, kAdd5}] = {CheckBiasAdd, {{false, kMatmul1, 0}, {true}}, {{kSplit1, 0}}};
    pattern_[{match_order++, kMatmul0}] = {CheckMatmul, {{false, kInputI, 0}, {true}}, {{kAdd3, 0}}};
    pattern_[{match_order++, kMatmul1}] = {CheckMatmul, {{false, kInputH, 0}, {true}}, {{kAdd5, 0}}};
  }

  bool FillRealPattern(uint32_t node_index, std::map<uint32_t, NodeInfo> *real_pattern) {
    const auto &link_infos = link_info_manager_->GetLinkInfos();
    if (real_pattern->find(node_index) != real_pattern->end()) {
      return false;
    }
    real_pattern->insert({node_index, {nullptr}});
    auto in_tensor_indexes = graph_->nodes[node_index]->inputIndex;
    for (auto tensor_index : in_tensor_indexes) {
      if (link_infos.find(tensor_index) == link_infos.end()) {
        return false;
      }
      const auto &tensor_out_info = link_infos.at(tensor_index).second;
      if (tensor_out_info.node_index < 0) {
        real_pattern->at(node_index).in_infos.push_back({true});
      } else {
        real_pattern->at(node_index)
          .in_infos.push_back({false, static_cast<uint32_t>(tensor_out_info.node_index), tensor_out_info.out_index});
      }
    }
    auto out_tensor_indexes = graph_->nodes[node_index]->outputIndex;
    for (auto tensor_index : out_tensor_indexes) {
      if (link_infos.find(tensor_index) == link_infos.end()) {
        return false;
      }
      const auto &in_tensor_out_info = link_infos.at(tensor_index).first;
      for (const auto &in_node_info : in_tensor_out_info) {
        for (auto index : in_node_info.in_indexes) {
          real_pattern->at(node_index).out_infos.push_back({static_cast<uint32_t>(in_node_info.node_index), index});
        }
      }
    }
    return true;
  }

  bool CheckPattern(const std::map<uint32_t, NodeInfo> &real_pattern,
                    const std::pair<int, uint32_t> &pattern_node_index) {
    const auto &real_in_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).in_infos;
    const auto &virtual_in_infos = pattern_.at(pattern_node_index).in_infos;
    if (real_in_infos.size() != virtual_in_infos.size()) {
      return false;
    }
    for (size_t i = 0; i < virtual_in_infos.size(); ++i) {
      if (virtual_in_infos[i].is_const) {
        if (!real_in_infos[i].is_const) {
          return false;
        }
        continue;
      }
      if (virtual_in_infos[i].tensor_index_ != real_in_infos[i].tensor_index_) {
        return false;
      }
      if (real_node_map_.find(virtual_in_infos[i].node_index_) == real_node_map_.end()) {
        real_node_map_.insert({virtual_in_infos[i].node_index_, real_in_infos[i].node_index_});
      } else if (real_node_map_.at(virtual_in_infos[i].node_index_) != real_in_infos[i].node_index_) {
        return false;
      }
    }
    const auto &real_out_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).out_infos;
    const auto &virtual_out_infos = pattern_.at(pattern_node_index).out_infos;
    if (virtual_out_infos.empty()) {
      return true;
    }
    if (real_out_infos.size() != virtual_out_infos.size()) {
      return false;
    }
    for (size_t i = 0; i < virtual_out_infos.size(); ++i) {
      if (virtual_out_infos[i].tensor_index_ != real_out_infos[i].tensor_index_) {
        return false;
      }
      if (real_node_map_.find(virtual_out_infos[i].node_index_) == real_node_map_.end()) {
        real_node_map_.insert({virtual_out_infos[i].node_index_, real_out_infos[i].node_index_});
      } else if (real_node_map_.at(virtual_out_infos[i].node_index_) != real_out_infos[i].node_index_) {
        return false;
      }
    }
    return true;
  }

  bool CheckClosure(const std::map<uint32_t, uint32_t> &node_map) {
    std::set<uint32_t> real_nodes;
    (void)std::for_each(node_map.begin(), node_map.end(),
                        [&real_nodes](std::pair<uint32_t, uint32_t> pair) { real_nodes.insert(pair.second); });
    if (real_nodes.size() != node_map.size()) {
      return false;
    }
    const auto &link_infos = link_info_manager_->GetLinkInfos();
    for (uint32_t start = kAdd1; start <= kMatmul1; ++start) {
      if (node_map.find(start) == node_map.end()) {
        return false;
      }
      const auto &node = graph_->nodes[node_map.at(start)];
      auto out_tensor_indexes = node->outputIndex;
      for (auto out_index : out_tensor_indexes) {
        if (link_infos.find(out_index) == link_infos.end()) {
          return false;
        }
        for (const auto &in_node_info : link_infos.at(out_index).first) {
          if (real_nodes.find(in_node_info.node_index) == real_nodes.end()) {
            return false;
          }
        }
      }
    }
    return true;
  }

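  // Tries to match the whole pattern starting from a candidate kAdd0 node, binding
  // every pattern id to a real node index along the way. The hidden-state weight
  // must be 2-D with shape [3 * hidden_size, hidden_size], and CheckClosure then
  // verifies that the intermediate nodes feed nothing outside the matched subgraph,
  // so they can be deleted safely.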
  bool MatchPattern(uint32_t add_index) {
    real_node_map_.clear();
    real_node_map_[kAdd0] = add_index;
    std::map<uint32_t, NodeInfo> real_pattern;
    for (const auto &pair : pattern_) {
      if (real_node_map_.find(pair.first.second) == real_node_map_.end()) {
        return false;
      }
      auto node_index = real_node_map_[pair.first.second];
      if (!pair.second.checker(graph_, node_index)) {
        return false;
      }
      if (!FillRealPattern(node_index, &real_pattern)) {
        return false;
      }
      if (!CheckPattern(real_pattern, pair.first)) {
        return false;
      }
    }
    auto weight_hidden_index = graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1];
    auto weight_hidden_shape = graph_->allTensors[weight_hidden_index]->dims;
    if (weight_hidden_shape.size() != C2NUM || weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
      return false;
    }
    return CheckClosure(real_node_map_);
  }

  STATUS CreateCustomGruCell() {
    std::vector<uint32_t> inputs;
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[0]);  // x
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[1]);  // weight_input
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]);  // weight_hidden
    inputs.push_back(graph_->nodes[real_node_map_[kAdd3]]->inputIndex[1]);     // bias_input
    inputs.push_back(graph_->nodes[real_node_map_[kAdd5]]->inputIndex[1]);     // bias_hidden
    inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[0]);  // init_h
    auto outputs = graph_->nodes[real_node_map_[kAdd0]]->outputIndex;
    auto attrs = CreateCustom();
    MS_CHECK_TRUE_RET(attrs != nullptr, RET_NULL_PTR);
    auto prim_t = std::make_unique<schema::PrimitiveT>();
    MS_CHECK_TRUE_MSG(prim_t != nullptr, RET_ERROR, "Create PrimitiveT failed.");
    prim_t->value.type = schema::PrimitiveType_Custom;
    prim_t->value.value = attrs.release();
    auto custom_gru = std::make_unique<schema::CNodeT>();
    MS_CHECK_TRUE_MSG(custom_gru != nullptr, RET_ERROR, "Create Custom-Gru failed.");
    custom_gru->name = graph_->nodes[real_node_map_[kAdd0]]->name;
    custom_gru->inputIndex = inputs;
    custom_gru->outputIndex = outputs;
    custom_gru->primitive = std::move(prim_t);
    link_info_manager_->Replace(real_node_map_[kAdd0], std::move(custom_gru));
    std::set<uint32_t> delete_nodes;
    for (uint32_t i = kAdd1; i <= kMatmul1; ++i) {
      delete_nodes.insert(real_node_map_[i]);
    }
    link_info_manager_->AddDeleteNodes(delete_nodes);
    return RET_OK;
  }

  std::map<std::pair<int, uint32_t>, NodeInfo> pattern_;
  std::map<uint32_t, uint32_t> real_node_map_;
  schema::MetaGraphT *graph_{nullptr};
  std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr};
};

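// The pass is a no-op unless the build defines ENABLE_ARM64, and it only fuses
// models that contain exactly one subgraph.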
STATUS GruFusionPass::Run(schema::MetaGraphT *graph) {
#ifndef ENABLE_ARM64
  return RET_OK;
#endif
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is a nullptr.";
    return RET_NULL_PTR;
  }
  if (graph->subGraph.size() != 1) {
    return RET_OK;
  }
  if (FuseToGruCell(graph) != RET_OK) {
    return RET_ERROR;
  }
  return FuseGruCell(graph);
}

STATUS GruFusionPass::FuseToGruCell(schema::MetaGraphT *graph) {
  GruCellFusion gru_cell_fusion{};
  if (gru_cell_fusion.Run(graph) != RET_OK) {
    MS_LOG(ERROR) << "Fuse GruCell failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

STATUS GruFusionPass::FuseGruCell(schema::MetaGraphT *graph) {
  link_info_manager_ = std::make_shared<LinkInfoManager>(graph);
  for (uint32_t i = 0; i < static_cast<uint32_t>(graph->nodes.size()); ++i) {
    if (!CheckStack(graph, i)) {
      continue;
    }
    std::vector<uint32_t> strided_slices;
    std::vector<uint32_t> squeezes;
    std::vector<uint32_t> gru_cells;
    if (!MatchPatten(graph, i, &strided_slices, &squeezes, &gru_cells)) {
      continue;
    }
    if (CreateGru(graph, i, strided_slices, squeezes, gru_cells) != RET_OK) {
      MS_LOG(ERROR) << "Fuse GruCell failed.";
      return RET_ERROR;
    }
  }
  link_info_manager_->UpdateMetaGraph();
  link_info_manager_ = nullptr;
  return RET_OK;
}

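// Walks backwards from a Stack node: every Stack input must come from a CustomGRU
// cell fed by a Squeeze of a StridedSlice, where the i-th StridedSlice reads
// position i along the first dimension of one shared input tensor. The GRU cells
// must also share their weight/bias inputs and chain their hidden states
// (CheckGruCellConnection) before the whole chain is collapsed into one CustomGRU.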
bool GruFusionPass::MatchPatten(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices,
                                std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells) {
  auto &link_infos = link_info_manager_->GetLinkInfos();
  auto &stack_node = graph->nodes[stack_index];
  int batch_point = 0;
  auto CommonCheck = [&link_infos](uint32_t tensor_index) {
    if (link_infos.find(tensor_index) == link_infos.end()) {
      return std::make_pair(false, 0);
    }
    const auto &in_node_info = link_infos.at(tensor_index).first;
    if (in_node_info.size() != 1 && in_node_info.front().in_indexes.size() != 1) {
      return std::make_pair(false, 0);
    }
    auto node_index = link_infos.at(tensor_index).second.node_index;
    if (node_index < 0) {
      return std::make_pair(false, 0);
    }
    return std::make_pair(true, node_index);
  };
  for (auto tensor_index : stack_node->inputIndex) {
    auto check_info = CommonCheck(tensor_index);
    if (!check_info.first || !CheckGruCell(graph, check_info.second)) {
      return false;
    }
    gru_cells->push_back(check_info.second);
    auto &gru_cell_node = graph->nodes[check_info.second];
    check_info = CommonCheck(gru_cell_node->inputIndex.front());
    if (!check_info.first || !CheckSqueeze(graph, check_info.second)) {
      return false;
    }
    squeezes->push_back(check_info.second);
    auto &squeeze_node = graph->nodes[check_info.second];
    check_info = CommonCheck(squeeze_node->inputIndex.front());
    if (!check_info.first || !CheckStridedSlice(graph, check_info.second, batch_point)) {
      return false;
    }
    strided_slices->push_back(check_info.second);
    ++batch_point;
  }
  if (strided_slices->empty()) {
    return false;
  }
  uint32_t input_index = graph->nodes[strided_slices->front()]->inputIndex.front();
  if (std::any_of(strided_slices->begin(), strided_slices->end(), [input_index, graph](uint32_t strided_slice) {
        return graph->nodes[strided_slice]->inputIndex.front() != input_index;
      })) {
    return false;
  }
  auto in_shape = graph->allTensors[input_index]->dims;
  if (in_shape.empty() || in_shape.front() != batch_point) {
    return false;
  }
  return CheckGruCellConnection(graph, *gru_cells);
}

bool GruFusionPass::CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells) {
  auto &first_node = graph->nodes[gru_cells.front()];
  if (first_node->inputIndex.size() != C6NUM) {
    return false;
  }
  auto init_h = first_node->outputIndex.front();
  for (size_t i = 1; i < gru_cells.size(); ++i) {
    auto &node = graph->nodes[gru_cells[i]];
    if (node->inputIndex.size() != first_node->inputIndex.size()) {
      return false;
    }
    for (size_t j = 1; j < C5NUM; ++j) {
      if (node->inputIndex[j] != first_node->inputIndex[j]) {
        return false;
      }
    }
    if (node->inputIndex[C5NUM] != init_h) {
      return false;
    }
    init_h = node->outputIndex.front();
  }
  return true;
}

STATUS GruFusionPass::CreateGru(schema::MetaGraphT *graph, uint32_t stack_index,
                                const std::vector<uint32_t> &strided_slices, const std::vector<uint32_t> &squeezes,
                                const std::vector<uint32_t> &gru_cells) {
  auto &gru_cell_node = graph->nodes[gru_cells.front()];
  gru_cell_node->inputIndex[0] = graph->nodes[strided_slices.front()]->inputIndex[0];
  gru_cell_node->outputIndex[0] = graph->nodes[stack_index]->outputIndex[0];
  std::set<uint32_t> delete_node{stack_index};
  (void)delete_node.insert(strided_slices.begin(), strided_slices.end());
  (void)delete_node.insert(squeezes.begin(), squeezes.end());
  (void)delete_node.insert(gru_cells.begin() + 1, gru_cells.end());
  link_info_manager_->AddDeleteNodes(delete_node);
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore