# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops as P
from mindspore.nn.optim import Momentum
from mindspore.common import ParameterTuple


class GradofParams(nn.Cell):
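    """Return gradients of `net` w.r.t. its trainable parameters.

    With `sens=True`, the wrapped gradient function additionally expects a
    sensitivity (initial output gradient) tensor as its last input.
    """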
    def __init__(self, net, sens=False):
        super().__init__()
        self.grad = P.GradOperation(get_all=False, get_by_list=True, sens_param=sens)
        self.net = net
        self.params = ParameterTuple(self.net.trainable_params())

    def construct(self, *x):
        out = self.grad(self.net, self.params)(*x)
        return out


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_temporary_cell_variables():
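    """
    Feature: PyNative temporary cell variables.
    Description: A cell instantiated inside construct (nn.ReLU() in
        TempCellNet) should behave exactly like an equivalent member cell.
    Expectation: Gradients from both networks match before and after one
        optimizer step.
    """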
    context.set_context(mode=context.PYNATIVE_MODE)

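    # Baseline network: ReLU is created once in __init__ and held as a member cell.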
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.conv = nn.Conv2d(1, 1, 3, weight_init='ones', pad_mode='pad')
            self.relu = nn.ReLU()

        def construct(self, x):
            x = self.conv(x)
            x = self.relu(x)
            x = self.add(x, x)
            return x

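    # Variant under test: the ReLU cell is instantiated on the fly inside construct.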
    class TempCellNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.conv = nn.Conv2d(1, 1, 3, weight_init='ones', pad_mode='pad')

        def construct(self, x):
            x = self.conv(x)
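            # Temporary cell: nn.ReLU() is built here and discarded after the call, not stored on self.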
            x = nn.ReLU()(x)
            x = self.add(x, x)
            return x

    input_data = Tensor(np.random.randn(1, 1, 224, 224).astype(np.float32))
    # First run: baseline Net. Compute gradients, apply one Momentum update,
    # then compute gradients again with the updated parameters.
    net = Net()
    backnet = GradofParams(net)
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    grad_first = backnet(input_data)
    optimizer(grad_first)
    grad_second = backnet(input_data)
    # Second run: TempCellNet, driven through the identical
    # gradient/update/gradient sequence.
    compare_net = TempCellNet()
    compare_backnet = GradofParams(compare_net)
    compare_optimizer = Momentum(filter(lambda x: x.requires_grad, compare_net.get_parameters()), 0.1, 0.9)
    compare_grad_first = compare_backnet(input_data)
    compare_optimizer(compare_grad_first)
    compare_grad_second = compare_backnet(input_data)
    # Compare: gradients from the two variants must match both before and
    # after the parameter update.
    assert np.allclose(grad_first[0].asnumpy(), compare_grad_first[0].asnumpy(), rtol=0.01, atol=0.01)
    assert np.allclose(grad_second[0].asnumpy(), compare_grad_second[0].asnumpy(), rtol=0.01, atol=0.01)