// NOTE(review): removed code-browser navigation residue ("Home / Line# / Scopes# /
// Navigate / Raw / Download") that was pasted in from an HTML source viewer and is
// not part of this source file.
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
17 
18 #include <memory>
19 #include <vector>
20 #include <string>
21 
22 #include "backend/session/anf_runtime_algorithm.h"
23 #include "ir/primitive.h"
24 #include "utils/utils.h"
25 #include "backend/optimizer/common/helper.h"
26 
27 namespace mindspore {
28 namespace opt {
29 constexpr size_t kInputIndex = 1;
30 
IsScalar(const BaseRef & n)31 bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
32   if (utils::isa<AnfNodePtr>(n)) {
33     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
34     MS_EXCEPTION_IF_NULL(in);
35     auto shape_ptr = in->Shape();
36     MS_EXCEPTION_IF_NULL(shape_ptr);
37     auto shape = shape_ptr->cast<abstract::ShapePtr>();
38     MS_EXCEPTION_IF_NULL(shape);
39     if (shape->shape().size() != 0) {
40       return false;
41     }
42     auto dtype = in->Type();
43     MS_EXCEPTION_IF_NULL(dtype);
44     if (dtype->type_id() != kObjectTypeTensorType) {
45       return false;
46     }
47     auto type_ptr = dyn_cast<TensorType>(dtype);
48     MS_EXCEPTION_IF_NULL(type_ptr);
49     auto element = type_ptr->element();
50     MS_EXCEPTION_IF_NULL(element);
51     auto element_type = element->type_id();
52     if (element_type != kNumberTypeFloat32) {
53       return false;
54     }
55     return true;
56   }
57   return false;
58 }
59 
IsCast(const BaseRef & n)60 bool ApplyMomentumWeightDecayScaleFusion::IsCast(const BaseRef &n) {
61   if (utils::isa<AnfNodePtr>(n)) {
62     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
63     MS_EXCEPTION_IF_NULL(in);
64     if (IsPrimitiveCNode(in, prim::kPrimCast) ||
65         (IsPrimitiveCNode(in, prim::kPrimDepend) &&
66          IsPrimitiveCNode(in->cast<CNodePtr>()->input(kInputIndex), prim::kPrimCast))) {
67       return true;
68     }
69   }
70   return false;
71 }
72 
// Returns the real (pre-cast) input of a Cast node. The Cast may be wrapped
// in a Depend, in which case the Depend's data input is unwrapped first.
// Returns nullptr when `node` matches neither form.
AnfNodePtr GetCastInput(const AnfNodePtr &node) {
  if (IsPrimitiveCNode(node, prim::kPrimCast)) {
    return node->cast<CNodePtr>()->input(kInputIndex);
  }
  if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
    return nullptr;
  }
  auto depend_input = node->cast<CNodePtr>()->input(kInputIndex);
  if (!IsPrimitiveCNode(depend_input, prim::kPrimCast)) {
    return nullptr;
  }
  return depend_input->cast<CNodePtr>()->input(kInputIndex);
}
85 
DefinePattern() const86 const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
87   VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
88   VectorRef weight =
89     VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), cast_gradient_});
90   VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
91   VectorRef apply_momentum =
92     VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
93   return apply_momentum;
94 }
95 
// Replaces the matched subgraph with a single FusedWeightScaleApplyMomentum
// CNode, feeding it the raw (pre-cast) gradient and copying the original
// node's inferred output type/shape and scope onto the fused node.
const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                              const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);

  // Look up the concrete node each pattern variable was bound to.
  auto bound = [&equiv](const auto &pattern_var) { return utils::cast<AnfNodePtr>((*equiv)[pattern_var]); };
  auto weight_decay = bound(weight_decay_);
  auto scale = bound(scale_);
  auto variable = bound(variable_);
  auto accumulation = bound(accumulation_);
  auto learning_rate = bound(learning_rate_);
  auto cast_gradient = bound(cast_gradient_);
  auto momentum = bound(momentum_);
  auto monad_state = bound(monad_state_);

  MS_EXCEPTION_IF_NULL(weight_decay);
  MS_EXCEPTION_IF_NULL(scale);
  MS_EXCEPTION_IF_NULL(variable);
  MS_EXCEPTION_IF_NULL(accumulation);
  MS_EXCEPTION_IF_NULL(learning_rate);
  MS_EXCEPTION_IF_NULL(cast_gradient);
  MS_EXCEPTION_IF_NULL(momentum);
  MS_EXCEPTION_IF_NULL(monad_state);

  // The fused kernel consumes the original gradient, so strip the Cast
  // (and any Depend wrapper) that the pattern captured.
  auto gradient = GetCastInput(cast_gradient);
  MS_EXCEPTION_IF_NULL(gradient);

  auto fused_prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
  MS_EXCEPTION_IF_NULL(fused_prim);
  std::vector<AnfNodePtr> fused_inputs = {NewValueNode(fused_prim), weight_decay, scale,    variable,   accumulation,
                                          learning_rate,            gradient,     momentum, monad_state};
  auto fused_node = graph->NewCNode(fused_inputs);
  MS_EXCEPTION_IF_NULL(fused_node);

  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_node.get());
  fused_node->set_scope(node->scope());
  return fused_node;
}
133 }  // namespace opt
134 }  // namespace mindspore
135