• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.common.parameter import Parameter
23from mindspore.common.initializer import initializer
24from mindspore.ops import operations as P
25
# Configure MindSpore for graph (compiled) mode on the CPU backend before any
# network in this module is built or executed.
context.set_context(device_target="CPU", mode=context.GRAPH_MODE)
27
28
class NetCenteredRMSProp(nn.Cell):
    """Zero-argument cell that performs one ``P.ApplyCenteredRMSProp`` step.

    Every operand is captured when the cell is constructed. Calling the cell
    forwards them to the primitive and returns its result.

    NOTE(review): the tests in this file pass ``Parameter`` objects for
    ``var``/``mg``/``rms``/``mom`` and read them back after the call, i.e. they
    rely on the primitive updating those buffers in place.
    """

    def __init__(self, lr, decay, momentum, epsilon, var, g, mg, rms, mom):
        super().__init__()
        self.rms_opt = P.ApplyCenteredRMSProp()
        # Hyper-parameters, stored exactly as given.
        self.lr = lr
        self.decay = decay
        self.momentum = momentum
        self.epsilon = epsilon
        # State/input tensors: variable, gradient, mean gradient,
        # mean square and moment.
        self.var = var
        self.g = g
        self.mg = mg
        self.rms = rms
        self.mom = mom

    def construct(self):
        # Operand order expected by ApplyCenteredRMSProp:
        # var, mean_gradient, mean_square, moment, grad,
        # learning_rate, decay, momentum, epsilon.
        return self.rms_opt(
            self.var,
            self.mg,
            self.rms,
            self.mom,
            self.g,
            self.lr,
            self.decay,
            self.momentum,
            self.epsilon
        )
46
47
class NetRMSProp(nn.Cell):
    """Zero-argument cell that performs one ``P.ApplyRMSProp`` step.

    The constructor signature mirrors ``NetCenteredRMSProp`` so that both can
    be built the same way. ``mg`` (mean gradient) is stored but never passed to
    the primitive, because the non-centered update does not use it.

    NOTE(review): the tests in this file pass ``Parameter`` objects for
    ``var``/``rms``/``mom`` and read them back after the call, i.e. they rely
    on the primitive updating those buffers in place.
    """

    def __init__(self, lr, decay, momentum, epsilon, var, g, mg, rms, mom):
        super().__init__()
        # Hyper-parameters, stored exactly as given.
        self.lr = lr
        self.decay = decay
        self.momentum = momentum
        self.epsilon = epsilon
        # State/input tensors; ``mg`` is kept only for signature parity.
        self.var = var
        self.g = g
        self.mg = mg
        self.rms = rms
        self.mom = mom
        self.rms_opt = P.ApplyRMSProp()

    def construct(self):
        # Operand order expected by ApplyRMSProp (learning rate precedes grad):
        # var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon.
        return self.rms_opt(
            self.var,
            self.rms,
            self.mom,
            self.lr,
            self.g,
            self.decay,
            self.momentum,
            self.epsilon
        )
64
65
def rmsprop_numpy(variable, gradients, mean_square, moment,
                  learning_rate, decay, momentum, epsilon):
    """NumPy reference for a single non-centered RMSProp step.

    Update rule::

        ms  <- ms * decay + (1 - decay) * g * g
        mom <- momentum * mom + lr / sqrt(ms + epsilon) * g
        var <- var - mom

    The inputs are not modified; new arrays are built for every updated value.

    Returns:
        tuple: ``(variable, gradients, mean_square, moment)`` after the step.
        ``gradients`` is the very object that was passed in.
    """
    new_mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients
    scale = learning_rate / np.sqrt(new_mean_square + epsilon)
    new_moment = momentum * moment + scale * gradients
    new_variable = variable - new_moment
    return new_variable, gradients, new_mean_square, new_moment
72
73
def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment,
                        learning_rate, decay, momentum, epsilon):
    """NumPy reference for a single centered RMSProp step.

    Update rule::

        mg  <- mg * decay + (1 - decay) * g
        ms  <- ms * decay + (1 - decay) * g * g
        mom <- momentum * mom + lr / sqrt(ms - mg * mg + epsilon) * g
        var <- var - mom

    The inputs are not modified; new arrays are built for every updated value.

    Returns:
        tuple: ``(variable, gradients, mean_gradients, mean_square, moment)``
        after the step. ``gradients`` is the very object that was passed in.
    """
    new_mean_gradients = mean_gradients * decay + (1.0 - decay) * gradients
    new_mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients
    # Second moment minus the squared running mean of the gradient.
    centered = new_mean_square - new_mean_gradients * new_mean_gradients
    new_moment = momentum * moment + learning_rate / np.sqrt(centered + epsilon) * gradients
    new_variable = variable - new_moment
    return new_variable, gradients, new_mean_gradients, new_mean_square, new_moment
