• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
19 
20 #include <vector>
21 #include <algorithm>
22 
23 #include "frontend/optimizer/irpass.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "frontend/optimizer/optimizer.h"
26 #include "frontend/optimizer/anf_visitor.h"
27 #include "frontend/operator/ops.h"
28 #include "utils/compile_config.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace irpass {
33 // {prim::kPrimPartial, func_graph, ...}
34 class PartialDeferInline : public AnfVisitor {
35  public:
operator()36   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
37     static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
38     if (!enable_pre_lift) {
39       return nullptr;
40     }
41     auto cnode = node->cast<CNodePtr>();
42     auto real_func = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(1)->abstract());
43     if (real_func != nullptr) {
44       *(real_func->func_graph()->indirect()) = true;
45     }
46     return nullptr;
47   }
48 };
49 
50 // {prim::kPrimSwitch, cond, true_branch, false_branch}
51 class SwitchDeferInline : public AnfVisitor {
52  public:
operator()53   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
54     auto cnode = node->cast<CNodePtr>();
55     auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract());
56     if (true_abstract != nullptr) {
57       *(true_abstract->func_graph()->indirect()) = true;
58     }
59     auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract());
60     if (false_abstract != nullptr) {
61       *(false_abstract->func_graph()->indirect()) = true;
62     }
63     return nullptr;
64   }
65 };
66 
67 // {prim::kPrimSwitchLayer, Index, layers}
68 class SwitchLayerDeferInline : public AnfVisitor {
69  public:
operator()70   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
71     auto cnode = node->cast<CNodePtr>();
72     auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract());
73     if (tuple == nullptr) {
74       return nullptr;
75     }
76     for (auto elem : tuple->elements()) {
77       auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
78       if (abstract != nullptr) {
79         *(abstract->func_graph()->indirect()) = true;
80       }
81     }
82     return nullptr;
83   }
84 };
85 }  // namespace irpass
86 }  // namespace opt
87 }  // namespace mindspore
88 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
89