# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import ast
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from src.model import LeNet5
from src.adam import AdamWeightDecayOp

# (flag, type converter, default) for every command-line option, listed in the
# order the options are registered with the parser.
_ARG_SPECS = (
    ("--device_target", str, "CPU"),
    ("--server_mode", str, "FEDERATED_LEARNING"),
    ("--ms_role", str, "MS_WORKER"),
    ("--worker_num", int, 0),
    ("--server_num", int, 1),
    ("--scheduler_ip", str, "127.0.0.1"),
    ("--scheduler_port", int, 8113),
    ("--fl_server_port", int, 6666),
    ("--start_fl_job_threshold", int, 1),
    ("--start_fl_job_time_window", int, 3000),
    ("--update_model_ratio", float, 1.0),
    ("--update_model_time_window", int, 3000),
    ("--fl_name", str, "Lenet"),
    ("--fl_iteration_num", int, 25),
    ("--client_epoch_num", int, 20),
    ("--client_batch_size", int, 32),
    ("--client_learning_rate", float, 0.1),
    ("--scheduler_manage_port", int, 11202),
    ("--config_file_path", str, ""),
    ("--encrypt_type", str, "NOT_ENCRYPT"),
    # Options consumed when encrypt_type is 'DP_ENCRYPT'.
    ("--dp_eps", float, 50.0),
    ("--dp_delta", float, 0.01),  # suggested value: 1/worker_num
    ("--dp_norm_clip", float, 1.0),
    # Options consumed when encrypt_type is 'PW_ENCRYPT'.
    ("--share_secrets_ratio", float, 1.0),
    ("--cipher_time_window", int, 300000),
    ("--reconstruct_secrets_threshold", int, 3),
    ("--client_password", str, ""),
    ("--server_password", str, ""),
    # ast.literal_eval turns the text "True"/"False" into a real bool.
    ("--enable_ssl", ast.literal_eval, False),
)


def _build_parser():
    """Create the argument parser for the federated-learning LeNet test."""
    arg_parser = argparse.ArgumentParser(description="test_fl_lenet")
    for flag, arg_type, default in _ARG_SPECS:
        arg_parser.add_argument(flag, type=arg_type, default=default)
    return arg_parser


parser = _build_parser()

# parse_known_args (rather than parse_args) ignores any command-line options
# this script does not declare instead of exiting with an error.
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
# NOTE(review): passwords supplied as command-line arguments can be visible to
# other users of the host (e.g. in the process list); consider reading them
# from the environment or a protected file instead.
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl

# Federated-learning settings, passed as keyword arguments to
# context.set_fl_context below.
ctx = {
    "enable_fl": True,
    "server_mode": server_mode,
    "ms_role": ms_role,
    "worker_num": worker_num,
    "server_num": server_num,
    "scheduler_ip": scheduler_ip,
    "scheduler_port": scheduler_port,
    "fl_server_port": fl_server_port,
    "start_fl_job_threshold": start_fl_job_threshold,
    "start_fl_job_time_window": start_fl_job_time_window,
    "update_model_ratio": update_model_ratio,
    "update_model_time_window": update_model_time_window,
    "share_secrets_ratio": share_secrets_ratio,
    "cipher_time_window": cipher_time_window,
    "reconstruct_secrets_threshold": reconstruct_secrets_threshold,
    "fl_name": fl_name,
    "fl_iteration_num": fl_iteration_num,
    "client_epoch_num": client_epoch_num,
    "client_batch_size": client_batch_size,
    "client_learning_rate": client_learning_rate,
    "scheduler_manage_port": scheduler_manage_port,
    "config_file_path": config_file_path,
    "dp_eps": dp_eps,
    "dp_delta": dp_delta,
    "dp_norm_clip": dp_norm_clip,
    "encrypt_type": encrypt_type,
    "client_password": client_password,
    "server_password": server_password,
    "enable_ssl": enable_ssl
}

context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_fl_context(**ctx)

if __name__ == "__main__":
    # Run a few training steps on random data and print the per-step losses.
    epoch = 5  # number of training steps (one random batch per step)
    # Passed to LeNet5 as its class count; presumably the logits width — confirm.
    num_classes = 62
    batch_size = 32
    np.random.seed(0)
    network = LeNet5(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    # NOTE(review): this optimizer is never used by the training loop below;
    # it is kept in case constructing it matters to the FL runtime — confirm.
    net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)
    net_with_criterion = WithLossCell(network, criterion)
    train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network.set_train()
    losses = []

    for _ in range(epoch):
        data = Tensor(np.random.rand(batch_size, 3, 32, 32).astype(np.float32))
        # randint's upper bound is exclusive, so use num_classes to allow every
        # label 0..num_classes-1 (the previous bound of 61 never produced 61).
        label = Tensor(np.random.randint(0, num_classes, (batch_size,)).astype(np.int32))
        loss = train_network(data, label).asnumpy()
        losses.append(loss)
    print(losses)