# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

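# Hybrid training example for MindSpore federated learning: a LeNet5 worker
# pulls part of its weights from the federated server, trains locally on
# synthetic data, and pushes the updated fully connected layers and its
# training loss back to the server.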
import argparse
import ast
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import WithLossCell
from src.cell_wrapper import TrainOneStepCellWithServerCommunicator
from src.model import LeNet5, PushMetrics
# from src.adam import AdamWeightDecayOp

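# Command-line arguments: cluster topology (scheduler/server/worker), the
# federated-learning schedule (iteration counts, thresholds, time windows),
# local training hyper-parameters, and optional encryption/SSL settings.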
parser = argparse.ArgumentParser(description="test_hybrid_train_lenet")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)  # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)

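# A worker process might be launched along these lines (the script name and
# exact flag values are illustrative; the repository's launch scripts may
# differ):
#   python test_hybrid_train_lenet.py --ms_role=MS_WORKER --server_num=1 \
#       --scheduler_ip=127.0.0.1 --scheduler_port=8113 --fl_server_port=6666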
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl

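# Gather the parsed arguments into the federated-learning context that is
# handed to context.set_fl_context() below.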
ctx = {
    "enable_fl": True,
    "server_mode": server_mode,
    "ms_role": ms_role,
    "worker_num": worker_num,
    "server_num": server_num,
    "scheduler_ip": scheduler_ip,
    "scheduler_port": scheduler_port,
    "fl_server_port": fl_server_port,
    "start_fl_job_threshold": start_fl_job_threshold,
    "start_fl_job_time_window": start_fl_job_time_window,
    "update_model_ratio": update_model_ratio,
    "update_model_time_window": update_model_time_window,
    "fl_name": fl_name,
    "fl_iteration_num": fl_iteration_num,
    "client_epoch_num": client_epoch_num,
    "client_batch_size": client_batch_size,
    "client_learning_rate": client_learning_rate,
    "worker_step_num_per_iteration": worker_step_num_per_iteration,
    "scheduler_manage_port": scheduler_manage_port,
    "config_file_path": config_file_path,
    "share_secrets_ratio": share_secrets_ratio,
    "cipher_time_window": cipher_time_window,
    "reconstruct_secrets_threshold": reconstruct_secrets_threshold,
    "dp_eps": dp_eps,
    "dp_delta": dp_delta,
    "dp_norm_clip": dp_norm_clip,
    "encrypt_type": encrypt_type,
    "client_password": client_password,
    "server_password": server_password,
    "enable_ssl": enable_ssl
}

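# Run in graph mode and register the federated-learning context before the
# network is built.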
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_fl_context(**ctx)

if __name__ == "__main__":
    epoch = 50000
    np.random.seed(0)
    network = LeNet5(62)
    push_metrics = PushMetrics()
    if context.get_fl_context("ms_role") == "MS_WORKER":
        # Do not configure a frozen layer to be both pulled from and pushed to
        # the servers: frozen weights never change locally, so pushing them
        # back would be meaningless. Here conv1 and conv2 are frozen and only pulled.
        network.conv1.weight.requires_grad = False
        network.conv2.weight.requires_grad = False
        # Pull these weights from the server before running the backbone.
        network.conv1.set_param_fl(pull_from_server=True)
        network.conv2.set_param_fl(pull_from_server=True)

        # Push these weights to the server after the optimizer has updated them.
        network.fc1.set_param_fl(push_to_server=True)
        network.fc2.set_param_fl(push_to_server=True)
        network.fc3.set_param_fl(push_to_server=True)

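    # Standard LeNet training setup: cross-entropy loss, momentum optimizer, and
    # a train cell that, per its name, also carries the worker-server weight
    # communication (defined in src.cell_wrapper).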
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    # net_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)
    net_with_criterion = WithLossCell(network, criterion)
    # train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network = TrainOneStepCellWithServerCommunicator(net_with_criterion, net_opt)
    train_network.set_train()
    losses = []

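    # Train on synthetic random images and labels; a worker pushes its current
    # loss to the server via PushMetrics once every
    # worker_step_num_per_iteration steps.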
    for i in range(epoch):
        data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
        label = Tensor(np.random.randint(0, 62, (32,)).astype(np.int32))
        loss = train_network(data, label).asnumpy()
        if context.get_fl_context("ms_role") == "MS_WORKER":
            if (i + 1) % worker_step_num_per_iteration == 0:
                push_metrics(Tensor(loss, mstype.float32), Tensor(loss, mstype.float32))
        losses.append(loss)
    print(losses)