• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/arithmetic_simplify.h"
17 
18 #include <algorithm>
19 #include <list>
20 #include <string>
21 #include <functional>
22 #include <set>
23 #include <vector>
24 #include <utility>
25 
26 #include "ir/anf.h"
27 #include "utils/hash_map.h"
28 #include "utils/hash_set.h"
29 #include "utils/anf_utils.h"
30 #include "utils/check_convert_utils.h"
31 #include "backend/common/graph_kernel/core/graph_builder.h"
32 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
33 #include "backend/common/graph_kernel/graph_kernel_flags.h"
34 #include "backend/common/graph_kernel/model/node.h"
35 #include "backend/common/graph_kernel/model/op_node.h"
36 #include "backend/common/graph_kernel/model/graph_builder.h"
37 #include "ops/auto_generate/gen_ops_primitive.h"
38 
39 namespace mindspore::graphkernel {
40 // operator which follows commutative rules
41 static mindspore::HashSet<std::string> commutative_ops{"Add", "Mul"};
42 
class PatternNode;
using PatternNodePtr = std::shared_ptr<PatternNode>;
using PatternNodePtrList = std::vector<PatternNodePtr>;

// One node of a pattern expression tree. The stored name is either a primitive
// op name (e.g. "Add"), a named constant (e.g. "const1"), a single-letter
// parameter (e.g. "A"), or a literal value (e.g. "0.5").
class PatternNode {
 public:
  explicit PatternNode(const std::string &op) : op_name_(op) {}
  ~PatternNode() = default;

  // Appends one child (an argument of this op) to the node.
  void AddInput(const PatternNodePtr &child) { child_nodes_.push_back(child); }
  // Accessors return copies; pattern trees are small, so this is cheap.
  std::string op() const { return op_name_; }
  std::vector<PatternNodePtr> inputs() const { return child_nodes_; }

 private:
  std::string op_name_ = "";  // op, const or parameter name
  std::vector<PatternNodePtr> child_nodes_;
};
59 
60 using ParaMap = mindspore::HashMap<char, inner::NodePtr>;
61 using ConstMap = mindspore::HashMap<std::string, inner::NodePtr>;
62 
/* This class stores one kind of pattern tree; it is constructed from a string expression.
 Ex. "Pow(Exp(A),B)=Exp(Mul(A,B))"
 then the left tree is
          A                             A    B
           \                             \  /
            Exp    B                       Mul
             \   /                           \
 left tree:   Pow         right tree:         Exp
 lhs_root_ is Pow; rhs_root_ is Exp */
class PatternTree {
 public:
  // pattern_str -> ex. "Pow(Exp(A),B)=Exp(Mul(A,B))"
  explicit PatternTree(const std::string &pattern_str) { (void)BuildTree(pattern_str); }
  virtual ~PatternTree() = default;

  // Roots of the matched (left) and replacement (right) expression trees.
  PatternNodePtr lhs_root() { return lhs_root_; }
  PatternNodePtr rhs_root() { return rhs_root_; }
  // Op name at the root of the left-hand side; "" if the tree is not built.
  std::string GetRootOp() const { return lhs_root_ == nullptr ? "" : lhs_root_->op(); }
  // build tree with expression string
  PatternNodePtr BuildTree(const std::string &pattern_str);
  // traverse pattern tree, return order is topological order
  void DfsTraverse(const std::shared_ptr<PatternNodePtrList> &res, const PatternNodePtr &cur) const;
  // For some patterns, the input parameters may change (e.g.: ReduceSum(ReduceSum(A,B),C)=ReduceSum(A,D)),
  // in this case we need to compute the new axes(D), and update the parameter map.
  // Default implementation: keep the parameter map unchanged.
  virtual std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
                                                    const std::shared_ptr<ParaMap> &para_to_ref) const {
    (void)origin_root;
    return para_to_ref;
  }
  // leverage pattern tree node and lite node's mapping relation to build lite node graph from pattern tree's right
  // side
  inner::NodePtr AlterGraph(const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref,
                            const inner::NodePtr &origin_root);
  // invoke DfsMatchGraph
  inner::NodePtrList MatchGraph(const inner::NodePtr &root, const std::shared_ptr<ParaMap> &para_to_ref,
                                const std::shared_ptr<ConstMap> &const_to_ref);

 protected:
  // set attributes for certain pattern node if needed;
  // default implementation: every node on the right-hand side gets an empty attribute map.
  virtual mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &) {
    auto right_pattern = std::make_shared<PatternNodePtrList>();
    DfsTraverse(right_pattern, rhs_root_);
    mindspore::HashMap<PatternNodePtr, inner::DAttrs> attrs_map;
    for (auto &i : (*right_pattern)) {
      attrs_map[i] = {};
    }
    return attrs_map;
  }
  // check whether inputs and attributes meet requirements for certain pattern node if needed;
  virtual bool CheckInputsAndAttrs(const inner::NodePtr &) const { return true; }

 private:
  PatternNodePtr lhs_root_ = nullptr;  // left side's root
  PatternNodePtr rhs_root_ = nullptr;  // right side's root
};
118 
CutStr(const string & s,size_t start_pos=0,size_t len=std::string::npos)119 std::string CutStr(const string &s, size_t start_pos = 0, size_t len = std::string::npos) {
120   std::string new_str = "";
121   if (start_pos >= s.length()) {
122     MS_LOG(EXCEPTION) << "Start index " << start_pos << " is out of range [0, " << s.length() << ") in string: " << s;
123   }
124   for (size_t i = 0; i < len; i++) {
125     if (start_pos + i >= s.length()) {
126       break;
127     }
128     new_str += s[start_pos + i];
129   }
130   return new_str;
131 }
132 
// Returns true if s begins with prefix.
bool StartWith(const std::string &s, const std::string &prefix) {
  if (s.length() < prefix.length()) {
    return false;
  }
  // compare(0, n, prefix) inspects only the first n characters; the previous
  // find(prefix) == 0 could scan the whole string before reporting a mismatch.
  return s.compare(0, prefix.length(), prefix) == 0;
}
139 
// build pattern tree; left side's root is lhs_root_; right side's root is rhs_root_
PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) {
  size_t pos = pattern_str.find("=");
  if (pos != std::string::npos) {
    // "lhs=rhs": build each side recursively and record the two roots.
    auto left_expression = CutStr(pattern_str, 0, pos);
    lhs_root_ = BuildTree(left_expression);
    auto right_expression = CutStr(pattern_str, pos + 1);
    rhs_root_ = BuildTree(right_expression);
  } else {
    size_t p_start = pattern_str.find("(");
    if (p_start != std::string::npos) {
      // "Op(arg1,arg2,...)": split the argument list on top-level commas only.
      size_t p_end = pattern_str.rfind(")");
      auto op_name = CutStr(pattern_str, 0, p_start);
      auto op_inputs = CutStr(pattern_str, p_start + 1, (p_end - p_start) - 1);
      PatternNodePtr cur_node = std::make_shared<PatternNode>(op_name);
      int tmp = 0;       // parenthesis nesting depth; 0 means top level
      size_t comma = 0;  // scan position inside op_inputs
      while (comma < op_inputs.length()) {
        if (op_inputs[comma] == '(') {
          tmp++;
        }
        if (op_inputs[comma] == ')') {
          tmp--;
        }
        if (op_inputs[comma] == ',' && tmp == 0) {
          // Top-level comma: everything before it is one complete argument.
          auto first_half = CutStr(op_inputs, 0, comma);
          cur_node->AddInput(BuildTree(first_half));
          // Continue scanning the rest of the argument list from the start.
          auto second_half = CutStr(op_inputs, comma + 1);
          op_inputs = second_half;
          comma = 0;
        } else {
          comma++;
        }
      }
      // Whatever remains after the last top-level comma is the final argument.
      cur_node->AddInput(BuildTree(op_inputs));
      return cur_node;
    } else {
      // Leaf: a parameter ("A"), named constant ("const1") or literal ("0.5").
      return std::make_shared<PatternNode>(pattern_str);
    }
  }
  // Reached only by the top-level "lhs=rhs" call; its return value is unused.
  return nullptr;
}
182 
PatternNodeType(const std::string & n)183 inner::NType PatternNodeType(const std::string &n) {
184   // return (Primitive, Parameter or Value)
185   if (n.length() > 0 && n[n.length() - 1] >= '0' && n[n.length() - 1] <= '9') {
186     return inner::NType::Tensor;
187   } else if (n.length() == 1 && n[0] >= 'A' && n[0] <= 'Z') {
188     return inner::NType::Parameter;
189   } else {
190     return inner::NType::Primitive;
191   }
192 }
193 
// Strips '[', ']' and ' ' from s (e.g. a tensor dump "[ 1.5 ]" becomes "1.5"),
// so the result can be parsed with std::stod.
std::string CleanStr(const std::string &s) {
  std::string res;
  res.reserve(s.size());  // at most s.size() characters survive the filter
  for (const char c : s) {
    if (c != '[' && c != ']' && c != ' ') {
      res += c;
    }
  }
  return res;
}
203 
// Checks whether one lite-graph node can stand for the given pattern op name,
// recording the binding in "para_to_ref" (parameter letters) or
// "const_to_ref" (named constants) as a side effect.
bool CheckCurNode(const inner::NodePtr &tmp_node, const std::string &tmp_pattern_op,
                  const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref) {
  // put lite graph node's mapping to pattern node into "para_to_ref" and "const_to_ref"
  switch (PatternNodeType(tmp_pattern_op)) {
    case inner::NType::Parameter: {
      // A parameter letter must bind to the same graph node throughout a match.
      if (para_to_ref->find(tmp_pattern_op[0]) != para_to_ref->end()) {
        if ((*para_to_ref)[tmp_pattern_op[0]] != tmp_node) {
          return false;
        }
      } else {
        (*para_to_ref)[tmp_pattern_op[0]] = tmp_node;
      }
      break;
    }
    case inner::NType::Tensor: {
      if (tmp_node->NodeType() != inner::NType::Tensor) {
        return false;
      }
      // Values are compared numerically after stripping brackets/spaces from the dump.
      auto node_value_str = std::static_pointer_cast<inner::ConstTensorNode>(tmp_node)->ToString();
      double node_value = std::stod(CleanStr(node_value_str));
      if (StartWith(tmp_pattern_op, "const")) {
        // Named constant ("constN"): must bind consistently, like a parameter.
        if (const_to_ref->find(tmp_pattern_op) != const_to_ref->end()) {
          auto pattern_value_str =
            std::static_pointer_cast<inner::ConstTensorNode>((*const_to_ref)[tmp_pattern_op])->ToString();
          double pattern_value = std::stod(CleanStr(pattern_value_str));
          if (pattern_value != node_value) {
            return false;
          }
        } else {
          (*const_to_ref)[tmp_pattern_op] = tmp_node;
        }
      } else {
        // Literal in the pattern (e.g. "0.5"): must equal the node's value exactly.
        double pattern_value = std::stod(tmp_pattern_op);
        if (pattern_value != node_value) {
          return false;
        }
      }
      break;
    }
    case inner::NType::Primitive: {
      // Primitive: the graph node must be an op with exactly this name.
      if (tmp_node->NodeType() != inner::NType::Primitive ||
          std::static_pointer_cast<inner::PrimOp>(tmp_node)->op() != tmp_pattern_op) {
        return false;
      }
      break;
    }
    default:
      break;
  }
  return true;
}
255 
// recursion for the match of lite node graph and pattern tree's left side, store the mapping of pattern tree node to
// lite graph
bool DfsMatchGraph(const inner::NodePtr &tmp_node, const PatternNodePtr &tmp_pattern,
                   const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref,
                   const std::shared_ptr<inner::NodePtrList> &res) {
  std::string tmp_pattern_op = tmp_pattern->op();
  if (!CheckCurNode(tmp_node, tmp_pattern_op, para_to_ref, const_to_ref)) {
    return false;
  }
  std::vector<PatternNodePtr> tmp_pattern_inputs = tmp_pattern->inputs();
  auto tmp_node_inputs = tmp_node->inputs();
  // check if a node meets requirement, and DFS check its inputs
  if (tmp_pattern_inputs.size() != 0 && tmp_node_inputs.size() != tmp_pattern_inputs.size()) {
    return false;
  }
  if (PatternNodeType(tmp_pattern_op) == inner::NType::Primitive) {
    // exchange inputs for the node who meets commutative rules
    if (commutative_ops.find(tmp_pattern_op) != commutative_ops.end()) {
      // Snapshot the current bindings so a failed first attempt can be rolled back.
      ParaMap para_to_ref_copy = *para_to_ref;
      ConstMap const_to_ref_copy = *const_to_ref;
      bool first_match = DfsMatchGraph(tmp_node_inputs[0], tmp_pattern_inputs[0], para_to_ref, const_to_ref, res) &&
                         DfsMatchGraph(tmp_node_inputs[1], tmp_pattern_inputs[1], para_to_ref, const_to_ref, res);
      if (!first_match) {
        // Roll back any partial bindings and matched nodes, then retry with the
        // two inputs swapped (the op is commutative).
        res->clear();
        para_to_ref->clear();
        const_to_ref->clear();
        for (auto &i : para_to_ref_copy) {
          (*para_to_ref)[i.first] = i.second;
        }
        for (auto &i : const_to_ref_copy) {
          (*const_to_ref)[i.first] = i.second;
        }
        bool second_match = DfsMatchGraph(tmp_node_inputs[0], tmp_pattern_inputs[1], para_to_ref, const_to_ref, res) &&
                            DfsMatchGraph(tmp_node_inputs[1], tmp_pattern_inputs[0], para_to_ref, const_to_ref, res);
        if (!second_match) {
          return false;
        }
      }
    } else {
      // Non-commutative op: inputs must match positionally.
      for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) {
        if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) {
          return false;
        }
      }
    }
    // Matched primitive nodes are collected in topological order (inputs first).
    res->push_back(tmp_node);
  }
  return true;
}
305 
306 // traverse pattern tree and return topological order
DfsTraverse(const std::shared_ptr<PatternNodePtrList> & res,const PatternNodePtr & cur) const307 void PatternTree::DfsTraverse(const std::shared_ptr<PatternNodePtrList> &res, const PatternNodePtr &cur) const {
308   if (cur == nullptr) {
309     return;
310   }
311   for (auto &p : cur->inputs()) {
312     if (PatternNodeType(p->op()) == inner::NType::Primitive) {
313       DfsTraverse(res, p);
314     }
315   }
316   res->push_back(cur);
317 }
318 
319 // invoke DfsMatchGraph
MatchGraph(const inner::NodePtr & root,const std::shared_ptr<ParaMap> & para_to_ref,const std::shared_ptr<ConstMap> & const_to_ref)320 inner::NodePtrList PatternTree::MatchGraph(const inner::NodePtr &root, const std::shared_ptr<ParaMap> &para_to_ref,
321                                            const std::shared_ptr<ConstMap> &const_to_ref) {
322   auto res = std::make_shared<inner::NodePtrList>();
323   if (!DfsMatchGraph(root, lhs_root_, para_to_ref, const_to_ref, res)) {
324     return {};
325   }
326   if (CheckInputsAndAttrs(root)) {
327     return *res;
328   }
329   return {};
330 }
331 
// leverage pattern tree node and lite node's mapping relation to build new lite node graph from pattern tree's right
// side
inner::NodePtr PatternTree::AlterGraph(const std::shared_ptr<ParaMap> &para_to_ref,
                                       const std::shared_ptr<ConstMap> &const_to_ref,
                                       const inner::NodePtr &origin_root) {
  // Emit the right-hand side in topological order so inputs exist before use.
  auto res = std::make_shared<PatternNodePtrList>();
  DfsTraverse(res, rhs_root_);
  auto all_attrs = SetAttributes(origin_root);
  inner::GraphBuilder gb("");
  // Maps each right-hand-side pattern node to the lite-graph node emitted for it.
  mindspore::HashMap<PatternNodePtr, inner::NodePtr> pattern_to_ref;
  for (auto &n : (*res)) {
    if (PatternNodeType(n->op()) != inner::NType::Primitive) {
      continue;
    }
    // Resolve this op's inputs: previously emitted ops, bound parameters,
    // bound named constants, or freshly built literal tensors.
    inner::NodePtrList inputs;
    for (auto &i : n->inputs()) {
      if (PatternNodeType(i->op()) == inner::NType::Primitive) {
        inputs.push_back(pattern_to_ref[i]);
      } else if (PatternNodeType(i->op()) == inner::NType::Parameter) {
        inputs.push_back((*para_to_ref)[i->op()[0]]);
      } else {
        if (StartWith(i->op(), "const")) {
          inputs.push_back((*const_to_ref)[i->op()]);
        } else {
          tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(std::stof(i->op())));
          inputs.push_back(gb.Value(data));
        }
      }
    }
    auto p = gb.Emit(n->op(), inputs, all_attrs[n]);
    pattern_to_ref[n] = p;
  }
  auto &alter_graph = gb.Get()->ops();
  if (alter_graph.empty()) {
    // No op was emitted: the right-hand side is a single leaf (parameter,
    // named constant, or literal), so return the corresponding node directly.
    if (PatternNodeType(rhs_root_->op()) == inner::NType::Parameter) {
      return (*para_to_ref)[rhs_root_->op()[0]];
    } else {
      if (StartWith(rhs_root_->op(), "const")) {
        return (*const_to_ref)[rhs_root_->op()];
      } else {
        tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(std::stof(rhs_root_->op())));
        return gb.Value(data);
      }
    }
  }
  // Otherwise the last emitted op is the new root of the rewritten subgraph.
  return alter_graph.back();
}
379 
// Reduce(Reduce(A,B),C) = Reduce(A,D)
class ExtraReduce1PatternTree : public PatternTree {
 public:
  explicit ExtraReduce1PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
  ~ExtraReduce1PatternTree() override = default;

