# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PyNative-mode tests for @ms_function interaction with custom bprop and forward."""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.nn import Momentum
from mindspore import context, Tensor
from mindspore.common.api import ms_function

# Gradient operation returning gradients w.r.t. all inputs of the wrapped net.
grad_all = C.GradOperation(get_all=True)


class CellBprop(nn.Cell):
    """Cell whose custom ``bprop`` is decorated with @ms_function.

    Decorating a custom ``bprop`` with @ms_function is not supported in
    PyNative mode; taking the gradient of this cell is expected to raise
    ``RuntimeError`` (asserted in ``test_cell_bprop_grad``).
    """

    def construct(self, x, y):
        return 2 * x * x + y * y

    @ms_function
    def bprop(self, x, y, out, dout):
        # Custom gradient: pass dout through for x, return 2 * y for y.
        return dout, 2 * y


def test_cell_bprop_grad():
    """Grad of a cell with an @ms_function-decorated bprop must raise."""
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net = CellBprop()
    with pytest.raises(RuntimeError):
        grad_all(net)(input_x, input_y)


class ConvNet(nn.Cell):
    """Minimal 1-in / 2-out channel conv net used as the optimization target."""

    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0,
                              weight_init="ones", pad_mode="valid")

    def construct(self, x):
        return self.conv(x)


class MomentumWithMsFunc(nn.Cell):
    """Applies a Momentum optimizer inside an @ms_function-compiled construct.

    Only trainable parameters of ``net`` (``requires_grad``) are optimized,
    with learning rate 0.1 and momentum 0.9.
    """

    def __init__(self, net):
        super(MomentumWithMsFunc, self).__init__()
        self.net = net
        self.optimizer = Momentum(
            filter(lambda x: x.requires_grad, self.net.get_parameters()),
            0.1, 0.9)

    @ms_function
    def construct(self, grads):
        return self.optimizer(grads)


def test_ms_func_decorate_forward():
    """Run an optimizer step through an @ms_function construct in PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
    net = ConvNet()
    grad_out = grad_all(net)(input_x)
    opt = MomentumWithMsFunc(net)
    opt(grad_out)