1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ 19 20 #include <vector> 21 #include <memory> 22 23 #include "frontend/optimizer/optimizer.h" 24 #include "mindspore/core/ops/sequence_ops.h" 25 #include "mindspore/core/ops/math_ops.h" 26 #include "frontend/optimizer/irpass.h" 27 #include "frontend/optimizer/anf_visitor.h" 28 #include "frontend/operator/ops.h" 29 30 namespace mindspore { 31 namespace opt { 32 namespace irpass { 33 // {prim::kPrimTupleGetItem, {target_grad, Xs}, C} 34 class MinMaximumGrad : public AnfVisitor { 35 public: operator()36 AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { 37 Reset(); 38 AnfVisitor::Match(prim::kPrimTupleGetItem, {MinMaximumGrad::IsOriginMaxMinGrad, IsValueNode<Int64Imm>})(node); 39 if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) { 40 return nullptr; 41 } 42 43 // check single use 44 auto mng = optimizer->manager(); 45 MS_EXCEPTION_IF_NULL(mng); 46 auto &users = mng->node_users(); 47 if (users.find(grad_) == users.end() || users[grad_].size() != 1) { 48 return nullptr; 49 } 50 51 // {target_grad, Xs} 52 auto &inputs = grad_->inputs(); 53 auto prim = GetValueNode<PrimitivePtr>(inputs[0]); 54 55 auto new_prim = std::make_shared<Primitive>(prim->name()); 56 new_prim->set_attr("grad_x", MakeValue(true)); 57 new_prim->set_attr("grad_y", MakeValue(true)); 58 59 if (idx_ == 0) { 60 new_prim->set_attr("grad_y", MakeValue(false)); 61 } 62 if (idx_ == 1) { 63 new_prim->set_attr("grad_x", MakeValue(false)); 64 } 65 66 std::vector<AnfNodePtr> args; 67 args.push_back(NewValueNode(new_prim)); 68 (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend()); 69 70 auto fg = node->func_graph(); 71 auto new_code = fg->NewCNode(args); 72 if (AnfUtils::GetDumpFlag(grad_)) { 73 AnfUtils::SetDumpFlag(new_code); 74 } 75 76 return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), new_code, NewValueNode(MakeValue(idx_))}); 77 } 78 Visit(const CNodePtr & cnode)79 void Visit(const CNodePtr &cnode) override { grad_ = cnode; } 80 Visit(const ValueNodePtr & vnode)81 void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int64_t>(vnode->value()); } 82 Reset()83 void Reset() { 84 idx_ = -1; 85 grad_ = nullptr; 86 } 87 88 // Check if node is MinimumGrad() or MaximumGrad() IsOriginMaxMinGrad(const AnfNodePtr & node)89 static bool IsOriginMaxMinGrad(const AnfNodePtr &node) { 90 if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) { 91 return false; 92 } 93 94 auto cnode = node->cast<CNodePtr>(); 95 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); 96 auto x_v = prim->GetAttr("grad_x"); 97 auto y_v = prim->GetAttr("grad_y"); 98 if (x_v == nullptr || y_v == nullptr || !x_v->isa<BoolImm>() || !y_v->isa<BoolImm>()) { 99 return false; 100 } 101 102 bool x = GetValue<bool>(x_v); 103 bool y = GetValue<bool>(y_v); 104 return x && y; 105 } 106 107 private: 108 int64_t idx_{-1}; 109 CNodePtr grad_{nullptr}; 110 }; 111 } // namespace irpass 112 } // namespace opt 113 } // namespace mindspore 114 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ 115