  // Computes the merged reduce axes D from B (inner reduce) and C (outer
  // reduce), stores D in the parameter map and drops the now-unused B and C.
  std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
                                            const std::shared_ptr<ParaMap> &para_to_ref) const override {
    MS_EXCEPTION_IF_NULL(para_to_ref);
    auto axes1_tensornode = (*para_to_ref)['B']->As<inner::ConstTensorNode>();
    MS_EXCEPTION_IF_NULL(axes1_tensornode);
    auto axes2_tensornode = (*para_to_ref)['C']->As<inner::ConstTensorNode>();
    MS_EXCEPTION_IF_NULL(axes2_tensornode);
    auto axes1 = CheckAndConvertUtils::CheckTensorIntValue("axes", axes1_tensornode->data(), "Reduce");
    auto axes2 = CheckAndConvertUtils::CheckTensorIntValue("axes", axes2_tensornode->data(), "Reduce");
    bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
    std::vector<int64_t> axes;
    std::set<int64_t> axis_set;
    if (keep_dims) {
      // keep_dims=true: the inner reduce does not change axis numbering, so
      // the merged axes are simply the union of both axis sets.
      for (auto &i : axes1) {
        (void)axis_set.insert(i);
      }
      for (auto &i : axes2) {
        (void)axis_set.insert(i);
      }
    } else {
      // keep_dims=false: the inner reduce removed its axes, so axes2 is
      // numbered in the squeezed shape. Build "mp" mapping each squeezed axis
      // index back to the original axis index of A, then take the union of
      // axes1 with the remapped axes2.
      std::set<int64_t> st(axes1.begin(), axes1.end());
      mindspore::HashMap<int64_t, int64_t> mp;
      int64_t shift = 0;
      auto size = SizeToLong((*para_to_ref)['A']->shape.size());
      for (int64_t n = 0; n < size; n++) {
        if (st.find(n) != st.end()) {
          shift++;
        } else {
          mp[n - shift] = n;
        }
      }
      (void)std::for_each(axes1.begin(), axes1.end(), [&axis_set](auto &i) { (void)axis_set.insert(i); });
      (void)std::for_each(axes2.begin(), axes2.end(), [&axis_set, &mp](auto &i) { (void)axis_set.insert(mp[i]); });
    }
    // std::set keeps the merged axes sorted and de-duplicated.
    (void)std::copy(axis_set.begin(), axis_set.end(), std::back_inserter(axes));
    inner::GraphBuilder gb("");
    auto new_axes_tensornode = gb.Tensor(axes);
    (*para_to_ref)['D'] = new_axes_tensornode;
    (void)para_to_ref->erase('B');
    (void)para_to_ref->erase('C');
    return para_to_ref;
  }

 protected:
  // Fusable only when both reduces use the same keep_dims setting and the
  // inner reduce's output rank is statically known.
  bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
    auto first_reduce_shape = origin_root->input(0)->shape;
    return (GetValue<bool>((origin_root->inputs()[0])->attrs().find("keep_dims")->second) ==
              GetValue<bool>(origin_root->attrs().find("keep_dims")->second) &&
            !IsDynamicRank(first_reduce_shape));
  }
  // Propagates keep_dims (and, for ReduceSum, skip_mode) onto the fused reduce.
  mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
    auto attrs_map = PatternTree::SetAttributes(origin_root);
    bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
    if (GetRootOp() == prim::kPrimReduceSum->name()) {
      auto iter = origin_root->attrs().find("skip_mode");
      if (iter != origin_root->attrs().end()) {
        bool skip_mode = GetValue<bool>(iter->second);
        attrs_map[this->rhs_root()] = {{"keep_dims", MakeValue(keep_dims)}, {"skip_mode", MakeValue(skip_mode)}};
      } else {
        MS_LOG(EXCEPTION) << origin_root->ToString() << "not found skip_mode attrs.";
      }
    } else {
      attrs_map[this->rhs_root()] = {{"keep_dims", MakeValue(keep_dims)}};
    }
    return attrs_map;
  }
};
453 
454 // "ReduceSum(Neg(A),B)=Neg(ReduceSum(A,B))"
455 class ExtraReduce2PatternTree : public PatternTree {
456  public:
ExtraReduce2PatternTree(const std::string & pattern_str)457   explicit ExtraReduce2PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
458   ~ExtraReduce2PatternTree() override = default;
459 
460  protected:
SetAttributes(const inner::NodePtr & origin_root)461   mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
462     auto attrs_map = PatternTree::SetAttributes(origin_root);
463     bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
464     auto iter = origin_root->attrs().find("skip_mode");
465     if (iter != origin_root->attrs().end()) {
466       bool skip_mode = GetValue<bool>(iter->second);
467       attrs_map[this->rhs_root()->inputs()[0]] = {{"keep_dims", MakeValue(keep_dims)},
468                                                   {"skip_mode", MakeValue(skip_mode)}};
469     } else {
470       MS_LOG(EXCEPTION) << origin_root->ToString() << "not found skip_mode attrs.";
471     }
472     return attrs_map;
473   }
474 };
475 
// "ReduceSum(A,B)=ReShape(A,C)"
class ReducePatternTree : public PatternTree {
 public:
  explicit ReducePatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
  ~ReducePatternTree() override = default;

