• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for auto mixed precision (AMP) train-network construction."""
16import numpy as np
17import pytest
18
19import mindspore.context as context
20from mindspore import Tensor
21from mindspore import amp
22from mindspore import nn
23from mindspore.communication.management import init
24from mindspore.communication._comm_helper import GlobalComm
25from mindspore.context import ParallelMode
26from mindspore.train import Model
27from ....dataset_mock import MindData
28
29
def setup_module(module):
    """Pytest module hook: compile every test in this module in graph mode."""
    del module  # the hook signature requires it, but it is not used
    context.set_context(mode=context.GRAPH_MODE)
33
34
class Net(nn.Cell):
    """One ``nn.Dense`` layer whose ``construct`` returns an MSE loss.

    Args:
        in_features: input width passed to ``nn.Dense``.
        out_features: output width passed to ``nn.Dense``.
    """

    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.dense = nn.Dense(in_features, out_features)
        self.loss = nn.MSELoss()

    def construct(self, input_x, label):
        # Project the input, then score it against the label in one expression.
        return self.loss(self.dense(input_x), label)
45
46
class NetNoLoss(nn.Cell):
    """One ``nn.Dense`` layer with no loss attached.

    The loss is supplied separately by the caller (``build_train_network``
    or ``Model``).

    Args:
        in_features: input width passed to ``nn.Dense``.
        out_features: output width passed to ``nn.Dense``.
    """

    def __init__(self, in_features, out_features):
        super(NetNoLoss, self).__init__()
        self.dense = nn.Dense(in_features, out_features)

    def construct(self, input_x):
        # Raw dense-layer output; no loss is applied here.
        projected = self.dense(input_x)
        return projected
54
55
def test_amp_o0():
    """Build an AMP train network at level "O0" around ``Net`` and run one step."""
    features = Tensor(np.ones((16, 16), dtype=np.float32))
    targets = Tensor(np.zeros((16, 16), dtype=np.float32))
    model_with_loss = Net(16, 16)

    momentum_opt = nn.Momentum(model_with_loss.trainable_params(),
                               learning_rate=0.1, momentum=0.9)
    step_cell = amp.build_train_network(model_with_loss, momentum_opt, level="O0")
    _ = step_cell(features, targets)
64
65
def test_amp_o2():
    """Build an AMP train network at level "O2" around ``Net`` and run one step."""
    features = Tensor(np.ones((16, 16), dtype=np.float32))
    targets = Tensor(np.zeros((16, 16), dtype=np.float32))
    model_with_loss = Net(16, 16)

    momentum_opt = nn.Momentum(model_with_loss.trainable_params(),
                               learning_rate=0.1, momentum=0.9)
    step_cell = amp.build_train_network(model_with_loss, momentum_opt, level="O2")
    _ = step_cell(features, targets)
74
75
def test_amp_o2_loss():
    """Level "O2" AMP step where the MSE loss is passed to ``build_train_network``."""
    features = Tensor(np.ones((16, 16), dtype=np.float32))
    targets = Tensor(np.zeros((16, 16), dtype=np.float32))
    backbone = NetNoLoss(16, 16)
    criterion = nn.MSELoss()
    momentum_opt = nn.Momentum(backbone.trainable_params(),
                               learning_rate=0.1, momentum=0.9)
    step_cell = amp.build_train_network(backbone, momentum_opt, criterion, level="O2")
    _ = step_cell(features, targets)
84
85
def test_amp_o0_loss():
    """AMP step with a separate MSE loss, using ``build_train_network``'s default level."""
    features = Tensor(np.ones((16, 16), dtype=np.float32))
    targets = Tensor(np.zeros((16, 16), dtype=np.float32))
    backbone = NetNoLoss(16, 16)
    criterion = nn.MSELoss()
    momentum_opt = nn.Momentum(backbone.trainable_params(),
                               learning_rate=0.1, momentum=0.9)
    # No ``level=`` argument: this exercises the builder's default level.
    step_cell = amp.build_train_network(backbone, momentum_opt, criterion)
    _ = step_cell(features, targets)
