• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
17 
18 #include <algorithm>
19 #include <list>
20 #include <string>
21 #include <unordered_set>
22 #include <functional>
23 #include <set>
24 #include <vector>
25 #include <unordered_map>
26 
27 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
28 #include "backend/session/anf_runtime_algorithm.h"
29 #include "ir/anf.h"
30 #include "utils/context/graph_kernel_flags.h"
31 
32 namespace mindspore {
33 namespace opt {
// Names of the binary operators whose two inputs may be swapped while matching a pattern.
static std::unordered_set<std::string> commutative_ops = {"Add", "Mul"};
36 
class PatternNode;
using PatternNodePtr = std::shared_ptr<PatternNode>;
using PatternNodePtrList = std::vector<PatternNodePtr>;

// One node of a pattern tree: an operator name, a constant, or a parameter, plus its ordered inputs.
class PatternNode {
 public:
  explicit PatternNode(const std::string &op) : op_(op) {}
  ~PatternNode() = default;
  // Accessors return const references so that the matching code, which calls them repeatedly,
  // does not copy the string / vector on every call. The references stay valid as long as the
  // node itself is alive.
  const std::string &op() const { return op_; }
  const std::vector<PatternNodePtr> &inputs() const { return inputs_; }
  // Append `input` as the next (right-most) input of this node.
  void AddInput(const PatternNodePtr &input) { inputs_.push_back(input); }

 private:
  std::string op_;  // ex. "Add","const1","A","0.5" (any op, const or parameter)
  std::vector<PatternNodePtr> inputs_;
};
53 
54 using ParaMap = std::unordered_map<char, graphkernel::NodePtr>;
55 using ConstMap = std::unordered_map<std::string, graphkernel::NodePtr>;
56 
57 /* This class works to store a kind of pattern tree; it needs a string expression to construct;
58  Ex."Pow(Exp(A),B)=Exp(Mul(A,B))"
59  then the left tree is
60           A                             A    B
61            \                             \  /
62             Exp    B                       Mul
63              \   /                           \
64  left tree:   Pow         right tree:         Exp
65  lhs_root_ is Pow ;lhs_root_ is Exp */
66 class PatternTree {
67  public:
68   // pattern_str->ex."Pow(Exp(A),B)=Exp(Mul(A,B))"
PatternTree(const std::string & pattern_str)69   explicit PatternTree(const std::string &pattern_str) { BuildTree(pattern_str); }
70   virtual ~PatternTree() = default;
71 
lhs_root()72   PatternNodePtr lhs_root() { return lhs_root_; }
rhs_root()73   PatternNodePtr rhs_root() { return rhs_root_; }
GetRootOp() const74   std::string GetRootOp() const { return lhs_root_ == nullptr ? "" : lhs_root_->op(); }
75   // build tree with expression string
76   PatternNodePtr BuildTree(const std::string &pattern_str);
77   // traverse pattern tree, return order is topological order
78   void DfsTraverse(const std::shared_ptr<PatternNodePtrList> &res, const PatternNodePtr &cur) const;
79   // leverage pattern tree node and lite node's mapping relation to build lite node graph from pattern tree's right
80   // side
81   graphkernel::NodePtr AlterGraph(const std::shared_ptr<ParaMap> &para_to_ref,
82                                   const std::shared_ptr<ConstMap> &const_to_ref,
83                                   const graphkernel::NodePtr &origin_root);
84   // invoke DfsMatchGraph
85   graphkernel::NodePtrList MatchGraph(const graphkernel::NodePtr &root, const std::shared_ptr<ParaMap> &para_to_ref,
86                                       const std::shared_ptr<ConstMap> &const_to_ref);
87 
88  protected:
89   // set attributes for certain pattern node if needed;
SetAttributes(const graphkernel::NodePtr &)90   virtual std::unordered_map<PatternNodePtr, graphkernel::DAttrs> SetAttributes(const graphkernel::NodePtr &) {
91     auto right_pattern = std::make_shared<PatternNodePtrList>();
92     DfsTraverse(right_pattern, rhs_root_);
93     std::unordered_map<PatternNodePtr, graphkernel::DAttrs> attrs_map;
94     for (auto &i : (*right_pattern)) {
95       attrs_map[i] = {};
96     }
97     return attrs_map;
98   }
99   // check attributes meet requirements for certain pattern node if needed;
CheckAttributes(const graphkernel::NodePtr & origin_root) const100   virtual bool CheckAttributes(const graphkernel::NodePtr &origin_root) const { return true; }
101 
102  private:
103   PatternNodePtr lhs_root_ = nullptr;  // left side's root
104   PatternNodePtr rhs_root_ = nullptr;  // right side's root
105 };
106 
CutStr(const string & s,size_t start_pos=0,size_t len=std::string::npos)107 std::string CutStr(const string &s, size_t start_pos = 0, size_t len = std::string::npos) {
108   std::string new_str = "";
109   if (start_pos >= s.length()) {
110     MS_LOG(EXCEPTION) << "Cut is illegal.";
111   }
112   for (size_t i = 0; i < len; i++) {
113     if (start_pos + i >= s.length()) break;
114     new_str += s[start_pos + i];
115   }
116   return new_str;
117 }
118 
// Return true when `s` begins with `prefix` (an empty prefix always matches).
bool StartWith(const std::string &s, const std::string &prefix) {
  if (s.length() < prefix.length()) return false;
  // Compare only the leading characters; find() would keep scanning the rest of `s` on a mismatch.
  return s.compare(0, prefix.length(), prefix) == 0;
}
123 
124 // build pattern tree ;left side's root is lhs_root_ ; right side's root is rhs_root_
BuildTree(const std::string & pattern_str)125 PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) {
126   size_t pos = pattern_str.find("=");
127   if (pos != std::string::npos) {
128     auto left_expression = CutStr(pattern_str, 0, pos);
129     lhs_root_ = BuildTree(left_expression);
130     auto right_expression = CutStr(pattern_str, pos + 1);
131     rhs_root_ = BuildTree(right_expression);
132   } else {
133     size_t p_start = pattern_str.find("(");
134     if (p_start != std::string::npos) {
135       size_t p_end = pattern_str.rfind(")");
136       auto op_name = CutStr(pattern_str, 0, p_start);
137       auto op_inputs = CutStr(pattern_str, p_start + 1, p_end - p_start - 1);
138       PatternNodePtr cur_node = std::make_shared<PatternNode>(op_name);
139       int tmp = 0;
140       size_t comma = 0;
141       while (comma < op_inputs.length()) {
142         if (op_inputs[comma] == '(') {
143           tmp++;
144         }
145         if (op_inputs[comma] == ')') {
146           tmp--;
147         }
148         if (op_inputs[comma] == ',' && tmp == 0) {
149           auto first_half = CutStr(op_inputs, 0, comma);
150           cur_node->AddInput(BuildTree(first_half));
151           auto second_half = CutStr(op_inputs, comma + 1);
152           op_inputs = second_half;
153           comma = 0;
154         } else {
155           comma++;
156         }
157       }
158       cur_node->AddInput(BuildTree(op_inputs));
159       return cur_node;
160     } else {
161       return std::make_shared<PatternNode>(pattern_str);
162     }
163   }
164   return nullptr;
165 }
166 
PatternNodeType(const std::string & n)167 graphkernel::NType PatternNodeType(const std::string &n) {
168   // return (Primitive, Parameter or Value)
169   if (n.length() > 0 && '0' <= n[n.length() - 1] && n[n.length() - 1] <= '9') {
170     return graphkernel::NType::Value;
171   } else if (n.length() == 1 && 'A' <= n[0] && n[0] <= 'Z') {
172     return graphkernel::NType::Parameter;
173   } else {
174     return graphkernel::NType::Primitive;
175   }
176 }
177 
// Return a copy of `s` with every '[', ']' and space character removed.
std::string CleanStr(const std::string &s) {
  std::string cleaned;
  for (const char ch : s) {
    const bool dropped = (ch == '[' || ch == ']' || ch == ' ');
    if (!dropped) {
      cleaned.push_back(ch);
    }
  }
  return cleaned;
}
187 
CheckCurNode(const graphkernel::NodePtr & tmp_node,const std::string & tmp_pattern_op,const std::shared_ptr<ParaMap> & para_to_ref,const std::shared_ptr<ConstMap> & const_to_ref)188 bool CheckCurNode(const graphkernel::NodePtr &tmp_node, const std::string &tmp_pattern_op,
189                   const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref) {
190   // put lite graph node's mapping to pattern node into "para_to_ref" and "const_to_ref"
191   switch (PatternNodeType(tmp_pattern_op)) {
192     case graphkernel::NType::Parameter: {
193       if (para_to_ref->find(tmp_pattern_op[0]) != para_to_ref->end()) {
194         if ((*para_to_ref)[tmp_pattern_op[0]] != tmp_node) {
195           return false;
196         }
197       } else {
198         (*para_to_ref)[tmp_pattern_op[0]] = tmp_node;
199       }
200       break;
201     }
202     case graphkernel::NType::Value: {
203       if (tmp_node->NodeType() != graphkernel::NType::Value) {
204         return false;
205       }
206       auto node_value_str = std::static_pointer_cast<graphkernel::ConstTensorNode>(tmp_node)->ToString();
207       double node_value = std::stod(CleanStr(node_value_str));
208       if (StartWith(tmp_pattern_op, "const")) {
209         if (const_to_ref->find(tmp_pattern_op) != const_to_ref->end()) {
210           auto pattern_value_str =
211             std::static_pointer_cast<graphkernel::ConstTensorNode>((*const_to_ref)[tmp_pattern_op])->ToString();
212           double pattern_value = std::stod(CleanStr(pattern_value_str));
213           if (pattern_value != node_value) return false;
214         } else {
215           (*const_to_ref)[tmp_pattern_op] = tmp_node;
216         }
217       } else {
218         double pattern_value = std::stod(tmp_pattern_op);
219         if (pattern_value != node_value) {
220           return false;
221         }
222       }
223       break;
224     }
225     case graphkernel::NType::Primitive: {
226       if (tmp_node->NodeType() != graphkernel::NType::Primitive ||
227           std::static_pointer_cast<graphkernel::PrimOp>(tmp_node)->op() != tmp_pattern_op) {
228         return false;
229       }
230       break;
231     }
232     default:
233       break;
234   }
235   return true;
236 }
237 
238 // recursion for thr match of lite node graph and pattern tree's left side, store the mapping of pattern tree node to
239 // lite graph
DfsMatchGraph(const graphkernel::NodePtr & tmp_node,const PatternNodePtr & tmp_pattern,const std::shared_ptr<ParaMap> & para_to_ref,const std::shared_ptr<ConstMap> & const_to_ref,const std::shared_ptr<graphkernel::NodePtrList> & res)240 bool DfsMatchGraph(const graphkernel::NodePtr &tmp_node, const PatternNodePtr &tmp_pattern,
241                    const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref,
242                    const std::shared_ptr<graphkernel::NodePtrList> &res) {
243   std::string tmp_pattern_op = tmp_pattern->op();
244   if (!CheckCurNode(tmp_node, tmp_pattern_op, para_to_ref, const_to_ref)) {
245     return false;
246   }
247   std::vector<PatternNodePtr> tmp_pattern_inputs = tmp_pattern->inputs();
248   auto tmp_node_inputs = tmp_node->inputs();
249   // check if a node meets requiremnet,and DFS check its inputs
250   if (tmp_pattern_inputs.size() != 0 && tmp_node_inputs.size() != tmp_pattern_inputs.size()) {
251     return false;
252   }
253   if (PatternNodeType(tmp_pattern_op) == graphkernel::NType::Primitive) {
254     // exchange inputs for the node who meets commutative rules
255     if (commutative_ops.find(tmp_pattern_op) != commutative_ops.end()) {
256       ParaMap para_to_ref_copy = *para_to_ref;
257       ConstMap const_to_ref_copy = *const_to_ref;
258       bool first_match = DfsMatchGraph(tmp_node_inputs[0], tmp_pattern_inputs[0], para_to_ref, const_to_ref, res) &&
259                          DfsMatchGraph(tmp_node_inputs[1], tmp_pattern_inputs[1], para_to_ref, const_to_ref, res);
260       if (!first_match) {
261         res->clear();
262         para_to_ref->clear();
263         const_to_ref->clear();
264         for (auto &i : para_to_ref_copy) {
265           (*para_to_ref)[i.first] = i.second;
266         }
267         for (auto &i : const_to_ref_copy) {
268           (*const_to_ref)[i.first] = i.second;
269         }
270         bool second_match = DfsMatchGraph(tmp_node_inputs[0], tmp_pattern_inputs[1], para_to_ref, const_to_ref, res) &&
271                             DfsMatchGraph(tmp_node_inputs[1], tmp_pattern_inputs[0], para_to_ref, const_to_ref, res);
272         if (!second_match) {
273           return false;
274         }
275       }
276     } else {
277       for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) {
278         if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) {
279           return false;
280         }
281       }
282     }
283     res->push_back(tmp_node);
284   }
285   return true;
286 }
287 
288 // traverse pattern tree and return topological order
DfsTraverse(const std::shared_ptr<PatternNodePtrList> & res,const PatternNodePtr & cur) const289 void PatternTree::DfsTraverse(const std::shared_ptr<PatternNodePtrList> &res, const PatternNodePtr &cur) const {
290   if (cur == nullptr) {
291     return;
292   }
293   for (auto &p : cur->inputs()) {
294     if (PatternNodeType(p->op()) == graphkernel::NType::Primitive) {
295       DfsTraverse(res, p);
296     }
297   }
298   res->push_back(cur);
299 }
300 
301 // invoke DfsMatchGraph
MatchGraph(const graphkernel::NodePtr & root,const std::shared_ptr<ParaMap> & para_to_ref,const std::shared_ptr<ConstMap> & const_to_ref)302 graphkernel::NodePtrList PatternTree::MatchGraph(const graphkernel::NodePtr &root,
303                                                  const std::shared_ptr<ParaMap> &para_to_ref,
304                                                  const std::shared_ptr<ConstMap> &const_to_ref) {
305   auto res = std::make_shared<graphkernel::NodePtrList>();
306   if (!DfsMatchGraph(root, lhs_root_, para_to_ref, const_to_ref, res)) {
307     return {};
308   }
309   if (CheckAttributes(root)) {
310     return *res;
311   }
312   return {};
313 }
314 
315 // leverage pattern tree node and lite node's mapping relation to build new lite node graph from pattern tree's right
316 // side
AlterGraph(const std::shared_ptr<ParaMap> & para_to_ref,const std::shared_ptr<ConstMap> & const_to_ref,const graphkernel::NodePtr & origin_root)317 graphkernel::NodePtr PatternTree::AlterGraph(const std::shared_ptr<ParaMap> &para_to_ref,
318                                              const std::shared_ptr<ConstMap> &const_to_ref,
319                                              const graphkernel::NodePtr &origin_root) {
320   auto res = std::make_shared<PatternNodePtrList>();
321   DfsTraverse(res, rhs_root_);
322   auto all_attrs = SetAttributes(origin_root);
323   graphkernel::LiteGraph::GraphBuilder gb("");
324   std::unordered_map<PatternNodePtr, graphkernel::NodePtr> pattern_to_ref;
325   for (auto &n : (*res)) {
326     if (PatternNodeType(n->op()) != graphkernel::NType::Primitive) continue;
327     graphkernel::NodePtrList inputs;
328     for (auto &i : n->inputs()) {
329       if (PatternNodeType(i->op()) == graphkernel::NType::Primitive) {
330         inputs.push_back(pattern_to_ref[i]);
331       } else if (PatternNodeType(i->op()) == graphkernel::NType::Parameter) {
332         inputs.push_back((*para_to_ref)[i->op()[0]]);
333       } else {
334         if (StartWith(i->op(), "const")) {
335           inputs.push_back((*const_to_ref)[i->op()]);
336         } else {
337           tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(std::stof(i->op())));
338           inputs.push_back(gb.Value(data));
339         }
340       }
341     }
342     auto p = gb.Emit(n->op(), inputs, all_attrs[n]);
343     pattern_to_ref[n] = p;
344   }
345   auto &alter_graph = gb.Get()->ops();
346   if (alter_graph.empty()) {
347     if (PatternNodeType(rhs_root_->op()) == graphkernel::NType::Parameter) {
348       return (*para_to_ref)[rhs_root_->op()[0]];
349     } else {
350       if (StartWith(rhs_root_->op(), "const")) {
351         return (*const_to_ref)[rhs_root_->op()];
352       } else {
353         tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(std::stof(rhs_root_->op())));
354         return gb.Value(data);
355       }
356     }
357   }
358   return alter_graph.back();
359 }
360 
361 // Reduce(Reduce(A)) = Reduce(A)
362 class ExtraReduce1PatternTree : public PatternTree {
363  public:
ExtraReduce1PatternTree(const std::string & pattern_str)364   explicit ExtraReduce1PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
365   ~ExtraReduce1PatternTree() = default;
366 
367  protected:
CheckAttributes(const graphkernel::NodePtr & origin_root) const368   bool CheckAttributes(const graphkernel::NodePtr &origin_root) const override {
369     return (GetValue<bool>((origin_root->inputs()[0])->attrs().find("keep_dims")->second) ==
370             GetValue<bool>(origin_root->attrs().find("keep_dims")->second));
371   }
SetAttributes(const graphkernel::NodePtr & origin_root)372   std::unordered_map<PatternNodePtr, graphkernel::DAttrs> SetAttributes(
373     const graphkernel::NodePtr &origin_root) override {
374     auto attrs_map = PatternTree::SetAttributes(origin_root);
375     std::vector<int64_t> axis;
376     std::set<int64_t> axis_set;
377     auto first_reduce = origin_root->inputs()[0];
378     bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
379     if (keep_dims) {
380       for (auto &i : GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second)) {
381         axis_set.insert(i);
382       }
383       for (auto &i : GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second)) {
384         axis_set.insert(i);
385       }
386     } else {
387       auto first_axis = GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second);
388       auto second_axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
389       std::set<int64_t> st(first_axis.begin(), first_axis.end());
390       std::unordered_map<int64_t, int64_t> mp;
391       int64_t shift = 0;
392       for (int64_t n = 0; n < SizeToLong(first_reduce->inputs()[0]->shape.size()); n++) {
393         if (st.find(n) != st.end()) {
394           shift++;
395         } else {
396           mp[n - shift] = n;
397         }
398       }
399       std::for_each(first_axis.begin(), first_axis.end(), [&axis_set](auto &i) { axis_set.insert(i); });
400       std::for_each(second_axis.begin(), second_axis.end(), [&axis_set, &mp](auto &i) { axis_set.insert(mp[i]); });
401     }
402     std::copy(axis_set.begin(), axis_set.end(), std::back_inserter(axis));
403     attrs_map[this->rhs_root()] = {{"keep_dims", MakeValue(keep_dims)}, {"axis", MakeValue(axis)}};
404     return attrs_map;
405   }
406 };
407 
408 // "ReduceSum(Neg(A))=Neg(ReduceSum(A))"
409 class ExtraReduce2PatternTree : public PatternTree {
410  public:
ExtraReduce2PatternTree(const std::string & pattern_str)411   explicit ExtraReduce2PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
412   ~ExtraReduce2PatternTree() = default;
413 
414  protected:
SetAttributes(const graphkernel::NodePtr & origin_root)415   std::unordered_map<PatternNodePtr, graphkernel::DAttrs> SetAttributes(
416     const graphkernel::NodePtr &origin_root) override {
417     auto attrs_map = PatternTree::SetAttributes(origin_root);
418     bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
419     auto axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
420     attrs_map[this->rhs_root()->inputs()[0]] = {{"keep_dims", MakeValue(keep_dims)}, {"axis", MakeValue(axis)}};
421     return attrs_map;
422   }
423 };
424 
425 /*       A
426         /
427        Neg
428        /  \
429     Neg     Mul
430  Here we cannot transform Neg(Neg(A)) to A because Neg(A) is a input of Mul. OutsideRely is responsible for checking
431  this case.
432  */
OutsideRely(const graphkernel::NodePtrList & nodes,const graphkernel::NodePtr & root)433 bool OutsideRely(const graphkernel::NodePtrList &nodes, const graphkernel::NodePtr &root) {
434   std::unordered_set<graphkernel::Node *> nodes_can_simplify;
435   std::for_each(nodes.begin(), nodes.end(), [&nodes_can_simplify](auto n) { nodes_can_simplify.insert(n.get()); });
436   for (auto &n : nodes) {
437     if (n == root) {
438       continue;
439     }
440     for (auto &usr : n->users()) {
441       if (nodes_can_simplify.find(usr.first) == nodes_can_simplify.end()) {
442         return true;
443       }
444     }
445   }
446   return false;
447 }
448 
// One rewrite rule of the simplifier.
struct Expression {
  // Rule number; its decimal string form is looked up in the enable / disable flag lists.
  size_t id;
  // Rule text in the form "lhs=rhs", ex. "Add(A,0)=A".
  std::string math_expr;
  // Factory that builds the pattern tree for this rule; it is called with math_expr.
  std::function<PatternTreePtr(const std::string &)> func;
};

