• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import copy
17from mindspore.common.initializer import initializer
18
19
def average_weights(para_list):
    """Return the element-wise mean of a sequence of parameter dictionaries.

    Each value is divided by ``len(para_list)`` before it is accumulated. A
    name that is missing from some of the dictionaries is therefore still
    scaled by the total number of dictionaries, not by how often it appeared.

    Args:
        para_list: Sequence of mappings from parameter name to a value that
            supports division by an int and ``+``/``+=``.

    Returns:
        dict: Parameter name -> averaged value, in first-seen key order.
        Empty when ``para_list`` is empty.
    """
    count = len(para_list)
    averaged = {}
    for weights in para_list:
        for key in weights:
            share = weights[key] / count
            if key not in averaged:
                averaged[key] = share
            else:
                averaged[key] += share
    return averaged
30
31
def save_params(network, param_dict=None):
    """Snapshot the trainable parameters of ``network``.

    Two modes:

    * ``param_dict`` is None: build and return a new dict mapping each
      trainable parameter's name to a deep copy of it. Parameters whose name
      contains ``'learning_rate'`` or ``'adam'`` are skipped.
    * ``param_dict`` is given: refresh it in place. Every trainable parameter
      whose name is already a key is replaced by a fresh deep copy. No keys
      are added or removed, and the name filter above is not applied.
      Returns ``None``.
    """
    params = network.trainable_params()
    if param_dict is not None:
        # In-place refresh: only overwrite keys the caller already tracks.
        for item in params:
            if item.name in param_dict:
                param_dict[item.name] = copy.deepcopy(item)
        return None

    snapshot = {}
    for item in params:
        excluded = 'learning_rate' in item.name or 'adam' in item.name
        if not excluded:
            snapshot[item.name] = copy.deepcopy(item)
    return snapshot
40
41
def restore_params(network, param_dict, init_adam=True):
    """Reload the trainable parameters of ``network`` from ``param_dict``.

    Parameters whose name contains ``'learning_rate'`` are left untouched.
    Every other trainable parameter first has ``init_data()`` called on it,
    and then:

    * if ``init_adam`` is truthy and the name contains ``'adam'``, the
      parameter is set to ``initializer('zeros', ...)`` with its own shape
      and dtype, even when ``param_dict`` has an entry for it;
    * otherwise, if ``param_dict`` has an entry for the name, that value is
      written with ``set_data``;
    * otherwise nothing further is written.

    Args:
        network: Object exposing ``trainable_params()``.
        param_dict: Mapping from parameter name to the value to restore.
        init_adam: When truthy, zero ``'adam'``-named parameters instead of
            restoring them from ``param_dict``.
    """
    for item in network.trainable_params():
        if 'learning_rate' in item.name:
            continue
        item.init_data()
        zero_adam = init_adam and 'adam' in item.name
        if zero_adam:
            item.set_data(initializer('zeros', shape=item.shape, dtype=item.dtype))
        elif item.name in param_dict:
            item.set_data(param_dict[item.name])
55
56
def get_worker_upload_list():
    """Return the parameter names a worker marks for pushing to the server.

    The names cover the encoder's ``embedding_hidden_mapping_in`` projection,
    every weight/bias/layer-norm parameter of
    ``albert.encoder.albert_layer_groups.0.albert_layers.0``, the pooler, and
    the classifier. A new list is returned on every call.
    """
    layer = 'albert.encoder.albert_layer_groups.0.albert_layers.0'
    linear = ('weight', 'bias')
    layer_norm = ('gamma', 'beta')
    modules = [
        ('albert.encoder.embedding_hidden_mapping_in', linear),
        (layer + '.attention.attention.query', linear),
        (layer + '.attention.attention.key', linear),
        (layer + '.attention.attention.value', linear),
        (layer + '.attention.output.dense', linear),
        (layer + '.attention.output.layernorm', layer_norm),
        (layer + '.ffn', linear),
        (layer + '.ffn_output', linear),
        (layer + '.full_layer_layer_norm', layer_norm),
        ('albert.pooler', linear),
        ('classifier', linear),
    ]
    return ['.'.join((prefix, leaf)) for prefix, leaves in modules for leaf in leaves]
