# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test parallel optimizer with Momentum, Adam-family and Lamb optimizers """
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb, Momentum
from mindspore.ops import operations as P
from mindspore import context


class Net(nn.Cell):
    """Net definition"""
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x):
        q = self.fc1(x)
        k = self.fc2(x)
        v = self.fc3(x)
        k = self.transpose(k, (1, 0))
        c = self.relu4(self.matmul1(q, k))
        s = self.relu5(self.matmul2(c, v))
        s = self.fc4(s)
        return s


class Net2(nn.Cell):
    """Net definition"""
    def __init__(self, strategy1, strategy2):
        super(Net2, self).__init__()
        self.fc1 = P.MatMul().shard(strategy=strategy1)
        self.fc2 = P.MatMul().shard(strategy=strategy2)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2")

    def construct(self, x, y):
        x = self.fc1(x, self.p1)
        x = self.fc2(x, self.p2)
        return x - y


class Net3(nn.Cell):
    """Net definition"""
    def __init__(self, strategy1, strategy2):
        super(Net3, self).__init__()
        self.fc1 = P.MatMul().shard(strategy=strategy1)
        self.fc2 = P.MatMul().shard(strategy=strategy2)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)

    def construct(self, x, y):
        x = self.fc1(x, self.p1)
        x = self.fc2(x, self.p2)
        return x - y


def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
    label = Tensor(np.zeros([32, 16]).astype(np.float32))
    net = net(strategy1, strategy2)
    net = _VirtualDatasetCell(net)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
    train_network.set_auto_parallel()
    train_network.set_train()
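    # Compile only: trace and build the parallel graph so the test validates the
    # sharding and parallel-optimizer setup without executing a training step.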
    _cell_graph_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
    context.reset_auto_parallel_context()
    return train_network


def test_auto_parallel_momentum_1():
    auto_parallel_compile_net("auto_parallel", 8, Net2)


def test_auto_parallel_momentum_2():
    # data parallel case
    auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))


def test_auto_parallel_momentum_3():
    # hybrid parallel case
    # weight1 cannot be sharded and weight2 is repeated
    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
    param_dict = train_network.parameter_layout_dict
    # validate opt_shard_group
    assert not param_dict["weight1"][5]
    assert param_dict["weight2"][5].startswith("4")


def test_auto_parallel_momentum_4():
    # hybrid parallel case
    # devices are repeatedly used
    auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))


def test_auto_parallel_momentum_5():
    # test the parallel optimizer filter (parallel_optimizer=False on weight2)
    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
    param_dict = train_network.parameter_layout_dict
    # validate opt_shard_group
    assert not param_dict["weight1"][5]
    assert not param_dict["weight2"][5]


def test_auto_parallel_momentum_6():
    # test partial use of the parallel optimizer via optimizer_weight_shard_size
    # weight1 cannot be sharded and weight2 is repeated
    context.set_auto_parallel_context(optimizer_weight_shard_size=2)
    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
    param_dict = train_network.parameter_layout_dict
    # validate opt_shard_group
    assert param_dict["weight1"][5].startswith("2")
    assert param_dict["weight2"][5].startswith("2")


def test_AdamWeightDecay():
    """ test_AdamWeightDecay """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)
    context.reset_auto_parallel_context()


def test_lamb_compile():
    """ test_Lamb_compile """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Lamb(net.trainable_params(), learning_rate=0.1)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)
    context.reset_auto_parallel_context()


def test_lamb_split_fusion():
    """ test_Lamb_split_fusion """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
                                      all_reduce_fusion_config=[2, 4, 6, 8])
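    # all_reduce_fusion_config sets the allreduce fusion strategy by parameter
    # indices, so gradient all-reduce is split into fused groups at 2, 4, 6, 8.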
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Lamb(net.trainable_params(), learning_rate=0.1)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)
    context.reset_auto_parallel_context()


def test_edge_case():
    """ test_edge_case """
    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    net = Net()
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        Lamb(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_context(device_target="GPU")
        context.set_auto_parallel_context(parallel_mode="data_parallel")
        Lamb(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_context(device_target="Ascend")
        context.set_auto_parallel_context(parallel_mode="data_parallel")
        Adam(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(device_num=16)
        Lamb(net.trainable_params(), learning_rate=0.1)
    context.reset_auto_parallel_context()
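
# A minimal sketch (illustration only, not part of the original suite) of the
# layout lookup the assertions above perform: parameter_layout_dict maps each
# parameter name to its layout, and these tests read the optimizer shard group
# from index 5 (falsy when parallel-optimizer sharding does not apply to that
# parameter). The helper name below is hypothetical.
def _opt_shard_group(train_network, param_name):
    """Hypothetical helper: return the opt_shard_group entry of a layout."""
    return train_network.parameter_layout_dict[param_name][5]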