# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Hybrid-mode federated-learning LeNet training entry point.

Parses the cluster / federated-learning flags, configures the MindSpore
federated context, and (as a worker) runs a dummy-data training loop.
"""

import argparse
import ast

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import WithLossCell

from src.cell_wrapper import TrainOneStepCellWithServerCommunicator
from src.model import LeNet5, PushMetrics
# from src.adam import AdamWeightDecayOp

parser = argparse.ArgumentParser(description="test_hybrid_train_lenet")

# --- cluster topology / role flags ---
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)

# --- federated-learning job scheduling windows / thresholds ---
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
# --- federated-learning client / training hyper-parameters ---
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")

# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)  # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)

# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)

# --- SSL credentials ---
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
# ast.literal_eval turns the string "True"/"False" into a real bool.
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)

# Unknown flags are ignored so the launcher can pass extra options through.
args, _ = parser.parse_known_args()

# Bind the parsed flags to module-level names used below.
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl

# Keyword arguments forwarded verbatim to context.set_fl_context().
ctx = {
    "enable_fl": True,
    "server_mode": server_mode,
    "ms_role": ms_role,
    "worker_num": worker_num,
    "server_num": server_num,
    "scheduler_ip": scheduler_ip,
    "scheduler_port": scheduler_port,
    "fl_server_port": fl_server_port,
    "start_fl_job_threshold": start_fl_job_threshold,
    "start_fl_job_time_window": start_fl_job_time_window,
    "update_model_ratio": update_model_ratio,
    "update_model_time_window": update_model_time_window,
    "fl_name": fl_name,
    "fl_iteration_num": fl_iteration_num,
    "client_epoch_num": client_epoch_num,
    "client_batch_size": client_batch_size,
    "client_learning_rate": client_learning_rate,
    "worker_step_num_per_iteration": worker_step_num_per_iteration,
    "scheduler_manage_port": scheduler_manage_port,
    "config_file_path": config_file_path,
    "share_secrets_ratio": share_secrets_ratio,
    "cipher_time_window": cipher_time_window,
    "reconstruct_secrets_threshold": reconstruct_secrets_threshold,
    "dp_eps": dp_eps,
    "dp_delta": dp_delta,
    "dp_norm_clip": dp_norm_clip,
    "encrypt_type": encrypt_type,
    "client_password": client_password,
    "server_password": server_password,
    "enable_ssl": enable_ssl,
}

# Graph mode is required for the federated-learning cell wrappers.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_fl_context(**ctx)

if __name__ == "__main__":
    epoch = 50000
    np.random.seed(0)  # deterministic dummy data so runs are reproducible
    network = LeNet5(62)
    push_metrics = PushMetrics()
    if context.get_fl_context("ms_role") == "MS_WORKER":
        # Do not freeze layers that are also pushed to servers: pulling and
        # then overwriting a frozen layer would be meaningless.
        network.conv1.weight.requires_grad = False
        network.conv2.weight.requires_grad = False
        # Pull conv weights from the server before running the backbone.
        network.conv1.set_param_fl(pull_from_server=True)
        network.conv2.set_param_fl(pull_from_server=True)

        # Push fully-connected weights to the server after the optimizer step.
        network.fc1.set_param_fl(push_to_server=True)
        network.fc2.set_param_fl(push_to_server=True)
        network.fc3.set_param_fl(push_to_server=True)

    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    # net_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)
    net_with_criterion = WithLossCell(network, criterion)
    # train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network = TrainOneStepCellWithServerCommunicator(net_with_criterion, net_opt)
    train_network.set_train()
    losses = []

    for i in range(epoch):
        # Random dummy batch: 32 images of shape (3, 32, 32).
        data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
        # Fix off-by-one: the network has 62 output classes, so labels must
        # span 0..61 inclusive. randint's upper bound is exclusive, so it
        # must be 62 — the original 61 could never sample class 61.
        label = Tensor(np.random.randint(0, 62, (32,)).astype(np.int32))
        loss = train_network(data, label).asnumpy()
        if context.get_fl_context("ms_role") == "MS_WORKER":
            # Report metrics once per federated iteration, i.e. every
            # worker_step_num_per_iteration local steps.
            if (i + 1) % worker_step_num_per_iteration == 0:
                push_metrics(Tensor(loss, mstype.float32), Tensor(loss, mstype.float32))
        losses.append(loss)
    print(losses)