81
def upload_to_server(network, worker_upload_list):
    """Mark the trainable parameters named in ``worker_upload_list`` for upload.

    For every trainable parameter of ``network`` whose name appears in
    ``worker_upload_list``, call ``set_param_fl(push_to_server=True)``.
    Other parameters are not touched.

    Args:
        network: Object exposing ``trainable_params()``. Each parameter needs a
            ``name`` attribute and a ``set_param_fl`` method.
        worker_upload_list: Iterable of parameter names. Any iterable is
            accepted, including a one-shot generator.
    """
    # Materialize the names once. Each membership test is then O(1), and a
    # generator argument is not exhausted by the first `in` check (which
    # previously caused later parameters to be silently skipped).
    names = frozenset(worker_upload_list)
    for param in network.trainable_params():
        if param.name in names:
            param.set_param_fl(push_to_server=True)
86
def get_worker_download_list():
    """Return the parameter names a worker marks for pulling from the server.

    The names cover the encoder's ``embedding_hidden_mapping_in`` projection
    and every weight/bias/layer-norm parameter of
    ``albert.encoder.albert_layer_groups.0.albert_layers.0``. The pooler and
    classifier parameters are not included. A new list is returned on every
    call.
    """
    layer = 'albert.encoder.albert_layer_groups.0.albert_layers.0.'
    linear = ('weight', 'bias')
    layer_norm = ('gamma', 'beta')
    submodules = (
        ('attention.attention.query', linear),
        ('attention.attention.key', linear),
        ('attention.attention.value', linear),
        ('attention.output.dense', linear),
        ('attention.output.layernorm', layer_norm),
        ('ffn', linear),
        ('ffn_output', linear),
        ('full_layer_layer_norm', layer_norm),
    )
    names = ['albert.encoder.embedding_hidden_mapping_in.' + leaf for leaf in linear]
    for submodule, leaves in submodules:
        names.extend(layer + submodule + '.' + leaf for leaf in leaves)
    return names
108
def download_from_server(network, worker_download_list):
    """Mark the trainable parameters named in ``worker_download_list`` for download.

    For every trainable parameter of ``network`` whose name appears in
    ``worker_download_list``, call ``set_param_fl(pull_from_server=True)``.
    Other parameters are not touched.

    Args:
        network: Object exposing ``trainable_params()``. Each parameter needs a
            ``name`` attribute and a ``set_param_fl`` method.
        worker_download_list: Iterable of parameter names. Any iterable is
            accepted, including a one-shot generator.
    """
    # Materialize the names once. Each membership test is then O(1), and a
    # generator argument is not exhausted by the first `in` check (which
    # previously caused later parameters to be silently skipped).
    names = frozenset(worker_download_list)
    for param in network.trainable_params():
        if param.name in names:
            param.set_param_fl(pull_from_server=True)
113
def get_freeze_list():
    """Return the names of the embedding parameters to freeze.

    The list holds the word-embedding table, followed by the
    ``albert.embedding_postprocessor`` tables and layer-norm parameters.
    A new list is returned on every call.
    """
    postprocessor = 'albert.embedding_postprocessor.'
    leaves = (
        'embedding_table',
        'full_position_embeddings',
        'layernorm.gamma',
        'layernorm.beta',
    )
    return ['albert.word_embeddings.embedding_table'] + [postprocessor + leaf for leaf in leaves]
122
def freeze(network, freeze_list):
    """Disable gradients for the trainable parameters named in ``freeze_list``.

    For every trainable parameter of ``network`` whose name appears in
    ``freeze_list``, set ``requires_grad`` to ``False``. Other parameters are
    not touched.

    Args:
        network: Object exposing ``trainable_params()``. Each parameter needs a
            ``name`` attribute and a writable ``requires_grad`` attribute.
        freeze_list: Iterable of parameter names. Any iterable is accepted,
            including a one-shot generator.
    """
    # Materialize the names once. Each membership test is then O(1), and a
    # generator argument is not exhausted by the first `in` check (which
    # previously caused later parameters to be silently skipped).
    names = frozenset(freeze_list)
    for param in network.trainable_params():
        if param.name in names:
            param.requires_grad = False
127