// Expands to a lambda that builds a pattern tree of class `cls` from an expression string.
#define EXPR_PATTERN(cls) [](const std::string &expr) -> PatternTreePtr { return std::make_shared<cls>(expr); }
456 
457 static std::vector<Expression> expressions = {
458   // add
459   {1, "Add(A,0)=A", EXPR_PATTERN(PatternTree)},
460   {2, "Add(Mul(A,C),Mul(A,B))=Mul(A,Add(B,C))", EXPR_PATTERN(PatternTree)},
461   {3, "Add(Add(A,const1),const2)=Add(A,Add(const1,const2))", EXPR_PATTERN(PatternTree)},
462   {4, "Add(A,Neg(A))=0", EXPR_PATTERN(PatternTree)},
463   {5, "Add(Add(A,B),Neg(A))=B", EXPR_PATTERN(PatternTree)},
464   {6, "Add(Add(A,B),Add(Neg(A),C))=Add(B,C)", EXPR_PATTERN(PatternTree)},
465   // sub
466   {7, "Sub(A,0)=A", EXPR_PATTERN(PatternTree)},
467   {8, "Sub(A,const1)=Add(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
468   {9, "Sub(Mul(A,C),Mul(A,B))=Mul(A,Sub(B,C))", EXPR_PATTERN(PatternTree)},
469   {10, "Sub(Mul(A,C),Mul(B,C))=Mul(Sub(A,B),C)", EXPR_PATTERN(PatternTree)},
470   // log
471   {11, "Log(Exp(A))=A", EXPR_PATTERN(PatternTree)},
472   {12, "Log(Pow(A,B))=Mul(B,Log(Abs(A)))", EXPR_PATTERN(PatternTree)},
473   {13, "Log(Sqrt(A))=Mul(0.5,Log(A))", EXPR_PATTERN(PatternTree)},
474   {14, "Log(Rsqrt(A))=Mul(-0.5,Log(A))", EXPR_PATTERN(PatternTree)},
475   // pow
476   {15, "Pow(A,1)=A", EXPR_PATTERN(PatternTree)},
477   {16, "Pow(Exp(A),B)=Exp(Mul(A,B))", EXPR_PATTERN(PatternTree)},
478   {17, "Pow(A,2)=Mul(A,A)", EXPR_PATTERN(PatternTree)},
479   {18, "Pow(A,-1)=Reciprocal(A)", EXPR_PATTERN(PatternTree)},
480   // sqrt
481   {19, "Sqrt(Mul(A,A))=Abs(A)", EXPR_PATTERN(PatternTree)},
482   {20, "Rsqrt(Pow(A,-2))=Abs(A)", EXPR_PATTERN(PatternTree)},
483   {21, "Rsqrt(RealDiv(1,A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
484   {22, "Rsqrt(Reciprocal(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
485   // select
486   {23, "Select(A,B,B)=B", EXPR_PATTERN(PatternTree)},
487   // Neg
488   {24, "Neg(Neg(A))=A", EXPR_PATTERN(PatternTree)},
489   // mul
490   {25, "Mul(Mul(A,const1),Mul(B,const2))=Mul(Mul(A,B),Mul(const1,const2))", EXPR_PATTERN(PatternTree)},
491   {26, "Mul(Mul(A,const1),const2)=Mul(A,Mul(const1,const2))", EXPR_PATTERN(PatternTree)},
492   {27, "Mul(Exp(A),Exp(B))=Exp(Add(A,B))", EXPR_PATTERN(PatternTree)},
493   {28, "Mul(Mul(Exp(A),C),Exp(B))=Mul(Exp(Add(A,B)),C)", EXPR_PATTERN(PatternTree)},
494   {29, "Mul(Mul(Exp(A),C),Mul(Exp(B),D))=Mul(Exp(Add(A,B)),Mul(C,D))", EXPR_PATTERN(PatternTree)},
495   {30, "Mul(Sqrt(A),Sqrt(A))=A", EXPR_PATTERN(PatternTree)},
496   {31, "Mul(Mul(A,Sqrt(B)),Mul(C,Sqrt(B)))=Mul(Mul(A,B),C)", EXPR_PATTERN(PatternTree)},
497   {32, "Mul(Mul(A,Sqrt(B)),Sqrt(B))=Mul(A,B)", EXPR_PATTERN(PatternTree)},
498   {33, "Mul(Sqrt(A),Sqrt(B))=Sqrt(Mul(A,B))", EXPR_PATTERN(PatternTree)},
499   {34, "Mul(Rsqrt(A),Rsqrt(A))=Reciprocal(A)", EXPR_PATTERN(PatternTree)},
500   {35, "Mul(Mul(A,Rsqrt(B)),Rsqrt(B))=RealDiv(A,B)", EXPR_PATTERN(PatternTree)},
501   {36, "Mul(Mul(A,Rsqrt(B)),Mul(C,Rsqrt(B)))=RealDiv(Mul(A,C),B)", EXPR_PATTERN(PatternTree)},
502   {37, "Mul(Rsqrt(A),Rsqrt(B))=Rsqrt(Mul(A,B))", EXPR_PATTERN(PatternTree)},
503   {38, "Mul(A,Rsqrt(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
504   {39, "Mul(Abs(A),Abs(B))=Abs(Mul(A,B))", EXPR_PATTERN(PatternTree)},
505   {40, "Mul(Mul(Abs(A),C),Abs(B))=Mul(Abs(Mul(A,B)),C)", EXPR_PATTERN(PatternTree)},
506   {41, "Mul(Mul(Abs(A),C),Mul(Abs(B),D))=Mul(Abs(Mul(A,B)),Mul(C,D))", EXPR_PATTERN(PatternTree)},
507   {42, "Mul(Neg(A),const1)=Mul(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
508   // realdiv
509   {43, "RealDiv(A,1)=A", EXPR_PATTERN(PatternTree)},
510   {44, "RealDiv(Exp(A),Exp(B))=Exp(Sub(A,B))", EXPR_PATTERN(PatternTree)},
511   {45, "RealDiv(A,Exp(B))=Mul(A,Exp(Neg(B)))", EXPR_PATTERN(PatternTree)},
512   {46, "RealDiv(A,Pow(B,const1))=Mul(A,Pow(B,Neg(const1)))", EXPR_PATTERN(PatternTree)},
513   {47, "RealDiv(A,Sqrt(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
514   {48, "RealDiv(A,Sqrt(B))=Mul(A,Rsqrt(B))", EXPR_PATTERN(PatternTree)},
515   {49, "RealDiv(A,Rsqrt(B))=Mul(A,Sqrt(B))", EXPR_PATTERN(PatternTree)},
516   {50, "RealDiv(A,const1)=Mul(A,Reciprocal(const1))", EXPR_PATTERN(PatternTree)},
517   {51, "RealDiv(RealDiv(A,B),RealDiv(C,D))=RealDiv(Mul(A,D),Mul(B,C))", EXPR_PATTERN(PatternTree)},
518   {52, "RealDiv(Neg(A),const1)=RealDiv(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
519   {53, "RealDiv(RealDiv(A,B),C)=RealDiv(A,Mul(B,C))", EXPR_PATTERN(PatternTree)},
520   {54, "RealDiv(A,RealDiv(B,C))=RealDiv(Mul(A,C),B)", EXPR_PATTERN(PatternTree)},
521   // reduce1
522   {55, "ReduceSum(ReduceSum(A))=ReduceSum(A)", EXPR_PATTERN(ExtraReduce1PatternTree)},
523   {56, "ReduceMin(ReduceMin(A))=ReduceMin(A)", EXPR_PATTERN(ExtraReduce1PatternTree)},
524   {57, "ReduceMax(ReduceMax(A))=ReduceMax(A)", EXPR_PATTERN(ExtraReduce1PatternTree)},
525   // reduce2
526   {58, "ReduceSum(Neg(A))=Neg(ReduceSum(A))", EXPR_PATTERN(ExtraReduce2PatternTree)},
527   {59, "ReduceSum(RealDiv(A,const1))=RealDiv(ReduceSum(A),const1)", EXPR_PATTERN(ExtraReduce2PatternTree)},
528   {60, "ReduceSum(Mul(A,const1))=Mul(ReduceSum(A),const1)", EXPR_PATTERN(ExtraReduce2PatternTree)},
529   {61, "CReal(Complex(A,B))=A", EXPR_PATTERN(PatternTree)},
530   {62, "CImag(Complex(A,B))=B", EXPR_PATTERN(PatternTree)},
531 };
532 
GetExpressions()533 std::unordered_map<std::string, std::vector<PatternTreePtr>> GetExpressions() {
534   const auto &flags = context::GraphKernelFlags::GetInstance();
535   std::unordered_map<std::string, std::vector<PatternTreePtr>> expression_map;
536   std::unordered_set<std::string> enable_ids{flags.enable_simplify_exprs_only.begin(),
537                                              flags.enable_simplify_exprs_only.end()};
538   std::unordered_set<std::string> disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()};
539   for (auto &e : expressions) {
540     if (!enable_ids.empty()) {
541       if (enable_ids.count(std::to_string(e.id)) == 0) continue;
542     } else {
543       if (disable_ids.count(std::to_string(e.id)) > 0) continue;
544     }
545     PatternTreePtr pt = e.func(e.math_expr);
546     expression_map[pt->GetRootOp()].push_back(pt);
547   }
548   return expression_map;
549 }
550 
551 // arithmetic simplify
DoArithmeticTrans(const graphkernel::LiteGraphPtr & litegraph)552 bool ArithmeticSimplify::DoArithmeticTrans(const graphkernel::LiteGraphPtr &litegraph) {
553   auto ops_list = litegraph->ops();
554   bool changed = false;
555   graphkernel::NodePtrList matched_nodes;
556   auto para_to_ref = std::make_shared<ParaMap>();    // A(B,C ...)->Node* mapping
557   auto const_to_ref = std::make_shared<ConstMap>();  // const->Node* mapping
558   PatternTreePtr cur_pattern;
559   auto iter = ops_list.rbegin();
560   while (iter != ops_list.rend()) {
561     bool can_simplify = false;
562     auto this_op = std::static_pointer_cast<graphkernel::PrimOp>(*iter)->op();
563     if (expressions_map_.find(this_op) != expressions_map_.end()) {
564       for (auto p : expressions_map_[this_op]) {
565         cur_pattern = p;
566         if (!para_to_ref->empty()) {
567           para_to_ref->clear();
568         }
569         if (!const_to_ref->empty()) {
570           const_to_ref->clear();
571         }
572         // match a pattern;if return is empty,then fails to match
573         matched_nodes = p->MatchGraph(*iter, para_to_ref, const_to_ref);
574         if (!matched_nodes.empty()) {
575           auto right_root_type = PatternNodeType(p->rhs_root()->op());
576           if (right_root_type == graphkernel::NType::Primitive && OutsideRely(matched_nodes, *iter)) {
577             continue;
578           }
579           // if no outside rely,then this is a successful match
580           can_simplify = true;
581           // get the new node to replace
582           graphkernel::NodePtr alter_graph_node = cur_pattern->AlterGraph(para_to_ref, const_to_ref, *iter);
583           (*iter)->ReplaceWith(alter_graph_node);
584           ops_list = litegraph->GetOrderedNodes();
585           iter = ops_list.rbegin();
586           changed = true;
587           break;
588         }
589       }
590     }
591     if (!can_simplify) {
592       ++iter;
593     }
594   }
595   return changed;
596 }
597 
598 // constant fold
DoConstantFold(const graphkernel::LiteGraphPtr & litegraph)599 bool ArithmeticSimplify::DoConstantFold(const graphkernel::LiteGraphPtr &litegraph) {
600   auto ops_list = litegraph->GetOrderedNodes();
601   bool changed = false;
602   auto iter = ops_list.begin();
603   while (iter != ops_list.end()) {
604     auto this_op = std::static_pointer_cast<graphkernel::PrimOp>(*iter);
605     auto value = this_op->InferValue(this_op->inputs(), this_op->attrs(), this_op->op());
606     if (value != nullptr) {
607       (*iter)->ReplaceWith(value);
608       ops_list = litegraph->GetOrderedNodes();
609       iter = ops_list.begin();
610       changed = true;
611     } else {
612       ++iter;
613     }
614   }
615   return changed;
616 }
617 
ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr & litegraph)618 void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) {
619   auto &outputs = litegraph->GetOutputs();
620   for (size_t i = 0; i < outputs.size(); i++) {
621     if (outputs[i]->NodeType() == graphkernel::NType::Value) {
622       graphkernel::LiteGraph::GraphBuilder gb;
623       std::vector<int64_t> new_shape = {1};
624       auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}});
625       litegraph->output()->SetInput(i, op_ptr);
626     } else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) {
627       graphkernel::LiteGraph::GraphBuilder gb;
628       auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
629       litegraph->output()->SetInput(i, op_ptr);
630     }
631   }
632   return;
633 }
634 
Run(const FuncGraphPtr & func_graph)635 bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
636   auto mng = func_graph->manager();
637   bool do_simplify = false;
638   expressions_map_ = GetExpressions();
639   for (auto node : func_graph->GetOrderedCnodes()) {
640     if (AnfAlgo::IsGraphKernel(node)) {
641       auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
642       graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
643       bool find_pattern = true;
644       bool change_anf_graph = false;
645       while (find_pattern) {
646         find_pattern = false;
647         find_pattern = DoArithmeticTrans(lg) || find_pattern;
648         find_pattern = DoConstantFold(lg) || find_pattern;
649         change_anf_graph = change_anf_graph || find_pattern;
650       }
651       if (!change_anf_graph) continue;
652       ReorganizeEmptyGraph(lg);
653       AnfNodePtrList outputs;
654       auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
655       new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
656       auto cnode = node->cast<CNodePtr>();
657       AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
658       EliminateRedundantParameters(new_funcgraph, &inputs);
659       auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
660       SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
661       mng->Replace(node, new_node);
662       mng->AddFuncGraph(new_funcgraph);
663       do_simplify = true;
664     }
665   }
666   return do_simplify;
667 }
668 }  // namespace opt
669 }  // namespace mindspore
670