• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.nn import Momentum
from mindspore import context, Tensor
from mindspore.common.api import ms_function

# Gradient operator used by the tests in this module. Per the GradOperation
# documentation, get_all=True requests gradients with respect to all inputs of
# the wrapped network.
grad_all = C.GradOperation(get_all=True)
25
26
class CellBprop(nn.Cell):
    """Cell with a hand-written ``bprop`` that is itself wrapped in ``ms_function``.

    ``construct`` computes ``2 * x * x + y * y``. ``bprop`` returns
    ``(dout, 2 * y)``, one value per forward input in ``(x, y)`` order.
    ``test_cell_bprop_grad`` expects taking gradients through this cell in
    PyNative mode to raise ``RuntimeError``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    @ms_function
    def bprop(self, x, y, out, dout):
        # These values are not the analytic gradient of construct(), which
        # would be (4 * x * dout, 2 * y * dout).
        # NOTE(review): presumably the ms_function decorator on bprop is what
        # triggers the RuntimeError the test expects -- confirm against the
        # framework.
        return dout, 2 * y
37
38
def test_cell_bprop_grad():
    """Check that differentiating ``CellBprop`` in PyNative mode raises.

    Builds two random float32 inputs of shape (2, 2) and expects
    ``grad_all(net)(input_x, input_y)`` to raise ``RuntimeError``.
    """
    # Select the execution mode before creating any framework objects, the same
    # order test_ms_func_decorate_forward uses.
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
    net = CellBprop()
    with pytest.raises(RuntimeError):
        grad_all(net)(input_x, input_y)
46
47
class ConvNet(nn.Cell):
    """Minimal network made of a single 2x2 convolution.

    The convolution is built with 1 and 2 as its first two positional arguments
    (input and output channel counts in ``nn.Conv2d``'s signature), stride 1,
    ``pad_mode="valid"`` with padding 0, and ``weight_init="ones"``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            1,
            2,
            kernel_size=2,
            stride=1,
            padding=0,
            weight_init="ones",
            pad_mode="valid",
        )

    def construct(self, x):
        return self.conv(x)
56
57
class MomentumWithMsFunc(nn.Cell):
    """Wrap a ``Momentum`` optimizer so its update step goes through ``ms_function``.

    Args:
        net: Cell whose parameters with ``requires_grad`` set are handed to the
            optimizer (learning rate 0.1, momentum 0.9).
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        trainable = [param for param in self.net.get_parameters() if param.requires_grad]
        self.optimizer = Momentum(trainable, learning_rate=0.1, momentum=0.9)

    @ms_function
    def construct(self, grads):
        # Hand the gradients to the optimizer and return its result unchanged.
        return self.optimizer(grads)
68
69
def test_ms_func_decorate_forward():
    """Smoke test: an ``ms_function``-decorated ``construct`` runs in PyNative mode.

    Computes ``grad_all`` of ``ConvNet`` for a random float32 input of shape
    (1, 1, 2, 2) and passes the result to ``MomentumWithMsFunc``. Nothing is
    asserted; the test passes when no exception is raised.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
    net = ConvNet()
    # NOTE(review): grad_all is built with get_all=True, which per the
    # GradOperation docs yields gradients w.r.t. the network inputs, yet the
    # result is fed to an optimizer over the network's parameters -- confirm
    # this is intended.
    grad_out = grad_all(net)(input_x)
    opt = MomentumWithMsFunc(net)
    opt(grad_out)
77