1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ 19 20 #include <memory> 21 22 #include "ir/pattern_matcher.h" 23 #include "frontend/optimizer/irpass.h" 24 #include "frontend/optimizer/optimizer.h" 25 26 namespace mindspore { 27 namespace opt { 28 namespace irpass { 29 namespace internal { 30 class GetRefValueTransform { 31 public: GetRefValueTransform()32 GetRefValueTransform() {} 33 ~GetRefValueTransform() = default; 34 operator()35 AnfNodePtr operator()(const AnfNodePtr &node) { 36 CNodePtr cnode = node->cast<CNodePtr>(); 37 auto inputs = cnode->inputs(); 38 auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>(); 39 if (fg != nullptr && fg->recursive()) { 40 MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString(); 41 return node; 42 } 43 auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue")); 44 auto output = new_fg->output(); 45 new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output})); 46 inputs[0] = NewValueNode(new_fg); 47 auto ret_node = cnode->func_graph()->NewCNode(inputs); 48 return ret_node; 49 } 50 }; 51 } // namespace internal 52 53 // {prim::kPrimMakeRef, X, Y, Z} -> Y 54 class MakeRefEliminater : public OptimizerCaller { 55 public: operator()56 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 57 PatternNode<AnfNodePtr> x, y, z; 58 MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); 59 return nullptr; 60 } 61 }; 62 63 // {prim::kPrimGetRefValue, Parameter} -> Parameter 64 class GetRefParamEliminater : public OptimizerCaller { 65 public: operator()66 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 67 PatternNode<AnfNodePtr> x; 68 MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); 69 return nullptr; 70 } 71 }; 72 73 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X 74 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y 75 // {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f} 76 class GetMakeRefEliminater : public OptimizerCaller { 77 public: operator()78 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 79 PatternNode<AnfNodePtr> x, y, z; 80 MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); 81 MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); 82 MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node)); 83 internal::GetRefValueTransform trans; 84 auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr { 85 auto rep = trans(x.GetNode(node)); 86 if (rep != nullptr) { 87 return rep; 88 } 89 return nullptr; 90 }; 91 MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node)); 92 return nullptr; 93 } 94 }; 95 96 // IsValueNode<RefKey> 97 class ReplaceRefkeyByParam : public OptimizerCaller { 98 public: operator()99 AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { 100 auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { 101 auto refkey = GetValueNode<RefKeyPtr>(node); 102 auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource()); 103 MS_EXCEPTION_IF_NULL(resource); 104 105 auto top_graph = resource->func_graph(); 106 MS_EXCEPTION_IF_NULL(top_graph); 107 108 for (const auto &tnode : top_graph->parameters()) { 109 auto para = tnode->cast<ParameterPtr>(); 110 if (para != nullptr && para->name() == refkey->tag()) { 111 return para; 112 } 113 } 114 return nullptr; 115 }; 116 PatternNode<AnfNodePtr> x; 117 MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node)); 118 return nullptr; 119 } 120 }; 121 } // namespace irpass 122 } // namespace opt 123 } // namespace mindspore 124 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ 125