 protected:
  // Replaces the reduce-axes parameter (B) with the reduce's output shape (C),
  // which becomes the second input of the Reshape.
  std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
                                            const std::shared_ptr<ParaMap> &para_to_ref) const override {
    MS_EXCEPTION_IF_NULL(para_to_ref);
    inner::GraphBuilder gb("");
    // Because an empty Tensor cannot be generated, the second input for the reshape function needs to be a Tuple.
    auto shape_node = gb.Tensor(origin_root->shape);
    (*para_to_ref)['C'] = shape_node;
    (void)para_to_ref->erase('B');
    return para_to_ref;
  }
  // The reduce is equivalent to a reshape only when every reduced axis has
  // dimension 1 (reducing over a size-1 axis does not change the data).
  bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
    auto reduce_shape = origin_root->input(0)->shape;
    if (IsDynamicShape(reduce_shape)) {
      return false;
    }
    if (reduce_shape.empty()) {
      return true;
    }
    // The axes input must be a constant tensor to be checked at compile time.
    auto reduce_axis = origin_root->input(1)->As<inner::ConstTensorNode>();
    if (reduce_axis == nullptr) {
      return false;
    }
    auto axis = CheckAndConvertUtils::CheckTensorIntValue("axis", reduce_axis->data(), "Reduce");
    for (auto &i : axis) {
      // Normalize negative axis indices before checking the dimension size.
      if (i < 0) {
        i += SizeToLong(reduce_shape.size());
      }
      if (reduce_shape[i] != 1) {
        return false;
      }
    }
    return true;
  }
};
517 
518 class CastPatternTree : public PatternTree {
519  public:
CastPatternTree(const std::string & pattern_str)520   explicit CastPatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
521   ~CastPatternTree() = default;
522 
523  protected:
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const524   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
525     auto dst_type_id = origin_root->type;
526     auto src_type_id = origin_root->input(0)->type;
527     return dst_type_id == src_type_id;
528   }
529 };
530 
531 // "LayoutTransform(LayoutTransform(A))=A"
532 class LayoutTransform1PatternTree : public PatternTree {
533  public:
LayoutTransform1PatternTree(const std::string & pattern_str)534   explicit LayoutTransform1PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
535   ~LayoutTransform1PatternTree() override = default;
536 
537  protected:
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const538   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
539     return (GetValue<string>((origin_root->inputs()[0])->attrs().find("src_format")->second) ==
540             GetValue<string>(origin_root->attrs().find("dst_format")->second));
541   }
542 };
543 
544 // "LayoutTransform(LayoutTransform(A))=LayoutTransform(A)"
545 class LayoutTransform2PatternTree : public PatternTree {
546  public:
LayoutTransform2PatternTree(const std::string & pattern_str)547   explicit LayoutTransform2PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
548   ~LayoutTransform2PatternTree() override = default;
549 
550  protected:
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const551   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
552     return (GetValue<string>((origin_root->inputs()[0])->attrs().find("src_format")->second) !=
553             GetValue<string>(origin_root->attrs().find("dst_format")->second));
554   }
SetAttributes(const inner::NodePtr & origin_root)555   mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
556     auto attrs_map = PatternTree::SetAttributes(origin_root);
557     attrs_map[this->rhs_root()] = {{"src_format", origin_root->inputs()[0]->attrs().find("src_format")->second},
558                                    {"dst_format", origin_root->attrs().find("dst_format")->second}};
559     return attrs_map;
560   }
561 };
562 
IsRedundantTransposePair(const ShapeVector & perm1,const ShapeVector & perm2)563 bool IsRedundantTransposePair(const ShapeVector &perm1, const ShapeVector &perm2) {
564   auto dim = perm2.size();
565   for (size_t i = 0; i < dim; i++) {
566     auto index = perm2[i] < 0 ? perm2[i] + static_cast<ShapeValueDType>(dim) : perm2[i];
567     MS_EXCEPTION_IF_CHECK_FAIL(static_cast<size_t>(index) < dim, "perm is out of bound");
568     auto axis = perm1[index] < 0 ? perm1[index] + static_cast<ShapeValueDType>(dim) : perm1[index];
569     if (static_cast<size_t>(axis) != i) {
570       return false;
571     }
572   }
573   return true;
574 }
// Transpose(Transpose(A,B),C)=A
class Transpose1PatternTree : public PatternTree {
 public:
  explicit Transpose1PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
  ~Transpose1PatternTree() override = default;

