• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_
19 
20 #include <vector>
21 #include <memory>
22 
23 #include "frontend/optimizer/optimizer.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/math_ops.h"
26 #include "frontend/optimizer/irpass.h"
27 #include "frontend/optimizer/anf_visitor.h"
28 #include "frontend/operator/ops.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace irpass {
33 // {prim::kPrimTupleGetItem, {target_grad, Xs}, C}
34 class MinMaximumGrad : public AnfVisitor {
35  public:
operator()36   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
37     Reset();
38     AnfVisitor::Match(prim::kPrimTupleGetItem, {MinMaximumGrad::IsOriginMaxMinGrad, IsValueNode<Int64Imm>})(node);
39     if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) {
40       return nullptr;
41     }
42 
43     // check single use
44     auto mng = optimizer->manager();
45     MS_EXCEPTION_IF_NULL(mng);
46     auto &users = mng->node_users();
47     if (users.find(grad_) == users.end() || users[grad_].size() != 1) {
48       return nullptr;
49     }
50 
51     // {target_grad, Xs}
52     auto &inputs = grad_->inputs();
53     auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
54 
55     auto new_prim = std::make_shared<Primitive>(prim->name());
56     new_prim->set_attr("grad_x", MakeValue(true));
57     new_prim->set_attr("grad_y", MakeValue(true));
58 
59     if (idx_ == 0) {
60       new_prim->set_attr("grad_y", MakeValue(false));
61     }
62     if (idx_ == 1) {
63       new_prim->set_attr("grad_x", MakeValue(false));
64     }
65 
66     std::vector<AnfNodePtr> args;
67     args.push_back(NewValueNode(new_prim));
68     (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
69 
70     auto fg = node->func_graph();
71     auto new_code = fg->NewCNode(args);
72     if (AnfUtils::GetDumpFlag(grad_)) {
73       AnfUtils::SetDumpFlag(new_code);
74     }
75 
76     return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), new_code, NewValueNode(MakeValue(idx_))});
77   }
78 
Visit(const CNodePtr & cnode)79   void Visit(const CNodePtr &cnode) override { grad_ = cnode; }
80 
Visit(const ValueNodePtr & vnode)81   void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int64_t>(vnode->value()); }
82 
Reset()83   void Reset() {
84     idx_ = -1;
85     grad_ = nullptr;
86   }
87 
88   // Check if node is MinimumGrad() or MaximumGrad()
IsOriginMaxMinGrad(const AnfNodePtr & node)89   static bool IsOriginMaxMinGrad(const AnfNodePtr &node) {
90     if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) {
91       return false;
92     }
93 
94     auto cnode = node->cast<CNodePtr>();
95     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
96     auto x_v = prim->GetAttr("grad_x");
97     auto y_v = prim->GetAttr("grad_y");
98     if (x_v == nullptr || y_v == nullptr || !x_v->isa<BoolImm>() || !y_v->isa<BoolImm>()) {
99       return false;
100     }
101 
102     bool x = GetValue<bool>(x_v);
103     bool y = GetValue<bool>(y_v);
104     return x && y;
105   }
106 
107  private:
108   int64_t idx_{-1};
109   CNodePtr grad_{nullptr};
110 };
111 }  // namespace irpass
112 }  // namespace opt
113 }  // namespace mindspore
114 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_
115