94
95
class MindDataSet(MindData):
    """Mock dataset yielding a fixed number of all-ones batches.

    Each batch is a tuple with one ``Tensor`` per entry of ``dataset_shapes``,
    built with ``np.ones`` and cast to the matching dtype in ``dataset_types``.

    Args:
        dataset_types: numpy dtypes, one per output column.
        dataset_shapes: array shapes, one per output column.
    """

    def __init__(self, dataset_types, dataset_shapes):
        super(MindDataSet, self).__init__(size=2, batch_size=32,
                                          np_types=dataset_types,
                                          output_shapes=dataset_shapes,
                                          input_indexs=(0, 1))

    def __next__(self):
        # ``_iter_num`` counts batches already returned and ``_size`` is the
        # batch count (presumably set by MindData from ``size=2`` above;
        # verify against dataset_mock). Stop after exactly ``_size`` batches.
        # The earlier ``_size < _iter_num`` test let one extra batch through.
        if self._iter_num >= self._size:
            raise StopIteration
        self._iter_num += 1
        return tuple(Tensor(np.ones(shape_).astype(type_))
                     for shape_, type_ in zip(self._output_shapes, self._np_types))
111
112
def test_compile_model_train_O0():
    """Train a ``Model`` with ``amp_level="O0"`` on mock data, then check eval compiles.

    Eval is expected to fail at the "acc" metric step with ``ValueError``.
    Reaching that point shows the eval path compiled.
    """
    column_types = (np.float32, np.float32)
    column_shapes = ((16, 16), (16, 16))

    mock_data = MindDataSet(column_types, column_shapes)

    backbone = NetNoLoss(16, 16)
    criterion = nn.MSELoss()
    momentum_opt = nn.Momentum(backbone.trainable_params(),
                               learning_rate=0.1, momentum=0.9)

    model = Model(backbone, loss_fn=criterion, optimizer=momentum_opt,
                  metrics={"acc"}, amp_level="O0")
    model.train(2, mock_data, dataset_sink_mode=False)
    # Not a real evaluation: the metric step is expected to raise; only
    # successful compilation matters here.
    with pytest.raises(ValueError):
        model.eval(mock_data)
128
129
def test_compile_model_train_O2():
    """Train a ``Model`` with ``amp_level="O2"`` on mock data, then check eval compiles.

    Eval is expected to fail at the "acc" metric step with ``ValueError``.
    Reaching that point shows the eval path compiled.
    """
    column_types = (np.float32, np.float32)
    column_shapes = ((16, 16), (16, 16))

    mock_data = MindDataSet(column_types, column_shapes)

    backbone = NetNoLoss(16, 16)
    criterion = nn.MSELoss()
    momentum_opt = nn.Momentum(backbone.trainable_params(),
                               learning_rate=0.1, momentum=0.9)

    model = Model(backbone, loss_fn=criterion, optimizer=momentum_opt,
                  metrics={"acc"}, amp_level="O2")
    model.train(2, mock_data, dataset_sink_mode=False)
    # Not a real evaluation: the metric step is expected to raise; only
    # successful compilation matters here.
    with pytest.raises(ValueError):
        model.eval(mock_data)
145
146
def test_compile_model_train_O2_parallel():
    """Train an ``amp_level="O2"`` ``Model`` under 8-device data parallelism.

    The auto-parallel context and ``GlobalComm.CHECK_ENVS`` are both global
    state. They are restored in ``finally`` blocks so a failure here cannot
    leak the parallel configuration or the disabled env check into later tests.
    """
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    context.set_auto_parallel_context(
        global_rank=0, device_num=8,
        gradients_mean=True, parameter_broadcast=True,
        parallel_mode=ParallelMode.DATA_PARALLEL)
    try:
        dataset = MindDataSet(dataset_types, dataset_shapes)

        net = NetNoLoss(16, 16)
        loss = nn.MSELoss()
        # Positional args after params: learning_rate=0.1, momentum=0.9, then
        # (presumably) weight_decay and loss_scale; confirm against nn.Momentum.
        optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)

        # Temporarily disable the env check so init() can run in this test,
        # presumably without a real distributed launch environment. Restore
        # the previous value even if init() raises.
        saved_check_envs = GlobalComm.CHECK_ENVS
        GlobalComm.CHECK_ENVS = False
        try:
            init()
        finally:
            GlobalComm.CHECK_ENVS = saved_check_envs

        model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
        model.train(2, dataset, dataset_sink_mode=False)
    finally:
        # Undo the set_auto_parallel_context() call above for subsequent tests.
        context.reset_auto_parallel_context()
165