• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test network turn on mix_precision."""
15
16import os
17import re
18import pytest
19import numpy as np
20from mindspore.common import dtype
21from mindspore import nn
22from mindspore import ops
23from mindspore import amp
24from mindspore import Tensor
25from mindspore import context
26from mindspore.train.loss_scale_manager import FixedLossScaleManager
27from mindspore.train.model import Model
28from utils import FakeData
29from utils import allclose_nparray
30from utils import FakeDataInitMode
31from utils import find_newest_validateir_file
32from utils import clean_all_ir_files
33from tests.security_utils import security_off_wrap
34
def read_validateir_file(path_folder):
    """Return the text of the newest validate IR dump found in ``path_folder``.

    The file is selected by ``find_newest_validateir_file``; this helper only
    reads it.

    Args:
        path_folder: Directory the tests in this module pass as
            ``save_graphs_path`` when enabling graph dumps.

    Returns:
        The whole file content as a ``str``.
    """
    filename = find_newest_validateir_file(path_folder)
    # The returned value is opened as-is (a one-argument os.path.join would
    # return it unchanged, so no path manipulation is needed).
    with open(filename, 'r', encoding='utf-8') as ir_file:
        content = ir_file.read()
    return content
40
41
class Net(nn.Cell):
    """ReLU -> BatchNorm -> 3x3 Conv -> BatchNorm -> ReLU -> mean over axes 2 and 3.

    Every parameter uses a constant initializer, so two instances built with the
    same arguments start from identical weights.

    Args:
        in_c: Number of input channels (feature count of the first BatchNorm).
        out_c: Number of convolution output channels (feature count of the
            second BatchNorm).
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Both batch-norm layers share the same constant initializers.
        bn_init = dict(gamma_init='ones',
                       beta_init='zeros',
                       moving_mean_init='zeros',
                       moving_var_init='ones')
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(num_features=in_c, **bn_init)
        self.bn2 = nn.BatchNorm2d(num_features=out_c, **bn_init)
        self.conv = nn.Conv2d(in_channels=in_c, out_channels=out_c,
                              kernel_size=3, stride=1, pad_mode='same',
                              has_bias=True, weight_init='ones', bias_init='ones')
        self.mean = ops.ReduceMean(keep_dims=False)

    def construct(self, x):
        # Activation and normalisation applied to the raw input.
        out = self.bn1(self.relu(x))
        # Convolution, second normalisation and activation.
        out = self.relu(self.bn2(self.conv(out)))
        # Average over axes 2 and 3 (the spatial axes of the 4-D inputs used in
        # this file); keep_dims=False drops them from the result.
        return self.mean(out, (2, 3))
74
75
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sit_auto_mix_precision_train_o3():
    """Check that one O3 training step gives the same output in graph and PyNative mode.

    The same input/label pair is fed to two freshly built, identically
    initialised O3 train networks (one per execution mode). Their outputs are
    compared with rtol=atol=1e-3.
    """
    # A fixed-seed local generator makes tolerance failures reproducible
    # without touching NumPy's global random state.
    rng = np.random.default_rng(0)
    input_data = rng.standard_normal(size=(32, 3, 224, 224), dtype=np.float64)
    label_data = rng.standard_normal(size=(32, 10), dtype=np.float32)

    def run_one_step():
        # Build a fresh network under the *current* execution mode and run a
        # single training step on the shared data.
        net = Net(3, 10)
        opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009,
                          weight_decay=0.001, loss_scale=0.0001)
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        train_network = amp.build_train_network(
            net, opt, loss, level="O3",
            loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False))
        return train_network(Tensor(input_data), Tensor(label_data))

    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    out = run_one_step().asnumpy()

    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE)
    out_pynative = run_one_step().asnumpy()

    assert np.allclose(out, out_pynative, 0.001, 0.001), \
        f"graph/pynative mismatch, max abs diff {np.max(np.abs(out - out_pynative))}"
106
107
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_sit_auto_mix_precision_model_o0():
    """Count Cast ops in the validate IR of an O0 Model wrapping a float16 net.

    Graph dumping is enabled into ``./test_amp_o0``. The dumped IR files are
    removed and ``save_graphs`` is switched back off even if an assertion
    fails, so neither the files nor the context setting leak into later tests.
    """
    ir_dir = './test_amp_o0/'
    input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
    dataset1 = FakeData(size=32,
                        batch_size=32,
                        image_size=(3, 224, 224),
                        num_classes=10,
                        fakedata_mode=FakeDataInitMode.OnesInit)
    dataset1.set_label_data_type(np.float16)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path='./test_amp_o0')
    try:
        net = Net(3, 10)
        net.to_float(dtype.float16)
        opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        model = Model(net, loss, opt, amp_level="O0")

        model.train(1, dataset1, dataset_sink_mode=False)
        # Expected Cast counts are tied to the current compiler passes.
        cast_ops = re.findall(r"Cast\(", read_validateir_file(ir_dir))
        assert len(cast_ops) == 5, f"train IR: expected 5 Cast ops, got {len(cast_ops)}"
        # Remove the training dumps before predicting, so the next read only
        # sees files produced by the predict graph.
        clean_all_ir_files(ir_dir)

        model.predict(Tensor(input_data))
        cast_ops = re.findall(r"Cast\(", read_validateir_file(ir_dir))
        assert len(cast_ops) == 11, f"predict IR: expected 11 Cast ops, got {len(cast_ops)}"
    finally:
        clean_all_ir_files(ir_dir)
        context.set_context(save_graphs=False)
139
140
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@security_off_wrap
def test_sit_auto_mix_precision_model_o2():
    """Check the O2 train IR Cast count and graph-vs-PyNative predict agreement.

    Graph dumping is enabled into ``./test_amp_o2``. The original version
    cleaned it only before ``predict``, so IR files dumped afterwards were left
    behind, and ``save_graphs`` stayed enabled. Both are now undone in a
    ``finally`` clause.
    """
    ir_dir = './test_amp_o2/'
    input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)

    def make_dataset():
        # Identical all-ones fake dataset; each mode trains on its own instance.
        return FakeData(size=32,
                        batch_size=32,
                        image_size=(3, 224, 224),
                        num_classes=10,
                        fakedata_mode=FakeDataInitMode.OnesInit)

    dataset1 = make_dataset()
    dataset2 = make_dataset()
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path='./test_amp_o2')
    try:
        net = Net(3, 10)
        opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        model = Model(net, loss, opt, amp_level="O2")
        model.train(1, dataset1, dataset_sink_mode=False)
        # Expected Cast count is tied to the current compiler passes.
        cast_ops = re.findall(r"Cast\(", read_validateir_file(ir_dir))
        assert len(cast_ops) == 14, f"train IR: expected 14 Cast ops, got {len(cast_ops)}"
        out_graph = model.predict(Tensor(input_data))

        # pynative mode
        context.set_context(mode=context.PYNATIVE_MODE)
        net_pynative = Net(3, 10)
        opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001,
                                   momentum=0.0009)
        loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        model_pynative = Model(net_pynative, loss_pynative, opt_pynative, amp_level="O2")
        model_pynative.train(1, dataset2, dataset_sink_mode=False)
        out_pynative = model_pynative.predict(Tensor(input_data))
    finally:
        clean_all_ir_files(ir_dir)
        context.set_context(save_graphs=False)

    allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
182