• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_training """
16import logging
17import numpy as np
18import pytest
19
20import mindspore.nn as nn
21from mindspore import Model, context
22from mindspore import Tensor
23from mindspore.train.callback import Callback
24from mindspore.nn.optim import Momentum
25from ..ut_filter import non_graph_engine
26from ....dataset_mock import MindData
27
28
class Net(nn.Cell):
    """Small conv -> batch-norm -> ReLU -> flatten -> dense classifier for the Model tests.

    The dense layer's input width is fixed for 3-channel 224x224 inputs (the
    size every test in this file feeds it); it produces 3 output values.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # A 3x3 'valid' (padding=0) convolution shrinks each 224-pixel side to
        # 222, so the flattened feature vector has 64 * 222 * 222 elements.
        self.fc = nn.Dense(64 * 222 * 222, 3)

    def construct(self, x):
        # Feature extraction then the dense projection to 3 outputs.
        features = self.flatten(self.relu(self.bn(self.conv(x))))
        return self.fc(features)
47
48
class LossNet(nn.Cell):
    """Same layer stack as ``Net``, with softmax cross-entropy applied inside the cell.

    ``construct`` takes the input batch and the labels and returns the loss value
    rather than the logits.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # 224 -> 222 per side after the 3x3 'valid' (padding=0) convolution.
        self.fc = nn.Dense(64 * 222 * 222, 3)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()

    def construct(self, x, y):
        # In this file, y is a float32 label tensor with one column per class.
        logits = self.fc(self.flatten(self.relu(self.bn(self.conv(x)))))
        return self.loss(logits, y)
69
70
def get_model(metrics=None):
    """Wrap a fresh ``Net`` in a trainable ``Model``.

    The loss is softmax cross-entropy. The optimizer is Momentum (lr=0.1,
    momentum=0.9) over the network's trainable parameters. ``metrics`` is
    passed through to ``Model`` unchanged.
    """
    network = Net()
    # Keyword arguments evaluate left to right: the loss is built before the optimizer.
    return Model(network,
                 loss_fn=nn.SoftmaxCrossEntropyWithLogits(),
                 optimizer=Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9),
                 metrics=metrics)
78
79
def get_dataset():
    """Return the mock ``MindData`` dataset shared by the training tests.

    It is configured with ``size=2`` and ``batch_size=32``. Its two float32
    columns have shapes (32, 3, 224, 224) and (32, 3). ``input_indexs=(0, 1)``
    selects both columns, presumably as network inputs (verify against MindData).
    """
    return MindData(size=2,
                    batch_size=32,
                    np_types=(np.float32, np.float32),
                    output_shapes=((32, 3, 224, 224), (32, 3)),
                    input_indexs=(0, 1))
90
91
@non_graph_engine
def test_single_input():
    """Predict on one random 224x224 image through a loss-free ``Model(Net())``."""
    image = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    context.set_context(mode=context.GRAPH_MODE)
    prediction = Model(Net()).predict(image)
    assert prediction is not None
100
101
@non_graph_engine
def test_multiple_argument():
    """Predict through ``Model(LossNet())``, passing both an image and a label tensor."""
    images = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    labels = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
    context.set_context(mode=context.GRAPH_MODE)
    loss_value = Model(LossNet()).predict(images, labels)
    assert loss_value is not None
111
112
def test_train_feed_mode(test_with_simu):
    """Build the mock dataset and model, then train 2 epochs unless ``test_with_simu`` is set.

    The dataset and model are constructed in both cases. Only ``train`` is
    skipped under the simulator.
    """
    dataset = get_dataset()
    model = get_model()
    if not test_with_simu:
        model.train(2, dataset)
120
121
def test_dataset_sink_mode_args_check():
    """Check that ``Model.train`` raises TypeError for non-bool ``dataset_sink_mode`` values."""
    dataset = get_dataset()
    model = get_model()
    # A bool-looking string and an int must both be rejected.
    for bad_flag in ("True", 1):
        with pytest.raises(TypeError):
            model.train(2, dataset, dataset_sink_mode=bad_flag)
131
132
@non_graph_engine
def test_eval():
    """Check that ``Model.eval`` raises ValueError for three mis-configured models.

    All three request the "loss" metric:
      1. ``Net`` with a ``loss_fn`` and no explicit eval network.
      2. ``LossNet`` as its own eval network with ``eval_indexes=[0, 1, 2]``.
      3. A separate ``LossNet`` as its own eval network with no ``eval_indexes``.
    """
    dataset = get_dataset()
    net = Net()
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(), metrics={"loss"})
    with pytest.raises(ValueError):
        model.eval(dataset)

    net2 = LossNet()
    model2 = Model(net2, eval_network=net2, eval_indexes=[0, 1, 2], metrics={"loss"})
    with pytest.raises(ValueError):
        model2.eval(dataset)

    # Use its own network so model3 does not share a Cell with model2.
    net3 = LossNet()
    model3 = Model(net3, eval_network=net3, metrics={"loss"})
    with pytest.raises(ValueError):
        model3.eval(dataset)
158
159
class TestGraphMode:
    """Grouping for graph-mode training tests (the test below does not set the context mode itself)."""

    def test_train_minddata_graph_mode(self, test_with_simu):
        """Train one epoch on a MindData mock whose ``input_indexs`` is empty."""
        # pylint: disable=unused-argument
        mock_data = MindData(size=2,
                             batch_size=32,
                             np_types=(np.float32, np.float32),
                             output_shapes=((32, 3, 224, 224), (32, 3)),
                             input_indexs=())
        get_model().train(1, mock_data)
175
176
class CallbackTest(Callback):
    """Callback that prints the current epoch and step numbers at the end of every step.

    It also works as a context manager: ``__enter__`` returns the instance and
    ``__exit__`` does not suppress exceptions.
    """

    def __init__(self):
        # Run the base-class initialiser so any state Callback sets up is present.
        super().__init__()

    def __enter__(self):
        return self

    def __exit__(self, *err):
        # Implicitly returns None (falsy), so exceptions propagate.
        pass

    def step_end(self, run_context):
        """Print ``cur_epoch_num`` and ``cur_step_num`` from the run context's original args."""
        cb_params = run_context.original_args()
        print(cb_params.cur_epoch_num, cb_params.cur_step_num)
