# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter-handling utilities for federated learning with an ALBERT model:
averaging, saving/restoring, server upload/download marking, and freezing."""

import copy
from mindspore.common.initializer import initializer

# The 18 shared-encoder parameter names exchanged with the FL server.
# Both the upload and the download list start from this common set; the
# upload list additionally sends the task head (pooler + classifier).
_ENCODER_PARAM_NAMES = [
    'albert.encoder.embedding_hidden_mapping_in.weight',
    'albert.encoder.embedding_hidden_mapping_in.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma',
    'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta',
]

# Task-head parameters that are uploaded to the server but never downloaded.
_HEAD_PARAM_NAMES = [
    'albert.pooler.weight',
    'albert.pooler.bias',
    'classifier.weight',
    'classifier.bias',
]


def average_weights(para_list):
    """Return the element-wise average of a list of parameter dicts.

    Args:
        para_list (list[dict]): parameter dicts mapping name -> value; values
            must support `/` by an int and `+=` (e.g. tensors or floats).
            Dicts may have differing key sets; each name is averaged over
            the full list length, matching the original behavior.

    Returns:
        dict: name -> averaged value. Empty dict for an empty input list.
    """
    global_parameter = {}
    length = len(para_list)
    for para in para_list:
        for name in para:
            if name in global_parameter:
                global_parameter[name] += para[name] / length
            else:
                global_parameter[name] = para[name] / length
    return global_parameter


def save_params(network, param_dict=None):
    """Snapshot the network's trainable parameters.

    Args:
        network: a MindSpore network exposing `trainable_params()`.
        param_dict (dict, optional): if given, it is updated IN PLACE with
            deep copies of the network's current parameters (only for names
            already present in it) and the function returns None.

    Returns:
        dict or None: when `param_dict` is None, a new dict of deep-copied
        parameters, excluding any whose name contains 'learning_rate' or
        'adam'; otherwise None.
    """
    if param_dict is None:
        return {param.name: copy.deepcopy(param) for param in network.trainable_params()
                if 'learning_rate' not in param.name and 'adam' not in param.name}
    for param in network.trainable_params():
        if param.name in param_dict:
            param_dict[param.name] = copy.deepcopy(param)
    return None


def restore_params(network, param_dict, init_adam=True):
    """Load saved values back into the network's trainable parameters.

    Parameters whose name contains 'learning_rate' are left untouched.
    When `init_adam` is True, optimizer-state parameters (name contains
    'adam') are reset to zeros instead of being restored from `param_dict`.

    Args:
        network: a MindSpore network exposing `trainable_params()`.
        param_dict (dict): name -> saved parameter value.
        init_adam (bool): zero-initialize Adam state instead of restoring it.
    """
    for param in network.trainable_params():
        if 'learning_rate' in param.name:
            continue
        param.init_data()
        if init_adam and 'adam' in param.name:
            # Fresh optimizer state: start Adam moments from zero.
            param.set_data(initializer('zeros', shape=param.shape, dtype=param.dtype))
        elif param.name in param_dict:
            param.set_data(param_dict[param.name])


def get_worker_upload_list():
    """Return the parameter names a worker pushes to the FL server
    (shared encoder plus the pooler/classifier head)."""
    return _ENCODER_PARAM_NAMES + _HEAD_PARAM_NAMES


def upload_to_server(network, worker_upload_list):
    """Mark the listed parameters so they are pushed to the FL server."""
    for param in network.trainable_params():
        if param.name in worker_upload_list:
            param.set_param_fl(push_to_server=True)


def get_worker_download_list():
    """Return the parameter names a worker pulls from the FL server
    (shared encoder only; the task head stays local)."""
    return list(_ENCODER_PARAM_NAMES)


def download_from_server(network, worker_download_list):
    """Mark the listed parameters so they are pulled from the FL server."""
    for param in network.trainable_params():
        if param.name in worker_download_list:
            param.set_param_fl(pull_from_server=True)


def get_freeze_list():
    """Return the embedding-related parameter names to freeze during
    federated fine-tuning."""
    return [
        'albert.word_embeddings.embedding_table',
        'albert.embedding_postprocessor.embedding_table',
        'albert.embedding_postprocessor.full_position_embeddings',
        'albert.embedding_postprocessor.layernorm.gamma',
        'albert.embedding_postprocessor.layernorm.beta'
    ]


def freeze(network, freeze_list):
    """Disable gradient computation for the listed parameters."""
    for param in network.trainable_params():
        if param.name in freeze_list:
            param.requires_grad = False