• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16import pytest
17
18import mindspore.nn as nn
19from mindspore import Tensor, Parameter
20from mindspore import context
21from mindspore.common import dtype as mstype
22from mindspore.nn.optim import Momentum
23from mindspore.nn.wrap.cell_wrapper import WithLossCell
24from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
25from mindspore.ops import functional as F
26from mindspore.ops import operations as P
27from mindspore.ops._grad.grad_base import bprop_getters
28from mindspore.ops._grad.grad_math_ops import binop_grad_common
29from mindspore.ops._utils import get_broadcast_shape
30from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
31from mindspore.train.loss_scale_manager import DynamicLossScaleManager
32
# Run every cell in this module in graph mode. The tests below (e.g.
# test_compile_grad_error, test_sequential_resolve_error) are written against
# graph compilation/resolution errors rather than eager execution.
context.set_context(mode=context.GRAPH_MODE)
34
35
class MockNeg(PrimitiveWithInfer):
    """Mock negation primitive whose dtype inference always fails.

    ``infer_dtype`` raises ``TypeError("InferError")`` unconditionally. The
    tests in this module use it (directly via ``NegCell`` and indirectly via
    the ``MockSub`` bprop) to check that an inference error raised from Python
    propagates out of compilation.
    """

    @prim_attr_register
    def __init__(self):
        """Declare MockNeg's single input ``x`` and single output ``y``."""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, input_x):
        """Return the input shape unchanged as the output shape."""
        return input_x

    def infer_dtype(self, input_x):
        """Always raise: failing dtype inference is the purpose of this mock."""
        raise TypeError("InferError")
48
49
class MockSub(PrimitiveWithInfer):
    """Mock element-wise subtraction primitive with broadcast shape inference."""

    @prim_attr_register
    def __init__(self):
        """Declare the two inputs (``x``, ``y``) and the single ``output``."""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def infer_shape(self, x_shape, y_shape):
        """The output shape is the broadcast of the two input shapes."""
        broadcast = get_broadcast_shape(x_shape, y_shape)
        return broadcast

    def infer_dtype(self, x_dtype, y_dtype):
        """The output dtype follows the first input; ``y_dtype`` is not checked."""
        result_dtype = x_dtype
        return result_dtype
61
62
@bprop_getters.register(MockSub)
def get_bprop_mock_sub(self):
    """Build the backward function registered for `MockSub`.

    The returned ``bprop`` passes ``dout`` as the raw gradient for ``x`` and
    ``MockNeg()(dout)`` as the raw gradient for ``y`` to ``binop_grad_common``
    (presumably that helper reduces them over broadcast axes; see its source).
    """
    negate = MockNeg()

    def bprop(x, y, out, dout):
        grad_y = negate(dout)
        return binop_grad_common(x, y, dout, grad_y)

    return bprop
72
73
class Net(nn.Cell):
    """Single dense layer ``input_ @ weight.T + bias`` with all-ones float32 parameters.

    Args:
        in_features (int): size of the last dimension of the input,
            expected to be 2-D ``(batch, in_features)``.
        out_features (int): size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        # weight is stored as (out_features, in_features), so the matmul must
        # transpose it. Without transpose_b the product only shape-checks when
        # in_features == out_features.
        self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul(transpose_b=True)
        self.add = P.Add()

    def construct(self, input_):
        output = self.add(self.matmul(input_, self.weight), self.bias)
        return output
85
86
class NetFP16(nn.Cell):
    """Dense layer ``input_ @ weight.T + bias`` computed in float16.

    Parameters are stored as all-ones float32. ``construct`` casts the input,
    weight and bias to float16 for the matmul/add, then casts the result back
    to float32.

    Args:
        in_features (int): size of the last dimension of the input,
            expected to be 2-D ``(batch, in_features)``.
        out_features (int): size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super(NetFP16, self).__init__()
        # weight is stored as (out_features, in_features), so the matmul must
        # transpose it. Without transpose_b the product only shape-checks when
        # in_features == out_features.
        self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul(transpose_b=True)
        self.add = P.Add()
        self.cast = P.Cast()

    def construct(self, input_):
        output = self.cast(
            self.add(self.matmul(self.cast(input_, mstype.float16), self.cast(self.weight, mstype.float16)),
                     self.cast(self.bias, mstype.float16)), mstype.float32)
        return output
101
102
def get_axis(x):
    """Return every axis index of ``x`` (``0`` through ``rank - 1``) via ``F.make_range``."""
    rank = F.tuple_len(F.shape(x))
    return F.make_range(0, rank)
108
109
class MSELoss(nn.Cell):
    """Mean squared error whose subtraction goes through ``MockSub``.

    Because the difference is computed with ``MockSub``, differentiating this
    loss uses the bprop registered for ``MockSub``.
    """

    def __init__(self):
        super(MSELoss, self).__init__()
        # reduce_sum is created but not used by construct.
        self.reduce_sum = P.ReduceSum()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()
        self.sub = MockSub()

    def construct(self, data, label):
        diff = self.sub(data, label)
        squared = self.square(diff)
        # Average over every axis of the difference.
        axes = get_axis(diff)
        return self.reduce_mean(squared, axes)
121
122
class NegCell(nn.Cell):
    """Cell that applies ``MockNeg`` to its single input."""

    def __init__(self):
        super().__init__()
        self.neg = MockNeg()

    def construct(self, x):
        negated = self.neg(x)
        return negated
130
131
class Net3(nn.Cell):
    """Apply ``NegCell`` and then ``nn.ReLU`` by looping over a tuple of cells."""

    def __init__(self):
        super(Net3, self).__init__()
        self.tuple = (NegCell(), nn.ReLU())

    def construct(self, x):
        # Feed the output of each cell into the next, in tuple order.
        for cell in self.tuple:
            x = cell(x)
        return x
141
142
def test_op_forward_infererror():
    """Running Net3 forward is expected to raise the TypeError from MockNeg's dtype inference."""
    net = Net3()
    input_me = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    with pytest.raises(TypeError):
        net(input_me)
149
150
class SequenceNet(nn.Cell):
    """Cell whose ``construct`` deliberately references the undefined name ``bbb``.

    ``test_sequential_resolve_error`` expects calling it to raise ``NameError``.
    """

    def __init__(self):
        super().__init__()
        layers = [nn.AvgPool2d(3, 1), nn.ReLU(), nn.Flatten()]
        self.seq = nn.SequentialCell(layers)

    def construct(self, x):
        # `bbb` is intentionally undefined; do not "fix" it.
        x = self.seq(x) + bbb  # pylint: disable=undefined-variable  # noqa: F821
        return x
159
160
def test_sequential_resolve_error():
    """Calling SequenceNet must raise NameError for its undefined name."""
    net = SequenceNet()
    input_me = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    with pytest.raises(NameError):
        net(input_me)
167
168
def test_compile_grad_error():
    """Running the loss-scaled training step is expected to raise TypeError.

    MSELoss subtracts with MockSub, whose registered bprop calls MockNeg.
    MockNeg.infer_dtype raises TypeError, so building the backward graph
    should fail.
    """
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    lr = Tensor(np.ones([1], np.float32) * 0.1)
    net = NetFP16(16, 16)
    loss = MSELoss()
    optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)

    net_with_loss = WithLossCell(net, loss)
    scale_manager = DynamicLossScaleManager()
    update_cell = scale_manager.get_update_cell()
    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell)
    train_network.set_train()
    with pytest.raises(TypeError) as e:
        train_network(inputs, label)
    # Report the captured exception outside the `with` block: a statement
    # placed after the raising call inside the block would never execute.
    print(e)
185