• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
19 
20 #include <memory>
21 
22 #include "ir/pattern_matcher.h"
23 #include "frontend/optimizer/irpass.h"
24 #include "frontend/optimizer/optimizer.h"
25 
26 namespace mindspore {
27 namespace opt {
28 namespace irpass {
29 namespace internal {
30 class GetRefValueTransform {
31  public:
GetRefValueTransform()32   GetRefValueTransform() {}
33   ~GetRefValueTransform() = default;
34 
operator()35   AnfNodePtr operator()(const AnfNodePtr &node) {
36     CNodePtr cnode = node->cast<CNodePtr>();
37     auto inputs = cnode->inputs();
38     auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
39     if (fg != nullptr && fg->recursive()) {
40       MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString();
41       return node;
42     }
43     auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
44     auto output = new_fg->output();
45     new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
46     inputs[0] = NewValueNode(new_fg);
47     auto ret_node = cnode->func_graph()->NewCNode(inputs);
48     return ret_node;
49   }
50 };
51 }  // namespace internal
52 
53 // {prim::kPrimMakeRef, X, Y, Z} -> Y
54 class MakeRefEliminater : public OptimizerCaller {
55  public:
operator()56   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
57     PatternNode<AnfNodePtr> x, y, z;
58     MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
59     return nullptr;
60   }
61 };
62 
63 // {prim::kPrimGetRefValue, Parameter} -> Parameter
64 class GetRefParamEliminater : public OptimizerCaller {
65  public:
operator()66   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
67     PatternNode<AnfNodePtr> x;
68     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
69     return nullptr;
70   }
71 };
72 
73 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
74 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
75 // {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
76 class GetMakeRefEliminater : public OptimizerCaller {
77  public:
operator()78   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
79     PatternNode<AnfNodePtr> x, y, z;
80     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
81     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
82     MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
83     internal::GetRefValueTransform trans;
84     auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
85       auto rep = trans(x.GetNode(node));
86       if (rep != nullptr) {
87         return rep;
88       }
89       return nullptr;
90     };
91     MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
92     return nullptr;
93   }
94 };
95 
96 // IsValueNode<RefKey>
97 class ReplaceRefkeyByParam : public OptimizerCaller {
98  public:
operator()99   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
100     auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
101       auto refkey = GetValueNode<RefKeyPtr>(node);
102       auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
103       MS_EXCEPTION_IF_NULL(resource);
104 
105       auto top_graph = resource->func_graph();
106       MS_EXCEPTION_IF_NULL(top_graph);
107 
108       for (const auto &tnode : top_graph->parameters()) {
109         auto para = tnode->cast<ParameterPtr>();
110         if (para != nullptr && para->name() == refkey->tag()) {
111           return para;
112         }
113       }
114       return nullptr;
115     };
116     PatternNode<AnfNodePtr> x;
117     MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
118     return nullptr;
119   }
120 };
121 }  // namespace irpass
122 }  // namespace opt
123 }  // namespace mindspore
124 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
125