# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for auto mixed precision (amp)."""
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore import amp
from mindspore import nn
from mindspore.communication.management import init
from mindspore.communication._comm_helper import GlobalComm
from mindspore.context import ParallelMode
from mindspore.train import Model
from ....dataset_mock import MindData


def setup_module(module):
    _ = module
    context.set_context(mode=context.GRAPH_MODE)


class Net(nn.Cell):
    """Dense network with the loss fused into construct."""

    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.dense = nn.Dense(in_features, out_features)
        self.loss = nn.MSELoss()

    def construct(self, input_x, label):
        output = self.dense(input_x)
        loss = self.loss(output, label)
        return loss


class NetNoLoss(nn.Cell):
    """Bare dense network; the loss function is supplied separately."""

    def __init__(self, in_features, out_features):
        super(NetNoLoss, self).__init__()
        self.dense = nn.Dense(in_features, out_features)

    def construct(self, input_x):
        return self.dense(input_x)


def test_amp_o0():
    """Build and run an O0 train network for a net with a fused loss."""
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    net = Net(16, 16)

    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = amp.build_train_network(net, optimizer, level="O0")
    _ = train_network(inputs, label)


def test_amp_o2():
    """Build and run an O2 train network for a net with a fused loss."""
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    net = Net(16, 16)

    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = amp.build_train_network(net, optimizer, level="O2")
    _ = train_network(inputs, label)


def test_amp_o2_loss():
    """Build and run an O2 train network from a bare net and a separate loss."""
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = amp.build_train_network(net, optimizer, loss, level="O2")
    _ = train_network(inputs, label)


def test_amp_o0_loss():
    """Build and run a train network (default level) from a bare net and a separate loss."""
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = amp.build_train_network(net, optimizer, loss)
    _ = train_network(inputs, label)


class MindDataSet(MindData):
    """Mock dataset that yields all-ones tensors of the given types and shapes."""

    def __init__(self, dataset_types, dataset_shapes):
        super(MindDataSet, self).__init__(size=2, batch_size=32,
                                          np_types=dataset_types,
                                          output_shapes=dataset_shapes,
                                          input_indexs=(0, 1))

    def __next__(self):
        if self._size < self._iter_num:
            raise StopIteration
        self._iter_num += 1
        lst = []
        for shape_, type_ in zip(self._output_shapes, self._np_types):
            lst.append(Tensor(np.ones(shape_).astype(type_)))
        return tuple(lst)


def test_compile_model_train_O0():
    """Compile and train a Model with amp_level O0 on the mock dataset."""
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))

    dataset = MindDataSet(dataset_types, dataset_shapes)

    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O0")
    model.train(2, dataset, dataset_sink_mode=False)
    with pytest.raises(ValueError):
        # Not an actual evaluation: the metrics step is expected to fail with
        # ValueError; this only checks that compilation succeeds.
        model.eval(dataset)


def test_compile_model_train_O2():
    """Compile and train a Model with amp_level O2 on the mock dataset."""
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))

    dataset = MindDataSet(dataset_types, dataset_shapes)

    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
    model.train(2, dataset, dataset_sink_mode=False)
    with pytest.raises(ValueError):
        # Not an actual evaluation: the metrics step is expected to fail with
        # ValueError; this only checks that compilation succeeds.
        model.eval(dataset)


def test_compile_model_train_O2_parallel():
    """Compile and train an O2 Model under a data-parallel context."""
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    context.set_auto_parallel_context(
        global_rank=0, device_num=8,
        gradients_mean=True, parameter_broadcast=True,
        parallel_mode=ParallelMode.DATA_PARALLEL)

    dataset = MindDataSet(dataset_types, dataset_shapes)

    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    # Positional args: learning_rate, momentum, weight_decay, loss_scale.
    optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
    # Temporarily disable the distributed-environment check so init() can run in a unit test.
    GlobalComm.CHECK_ENVS = False
    init()
    GlobalComm.CHECK_ENVS = True
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
    model.train(2, dataset, dataset_sink_mode=False)