• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import argparse
17import time
18import datetime
19import random
20import sys
21import requests
22import flatbuffers
23import numpy as np
24from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
25                              RequestUpdateModel, ResponseUpdateModel,
26                              FeatureMap, RequestGetModel, ResponseGetModel)
27
def str2bool(value):
    """Convert a command-line string such as ``"True"``, ``"false"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True`` for
    every non-empty string, so ``--use_elb False`` used to *enable* ELB.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % value)


parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
parser.add_argument("--use_elb", type=str2bool, default=False)
parser.add_argument("--server_num", type=int, default=1)

# parse_known_args: unrecognised command-line arguments are ignored rather than rejected.
args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num

# Client id sent with every request; the pid makes it distinct per client process.
str_fl_id = 'fl_lenet_' + str(pid)

# Plain-text response bodies meaning the server is not serving requests right now.
# Callers compare the response text against these before parsing the FlatBuffers reply.
server_not_available_rsp = ["The cluster is in safemode.",
                            "The server's training job is disabled or finished."]
46
def generate_port(base_port=None, num_servers=None, elb=None):
    """Pick the HTTP port the next request should be sent to.

    Without ELB every request goes to ``base_port``. With ELB enabled, requests
    are spread uniformly over ``num_servers`` consecutive ports starting at
    ``base_port``.

    Args:
        base_port: first server port; defaults to the ``--http_port`` value.
        num_servers: number of consecutive server ports; defaults to ``--server_num``.
        elb: whether to spread requests over the servers; defaults to ``--use_elb``.

    Returns:
        The port number as an int.

    Raises:
        ValueError: if ELB is enabled and ``num_servers`` is smaller than 1.
    """
    if elb is None:
        elb = use_elb
    if base_port is None:
        base_port = http_port
    if not elb:
        return base_port
    if num_servers is None:
        num_servers = server_num
    if num_servers < 1:
        raise ValueError("server_num must be >= 1 when ELB is enabled, got %r" % (num_servers,))
    # randrange is uniform over [0, num_servers); the previous
    # ``randint(0, 100000) % num_servers`` carried a small modulo bias.
    return base_port + random.randrange(num_servers)
52
53
def build_start_fl_job():
    """Serialize a ``RequestFLJob`` message for the ``/startFLJob`` endpoint.

    The request carries the fixed job name ``'fl_test_job'``, this client's id
    (``str_fl_id``), a hard-coded data size of 32 and a fixed timestamp string.

    Returns:
        The finished FlatBuffers buffer, ready to be POSTed.
    """
    builder = flatbuffers.Builder(1024)

    # Strings are created before the table is started: FlatBuffers does not
    # allow building nested objects while a table is under construction.
    name_offset = builder.CreateString('fl_test_job')
    id_offset = builder.CreateString(str_fl_id)
    ts_offset = builder.CreateString('2020/11/16/19/18')

    RequestFLJob.RequestFLJobStart(builder)
    RequestFLJob.RequestFLJobAddFlName(builder, name_offset)
    RequestFLJob.RequestFLJobAddFlId(builder, id_offset)
    RequestFLJob.RequestFLJobAddDataSize(builder, 32)
    RequestFLJob.RequestFLJobAddTimestamp(builder, ts_offset)
    builder.Finish(RequestFLJob.RequestFLJobEnd(builder))
    return builder.Output()
72
def build_feature_map(builder, names, lengths):
    """Write one ``FeatureMap`` table per weight into ``builder``, filled with random data.

    For each ``(name, length)`` pair, a vector of ``length`` random values in
    ``[0, 32)`` is generated, serialized as float32, and paired with ``name``
    as the weight's full name.

    Args:
        builder: ``flatbuffers.Builder`` the tables are written into.
        names: weight full names, e.g. ``"conv1.weight"``.
        lengths: element count of each weight, parallel to ``names``.

    Returns:
        ``(feature_maps, np_data)``: the FlatBuffers offsets of the created tables
        and the float64 numpy arrays that were serialized, both in ``names`` order.

    Raises:
        ValueError: if ``names`` and ``lengths`` have different sizes. (Returning
            ``None`` here made the caller's tuple unpacking fail with an unrelated
            ``TypeError``.)
    """
    if len(names) != len(lengths):
        raise ValueError("names and lengths must have the same size, got %d and %d"
                         % (len(names), len(lengths)))
    feature_maps = []
    np_data = []
    for name, length in zip(names, lengths):
        weight_full_name = builder.CreateString(name)
        FeatureMap.FeatureMapStartDataVector(builder, length)
        weight = np.random.rand(length) * 32
        np_data.append(weight)
        # FlatBuffers builds vectors back to front. Prepending from the last
        # element keeps the serialized order equal to ``weight``.
        for value in weight[::-1]:
            builder.PrependFloat32(value)
        data = builder.EndVector(length)
        FeatureMap.FeatureMapStart(builder)
        FeatureMap.FeatureMapAddData(builder, data)
        FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
        feature_maps.append(FeatureMap.FeatureMapEnd(builder))
    return feature_maps, np_data
94
def build_update_model(iteration):
    """Serialize a ``RequestUpdateModel`` message for the ``/updateModel`` endpoint.

    Random weights are generated (via ``build_feature_map``) for every parameter
    listed below and packed into the request, together with the job name, this
    client's id, the given iteration and a fixed timestamp string.

    Args:
        iteration: iteration number to report to the server.

    Returns:
        ``(buf, np_data)``: the finished FlatBuffers buffer and the list of numpy
        arrays that were serialized, in the order of the weight names below.
    """
    builder_update_model = flatbuffers.Builder(1)
    fl_name = builder_update_model.CreateString('fl_test_job')
    fl_id = builder_update_model.CreateString(str_fl_id)
    timestamp = builder_update_model.CreateString('2020/11/16/19/18')

    feature_maps, np_data = build_feature_map(
        builder_update_model,
        ["conv1.weight", "conv2.weight", "fc1.weight",
         "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
        [450, 2400, 48000, 10080, 5208, 120, 84, 62])

    # StartVector expects the number of elements that will be prepended. It was
    # hard-coded to 1, but newer flatbuffers releases may record this count as
    # the vector length, so pass the real count.
    RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model,
                                                               len(feature_maps))
    # Prepending in list order leaves the serialized vector in reverse order of
    # ``feature_maps`` (last weight first). This is kept as-is so the request
    # layout is unchanged.
    for single_feature_map in feature_maps:
        builder_update_model.PrependUOffsetTRelative(single_feature_map)
    feature_map = builder_update_model.EndVector(len(feature_maps))

    RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
    RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
    RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
    RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
    RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
    RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
    req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
    builder_update_model.Finish(req_update_model)
    buf = builder_update_model.Output()
    return buf, np_data
121
def build_get_model(iteration):
    """Serialize a ``RequestGetModel`` message for the ``/getModel`` endpoint.

    Args:
        iteration: iteration whose model is requested.

    Returns:
        The finished FlatBuffers buffer, ready to be POSTed.
    """
    builder = flatbuffers.Builder(1)
    name_offset = builder.CreateString('fl_test_job')
    ts_offset = builder.CreateString('2020/12/16/19/18')

    RequestGetModel.RequestGetModelStart(builder)
    RequestGetModel.RequestGetModelAddFlName(builder, name_offset)
    RequestGetModel.RequestGetModelAddIteration(builder, iteration)
    RequestGetModel.RequestGetModelAddTimestamp(builder, ts_offset)
    builder.Finish(RequestGetModel.RequestGetModelEnd(builder))
    return builder.Output()
135
def datetime_to_timestamp(datetime_obj):
    """Return ``datetime_obj`` as milliseconds since the epoch, as a float.

    A naive ``datetime_obj`` is interpreted in local time (``time.mktime``).
    Sub-millisecond precision is truncated.
    """
    whole_seconds = time.mktime(datetime_obj.timetuple())
    millis_part = datetime_obj.microsecond // 1000.0
    return whole_seconds * 1000.0 + millis_part
139
# Index of each weight name. Same order as the weight names passed to
# build_feature_map when the update request is built.
weight_to_idx = {weight_name: position for position, weight_name in enumerate((
    "conv1.weight",
    "conv2.weight",
    "fc1.weight",
    "fc2.weight",
    "fc3.weight",
    "fc1.bias",
    "fc2.bias",
    "fc3.bias",
))}

# One HTTP session shared by all requests so connections can be reused.
session = requests.Session()
current_iteration = 1
# Fixed seed so the randomly generated weights are reproducible between runs.
np.random.seed(0)
154
def start_fl_job():
    """Ask the server to admit this client into the current FL iteration.

    Returns:
        ``(result, iteration)``.

        ``result`` has two keys:

        - ``'reason'``: ``"Success"`` or ``"Restart iteration."``.
        - ``'next_ts'``: the retry time in milliseconds (same scale as
          ``datetime_to_timestamp``), or 0 on success.

        ``iteration`` is the iteration reported by the server, or 0 when the
        server answered with one of the plain-text "not available" messages.

    Exits the process with status 1 on any server return code other than
    ``SUCCEED`` or ``OutOfTime``.
    """
    start_fl_job_result = {}
    iteration = 0
    url = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("Start fl job url is ", url)

    x = session.post(url, data=build_start_fl_job())
    if x.text in server_not_available_rsp:
        # Server is not serving requests right now: retry in 500 ms.
        start_fl_job_result['reason'] = "Restart iteration."
        start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
        print("Start fl job when safemode.")
        return start_fl_job_result, iteration

    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    iteration = rsp_fl_job.Iteration()
    ret_code = rsp_fl_job.Retcode()
    if ret_code == ResponseCode.ResponseCode.SUCCEED:
        start_fl_job_result['reason'] = "Success"
        start_fl_job_result['next_ts'] = 0
    elif ret_code == ResponseCode.ResponseCode.OutOfTime:
        start_fl_job_result['reason'] = "Restart iteration."
        start_fl_job_result['next_ts'] = int(rsp_fl_job.NextReqTime().decode('utf-8'))
        print("Start fl job out of time. Next request at ",
              start_fl_job_result['next_ts'], "reason:", rsp_fl_job.Reason())
    else:
        print("Start fl job failed, return code is ", ret_code)
        # Bare sys.exit() exits with status 0 and reports this failure as
        # success. Use a non-zero status instead.
        sys.exit(1)
    return start_fl_job_result, iteration
183
def update_model(iteration):
    """Upload freshly generated random weights for ``iteration``.

    Args:
        iteration: iteration number obtained from ``start_fl_job``.

    Returns:
        ``(result, np_data)``.

        ``result`` has two keys:

        - ``'reason'``: ``"Success"`` or ``"Restart iteration."``.
        - ``'next_ts'``: the retry time in milliseconds, or 0 on success.

        ``np_data`` is the list of numpy arrays that were uploaded. It is
        returned even when the server asks for a restart.

    Exits the process with status 1 on any server return code other than
    ``SUCCEED`` or ``OutOfTime``.
    """
    update_model_result = {}

    url = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
    print("Update model url:", url, ", iteration:", iteration)
    update_model_buf, update_model_np_data = build_update_model(iteration)
    x = session.post(url, data=update_model_buf)
    if x.text in server_not_available_rsp:
        # Server is not serving requests right now: retry in 500 ms.
        update_model_result['reason'] = "Restart iteration."
        update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
        print("Update model when safemode.")
        return update_model_result, update_model_np_data

    rsp_update_model = ResponseUpdateModel.ResponseUpdateModel.GetRootAsResponseUpdateModel(x.content, 0)
    ret_code = rsp_update_model.Retcode()
    if ret_code == ResponseCode.ResponseCode.SUCCEED:
        update_model_result['reason'] = "Success"
        update_model_result['next_ts'] = 0
    elif ret_code == ResponseCode.ResponseCode.OutOfTime:
        update_model_result['reason'] = "Restart iteration."
        update_model_result['next_ts'] = int(rsp_update_model.NextReqTime().decode('utf-8'))
        print("Update model out of time. Next request at ",
              update_model_result['next_ts'], "reason:", rsp_update_model.Reason())
    else:
        print("Update model failed, return code is ", ret_code)
        # Bare sys.exit() exits with status 0 and reports this failure as
        # success. Use a non-zero status instead.
        sys.exit(1)
    return update_model_result, update_model_np_data
211
def get_model(iteration, update_model_data):
    """Poll ``/getModel`` until the model for ``iteration`` is returned, then print a sample.

    The request is retried every 0.5 s while the server answers with a
    "not available" message or ``SucNotReady``. On success, the first returned
    feature map is printed next to the locally uploaded weights of the same
    name (first 10 values each) for manual comparison.

    Args:
        iteration: iteration whose model is requested.
        update_model_data: list of numpy arrays uploaded by ``update_model``,
            looked up through ``weight_to_idx``.

    Returns:
        ``{'reason': "Success", 'next_ts': 0}``.

    Exits the process with status 1 on any server return code other than
    ``SUCCEED`` or ``SucNotReady``.
    """
    get_model_result = {}

    url = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("Get model url:", url, ", iteration:", iteration)

    while True:
        x = session.post(url, data=build_get_model(iteration))
        if x.text in server_not_available_rsp:
            print("Get model when safemode.")
            time.sleep(0.5)
            continue

        rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
        ret_code = rsp_get_model.Retcode()
        if ret_code == ResponseCode.ResponseCode.SUCCEED:
            break
        if ret_code == ResponseCode.ResponseCode.SucNotReady:
            time.sleep(0.5)
            continue
        print("Get model failed, return code is ", ret_code)
        # Bare sys.exit() exits with status 0 and reports this failure as
        # success. Use a non-zero status instead.
        sys.exit(1)

    # Only the first feature map is compared. Bound the loop by the vector
    # length so an empty response is not read past its end.
    for i in range(min(1, rsp_get_model.FeatureMapLength())):
        feature_map = rsp_get_model.FeatureMap(i)
        weight_name = feature_map.WeightFullname()
        print(weight_name)
        origin = update_model_data[weight_to_idx[weight_name.decode('utf-8')]]
        # NOTE(review): returned data is scaled by 32 before printing; presumably
        # to match the [0, 32) range of the uploaded weights — confirm.
        after = feature_map.DataAsNumpy() * 32
        print("Before update model", args.pid, origin[0:10])
        print("After get model", args.pid, after[0:10])
        sys.stdout.flush()

    get_model_result['reason'] = "Success"
    get_model_result['next_ts'] = 0
    return get_model_result
247
248
def _sleep_if_restart(step_result):
    """Return True if ``step_result`` asks for a restart, after waiting until its ``next_ts``.

    ``next_ts`` is compared against ``datetime_to_timestamp(now)``, so both are
    in milliseconds. The wait is skipped when that time has already passed.
    """
    if step_result['reason'] != "Restart iteration.":
        return False
    remaining_ms = step_result['next_ts'] - datetime_to_timestamp(datetime.datetime.now())
    if remaining_ms >= 0:
        time.sleep(remaining_ms / 1000)
    return True


# Run FL iterations forever:
#   1. join the job;
#   2. upload weights;
#   3. fetch the model.
# Any step may ask for a restart, in which case we wait and start over.
while True:
    result, current_iteration = start_fl_job()
    sys.stdout.flush()
    if _sleep_if_restart(result):
        continue

    result, update_data = update_model(current_iteration)
    sys.stdout.flush()
    if _sleep_if_restart(result):
        continue

    result = get_model(current_iteration, update_data)
    sys.stdout.flush()
    if _sleep_if_restart(result):
        continue

    print("")
    sys.stdout.flush()
279