# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training test whose loss is computed from a gradient of the network.

The cell handed to ``TrainOneStepCell`` returns the gradient of ``Net``
with respect to its input, so training must differentiate through a
gradient computation (a higher-order gradient).
"""

import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, Tensor, Parameter
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum
from mindspore.ops.composite import GradOperation
from mindspore.common import ParameterTuple

context.set_context(mode=context.GRAPH_MODE)


class _Grad(nn.Cell):
    """Wrap ``network`` with a ``GradOperation`` so that calling the cell returns gradients.

    Args:
        grad: A ``GradOperation`` instance. Its ``sens_param`` flag decides
            whether trailing inputs are treated as sensitivity values.
        network: The cell to differentiate.
        wrt_params: If True, the network's trainable parameters are passed
            to ``grad`` as the weights to differentiate with respect to.
        real_inputs_count: Number of leading inputs that are forward inputs of
            ``network``. The remaining inputs form the sensitivity argument.
            Only used when ``grad.sens_param`` is True.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        # No split requested, or no sensitivity expected: forward all inputs unchanged.
        if self.real_inputs_count is None or self.sens_param is False:
            if self.wrt_params:
                return self.grad(self.network, self.params)(*inputs)
            return self.grad(self.network)(*inputs)

        real_inputs = inputs[:self.real_inputs_count]
        # The trailing inputs are passed together as ONE tuple argument (not unpacked).
        sens_param_inputs = inputs[self.real_inputs_count:]
        if self.wrt_params:
            return self.grad(self.network, self.params)(*real_inputs, sens_param_inputs)
        return self.grad(self.network)(*real_inputs, sens_param_inputs)


class GradOfFirstInput(_Grad):
    """
    get grad of first input
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class Net(nn.Cell):
    """Affine map ``x * weight + bias`` with scalar float32 parameters.

    ``weight`` starts at 2.0 and ``bias`` at 1.0; both have ``requires_grad=True``.
    """

    def __init__(self):
        super().__init__()
        self.mul = ops.Mul()
        # ops.Add replaces the deprecated ops.TensorAdd.
        self.add = ops.Add()
        weight_np = np.array([2]).astype(np.float32)
        bias_np = np.array([1]).astype(np.float32)
        self.weight = Parameter(Tensor(weight_np),
                                name='weight', requires_grad=True)
        self.bias = Parameter(Tensor(bias_np),
                              name="bias", requires_grad=True)

    def construct(self, x):
        xw = self.mul(x, self.weight)
        output = self.add(xw, self.bias)
        return output


class WithLossCellLocal(nn.Cell):
    """Apply ``grad`` to ``data``, then score the result against ``label`` with ``loss``.

    Args:
        grad: Cell applied to ``data``. In this file it is a gradient cell.
        loss: Loss cell called as ``loss(out, label)``.
    """

    def __init__(self, grad, loss):
        # auto_prefix=False: child cells' parameter names are not prefixed by this wrapper.
        super().__init__(auto_prefix=False)
        self.grad = grad
        self.loss = loss

    def construct(self, data, label):
        out = self.grad(data)
        return self.loss(out, label)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_high_grad_train():
    """Run two training steps on a loss built from ``d(Net(x))/dx``.

    Only the first ``epoch`` samples are fed, one per step. The test passes
    if the steps compile and execute. No numeric result is asserted.
    """
    x_pure = np.random.randint(-10, 100, 32)
    x_train = x_pure.astype(np.float32)
    y_noise = 3 * x_pure + 2 + np.random.randn(32) / 10
    y_train = y_noise.astype(np.float32)
    net = Net()
    grad_net = GradOfFirstInput(net, sens_param=False)
    epoch = 2
    momentum = 0.0
    learning_rate = 0.001
    # Optimize only the parameters that require gradients (Net's weight and bias).
    optimizer = Momentum(filter(lambda param: param.requires_grad,
                                grad_net.get_parameters()), learning_rate, momentum)
    criterion = nn.loss.MSELoss()
    net_with_criterion = WithLossCellLocal(grad_net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    for i in range(epoch):
        train_network(Tensor([x_train[i]]), Tensor([y_train[i]]))