# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bert test."""

# pylint: disable=missing-docstring, arguments-differ, W0612

import os

import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn import learning_rate_schedule as lr_schedules
from tests.models.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, \
    BertTrainOneStepWithLossScaleCell
from ...dataset_mock import MindData
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph


_current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data"
context.set_context(mode=context.GRAPH_MODE)


def get_dataset(batch_size=1):
    dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
    dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
                      (batch_size, 20), (batch_size, 20), (batch_size, 20))

    dataset = MindData(size=2, batch_size=batch_size,
                       np_types=dataset_types,
                       output_shapes=dataset_shapes,
                       input_indexs=(0, 1))
    return dataset


def load_test_data(batch_size=1):
    dataset = get_dataset(batch_size)
    ret = dataset.next()
    ret = batch_tuple_tensor(ret, batch_size)
    return ret


def get_config(version='base'):
    """
    get_config definition
    """
    if version == 'base':
        return BertConfig(
            seq_length=128,
            vocab_size=21128,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            use_relative_positions=True,
            dtype=mstype.float32,
            compute_type=mstype.float32)
    if version == 'large':
        return BertConfig(
            seq_length=128,
            vocab_size=21128,
            hidden_size=1024,
            num_hidden_layers=24,
            num_attention_heads=16,
            intermediate_size=4096,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            use_relative_positions=True,
            dtype=mstype.float32,
            compute_type=mstype.float32)
    return BertConfig()


class BertLearningRate(lr_schedules.LearningRateSchedule):
    def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
        super(BertLearningRate, self).__init__()
        self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        # Select the warmup LR before warmup_steps and the polynomial-decay LR after.
        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
        warmup_lr = self.warmup_lr(global_step)
        decay_lr = self.decay_lr(global_step)
        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        return lr


def test_bert_train():
    """
    test_bert_train definition
    """

    class ModelBert(nn.Cell):
        """
        ModelBert definition
        """

        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    inputs = load_test_data(batch_size)

    config = get_config(version=version)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=False)


def test_bert_withlossscale_train():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    scaling_sens = Tensor(np.ones([1]).astype(np.float32))
    inputs = load_test_data(batch_size) + (scaling_sens,)

    config = get_config(version=version)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)


def bert_withlossscale_manager_train():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            manager = DynamicLossScaleManager()
            # The dynamic loss-scale manager supplies the update cell that
            # adjusts the loss scale when overflow is (or is not) detected.
            update_cell = manager.get_update_cell()
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                   scale_update_cell=update_cell)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    inputs = load_test_data(batch_size)

    config = get_config(version=version)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)


def bert_withlossscale_manager_train_feed():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            manager = DynamicLossScaleManager()
            # As above, the manager's update cell drives the dynamic
            # loss-scale adjustment inside the train-one-step wrapper.
            update_cell = manager.get_update_cell()
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                   scale_update_cell=update_cell)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    scaling_sens = Tensor(np.ones([1]).astype(np.float32))
    inputs = load_test_data(batch_size) + (scaling_sens,)

    config = get_config(version=version)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)
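

# A minimal sketch of invoking these cases directly, assuming a configured
# MindSpore environment. pytest only collects the test_* functions above, so
# the bert_withlossscale_manager_* cases must be called explicitly like this.
if __name__ == "__main__":
    test_bert_train()
    test_bert_withlossscale_train()
    bert_withlossscale_manager_train()
    bert_withlossscale_manager_train_feed()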