• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import pytest
17import numpy as np
18import mindspore.nn as nn
19import mindspore.ops as ops
20from mindspore import context, Tensor, Parameter
21from mindspore.nn import TrainOneStepCell
22from mindspore.nn.optim import Momentum
23from mindspore.ops.composite import GradOperation
24from mindspore.common import ParameterTuple
25
# Select graph mode for the whole module. This runs at import time and
# changes the global MindSpore context, so it also affects any code that
# imports this file.
context.set_context(mode=context.GRAPH_MODE)
27
28
class _Grad(nn.Cell):
    """
    Cell that applies a gradient operation to a wrapped network.

    Args:
        grad: Gradient operation. It is invoked as ``grad(network)`` or
            ``grad(network, params)`` and must expose a ``sens_param``
            attribute.
        network: Network handed to ``grad``.
        wrt_params (bool): When true, the trainable parameters of ``network``
            are collected once and passed to ``grad`` as a second argument.
            Default: False.
        real_inputs_count (int): Number of leading positional inputs that are
            forwarded unchanged. The remaining inputs are passed on as one
            trailing tuple (presumably the sensitivity values -- confirm
            against the callers). ``None`` forwards every input as-is.
            Default: None.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        # Mirror the flag of the gradient operation so construct() can decide
        # whether trailing inputs have to be split off.
        self.sens_param = grad.sens_param
        if wrt_params:
            self.params = ParameterTuple(network.trainable_params())

    def construct(self, *inputs):
        # Without an input count, or when the gradient operation takes no
        # sens input, every positional input goes straight through.
        if self.real_inputs_count is None or self.sens_param is False:
            if not self.wrt_params:
                return self.grad(self.network)(*inputs)
            return self.grad(self.network, self.params)(*inputs)

        # Leading inputs are forwarded one by one; the tail is passed as a
        # single tuple argument.
        forward_args = inputs[:self.real_inputs_count]
        trailing_args = inputs[self.real_inputs_count:]
        if not self.wrt_params:
            return self.grad(self.network)(*forward_args, trailing_args)
        return self.grad(self.network, self.params)(*forward_args, trailing_args)
51
52
class GradOfFirstInput(_Grad):
    """
    get grad of first input

    Wraps ``network`` with a ``GradOperation`` that is created with only the
    ``sens_param`` flag set; parameters are not differentiated.

    Args:
        network: Network to differentiate.
        sens_param (bool): Forwarded to ``GradOperation``. Default: True.
        real_inputs_count (int): Forwarded to ``_Grad``. Default: None.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        grad_op = GradOperation(sens_param=sens_param)
        super().__init__(grad=grad_op,
                         network=network,
                         real_inputs_count=real_inputs_count)
61
62
class Net(nn.Cell):
    """
    Single-feature linear network computing ``x * weight + bias``.

    ``weight`` is initialised to 2 and ``bias`` to 1. Both are float32
    parameters built from one-element arrays and both require gradients.
    """

    def __init__(self):
        super().__init__()
        self.mul = ops.Mul()
        # ``TensorAdd`` is the legacy name of the ``Add`` operator. Prefer the
        # current name and fall back to the old one only on releases that do
        # not provide ``Add`` yet, so the network builds on both.
        add_op = getattr(ops, "Add", None) or ops.TensorAdd
        self.add = add_op()
        weight_np = np.array([2]).astype(np.float32)
        bias_np = np.array([1]).astype(np.float32)
        self.weight = Parameter(Tensor(weight_np),
                                name='weight', requires_grad=True)
        self.bias = Parameter(Tensor(bias_np),
                              name="bias", requires_grad=True)

    def construct(self, x):
        # Scale the input, then shift it by the bias.
        xw = self.mul(x, self.weight)
        output = self.add(xw, self.bias)
        return output
79
80
class WithLossCellLocal(nn.Cell):
    """
    Cell that feeds the output of ``grad`` into ``loss``.

    Args:
        grad: Callable applied to ``data``; its result is the first argument
            of ``loss``.
        loss: Callable applied to ``(grad(data), label)``.
    """

    def __init__(self, grad, loss):
        # auto_prefix=False is forwarded to nn.Cell unchanged.
        super().__init__(auto_prefix=False)
        self.grad = grad
        self.loss = loss

    def construct(self, data, label):
        grad_out = self.grad(data)
        return self.loss(grad_out, label)
90
91
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_high_grad_train():
    """
    Train a loss that is computed on the gradient of ``Net`` for two steps.

    The forward pass of the trained cell is ``GradOfFirstInput(Net())``, so
    the optimizer step differentiates through a gradient. The test makes no
    assertions; it passes when both training steps run without raising.
    """
    # Noisy samples of y = 3x + 2; the integer draw comes first, then the noise.
    samples = np.random.randint(-10, 100, 32)
    x_train = samples.astype(np.float32)
    targets = 3 * samples + 2 + np.random.randn(32) / 10
    y_train = targets.astype(np.float32)

    net = Net()
    grad_net = GradOfFirstInput(net, sens_param=False)

    epoch = 2
    momentum = 0.0
    learning_rate = 0.001
    trainable = [param for param in grad_net.get_parameters()
                 if param.requires_grad]
    optimizer = Momentum(trainable, learning_rate, momentum)

    criterion = nn.loss.MSELoss()
    net_with_criterion = WithLossCellLocal(grad_net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()

    # One sample per step: only the first ``epoch`` samples are used.
    for x_value, y_value in zip(x_train[:epoch], y_train[:epoch]):
        train_network(Tensor([x_value]), Tensor([y_value]))
116