82
83
@pytest.mark.level0
@pytest.mark.platform_cpu
@pytest.mark.env_onecard
def test_rmsprop():
    """Run one RMSProp step with MindSpore on CPU and compare it with NumPy.

    NOTE(review): despite its name, this test sets ``centered=True`` and so
    exercises ``ApplyCenteredRMSProp``. The name is kept so the pytest node ID
    stays stable.
    """
    learning_rate, decay, momentum, epsilon, centered = [0.5, 0.8, 0.9, 1e-3, True]

    variable_np = np.array([1.0, 2.0], dtype=np.float32)
    gradients_np = np.array([0.1, 0.2], dtype=np.float32)
    mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32)
    mean_square_np = np.array([epsilon, epsilon], dtype=np.float32)
    moment_np = np.array([0.0, 0.0], dtype=np.float32)

    variable = Tensor(variable_np)
    gradients = Tensor(gradients_np)
    mean_gradients = Tensor(mean_gradients_np)
    mean_square = Tensor(mean_square_np)
    moment = Tensor(moment_np)

    # These Parameters are read back after the step to compare with NumPy.
    variable_ms = Parameter(initializer(variable, variable.shape), name='var')
    gradients_ms = Parameter(initializer(gradients, gradients.shape), name='grad')
    mean_gradients_ms = Parameter(initializer(mean_gradients, mean_gradients.shape), name='mg')
    mean_square_ms = Parameter(initializer(mean_square, mean_square.shape), name='msr')
    moment_ms = Parameter(initializer(moment, moment.shape), name='mom')

    if centered:
        variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \
            rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                                learning_rate, decay, momentum, epsilon)
        net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms,
                                 mean_square_ms, moment_ms)
        _ = net()

    else:
        variable_np, gradients_np, mean_square_np, moment_np = \
            rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                          learning_rate, decay, momentum, epsilon)
        net = NetRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms,
                         mean_square_ms, moment_ms)
        _ = net()

    # Two-sided absolute-tolerance check on every state tensor. A one-sided
    # ``actual - expected < tol`` would accept results far *below* the reference.
    tolerance = 1e-5
    for actual, expected in ((variable_ms, variable_np),
                             (gradients_ms, gradients_np),
                             (mean_gradients_ms, mean_gradients_np),
                             (mean_square_ms, mean_square_np),
                             (moment_ms, moment_np)):
        np.testing.assert_allclose(actual.asnumpy(), expected, rtol=0, atol=tolerance)
143
144
@pytest.mark.level0
@pytest.mark.platform_cpu
@pytest.mark.env_onecard
def test_rmspropcenter():
    """Run one RMSProp step with MindSpore on CPU and compare it with NumPy.

    NOTE(review): despite its name, this test sets ``centered=False`` and so
    exercises the non-centered ``ApplyRMSProp``. The name is kept so the pytest
    node ID stays stable.
    """
    learning_rate, decay, momentum, epsilon, centered = [0.1, 0.3, 0.9, 1.0, False]

    variable_np = np.array([1.0, 2.0], dtype=np.float32)
    gradients_np = np.array([0.1, 0.2], dtype=np.float32)
    mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32)
    mean_square_np = np.array([epsilon, epsilon], dtype=np.float32)
    moment_np = np.array([0.0, 0.0], dtype=np.float32)

    variable = Tensor(variable_np)
    gradients = Tensor(gradients_np)
    mean_gradients = Tensor(mean_gradients_np)
    mean_square = Tensor(mean_square_np)
    moment = Tensor(moment_np)

    # These Parameters are read back after the step to compare with NumPy.
    variable_ms = Parameter(initializer(variable, variable.shape), name='var')
    gradients_ms = Parameter(initializer(gradients, gradients.shape), name='grad')
    mean_gradients_ms = Parameter(initializer(mean_gradients, mean_gradients.shape), name='mg')
    mean_square_ms = Parameter(initializer(mean_square, mean_square.shape), name='msr')
    moment_ms = Parameter(initializer(moment, moment.shape), name='mom')

    if centered:
        variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \
            rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                                learning_rate, decay, momentum, epsilon)
        net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms,
                                 mean_square_ms, moment_ms)
        _ = net()
    else:
        variable_np, gradients_np, mean_square_np, moment_np = \
            rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                          learning_rate, decay, momentum, epsilon)
        net = NetRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms,
                         mean_square_ms, moment_ms)
        _ = net()

    # Two-sided absolute-tolerance check on every state tensor. A one-sided
    # ``actual - expected < tol`` would accept results far *below* the reference.
    tolerance = 1e-5
    for actual, expected in ((variable_ms, variable_np),
                             (gradients_ms, gradients_np),
                             (mean_gradients_ms, mean_gradients_np),
                             (mean_square_ms, mean_square_np),
                             (moment_ms, moment_np)):
        np.testing.assert_allclose(actual.asnumpy(), expected, rtol=0, atol=tolerance)
203