• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import argparse
17import ast
18import numpy as np
19
20import mindspore.context as context
21import mindspore.nn as nn
22from mindspore import Tensor
23from mindspore.nn import TrainOneStepCell, WithLossCell
24from src.model import LeNet5
25from src.adam import AdamWeightDecayOp
26
# (flag, converter, default) for every option this job accepts. Flags are
# registered in exactly this order, which is also the order --help lists them.
_FLAG_SPECS = (
    ("--device_target", str, "CPU"),
    ("--server_mode", str, "FEDERATED_LEARNING"),
    ("--ms_role", str, "MS_WORKER"),
    ("--worker_num", int, 0),
    ("--server_num", int, 1),
    ("--scheduler_ip", str, "127.0.0.1"),
    ("--scheduler_port", int, 8113),
    ("--fl_server_port", int, 6666),
    ("--start_fl_job_threshold", int, 1),
    ("--start_fl_job_time_window", int, 3000),
    ("--update_model_ratio", float, 1.0),
    ("--update_model_time_window", int, 3000),
    ("--fl_name", str, "Lenet"),
    ("--fl_iteration_num", int, 25),
    ("--client_epoch_num", int, 20),
    ("--client_batch_size", int, 32),
    ("--client_learning_rate", float, 0.1),
    ("--scheduler_manage_port", int, 11202),
    ("--config_file_path", str, ""),
    ("--encrypt_type", str, "NOT_ENCRYPT"),
    # Parameters for encrypt_type='DP_ENCRYPT'.
    ("--dp_eps", float, 50.0),
    ("--dp_delta", float, 0.01),  # 1/worker_num
    ("--dp_norm_clip", float, 1.0),
    # Parameters for encrypt_type='PW_ENCRYPT'.
    ("--share_secrets_ratio", float, 1.0),
    ("--cipher_time_window", int, 300000),
    ("--reconstruct_secrets_threshold", int, 3),
    # NOTE(review): passwords supplied as CLI flags are visible to other
    # local users through the process list.
    ("--client_password", str, ""),
    ("--server_password", str, ""),
    # ast.literal_eval turns the text "True"/"False" into a real bool.
    ("--enable_ssl", ast.literal_eval, False),
)


def _build_parser():
    """Return an ArgumentParser with every entry of _FLAG_SPECS registered."""
    built = argparse.ArgumentParser(description="test_fl_lenet")
    for flag, converter, default in _FLAG_SPECS:
        built.add_argument(flag, type=converter, default=default)
    return built


parser = _build_parser()
59
# parse_known_args() tolerates unrecognised command-line options instead of
# exiting; the list of leftovers is discarded into `_`.
args, _ = parser.parse_known_args()

# Re-export every parsed option as a module-level name with the same spelling.
device_target = args.device_target
server_mode, ms_role = args.server_mode, args.ms_role
worker_num, server_num = args.worker_num, args.server_num
scheduler_ip, scheduler_port = args.scheduler_ip, args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold, start_fl_job_time_window = (
    args.start_fl_job_threshold,
    args.start_fl_job_time_window,
)
update_model_ratio, update_model_time_window = (
    args.update_model_ratio,
    args.update_model_time_window,
)
share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold = (
    args.share_secrets_ratio,
    args.cipher_time_window,
    args.reconstruct_secrets_threshold,
)
fl_name, fl_iteration_num = args.fl_name, args.fl_iteration_num
client_epoch_num, client_batch_size, client_learning_rate = (
    args.client_epoch_num,
    args.client_batch_size,
    args.client_learning_rate,
)
scheduler_manage_port, config_file_path = (
    args.scheduler_manage_port,
    args.config_file_path,
)
dp_eps, dp_delta, dp_norm_clip = args.dp_eps, args.dp_delta, args.dp_norm_clip
encrypt_type = args.encrypt_type
client_password, server_password = args.client_password, args.server_password
enable_ssl = args.enable_ssl
90
# Federated-learning settings forwarded to MindSpore. Each key has the same
# spelling as the module-level option it comes from; "enable_fl" is always on.
ctx = dict(
    enable_fl=True,
    server_mode=server_mode,
    ms_role=ms_role,
    worker_num=worker_num,
    server_num=server_num,
    scheduler_ip=scheduler_ip,
    scheduler_port=scheduler_port,
    fl_server_port=fl_server_port,
    start_fl_job_threshold=start_fl_job_threshold,
    start_fl_job_time_window=start_fl_job_time_window,
    update_model_ratio=update_model_ratio,
    update_model_time_window=update_model_time_window,
    share_secrets_ratio=share_secrets_ratio,
    cipher_time_window=cipher_time_window,
    reconstruct_secrets_threshold=reconstruct_secrets_threshold,
    fl_name=fl_name,
    fl_iteration_num=fl_iteration_num,
    client_epoch_num=client_epoch_num,
    client_batch_size=client_batch_size,
    client_learning_rate=client_learning_rate,
    scheduler_manage_port=scheduler_manage_port,
    config_file_path=config_file_path,
    dp_eps=dp_eps,
    dp_delta=dp_delta,
    dp_norm_clip=dp_norm_clip,
    encrypt_type=encrypt_type,
    client_password=client_password,
    server_password=server_password,
    enable_ssl=enable_ssl,
)

# Both calls run at import time (outside the __main__ guard): select graph
# mode on the requested backend, then apply the FL settings above.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_fl_context(**ctx)
125
def make_random_batch(batch_size=32, num_classes=62):
    """Draw one synthetic training batch from NumPy's global RNG.

    Args:
        batch_size: number of samples in the batch.
        num_classes: number of label classes; labels are drawn from
            [0, num_classes).

    Returns:
        (data, label): ``data`` is a float32 array of shape
        (batch_size, 3, 32, 32) filled from ``np.random.rand``; ``label`` is an
        int32 array of shape (batch_size,).
    """
    data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
    # randint's upper bound is exclusive, so pass num_classes itself to make
    # every class id 0..num_classes-1 reachable (a bound of 61 with 62
    # classes never produced class 61).
    label = np.random.randint(0, num_classes, size=batch_size).astype(np.int32)
    return data, label


def main(epoch=5, num_classes=62, batch_size=32):
    """Run a few LeNet5 training steps on random data and print the losses.

    The MindSpore and FL contexts are configured at module import time, before
    this runs.

    Args:
        epoch: number of training steps; each step uses a fresh random batch.
        num_classes: value passed to LeNet5 (assumed to be its number of
            output classes) and used as the exclusive upper bound for labels.
        batch_size: samples per step.

    Returns:
        The list of per-step losses (the ``asnumpy()`` result of each step).
    """
    np.random.seed(0)  # fixed seed so every run sees the same synthetic data
    network = LeNet5(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    # NOTE(review): this optimizer is never attached to train_network below;
    # it is kept only in case constructing it matters to the FL job -- confirm
    # and remove if not.
    net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)  # noqa: F841
    net_with_criterion = WithLossCell(network, criterion)
    train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network.set_train()

    losses = []
    for _ in range(epoch):
        data, label = make_random_batch(batch_size, num_classes)
        loss = train_network(Tensor(data), Tensor(label)).asnumpy()
        losses.append(loss)
    print(losses)
    return losses


if __name__ == "__main__":
    main()
144