• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test model train """
16import numpy as np
17
18import mindspore.nn as nn
19from mindspore import Tensor, Parameter, Model
20from mindspore.common.initializer import initializer
21from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
22from mindspore.nn.optim import Momentum
23from mindspore.ops import operations as P
24
25
def lr_gen(fn, epoch_size):
    """Lazily yield ``fn(i)`` for each index ``i`` in ``range(epoch_size)``.

    In this file it supplies the per-step learning-rate sequence that is
    handed to the ``Momentum`` optimizer.

    Args:
        fn: A function that takes the integer index ``i`` as its only input.
        epoch_size: How many values to produce.

    Yields:
        ``fn(0), fn(1), ..., fn(epoch_size - 1)`` in order; ``fn`` is only
        called when the next value is requested.
    """
    yield from map(fn, range(epoch_size))
30
31
def me_train_tensor(net, input_np, label_np, epoch_size=2):
    """Train ``net`` for ``epoch_size`` steps on one fixed input batch.

    Args:
        net: The ``nn.Cell`` under test; its trainable parameters are
            optimized.
        input_np: NumPy array wrapped in a ``Tensor`` and fed to the network
            on every step.
        label_np: NumPy array reduced with ``argmax`` over its last axis to
            int32 class indices, matching the ``sparse=True`` loss below.
        epoch_size: Number of training steps; also the length of the
            learning-rate sequence given to the optimizer.
    """
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # Only parameters with requires_grad are optimized. The learning rate is a
    # constant 0.1 per step. After momentum=0.9, the positional 0.01 and 1024
    # are presumably weight decay and loss scale in MindSpore's Momentum
    # signature -- confirm for the MindSpore version in use.
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   lr_gen(lambda i: 0.1, epoch_size), 0.9, 0.01, 1024)
    # NOTE(review): this Model is built and then discarded; presumably it is
    # kept so that Model construction itself is exercised -- confirm before
    # removing it.
    Model(net, loss, opt)
    network = nn.WithLossCell(net, loss)
    train_net = nn.TrainOneStepCell(network, opt)
    train_net.set_train()
    # Collapse the label array to sparse int32 class indices.
    label_np = np.argmax(label_np, axis=-1).astype(np.int32)
    for epoch in range(epoch_size):
        print(f"epoch {epoch}")
        train_net(Tensor(input_np), Tensor(label_np))
45
46
def test_bias_add(test_with_simu):
    """Train a network made of a single ``P.BiasAdd`` for a few steps.

    Graph mode is forced first; when ``test_with_simu`` is truthy the test
    then returns without training.
    """
    import mindspore.context as context

    # The training API is implemented only under graph mode, so leave
    # PyNative mode if it is the current one.
    if context.get_context("mode") == context.PYNATIVE_MODE:
        context.set_context(mode=context.GRAPH_MODE)
    if test_with_simu:
        return

    class Net(nn.Cell):
        """Adds a learnable bias of length ``output_channels`` via ``P.BiasAdd``."""

        def __init__(self, output_channels, bias_init='zeros'):
            super(Net, self).__init__()
            self.biasAdd = P.BiasAdd()
            # A Tensor initializer must be 1-D with exactly one value per channel.
            shape_ok = (not isinstance(bias_init, Tensor)
                        or (bias_init.ndim == 1
                            and bias_init.shape[0] == output_channels))
            if not shape_ok:
                raise ValueError("bias_init shape error")
            self.bias = Parameter(initializer(bias_init, [output_channels]),
                                  name="bias")

        def construct(self, input_x):
            return self.biasAdd(input_x, self.bias)

    input_np = np.ones([1, 3, 3, 3], np.float32)
    label_np = np.ones([1, 3, 3, 3], np.int32) * 2
    ones_bias = Tensor(np.ones([3]).astype(np.float32))
    me_train_tensor(Net(3, bias_init=ones_bias), input_np, label_np)
81
82
def test_conv(test_with_simu):
    """Train a network made of a single ``nn.Conv2d`` for a few steps.

    Graph mode is forced first; when ``test_with_simu`` is truthy the test
    then returns without training.
    """
    import mindspore.context as context
    is_pynative_mode = (context.get_context("mode") == context.PYNATIVE_MODE)
    # training api is implemented under Graph mode
    if is_pynative_mode:
        context.set_context(mode=context.GRAPH_MODE)
    if test_with_simu:
        return

    class Net(nn.Cell):
        """Net definition: one ``nn.Conv2d(cin, cout, kernel_size)`` layer."""

        def __init__(self,
                     cin,
                     cout,
                     kernel_size):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(cin,
                                  cout,
                                  kernel_size)

        def construct(self, input_x):
            return self.conv(input_x)

    net = Net(3, 6, (3, 3))
    input_np = np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01
    label_np = np.ones([1, 6, 32, 32]).astype(np.int32)
    me_train_tensor(net, input_np, label_np)
113
114
def test_net():
    """Train a conv -> batch-norm -> ReLU -> mean -> dense network in graph mode."""
    import mindspore.context as context
    is_pynative_mode = (context.get_context("mode") == context.PYNATIVE_MODE)
    # training api is implemented under Graph mode
    if is_pynative_mode:
        context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        """Small classifier mapping 3-channel images to 12 outputs per sample."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(3, 64, (7, 7), pad_mode="same", stride=2)
            self.relu = nn.ReLU()
            self.bn = nn.BatchNorm2d(64)
            self.mean = P.ReduceMean(keep_dims=True)
            self.flatten = nn.Flatten()
            self.dense = nn.Dense(64, 12)

        def construct(self, input_x):
            output = self.conv(input_x)
            output = self.bn(output)
            output = self.relu(output)
            # Average over the last two axes (the spatial dims, assuming NCHW).
            output = self.mean(output, (-2, -1))
            output = self.flatten(output)
            output = self.dense(output)
            return output

    net = Net()
    input_np = np.ones([32, 3, 224, 224]).astype(np.float32) * 0.01
    label_np = np.ones([32, 12]).astype(np.int32)
    me_train_tensor(net, input_np, label_np)
150
151
def test_bn():
    """Train a batch-norm -> flatten -> dense network for a few steps."""
    import mindspore.context as context

    # Training is only implemented for graph mode; switch if PyNative is active.
    if context.get_context("mode") == context.PYNATIVE_MODE:
        context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        """``BatchNorm2d(cin)`` followed by a flatten and a ``cin -> cout`` Dense."""

        def __init__(self, cin, cout):
            super(Net, self).__init__()
            self.bn = nn.BatchNorm2d(cin)
            self.flatten = nn.Flatten()
            self.dense = nn.Dense(cin, cout)

        def construct(self, input_x):
            normalized = self.bn(input_x)
            return self.dense(self.flatten(normalized))

    channels, classes = 2048, 16
    net = Net(channels, classes)
    input_np = np.ones([32, channels, 1, 1]).astype(np.float32) * 0.01
    label_np = np.ones([32, classes]).astype(np.int32)
    me_train_tensor(net, input_np, label_np)
180