192
193
def test_train_callback(test_with_simu):
    """Train 2 epochs with a ``CallbackTest`` attached, unless ``test_with_simu`` is set.

    The dataset, model and callback are constructed in both cases.
    """
    dataset = get_dataset()
    model = get_model()
    callback = CallbackTest()
    if not test_with_simu:
        model.train(2, dataset, callbacks=callback)
202
203
# Module-level logger used by the tests below; its level is set to ERROR.
log = logging.getLogger("test")
log.setLevel(logging.ERROR)
206
207
# Invalid predict() input: a plain string instead of a Tensor must raise TypeError.
def test_model_build_abnormal_string():
    """Check that ``Model.predict`` rejects a string input with TypeError.

    ``pytest.raises`` replaces the earlier flag-plus-``finally`` pattern. That
    pattern turned any unexpected exception into a bare AssertionError and hid
    the real failure.
    """
    net = nn.ReLU()
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(net)
    with pytest.raises(TypeError) as exc_info:
        model.predict('aaa')
    log.error("Find type error: %r ", exc_info.value)
222
223
def test_init_model_error():
    """Check that the ``Model`` constructor rejects invalid argument combinations.

    Each case below names the argument that is wrong. The expected exception
    type is asserted per case.
    """
    net = nn.ReLU()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    # Unknown metric name.
    with pytest.raises(KeyError):
        Model(net, loss, metrics={"top1"})

    # Metrics requested with neither loss_fn nor eval_network supplied.
    with pytest.raises(ValueError):
        Model(net, metrics={"top_1_accuracy"})

    # Dict-form metrics whose value is None.
    with pytest.raises(TypeError):
        Model(net, metrics={"top5": None})

    # Empty eval_indexes.
    with pytest.raises(ValueError):
        Model(net, eval_network=net, eval_indexes=[], metrics={"top_1_accuracy"})

    # Three eval_indexes given as a tuple.
    with pytest.raises(ValueError):
        Model(net, eval_network=net, eval_indexes=(1, 2, 3), metrics={"top_1_accuracy"})

    # metrics given as a container type other than set/dict. The string case
    # used to be written ("top_1_accuracy"), which is a string, not a tuple,
    # so a genuine one-element tuple case is added as well.
    with pytest.raises(TypeError):
        Model(net, loss, metrics="top_1_accuracy")

    with pytest.raises(TypeError):
        Model(net, loss, metrics=("top_1_accuracy",))

    with pytest.raises(TypeError):
        Model(net, loss, metrics=())

    with pytest.raises(TypeError):
        Model(net, loss, metrics=["top_1_accuracy"])
245
246    with pytest.raises(TypeError):
247        Model(net, loss, metrics=())
248
249    with pytest.raises(TypeError):
250        Model(net, loss, metrics=["top_1_accuracy"])
251
252
253def test_model_eval_error():
254    """ test_model_eval_error """
255    dataset_types = (np.float32, np.float32)
256    dataset_shapes = ((32, 3, 224, 224), (32, 3))
257
258    dataset = MindData(size=2, batch_size=32,
259                       np_types=dataset_types,
260                       output_shapes=dataset_shapes,
261                       input_indexs=())
262
263    net = nn.ReLU()
264    loss = nn.SoftmaxCrossEntropyWithLogits()
265    context.set_context(mode=context.GRAPH_MODE)
266    model_nometrics = Model(net, loss)
267    with pytest.raises(ValueError):
268        model_nometrics.eval(dataset)
269
270    model_metrics_empty = Model(net, loss, metrics={})
271    with pytest.raises(ValueError):
272        model_metrics_empty.eval(dataset)
273