1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_ 19 20 #include <vector> 21 #include <algorithm> 22 23 #include "frontend/optimizer/irpass.h" 24 #include "mindspore/core/ops/framework_ops.h" 25 #include "frontend/optimizer/optimizer.h" 26 #include "frontend/optimizer/anf_visitor.h" 27 #include "frontend/operator/ops.h" 28 #include "utils/compile_config.h" 29 30 namespace mindspore { 31 namespace opt { 32 namespace irpass { 33 // {prim::kPrimPartial, func_graph, ...} 34 class PartialDeferInline : public AnfVisitor { 35 public: operator()36 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 37 static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1"); 38 if (!enable_pre_lift) { 39 return nullptr; 40 } 41 auto cnode = node->cast<CNodePtr>(); 42 auto real_func = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(1)->abstract()); 43 if (real_func != nullptr) { 44 *(real_func->func_graph()->indirect()) = true; 45 } 46 return nullptr; 47 } 48 }; 49 50 // {prim::kPrimSwitch, cond, true_branch, false_branch} 51 class SwitchDeferInline : public AnfVisitor { 52 public: operator()53 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 54 auto cnode = node->cast<CNodePtr>(); 55 auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract()); 56 if (true_abstract != nullptr) { 57 *(true_abstract->func_graph()->indirect()) = true; 58 } 59 auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract()); 60 if (false_abstract != nullptr) { 61 *(false_abstract->func_graph()->indirect()) = true; 62 } 63 return nullptr; 64 } 65 }; 66 67 // {prim::kPrimSwitchLayer, Index, layers} 68 class SwitchLayerDeferInline : public AnfVisitor { 69 public: operator()70 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 71 auto cnode = node->cast<CNodePtr>(); 72 auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract()); 73 if (tuple == nullptr) { 74 return nullptr; 75 } 76 for (auto elem : tuple->elements()) { 77 auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem); 78 if (abstract != nullptr) { 79 *(abstract->func_graph()->indirect()) = true; 80 } 81 } 82 return nullptr; 83 } 84 }; 85 } // namespace irpass 86 } // namespace opt 87 } // namespace mindspore 88 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_ 89