  // The right-hand side is just A, so drop the two perm parameters (B and C).
  std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
                                            const std::shared_ptr<ParaMap> &para_to_ref) const override {
    MS_EXCEPTION_IF_NULL(para_to_ref);
    (void)para_to_ref->erase('B');
    (void)para_to_ref->erase('C');
    return para_to_ref;
  }

 protected:
  // The transpose pair can be removed only when the formats agree and the
  // composed permutation is the identity over the input rank.
  bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
    auto transpose1_node = origin_root->input(0);
    MS_EXCEPTION_IF_NULL(transpose1_node);
    if (transpose1_node->format != origin_root->format) {
      MS_LOG(DEBUG) << "The input format of the first transpose is different from the output format of the second "
                       "transpose, can't remove this transpose pair.";
      return false;
    }
    auto input_shape = transpose1_node->input(0)->shape;
    auto perm2_node = origin_root->input(1);
    MS_EXCEPTION_IF_NULL(perm2_node);
    auto perm1_node = transpose1_node->input(1);
    MS_EXCEPTION_IF_NULL(perm1_node);
    // Both permutations must be constant tensors to be checked at compile time.
    auto perm2_tensornode = perm2_node->As<inner::ConstTensorNode>();
    MS_EXCEPTION_IF_NULL(perm2_tensornode);
    auto perm1_tensornode = perm1_node->As<inner::ConstTensorNode>();
    MS_EXCEPTION_IF_NULL(perm1_tensornode);
    auto perm2 = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm2_tensornode->data(), "Transpose");
    auto perm1 = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm1_tensornode->data(), "Transpose");
    if (perm1.size() != input_shape.size() || perm2.size() != input_shape.size()) {
      MS_LOG(DEBUG) << "The length of input shape and perm is not same";
      return false;
    }
    return IsRedundantTransposePair(perm1, perm2);
  }

  // Keep the original output format on the replacement node.
  mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
    auto attrs_map = PatternTree::SetAttributes(origin_root);
    attrs_map[this->rhs_root()] = {{"format", MakeValue(origin_root->format)}};
    return attrs_map;
  }
};
622 
// Transpose(A,B)=Reshape(A,C)
class Transpose2PatternTree : public PatternTree {
 public:
  explicit Transpose2PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
  ~Transpose2PatternTree() override = default;

