# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test network turn on mix_precision."""

import os
import re
import pytest
import numpy as np
from mindspore.common import dtype
from mindspore import nn
from mindspore import ops
from mindspore import amp
from mindspore import Tensor
from mindspore import context
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model
from utils import FakeData
from utils import allclose_nparray
from utils import FakeDataInitMode
from utils import find_newest_validateir_file
from utils import clean_all_ir_files
from tests.security_utils import security_off_wrap


def read_validateir_file(path_folder):
    """Return the text of the newest validate IR file found under `path_folder`.

    The file is located with `find_newest_validateir_file` and read in text
    mode in one go.
    """
    filename = find_newest_validateir_file(path_folder)
    # Single-argument join hands the path back unchanged; kept as in the
    # original helper.
    with open(os.path.join(filename), 'r') as f:
        content = f.read()
    return content


def _count_cast_ops(path_folder):
    """Count the `Cast(` occurrences in the newest validate IR file under `path_folder`."""
    return len(re.findall(r"Cast\(", read_validateir_file(path_folder)))


class Net(nn.Cell):
    """Small network: ReLU -> BatchNorm2d -> Conv2d -> BatchNorm2d -> ReLU -> ReduceMean.

    Args:
        in_c: Number of input channels (features of the first BatchNorm2d and
            input channels of the convolution).
        out_c: Number of output channels (output channels of the convolution
            and features of the second BatchNorm2d).
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(num_features=in_c,
                                  gamma_init='ones',
                                  beta_init='zeros',
                                  moving_mean_init='zeros',
                                  moving_var_init='ones')
        self.bn2 = nn.BatchNorm2d(num_features=out_c,
                                  gamma_init='ones',
                                  beta_init='zeros',
                                  moving_mean_init='zeros',
                                  moving_var_init='ones')
        self.conv = nn.Conv2d(in_channels=in_c,
                              out_channels=out_c,
                              kernel_size=3,
                              stride=1,
                              has_bias=True,
                              pad_mode='same',
                              weight_init='ones',
                              bias_init='ones')
        self.mean = ops.ReduceMean(keep_dims=False)

    def construct(self, x):
        """Run the forward pass and reduce-mean the result over axes (2, 3)."""
        x = self.relu(x)
        x = self.bn1(x)
        x = self.conv(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.mean(x, (2, 3))
        return x


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sit_auto_mix_precision_train_o3():
    """Run one step of an amp level "O3" train network in graph and PyNative mode.

    Both networks are fed the same random input and label, and their outputs
    are compared with np.allclose (rtol=0.001, atol=0.001).
    """
    input_data = np.random.randn(32, 3, 224, 224).astype(np.float64)
    label_data = np.random.randn(32, 10).astype(np.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    net = Net(3, 10)
    opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009, weight_decay=0.001,
                      loss_scale=0.0001)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    train_network = amp.build_train_network(net, opt, loss, level="O3",
                                            loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False))
    out = train_network(Tensor(input_data), Tensor(label_data))

    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE)
    net_pynative = Net(3, 10)
    opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001, momentum=0.0009,
                               weight_decay=0.001,
                               loss_scale=0.0001)
    loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    train_network_pynative = amp.build_train_network(net_pynative, opt_pynative, loss_pynative, level="O3",
                                                     loss_scale_manager=FixedLossScaleManager(
                                                         drop_overflow_update=False))
    out_pynative = train_network_pynative(Tensor(input_data), Tensor(label_data))
    assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_sit_auto_mix_precision_model_o0():
    """Train and predict with amp_level "O0" on a float16 net and count Cast ops.

    Graph dumping is enabled so that the newest validate IR file can be
    inspected after `Model.train` and again after `Model.predict`; the number
    of `Cast(` occurrences is asserted for each.
    """
    input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
    dataset1 = FakeData(size=32,
                        batch_size=32,
                        image_size=(3, 224, 224),
                        num_classes=10,
                        fakedata_mode=FakeDataInitMode.OnesInit)
    dataset1.set_label_data_type(np.float16)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path='./test_amp_o0')
    net = Net(3, 10)
    net.to_float(dtype.float16)
    opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    model = Model(net, loss, opt, amp_level="O0")
    model.train(1, dataset1, dataset_sink_mode=False)
    assert _count_cast_ops('./test_amp_o0/') == 5
    # Remove the training IR so the next lookup picks up the predict graph.
    clean_all_ir_files('./test_amp_o0')
    model.predict(Tensor(input_data))
    assert _count_cast_ops('./test_amp_o0/') == 11
    clean_all_ir_files('./test_amp_o0/')
    # Do not leave IR dumping enabled for whatever runs next in this process.
    context.set_context(save_graphs=False)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@security_off_wrap
def test_sit_auto_mix_precision_model_o2():
    """Train with amp_level "O2" in graph and PyNative mode and compare predictions.

    In graph mode the newest validate IR file written by `Model.train` is
    checked for the number of `Cast(` occurrences. The graph-mode and
    PyNative-mode predictions for the same random input are then compared
    with `allclose_nparray`.
    """
    input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
    dataset1 = FakeData(size=32,
                        batch_size=32,
                        image_size=(3, 224, 224),
                        num_classes=10,
                        fakedata_mode=FakeDataInitMode.OnesInit)
    dataset2 = FakeData(size=32,
                        batch_size=32,
                        image_size=(3, 224, 224),
                        num_classes=10,
                        fakedata_mode=FakeDataInitMode.OnesInit)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path='./test_amp_o2')
    net = Net(3, 10)
    opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    model = Model(net, loss, opt, amp_level="O2")
    model.train(1, dataset1, dataset_sink_mode=False)
    assert _count_cast_ops('./test_amp_o2/') == 14
    clean_all_ir_files('./test_amp_o2/')
    # The IR check is done: stop dumping so the predict and PyNative runs
    # below do not leave new, never-cleaned IR files behind.
    context.set_context(save_graphs=False)
    out_graph = model.predict(Tensor(input_data))

    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE)
    net_pynative = Net(3, 10)
    opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001, momentum=0.0009)
    loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    model_pynative = Model(net_pynative, loss_pynative, opt_pynative, amp_level="O2")
    model_pynative.train(1, dataset2, dataset_sink_mode=False)
    out_pynative = model_pynative.predict(Tensor(input_data))
    allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)