# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""" test_fit """
import sys
import re

import pytest
import numpy as np

import mindspore as ms
from mindspore import Model, nn
from mindspore.train.callback import LossMonitor
from mindspore import dataset as ds

def get_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = 0
        y = x * w + b
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size, drop_remainder=True)
    input_data = input_data.repeat(repeat_size)
    return input_data
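# With drop_remainder=True, create_dataset(4096, 1024) yields 4 training
# batches per epoch and create_dataset(1024, 512) yields 2 evaluation batches;
# these are the dataset sizes used by every test below.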


def define_model():
    net = nn.Dense(1, 1, has_bias=False)
    net_loss = nn.MSELoss()
    net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    return Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={'mse', 'mae'})
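# Note on the metric values asserted in the tests below: `get_data` always
# yields x = 0 with label y = b = 3.0, and the network above is a Dense layer
# without bias, so its prediction is always 0 regardless of the weight value.
# The MSE-loss gradient with respect to the weight is proportional to x and is
# therefore 0, so training never changes the weight and the evaluation results
# are deterministic: 'mse' = 3.0 ** 2 = 9.0 and 'mae' = 3.0.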


class Redirect:
    """
    Capture text written to stdout so that the metric output printed during
    `Model.fit` can be inspected by the tests.
    """
    content = ""

    def write(self, str1):
        # Prepend so the newest output appears at the front of `content`.
        self.content = str1 + self.content

    def flush(self):
        self.content = ""


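# Each test below swaps sys.stdout for a Redirect instance while `Model.fit`
# runs, then restores stdout and uses regular expressions to check that the
# printed evaluation metrics match the deterministic values described above.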
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_train_dataset_non_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with train dataset in non-sink mode.
    Description: test fit with train dataset in non-sink mode.
    Expectation: fit runs in non-sink mode and reports the expected metrics.
    """
    ms.set_context(mode=mode)
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    r = Redirect()
    current = sys.stdout
    sys.stdout = r
    model.fit(1, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=False)
    sys.stdout = current
    assert re.search("'mse': 9.0", r.content)
    assert re.search("'mae': 3.0", r.content)
    r.flush()

@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_train_dataset_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with train dataset in sink mode.
    Description: test fit with train dataset in sink mode.
    Expectation: fit runs in sink mode and reports the expected metrics.
    """
    ms.set_context(mode=mode)
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    r = Redirect()
    current = sys.stdout
    sys.stdout = r
    model.fit(1, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=True, sink_size=256)
    sys.stdout = current
    assert re.search("'mse': 9.0", r.content)
    assert re.search("'mae': 3.0", r.content)
    r.flush()

@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_valid_dataset_non_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with valid dataset in non-sink mode.
    Description: test fit with valid dataset in non-sink mode.
    Expectation: evaluation runs in non-sink mode and reports the expected metrics.
    """
    ms.set_context(mode=mode)
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    r = Redirect()
    current = sys.stdout
    sys.stdout = r
    model.fit(1, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=False)
    sys.stdout = current
    assert re.search("'mse': 9.0", r.content)
    assert re.search("'mae': 3.0", r.content)
    r.flush()

@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_valid_dataset_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with valid dataset in sink mode.
    Description: test fit with valid dataset in sink mode.
    Expectation: evaluation runs in sink mode and reports the expected metrics.
    """
    ms.set_context(mode=mode)
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    r = Redirect()
    current = sys.stdout
    sys.stdout = r
    model.fit(1, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=True)
    sys.stdout = current
    assert re.search("'mse': 9.0", r.content)
    assert re.search("'mae': 3.0", r.content)
    r.flush()

@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_valid_frequency(mode):
    """
    Feature: check the `valid_frequency` input of `mindspore.train.Model.fit`.
    Description: test fit when `valid_frequency` is given as an integer.
    Expectation: fit runs successfully and evaluation is performed at the expected epochs.
    """
    ms.set_context(mode=mode)
    model = define_model()
    callbacks = [LossMonitor()]
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    r = Redirect()
    current = sys.stdout
    sys.stdout = r
    # valid_frequency=2 means evaluation runs every 2 epochs, so the last
    # evaluation of the 4-epoch run happens at epoch 4.
    model.fit(4, ds_train, ds_eval, valid_frequency=2, callbacks=callbacks)
    sys.stdout = current
    assert re.search("Eval result: epoch 4", r.content)
    assert re.search("'mse': 9.0", r.content)
    assert re.search("'mae': 3.0", r.content)
    r.flush()