  // Replaces the perm parameter (B) with the transpose's output shape (C),
  // which becomes the second input of the Reshape.
  std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
                                            const std::shared_ptr<ParaMap> &para_to_ref) const override {
    MS_EXCEPTION_IF_NULL(para_to_ref);
    inner::GraphBuilder gb("");
    auto out_shape = origin_root->shape;
    auto out_shape_tensornode = gb.Tensor(out_shape);
    (*para_to_ref)['C'] = out_shape_tensornode;
    (void)para_to_ref->erase('B');
    return para_to_ref;
  }

 protected:
  bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
    auto input_shape = origin_root->input(0)->shape;
    if (IsDynamicRank(input_shape)) {
      MS_LOG(DEBUG) << "Skip dynamic rank case";
      return false;
    }
    // The permutation must be a constant tensor to be checked at compile time.
    auto perm_tensornode = origin_root->input(1)->As<inner::ConstTensorNode>();
    MS_EXCEPTION_IF_NULL(perm_tensornode);
    auto perm = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm_tensornode->data(), "Transpose");
    if (perm.size() != input_shape.size()) {
      MS_LOG(DEBUG) << "The length of input shape " << input_shape << " and perm " << perm << " is not same";
      return false;
    }
    auto rank = SizeToLong(input_shape.size());
    // If the axes which have dimension size greater than 1 keep ascending order in permutation, then this transpose can
    // be replaced by reshape
    ShapeValueDType prev_non_one_axis = -1;
    for (size_t i = 0; i < input_shape.size(); ++i) {
      if (perm[i] < -rank || perm[i] >= rank) {
        MS_LOG(DEBUG) << "perm[" << i << "] is " << perm[i] << ", which is out of range[-" << rank << ", " << rank
                      << ")";
        return false;
      }
      // Normalize negative axis indices into [0, rank).
      perm[i] = perm[i] < 0 ? (perm[i] + rank) : perm[i];
      // Size-1 axes can move freely; only non-one axes must stay in order.
      if (input_shape[perm[i]] != 1) {
        if (perm[i] < prev_non_one_axis) {
          MS_LOG(DEBUG) << "perm[" << i << "] is axis " << perm[i]
                        << ", which is greater than the previous non-one axis " << prev_non_one_axis
                        << ", replace failed";
          return false;
        }
        prev_non_one_axis = perm[i];
      }
    }
    return true;
  }

  // Keep the original output format on the Reshape node.
  mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
    auto attrs_map = PatternTree::SetAttributes(origin_root);
    attrs_map[this->rhs_root()] = {{"format", MakeValue(origin_root->format)}};
    return attrs_map;
  }
};
684 
685 // Reshape(Reshape(A,B),C)=Reshape(A,C)
686 class ReshapePatternTree : public PatternTree {
687  public:
ReshapePatternTree(const std::string & pattern_str)688   explicit ReshapePatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
689   ~ReshapePatternTree() override = default;
690 
UpdateParameters(const inner::NodePtr & origin_root,const std::shared_ptr<ParaMap> & para_to_ref) const691   std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
692                                             const std::shared_ptr<ParaMap> &para_to_ref) const override {
693     MS_EXCEPTION_IF_NULL(para_to_ref);
694     inner::GraphBuilder gb("");
695     auto out_shape = origin_root->shape;
696     auto out_shape_tensornode = gb.Tensor(out_shape);
697     (*para_to_ref)['C'] = out_shape_tensornode;
698     (void)para_to_ref->erase('B');
699     return para_to_ref;
700   }
701 
702  protected:
SetAttributes(const inner::NodePtr & origin_root)703   mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
704     auto attrs_map = PatternTree::SetAttributes(origin_root);
705     attrs_map[this->rhs_root()] = {{"format", MakeValue(origin_root->format)}};
706     return attrs_map;
707   }
708 };
709 
710 // Transpose(Transpose(Reshape(A,B),C),D)=Reshape(A,E), RTT is the abbreviation for Reshape + Transpose + Transpose
711 class RTTPatternTree : public PatternTree {
712  public:
RTTPatternTree(const std::string & pattern_str)713   explicit RTTPatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
714   ~RTTPatternTree() override = default;
715 
UpdateParameters(const inner::NodePtr & origin_root,const std::shared_ptr<ParaMap> & para_to_ref) const716   std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
717                                             const std::shared_ptr<ParaMap> &para_to_ref) const override {
718     MS_EXCEPTION_IF_NULL(para_to_ref);
719     inner::GraphBuilder gb("");
720     auto out_shape = origin_root->shape;
721     auto out_shape_tensornode = gb.Tensor(out_shape);
722     (*para_to_ref)['E'] = out_shape_tensornode;
723     (void)para_to_ref->erase('B');
724     (void)para_to_ref->erase('C');
725     (void)para_to_ref->erase('D');
726     return para_to_ref;
727   }
728 
729  protected:
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const730   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
731     auto perm2_node = origin_root->input(1);
732     MS_EXCEPTION_IF_NULL(perm2_node);
733     auto transpose1_node = origin_root->input(0);
734     MS_EXCEPTION_IF_NULL(transpose1_node);
735     auto perm1_node = transpose1_node->input(1);
736     MS_EXCEPTION_IF_NULL(perm1_node);
737     auto perm2_tensornode = perm2_node->As<inner::ConstTensorNode>();
738     MS_EXCEPTION_IF_NULL(perm2_tensornode);
739     auto perm1_tensornode = perm1_node->As<inner::ConstTensorNode>();
740     MS_EXCEPTION_IF_NULL(perm1_tensornode);
741     auto perm2 = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm2_tensornode->data(), "Transpose");
742     auto perm1 = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm1_tensornode->data(), "Transpose");
743     return IsRedundantTransposePair(perm1, perm2);
744   }
SetAttributes(const inner::NodePtr & origin_root)745   mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
746     auto attrs_map = PatternTree::SetAttributes(origin_root);
747     attrs_map[this->rhs_root()] = {{"format", MakeValue(origin_root->format)}};
748     return attrs_map;
749   }
750 };
751 
752 // StridedSlice(A,B,C,D)=Reshape(A,E)
753 class StridedSlicePatternTree : public PatternTree {
754  public:
StridedSlicePatternTree(const std::string & pattern_str)755   explicit StridedSlicePatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
756   ~StridedSlicePatternTree() override = default;
757 
UpdateParameters(const inner::NodePtr & origin_root,const std::shared_ptr<ParaMap> & para_to_ref) const758   std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
759                                             const std::shared_ptr<ParaMap> &para_to_ref) const override {
760     MS_EXCEPTION_IF_NULL(para_to_ref);
761     inner::GraphBuilder gb("");
762     auto out_shape = origin_root->shape;
763     auto out_shape_tensornode = gb.Tensor(out_shape);
764     (*para_to_ref)['E'] = out_shape_tensornode;
765     (void)para_to_ref->erase('B');
766     (void)para_to_ref->erase('C');
767     (void)para_to_ref->erase('D');
768     return para_to_ref;
769   }
770 
771  protected:
GetInputVec(const inner::NodePtr & origin_root,size_t input_idx,const std::string & node_name,const std::string & input_name) const772   const ShapeVector GetInputVec(const inner::NodePtr &origin_root, size_t input_idx, const std::string &node_name,
773                                 const std::string &input_name) const {
774     auto input_node = origin_root->input(input_idx);
775     MS_EXCEPTION_IF_NULL(input_node);
776     MS_EXCEPTION_IF_CHECK_FAIL(input_node->NodeType() == inner::NType::Tensor, "input must be a Tensor");
777     auto input_tensornode = input_node->As<inner::ConstTensorNode>();
778     auto input_vec = CheckAndConvertUtils::CheckTensorIntValue(input_name, input_tensornode->data(), node_name);
779     return input_vec;
780   }
781 
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const782   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
783     auto input_node = origin_root->input(0);
784     MS_EXCEPTION_IF_NULL(input_node);
785     auto input_shape = input_node->shape;
786     const ShapeVector &begin_vec = GetInputVec(origin_root, 1, "StridedSlice", "begin");
787     if (std::any_of(begin_vec.begin(), begin_vec.end(), [](ShapeValueDType i) { return i != 0; })) {
788       return false;
789     }
790     const ShapeVector &end_vec = GetInputVec(origin_root, 2, "StridedSlice", "end");
791     for (size_t i = 0; i < end_vec.size(); i++) {
792       if (end_vec[i] != input_shape[i]) {
793         return false;
794       }
795     }
796     const ShapeVector &strides_vec = GetInputVec(origin_root, 3, "StridedSlice", "strideds");
797     if (std::any_of(strides_vec.begin(), strides_vec.end(), [](ShapeValueDType i) { return i != 1; })) {
798       return false;
799     }
800     auto begin_mask = GetValue<int64_t>(origin_root->attrs().find("begin_mask")->second);
801     auto end_mask = GetValue<int64_t>(origin_root->attrs().find("end_mask")->second);
802     auto ellipsis_mask = GetValue<int64_t>(origin_root->attrs().find("ellipsis_mask")->second);
803     if (begin_mask != 0 || end_mask != 0 || ellipsis_mask != 0) {
804       return false;
805     }
806     auto shrink_axis_mask = LongToSize(GetValue<int64_t>(origin_root->attrs().find("shrink_axis_mask")->second));
807     for (size_t i = 0; i < input_shape.size(); i++) {
808       if (((shrink_axis_mask >> i) & 1) != 0 && input_shape[i] != 1) {
809         return false;
810       }
811     }
812     return true;
813   }
814 };
815 
816 // Transpose(Transpose(A,B),C)=Transpose(A,D)
817 class TransposeCombinePatternTree : public PatternTree {
818  public:
TransposeCombinePatternTree(const std::string & pattern_str)819   explicit TransposeCombinePatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
820   ~TransposeCombinePatternTree() override = default;
821 
UpdateParameters(const inner::NodePtr & origin_root,const std::shared_ptr<ParaMap> & para_to_ref) const822   std::shared_ptr<ParaMap> UpdateParameters(const inner::NodePtr &origin_root,
823                                             const std::shared_ptr<ParaMap> &para_to_ref) const override {
824     /* %0 = Transpose(p, (1, 0, 2))
825      * %1 = Transpose(%0, (0, 2, 1))
826      * --->
827      * %0 = Transpose(p, (1, 2, 0))
828      */
829     MS_EXCEPTION_IF_NULL(para_to_ref);
830     auto perm1_node = (*para_to_ref)['B']->As<inner::ConstTensorNode>();
831     MS_EXCEPTION_IF_NULL(perm1_node);
832     auto perm2_node = (*para_to_ref)['C']->As<inner::ConstTensorNode>();
833     MS_EXCEPTION_IF_NULL(perm1_node);
834     auto perm1 = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm1_node->data(), "Transpose");
835     auto perm2 = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm2_node->data(), "Transpose");
836     auto rank = SizeToLong(origin_root->shape.size());
837     (void)std::for_each(perm1.begin(), perm1.end(), [rank](auto &axis) { axis = axis < 0 ? axis + rank : axis; });
838     (void)std::for_each(perm2.begin(), perm2.end(), [rank](auto &axis) { axis = axis < 0 ? axis + rank : axis; });
839     ShapeVector new_perm(perm2.size());
840     for (size_t i = 0; i < perm2.size(); ++i) {
841       new_perm[i] = perm1[LongToSize(perm2[i])];
842     }
843     inner::GraphBuilder gb("");
844     (*para_to_ref)['D'] = gb.Tensor(new_perm);
845     (void)para_to_ref->erase('B');
846     (void)para_to_ref->erase('C');
847     return para_to_ref;
848   }
849 
850  protected:
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const851   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
852     auto perm2 = origin_root->input(1);
853     MS_EXCEPTION_IF_NULL(perm2);
854     auto trans1 = origin_root->input(0);
855     MS_EXCEPTION_IF_NULL(trans1);
856     auto perm1 = trans1->input(1);
857     MS_EXCEPTION_IF_NULL(perm1);
858     auto perm2_tensor = perm2->As<inner::ConstTensorNode>();
859     MS_EXCEPTION_IF_NULL(perm2_tensor);
860     auto perm1_tensor = perm1->As<inner::ConstTensorNode>();
861     MS_EXCEPTION_IF_NULL(perm1_tensor);
862     auto perm1_value = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm2_tensor->data(), "Transpose");
863     auto perm2_value = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm1_tensor->data(), "Transpose");
864     auto shape = origin_root->shape;
865     if (perm2_value.size() != shape.size() || perm1_value.size() != shape.size()) {
866       MS_LOG(DEBUG) << "perm1, perm2 and shape have different size. perm1: " << perm2_value << " perm2: " << perm1_value
867                     << " node shape: " << shape;
868       return false;
869     }
870     return true;
871   }
872 
SetAttributes(const inner::NodePtr & origin_root)873   mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
874     auto attrs_map = PatternTree::SetAttributes(origin_root);
875     attrs_map[this->rhs_root()] = {{kAttrDstFormat, MakeValue(origin_root->format)}};
876     return attrs_map;
877   }
878 };
879 
880 class FloatCheckPatternTree : public PatternTree {
881  public:
FloatCheckPatternTree(const std::string & pattern_str)882   explicit FloatCheckPatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
883   ~FloatCheckPatternTree() override = default;
884 
885  protected:
CheckInputsAndAttrs(const inner::NodePtr & origin_root) const886   bool CheckInputsAndAttrs(const inner::NodePtr &origin_root) const override {
887     auto type_id = origin_root->type;
888     return (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
889             type_id == kNumberTypeFloat64);
890   }
891 };
892 
893 /*       A
894         /
895        Neg
896        /  \
897     Neg     Mul
898  Here we cannot transform Neg(Neg(A)) to A because Neg(A) is a input of Mul. OutsideRely is responsible for checking
899  this case.
900  */
OutsideRely(const inner::NodePtrList & nodes,const inner::NodePtr & root)901 bool OutsideRely(const inner::NodePtrList &nodes, const inner::NodePtr &root) {
902   mindspore::HashSet<inner::Node *> nodes_can_simplify;
903   (void)std::for_each(nodes.begin(), nodes.end(),
904                       [&nodes_can_simplify](auto n) { (void)nodes_can_simplify.insert(n.get()); });
905   for (auto &n : nodes) {
906     if (n == root) {
907       continue;
908     }
909     for (auto &usr : n->users()) {
910       if (nodes_can_simplify.find(usr.first) == nodes_can_simplify.end()) {
911         return true;
912       }
913     }
914   }
915   return false;
916 }
917 
// One rewrite rule: an id (matched against the enable/disable simplify-expr
// flags), the "lhs=rhs" math expression, and a factory building the
// PatternTree subclass that implements the rule.
struct Expression {
  size_t id;  // unique, user-visible rule id
  std::string math_expr;  // pattern in "Op(...)=Op(...)" form
  std::function<PatternTreePtr(const std::string &)> func;  // pattern-tree factory
};
923 
// Wraps a PatternTree subclass into a factory lambda for the Expression table.
#define EXPR_PATTERN(cls) [](const std::string &expr) -> PatternTreePtr { return std::make_shared<cls>(expr); }

