# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""" test_fit """
import sys
import re

import pytest
import numpy as np

import mindspore as ms
from mindspore import Model, nn
from mindspore.train.callback import LossMonitor
from mindspore import dataset as ds


def get_data(num, w=2.0, b=3.0):
    """
    Yield `num` (data, label) pairs of shape (1,) float32 for the line y = x * w + b.

    x is deliberately fixed at 0, so every label equals `b`. With the bias-free
    `nn.Dense` built by `define_model`, the prediction for x = 0 is always 0 and
    the loss gradient w.r.t. the weight is 0. Training therefore never changes
    the model, and the eval metrics are deterministic: mse = b**2 and mae = |b|,
    i.e. 9.0 and 3.0 for the default b = 3.0.
    """
    for _ in range(num):
        x = 0
        y = x * w + b
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)


def create_dataset(num_data, batch_size=16, repeat_size=1):
    """
    Build a batched `GeneratorDataset` from `get_data`.

    Args:
        num_data (int): number of samples generated.
        batch_size (int): batch size; the incomplete last batch is dropped.
        repeat_size (int): number of times the dataset is repeated.

    Returns:
        The batched and repeated dataset with columns 'data' and 'label'.
    """
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size, drop_remainder=True)
    input_data = input_data.repeat(repeat_size)
    return input_data


def define_model():
    """Return a `Model` wrapping a 1->1 bias-free Dense net with MSE loss, Momentum and mse/mae metrics."""
    net = nn.Dense(1, 1, has_bias=False)
    net_loss = nn.MSELoss()
    net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    return Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={'mse', 'mae'})


class Redirect:
    """
    Minimal file-like object used in place of `sys.stdout` to capture the
    text printed (e.g. by callbacks) while `Model.fit` runs.
    """

    def __init__(self):
        self.content = ""

    def write(self, str1):
        # Append so the captured text keeps the order in which it was printed.
        self.content += str1

    def flush(self):
        # Called by `print(..., flush=True)` / `sys.stdout.flush()`; it must not
        # discard the captured text, otherwise earlier output would be lost.
        pass


def _fit_and_capture(mode, epoch, **fit_kwargs):
    """
    Run `Model.fit` in the given execution mode and return everything printed to stdout.

    Args:
        mode: MindSpore execution mode (`ms.GRAPH_MODE` or `ms.PYNATIVE_MODE`).
        epoch (int): number of training epochs passed to `Model.fit`.
        **fit_kwargs: extra keyword arguments forwarded to `Model.fit`
            (sink modes, `sink_size`, `valid_frequency`, ...).

    Returns:
        str: the text written to `sys.stdout` during `fit`.
    """
    ms.set_context(mode=mode)
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    redirect = Redirect()
    saved_stdout = sys.stdout
    sys.stdout = redirect
    try:
        model.fit(epoch, ds_train, ds_eval, callbacks=callbacks, **fit_kwargs)
    finally:
        # Always restore stdout, even if fit raises, so later tests and pytest's own output are unaffected.
        sys.stdout = saved_stdout
    return redirect.content


def _assert_eval_metrics(content):
    """Assert the captured output reports the deterministic eval metrics (see `get_data`)."""
    # Dots are escaped so they match a literal '.' rather than any character.
    assert re.search(r"'mse': 9\.0", content)
    assert re.search(r"'mae': 3\.0", content)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_train_dataset_non_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with train dataset in non-sink mode.
    Description: test fit with train dataset in non-sink mode.
    Expectation: run in non-sink mode.
    """
    content = _fit_and_capture(mode, 1, dataset_sink_mode=False)
    _assert_eval_metrics(content)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_train_dataset_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with train dataset in sink mode.
    Description: test fit with train dataset in sink mode.
    Expectation: run in sink mode.
    """
    content = _fit_and_capture(mode, 1, dataset_sink_mode=True, sink_size=256)
    _assert_eval_metrics(content)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_valid_dataset_non_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with valid dataset in non-sink mode.
    Description: test fit with valid dataset in non-sink mode.
    Expectation: run in non-sink mode.
    """
    content = _fit_and_capture(mode, 1, valid_dataset_sink_mode=False)
    _assert_eval_metrics(content)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_valid_dataset_sink_mode(mode):
    """
    Feature: `mindspore.train.Model.fit` with valid dataset in sink mode.
    Description: test fit with valid dataset in sink mode.
    Expectation: run in sink mode.
    """
    content = _fit_and_capture(mode, 1, valid_dataset_sink_mode=True)
    _assert_eval_metrics(content)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fit_valid_frequency(mode):
    """
    Feature: check `valid_frequency` input in `mindspore.train.Model.fit`.
    Description: when `valid_frequency` is an integer (evaluate every 2 epochs out of 4).
    Expectation: Executed fit valid frequency successfully.
    """
    content = _fit_and_capture(mode, 4, valid_frequency=2)
    assert re.search("Eval result: epoch 4", content)
    _assert_eval_metrics(content)