# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16network config setting, will be used in dataset.py, run_pretrain.py
17"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig

cfg = edict({
    'bert_network': 'base',
    'loss_scale_value': 65536,
    'scale_factor': 2,
    'scale_window': 1000,
    'optimizer': 'Lamb',
    'AdamWeightDecay': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 1e-10,
        'power': 5.0,
        'weight_decay': 1e-5,
        'eps': 1e-6,
        'warmup_steps': 10000,
    }),
    'Lamb': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 1e-10,
        'power': 10.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
})

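# Illustrative sketch (not from the original file; the exact call sites are an
# assumption): downstream scripts such as run_pretrain.py typically select the
# optimizer sub-config by name and build a dynamic loss scale manager from the
# values above, roughly:
#
#     from mindspore.train.loss_scale_manager import DynamicLossScaleManager
#
#     optimizer_cfg = cfg[cfg.optimizer]  # resolves to cfg.Lamb here
#     lr = optimizer_cfg.learning_rate
#     loss_scale_manager = DynamicLossScaleManager(cfg.loss_scale_value,
#                                                  cfg.scale_factor,
#                                                  cfg.scale_window)
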
'''
Three kinds of network are included:
base:  Google BERT-base (the base version of the BERT model).
nezha: BERT-NEZHA (a Chinese pretrained language model developed by Huawei, which introduces
       functional relative positional encoding as an effective positional encoding scheme).
large: BERT-large (the large version of the BERT model).
'''
if cfg.bert_network == 'base':
    bert_net_cfg = BertConfig(
        batch_size=32,
        seq_length=128,
        vocab_size=21128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=False,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16
    )
elif cfg.bert_network == 'nezha':
    bert_net_cfg = BertConfig(
        batch_size=32,
        seq_length=128,
        vocab_size=21128,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16
    )
elif cfg.bert_network == 'large':
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=512,
        vocab_size=30522,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=False,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
        enable_fused_layernorm=True
    )
else:
    # Guard against a silently undefined bert_net_cfg for an unknown network name.
    raise ValueError("unsupported bert_network: {}".format(cfg.bert_network))
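
# Usage sketch (an assumption for illustration; the actual constructor lives in
# bert_model.py): run_pretrain.py imports bert_net_cfg from this module and
# passes it to the network, roughly:
#
#     from src.config import cfg, bert_net_cfg
#     from src.bert_model import BertModel
#
#     net = BertModel(bert_net_cfg, is_training=True)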