// The rule table. Ids are stable and user-visible: GetExpressions() matches
// them against the enable_simplify_exprs_only / disable_simplify_exprs flags.
static std::vector<Expression> expressions = {
  // add
  {1, "Add(A,0)=A", EXPR_PATTERN(PatternTree)},
  {2, "Add(Mul(A,C),Mul(A,B))=Mul(A,Add(B,C))", EXPR_PATTERN(PatternTree)},
  {3, "Add(Add(A,const1),const2)=Add(A,Add(const1,const2))", EXPR_PATTERN(PatternTree)},
  {4, "Add(A,Neg(A))=0", EXPR_PATTERN(PatternTree)},
  {5, "Add(Add(A,B),Neg(A))=B", EXPR_PATTERN(PatternTree)},
  {6, "Add(Add(A,B),Add(Neg(A),C))=Add(B,C)", EXPR_PATTERN(PatternTree)},
  // sub
  {7, "Sub(A,0)=A", EXPR_PATTERN(PatternTree)},
  {8, "Sub(A,const1)=Add(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
  {9, "Sub(Mul(A,C),Mul(A,B))=Mul(A,Sub(B,C))", EXPR_PATTERN(PatternTree)},
  {10, "Sub(Mul(A,C),Mul(B,C))=Mul(Sub(A,B),C)", EXPR_PATTERN(PatternTree)},
  // log
  {11, "Log(Exp(A))=A", EXPR_PATTERN(PatternTree)},
  {12, "Log(Pow(A,B))=Mul(B,Log(Abs(A)))", EXPR_PATTERN(PatternTree)},
  {13, "Log(Sqrt(A))=Mul(0.5,Log(A))", EXPR_PATTERN(PatternTree)},
  {14, "Log(Rsqrt(A))=Mul(-0.5,Log(A))", EXPR_PATTERN(PatternTree)},
  // pow
  {15, "Pow(A,1)=A", EXPR_PATTERN(PatternTree)},
  {16, "Pow(Exp(A),B)=Exp(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  {17, "Pow(A,2)=Mul(A,A)", EXPR_PATTERN(PatternTree)},
  {18, "Pow(A,-1)=Reciprocal(A)", EXPR_PATTERN(PatternTree)},
  // sqrt
  {19, "Sqrt(Mul(A,A))=Abs(A)", EXPR_PATTERN(PatternTree)},
  {20, "Rsqrt(Pow(A,-2))=Abs(A)", EXPR_PATTERN(PatternTree)},
  {21, "Rsqrt(RealDiv(1,A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  {22, "Rsqrt(Reciprocal(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  // select
  {23, "Select(A,B,B)=B", EXPR_PATTERN(PatternTree)},
  // Neg
  {24, "Neg(Neg(A))=A", EXPR_PATTERN(PatternTree)},
  // mul
  {25, "Mul(Mul(A,const1),Mul(B,const2))=Mul(Mul(A,B),Mul(const1,const2))", EXPR_PATTERN(PatternTree)},
  {26, "Mul(Mul(A,const1),const2)=Mul(A,Mul(const1,const2))", EXPR_PATTERN(PatternTree)},
  {27, "Mul(Exp(A),Exp(B))=Exp(Add(A,B))", EXPR_PATTERN(PatternTree)},
  {28, "Mul(Mul(Exp(A),C),Exp(B))=Mul(Exp(Add(A,B)),C)", EXPR_PATTERN(PatternTree)},
  {29, "Mul(Mul(Exp(A),C),Mul(Exp(B),D))=Mul(Exp(Add(A,B)),Mul(C,D))", EXPR_PATTERN(PatternTree)},
  {30, "Mul(Sqrt(A),Sqrt(A))=A", EXPR_PATTERN(PatternTree)},
  {31, "Mul(Mul(A,Sqrt(B)),Mul(C,Sqrt(B)))=Mul(Mul(A,B),C)", EXPR_PATTERN(PatternTree)},
  {32, "Mul(Mul(A,Sqrt(B)),Sqrt(B))=Mul(A,B)", EXPR_PATTERN(PatternTree)},
  {33, "Mul(Sqrt(A),Sqrt(B))=Sqrt(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  {34, "Mul(Rsqrt(A),Rsqrt(A))=Reciprocal(A)", EXPR_PATTERN(PatternTree)},
  {35, "Mul(Mul(A,Rsqrt(B)),Rsqrt(B))=RealDiv(A,B)", EXPR_PATTERN(PatternTree)},
  {36, "Mul(Mul(A,Rsqrt(B)),Mul(C,Rsqrt(B)))=RealDiv(Mul(A,C),B)", EXPR_PATTERN(PatternTree)},
  {37, "Mul(Rsqrt(A),Rsqrt(B))=Rsqrt(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  {38, "Mul(A,Rsqrt(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  {39, "Mul(Abs(A),Abs(B))=Abs(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  {40, "Mul(Mul(Abs(A),C),Abs(B))=Mul(Abs(Mul(A,B)),C)", EXPR_PATTERN(PatternTree)},
  {41, "Mul(Mul(Abs(A),C),Mul(Abs(B),D))=Mul(Abs(Mul(A,B)),Mul(C,D))", EXPR_PATTERN(PatternTree)},
  {42, "Mul(Neg(A),const1)=Mul(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
  // realdiv
  {43, "RealDiv(A,1)=A", EXPR_PATTERN(PatternTree)},
  {44, "RealDiv(Exp(A),Exp(B))=Exp(Sub(A,B))", EXPR_PATTERN(PatternTree)},
  {45, "RealDiv(A,Exp(B))=Mul(A,Exp(Neg(B)))", EXPR_PATTERN(PatternTree)},
  {46, "RealDiv(A,Pow(B,const1))=Mul(A,Pow(B,Neg(const1)))", EXPR_PATTERN(PatternTree)},
  {47, "RealDiv(A,Sqrt(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  {48, "RealDiv(A,Sqrt(B))=Mul(A,Rsqrt(B))", EXPR_PATTERN(PatternTree)},
  {49, "RealDiv(A,Rsqrt(B))=Mul(A,Sqrt(B))", EXPR_PATTERN(PatternTree)},
  {50, "RealDiv(A,const1)=Mul(A,Reciprocal(const1))", EXPR_PATTERN(FloatCheckPatternTree)},
  {51, "RealDiv(RealDiv(A,B),RealDiv(C,D))=RealDiv(Mul(A,D),Mul(B,C))", EXPR_PATTERN(PatternTree)},
  {52, "RealDiv(Neg(A),const1)=RealDiv(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
  {53, "RealDiv(RealDiv(A,B),C)=RealDiv(A,Mul(B,C))", EXPR_PATTERN(PatternTree)},
  {54, "RealDiv(A,RealDiv(B,C))=RealDiv(Mul(A,C),B)", EXPR_PATTERN(PatternTree)},
  // reduce1, B, C, D are all axes input
  {55, "ReduceSum(ReduceSum(A,B),C)=ReduceSum(A,D)", EXPR_PATTERN(ExtraReduce1PatternTree)},
  {56, "ReduceMin(ReduceMin(A,B),C)=ReduceMin(A,D)", EXPR_PATTERN(ExtraReduce1PatternTree)},
  {57, "ReduceMax(ReduceMax(A,B),C)=ReduceMax(A,D)", EXPR_PATTERN(ExtraReduce1PatternTree)},
  // reduce2, B is axes input
  {58, "ReduceSum(Neg(A),B)=Neg(ReduceSum(A,B))", EXPR_PATTERN(ExtraReduce2PatternTree)},
  {59, "ReduceSum(RealDiv(A,const1),B)=RealDiv(ReduceSum(A,B),const1)", EXPR_PATTERN(ExtraReduce2PatternTree)},
  {60, "ReduceSum(Mul(A,const1),B)=Mul(ReduceSum(A,B),const1)", EXPR_PATTERN(ExtraReduce2PatternTree)},
  {61, "CReal(Complex(A,B))=A", EXPR_PATTERN(PatternTree)},
  {62, "CImag(Complex(A,B))=B", EXPR_PATTERN(PatternTree)},
  // lite only
  {63, "LayoutTransform(LayoutTransform(A))=A", EXPR_PATTERN(LayoutTransform1PatternTree)},
  {64, "LayoutTransform(LayoutTransform(A))=LayoutTransform(A)", EXPR_PATTERN(LayoutTransform2PatternTree)},
  // patterns that can be transformed to reshape
  {65, "Transpose(Transpose(A,B),C)=A", EXPR_PATTERN(Transpose1PatternTree)},
  {66, "Transpose(A,B)=Reshape(A,C)", EXPR_PATTERN(Transpose2PatternTree)},
  {67, "Reshape(Reshape(A,B),C)=Reshape(A,C)", EXPR_PATTERN(ReshapePatternTree)},
  {68, "Transpose(Transpose(Reshape(A,B),C),D)=Reshape(A,E)", EXPR_PATTERN(RTTPatternTree)},
  {69, "StridedSlice(A,B,C,D)=Reshape(A,E)", EXPR_PATTERN(StridedSlicePatternTree)},
  // cmp + logical
  {70, "LogicalNot(Greater(A,B))=LessEqual(A,B)", EXPR_PATTERN(PatternTree)},
  {71, "LogicalNot(LessEqual(A,B))=Greater(A,B)", EXPR_PATTERN(PatternTree)},
  {72, "LogicalNot(GreaterEqual(A,B))=Less(A,B)", EXPR_PATTERN(PatternTree)},
  {73, "LogicalNot(Less(A,B))=GreaterEqual(A,B)", EXPR_PATTERN(PatternTree)},
  {74, "LogicalNot(NotEqual(A,B))=Equal(A,B)", EXPR_PATTERN(PatternTree)},
  {75, "LogicalNot(Equal(A,B))=NotEqual(A,B)", EXPR_PATTERN(PatternTree)},
  // reduce -> reshape
  {76, "ReduceSum(A,B)=Reshape(A,C)", EXPR_PATTERN(ReducePatternTree)},
  {77, "ReduceMin(A,B)=Reshape(A,C)", EXPR_PATTERN(ReducePatternTree)},
  {78, "ReduceMax(A,B)=Reshape(A,C)", EXPR_PATTERN(ReducePatternTree)},
  {79, "Cast(A,B)=A", EXPR_PATTERN(CastPatternTree)},
  // transpose
  {80, "Transpose(Transpose(A,B),C)=Transpose(A,D)", EXPR_PATTERN(TransposeCombinePatternTree)},
};
1024 
GetExpressions()1025 mindspore::HashMap<std::string, std::vector<PatternTreePtr>> GetExpressions() {
1026   const auto &flags = GraphKernelFlags::GetInstance();
1027   mindspore::HashMap<std::string, std::vector<PatternTreePtr>> expression_map;
1028   mindspore::HashSet<std::string> enable_ids{flags.enable_simplify_exprs_only.begin(),
1029                                              flags.enable_simplify_exprs_only.end()};
1030   mindspore::HashSet<std::string> disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()};
1031   for (auto &e : expressions) {
1032     if (!enable_ids.empty()) {
1033       if (enable_ids.count(std::to_string(e.id)) == 0) {
1034         continue;
1035       }
1036     } else {
1037       if (disable_ids.count(std::to_string(e.id)) > 0) {
1038         continue;
1039       }
1040     }
1041     PatternTreePtr pt = e.func(e.math_expr);
1042     expression_map[pt->GetRootOp()].push_back(pt);
1043   }
1044   return expression_map;
1045 }
1046 
// arithmetic simplify
// Applies at most ONE pattern rewrite to the litegraph: ops are scanned in
// reverse order, and the function exits as soon as some pattern from
// expressions_map_ matches and the node is replaced. Returns true if a
// replacement happened; the caller (Run) loops until no progress is made.
bool ArithmeticSimplify::DoArithmeticTrans(const inner::LiteGraphPtr &litegraph) {
  auto ops_list = litegraph->ops();
  bool changed = false;
  inner::NodePtrList matched_nodes;
  auto para_to_ref = std::make_shared<ParaMap>();    // A(B,C ...)->Node* mapping
  auto const_to_ref = std::make_shared<ConstMap>();  // const->Node* mapping
  PatternTreePtr cur_pattern;
  auto iter = ops_list.rbegin();
  while (iter != ops_list.rend()) {
    bool can_simplify = false;
    auto this_op = std::static_pointer_cast<inner::PrimOp>(*iter)->op();
    if (expressions_map_.find(this_op) != expressions_map_.end()) {
      for (auto p : expressions_map_[this_op]) {
        cur_pattern = p;
        // the mappings are per-pattern; clear leftovers from the previous try
        if (!para_to_ref->empty()) {
          para_to_ref->clear();
        }
        if (!const_to_ref->empty()) {
          const_to_ref->clear();
        }
        // match a pattern;if return is empty,then fails to match
        matched_nodes = p->MatchGraph(*iter, para_to_ref, const_to_ref);
        if (!matched_nodes.empty()) {
          auto right_root_type = PatternNodeType(p->rhs_root()->op());
          // an intermediate matched node also used outside the pattern cannot
          // be folded away (see OutsideRely above)
          if (right_root_type == inner::NType::Primitive && OutsideRely(matched_nodes, *iter)) {
            continue;
          }
          // if no outside rely,then this is a successful match
          can_simplify = true;
          para_to_ref = cur_pattern->UpdateParameters(*iter, para_to_ref);
          // get the new node to replace
          inner::NodePtr alter_graph_node = cur_pattern->AlterGraph(para_to_ref, const_to_ref, *iter);
          (*iter)->ReplaceWith(alter_graph_node);
          changed = true;
          break;
        }
      }
    }
    if (!can_simplify) {
      ++iter;
    } else {
      // ops_list snapshot is stale after ReplaceWith; stop and let Run re-invoke
      break;
    }
  }
  return changed;
}
1094 
1095 // constant fold
DoConstantFold(const inner::LiteGraphPtr & litegraph)1096 bool ArithmeticSimplify::DoConstantFold(const inner::LiteGraphPtr &litegraph) {
1097   auto ops_list = litegraph->GetOrderedNodes();
1098   bool changed = false;
1099   auto iter = ops_list.begin();
1100   while (iter != ops_list.end()) {
1101     auto this_op = std::static_pointer_cast<inner::PrimOp>(*iter);
1102     auto value = this_op->InferValue(this_op->inputs(), this_op->attrs());
1103     if (value != nullptr) {
1104       (*iter)->ReplaceWith(value);
1105       ops_list = litegraph->GetOrderedNodes();
1106       iter = ops_list.begin();
1107       changed = true;
1108     } else {
1109       ++iter;
1110     }
1111   }
1112   return changed;
1113 }
1114 
ResetOutputs(const inner::LiteGraphPtr & litegraph)1115 bool ResetOutputs(const inner::LiteGraphPtr &litegraph) {
1116   /** If after arithmetic transformation and constant folding, an output of subgraph is just a Tensor or Parameter,
1117    * insert Reshape/BroadcastTo and reset the output to this op.
1118    */
1119   auto &outputs = litegraph->GetOutputs();
1120   for (size_t i = 0; i < outputs.size(); i++) {
1121     MS_EXCEPTION_IF_NULL(outputs[i]);
1122     auto out_shape = outputs[i]->shape;
1123     if (outputs[i]->NodeType() == inner::NType::Tensor) {
1124       if (IsDynamic(out_shape)) {
1125         return false;
1126       }
1127       inner::GraphBuilder gb;
1128       auto output_shape = outputs[i]->As<inner::ConstTensorNode>()->data()->shape();
1129       auto op_ptr = gb.BroadcastTo(outputs[i], output_shape);
1130       litegraph->SetOutput(i, op_ptr);
1131     } else if (outputs[i]->NodeType() == inner::NType::Parameter) {
1132       if (IsDynamicRank(out_shape) ||
1133           std::count_if(out_shape.begin(), out_shape.end(), [](int64_t sh) { return sh < 0; }) > 1) {
1134         return false;
1135       }
1136       inner::GraphBuilder gb;
1137       auto op_ptr = gb.Reshape(outputs[i], out_shape);
1138       litegraph->SetOutput(i, op_ptr);
1139     }
1140   }
1141   return true;
1142 }
1143 
// Runs arithmetic simplification over every graph-kernel subgraph in
// func_graph: each subgraph is lowered to a LiteGraph, constant folding and
// pattern rewriting alternate until a fixpoint, and a changed subgraph is
// rebuilt and spliced back in place of the original node. Returns true if any
// node was replaced.
bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  bool do_simplify = false;
  expressions_map_ = GetExpressions();
  for (auto node : func_graph->GetOrderedCnodes()) {
    if (AnfUtils::IsGraphKernel(node)) {
      auto sub_graph = GetCNodeFuncGraph(node);
      // subgraphs built for inplace assignment must keep their structure; skip
      if (auto type = sub_graph->get_attr("composite_type")) {
        if (GetValue<std::string>(type) == "inplace_assign_builder") {
          continue;
        }
      }
      auto cnode = node->cast<CNodePtr>();
      AnfNodePtrList inputs = cnode->inputs();
      inner::LiteGraphPtr lg = GkUtils::AnfGraph2LiteGraph(sub_graph);
      bool find_pattern = true;
      bool change_anf_graph = false;
      try {
        MS_LOG_TRY_CATCH_SCOPE;
        // alternate folding and rewriting until neither makes progress
        while (find_pattern) {
          find_pattern = false;
          find_pattern = DoConstantFold(lg) || find_pattern;
          find_pattern = DoArithmeticTrans(lg) || find_pattern;
          change_anf_graph = change_anf_graph || find_pattern;
        }
      } catch (const std::exception &e) {
        // best effort: an error on one subgraph must not abort the whole pass
        MS_LOG(INFO) << "During arithmetic simplify for node [" << node->fullname_with_scope()
                     << "], an error occurs: " << e.what();
        continue;
      }
      AnfNodePtrList input_nodes{inputs.begin() + 1, inputs.end()};  // drop the primitive input
      if (!change_anf_graph) {
        continue;
      }
      if (!ResetOutputs(lg)) {
        continue;
      }
      auto new_funcgraph = GkUtils::LiteGraph2AnfGraph(lg, Callback::Instance());
      if (new_funcgraph == nullptr) {
        continue;
      }
      (void)ConvertTensorToParameter(new_funcgraph, &input_nodes);
      new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
      auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, input_nodes);
      (void)mng->Replace(node, new_node);
      mng->AddFuncGraph(new_funcgraph);
      do_simplify = true;
    }
  }
  return do_simplify;
}
1196 }  // namespace mindspore::graphkernel
1197