1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" 16network config setting, will be used in dataset.py, run_pretrain.py 17""" 18from easydict import EasyDict as edict 19import mindspore.common.dtype as mstype 20from .bert_model import BertConfig 21cfg = edict({ 22 'bert_network': 'base', 23 'loss_scale_value': 65536, 24 'scale_factor': 2, 25 'scale_window': 1000, 26 'optimizer': 'Lamb', 27 'AdamWeightDecay': edict({ 28 'learning_rate': 3e-5, 29 'end_learning_rate': 1e-10, 30 'power': 5.0, 31 'weight_decay': 1e-5, 32 'eps': 1e-6, 33 'warmup_steps': 10000, 34 }), 35 'Lamb': edict({ 36 'learning_rate': 3e-5, 37 'end_learning_rate': 1e-10, 38 'power': 10.0, 39 'warmup_steps': 10000, 40 'weight_decay': 0.01, 41 'eps': 1e-6, 42 }), 43 'Momentum': edict({ 44 'learning_rate': 2e-5, 45 'momentum': 0.9, 46 }), 47}) 48 49''' 50Including two kinds of network: \ 51base: Goole BERT-base(the base version of BERT model). 52large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ 53 Functional Relative Posetional Encoding as an effective positional encoding scheme). 54''' 55if cfg.bert_network == 'base': 56 bert_net_cfg = BertConfig( 57 batch_size=32, 58 seq_length=128, 59 vocab_size=21128, 60 hidden_size=768, 61 num_hidden_layers=12, 62 num_attention_heads=12, 63 intermediate_size=3072, 64 hidden_act="gelu", 65 hidden_dropout_prob=0.1, 66 attention_probs_dropout_prob=0.1, 67 max_position_embeddings=512, 68 type_vocab_size=2, 69 initializer_range=0.02, 70 use_relative_positions=False, 71 input_mask_from_dataset=True, 72 token_type_ids_from_dataset=True, 73 dtype=mstype.float32, 74 compute_type=mstype.float16 75 ) 76if cfg.bert_network == 'nezha': 77 bert_net_cfg = BertConfig( 78 batch_size=32, 79 seq_length=128, 80 vocab_size=21128, 81 hidden_size=1024, 82 num_hidden_layers=24, 83 num_attention_heads=16, 84 intermediate_size=4096, 85 hidden_act="gelu", 86 hidden_dropout_prob=0.1, 87 attention_probs_dropout_prob=0.1, 88 max_position_embeddings=512, 89 type_vocab_size=2, 90 initializer_range=0.02, 91 use_relative_positions=True, 92 input_mask_from_dataset=True, 93 token_type_ids_from_dataset=True, 94 dtype=mstype.float32, 95 compute_type=mstype.float16 96 ) 97if cfg.bert_network == 'large': 98 bert_net_cfg = BertConfig( 99 batch_size=16, 100 seq_length=512, 101 vocab_size=30522, 102 hidden_size=1024, 103 num_hidden_layers=24, 104 num_attention_heads=16, 105 intermediate_size=4096, 106 hidden_act="gelu", 107 hidden_dropout_prob=0.1, 108 attention_probs_dropout_prob=0.1, 109 max_position_embeddings=512, 110 type_vocab_size=2, 111 initializer_range=0.02, 112 use_relative_positions=False, 113 input_mask_from_dataset=True, 114 token_type_ids_from_dataset=True, 115 dtype=mstype.float32, 116 compute_type=mstype.float16, 117 enable_fused_layernorm=True 118 ) 119