1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
17
18 #include <memory>
19 #include <vector>
20 #include <string>
21
22 #include "backend/session/anf_runtime_algorithm.h"
23 #include "ir/primitive.h"
24 #include "utils/utils.h"
25 #include "backend/optimizer/common/helper.h"
26
27 namespace mindspore {
28 namespace opt {
// Position of the first data operand of a CNode; slot 0 holds the primitive value node.
constexpr size_t kInputIndex{1};
30
IsScalar(const BaseRef & n)31 bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
32 if (utils::isa<AnfNodePtr>(n)) {
33 AnfNodePtr in = utils::cast<AnfNodePtr>(n);
34 MS_EXCEPTION_IF_NULL(in);
35 auto shape_ptr = in->Shape();
36 MS_EXCEPTION_IF_NULL(shape_ptr);
37 auto shape = shape_ptr->cast<abstract::ShapePtr>();
38 MS_EXCEPTION_IF_NULL(shape);
39 if (shape->shape().size() != 0) {
40 return false;
41 }
42 auto dtype = in->Type();
43 MS_EXCEPTION_IF_NULL(dtype);
44 if (dtype->type_id() != kObjectTypeTensorType) {
45 return false;
46 }
47 auto type_ptr = dyn_cast<TensorType>(dtype);
48 MS_EXCEPTION_IF_NULL(type_ptr);
49 auto element = type_ptr->element();
50 MS_EXCEPTION_IF_NULL(element);
51 auto element_type = element->type_id();
52 if (element_type != kNumberTypeFloat32) {
53 return false;
54 }
55 return true;
56 }
57 return false;
58 }
59
IsCast(const BaseRef & n)60 bool ApplyMomentumWeightDecayScaleFusion::IsCast(const BaseRef &n) {
61 if (utils::isa<AnfNodePtr>(n)) {
62 AnfNodePtr in = utils::cast<AnfNodePtr>(n);
63 MS_EXCEPTION_IF_NULL(in);
64 if (IsPrimitiveCNode(in, prim::kPrimCast) ||
65 (IsPrimitiveCNode(in, prim::kPrimDepend) &&
66 IsPrimitiveCNode(in->cast<CNodePtr>()->input(kInputIndex), prim::kPrimCast))) {
67 return true;
68 }
69 }
70 return false;
71 }
72
/**
 * Returns the value being cast by `node`.
 *
 * Accepts either Cast(x) or Depend(Cast(x), ...) and returns x in both cases; any other node
 * (including nullptr) yields nullptr.
 * NOTE(review): in the Depend form only the Cast's input is returned; the Depend's remaining
 * inputs are not carried along by this helper — confirm callers do not need them.
 *
 * @param node Candidate Cast or Depend node.
 * @return The Cast's first data input, or nullptr if `node` has neither accepted form.
 */
AnfNodePtr GetCastInput(const AnfNodePtr &node) {
  AnfNodePtr candidate = node;
  // Look through a wrapping Depend to its first data input.
  if (IsPrimitiveCNode(candidate, prim::kPrimDepend)) {
    candidate = candidate->cast<CNodePtr>()->input(kInputIndex);
  }
  if (!IsPrimitiveCNode(candidate, prim::kPrimCast)) {
    return nullptr;
  }
  return candidate->cast<CNodePtr>()->input(kInputIndex);
}
85
DefinePattern() const86 const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
87 VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
88 VectorRef weight =
89 VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), cast_gradient_});
90 VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
91 VectorRef apply_momentum =
92 VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
93 return apply_momentum;
94 }
95
/**
 * Replaces a matched ApplyMomentum with one kFusedWeightScaleApplyMomentum node.
 *
 * The fused node's inputs are, in order:
 *   (weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum, monad_state)
 * Here `gradient` is the value under the Cast bound to cast_gradient_, as returned by GetCastInput.
 * The new node copies output 0's inferred type and shape, and the scope, from the matched node.
 *
 * @param graph Graph that will own the new node.
 * @param node  The matched ApplyMomentum node.
 * @param equiv Bindings of the pattern variables.
 * @return The new fused CNode.
 * Raises (via MS_EXCEPTION_IF_NULL) on a null argument, null binding, or null un-cast gradient.
 */
const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                              const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
  auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
  auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
  auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
  auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
  auto cast_gradient = utils::cast<AnfNodePtr>((*equiv)[cast_gradient_]);
  auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
  auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);

  MS_EXCEPTION_IF_NULL(weight_decay);
  MS_EXCEPTION_IF_NULL(scale);
  MS_EXCEPTION_IF_NULL(variable);
  MS_EXCEPTION_IF_NULL(accumulation);
  MS_EXCEPTION_IF_NULL(learning_rate);
  MS_EXCEPTION_IF_NULL(cast_gradient);
  MS_EXCEPTION_IF_NULL(momentum);
  MS_EXCEPTION_IF_NULL(monad_state);

  auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
  MS_EXCEPTION_IF_NULL(prim);
  // The fused kernel consumes the pre-Cast gradient directly (Cast(g) or Depend(Cast(g), ...) -> g).
  // NOTE(review): in the Depend form the Depend's other inputs are not forwarded to the fused node — confirm intended.
  auto gradient = GetCastInput(cast_gradient);
  MS_EXCEPTION_IF_NULL(gradient);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation,
                                    learning_rate, gradient, momentum, monad_state};
  auto replace_node = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(replace_node);

  // Braced arguments list-initialize the std::vector parameters directly. The previous
  // `auto x = {...}` deduced std::initializer_list and relied on an implicit conversion.
  constexpr size_t kApplyMomentumOutputIdx = 0;
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, kApplyMomentumOutputIdx)},
                                      {AnfAlgo::GetOutputInferShape(node, kApplyMomentumOutputIdx)},
                                      replace_node.get());
  replace_node->set_scope(node->scope());
  return replace_node;
}
133 } // namespace opt
134 } // namespace mindspore
135