# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simulated federated-learning client for a LeNet model.

Each round posts a ``startFLJob`` request, uploads randomly generated LeNet
weights with ``updateModel`` and then polls ``getModel`` until the server
reports success, printing a slice of one weight before upload and after
download.
"""

import argparse
import time
import datetime
import random
import sys
import requests
import flatbuffers
import numpy as np
from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
                              RequestUpdateModel, ResponseUpdateModel,
                              FeatureMap, RequestGetModel, ResponseGetModel)


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- so ``--use_elb False`` used to enable ELB.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y", "on"):
        return True
    # The empty string stays False, matching the old bool("") behaviour.
    if lowered in ("false", "0", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r." % value)


parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
parser.add_argument("--use_elb", type=_str2bool, default=False)
parser.add_argument("--server_num", type=int, default=1)

args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num

str_fl_id = 'fl_lenet_' + str(pid)

# Plain-text bodies the server returns instead of a flatbuffer when it cannot
# serve the request; the client backs off and retries when it sees one.
server_not_available_rsp = ["The cluster is in safemode.",
                            "The server's training job is disabled or finished."]
# Byte-encoded copies: comparing ``response.content`` avoids ``response.text``,
# which may run charset detection over large binary flatbuffer bodies when the
# server sends no charset.
_server_not_available_rsp_bytes = frozenset(msg.encode('utf-8') for msg in server_not_available_rsp)


def _is_server_unavailable(response):
    """Return True if *response*'s body is one of the server-unavailable messages."""
    return response.content in _server_not_available_rsp_bytes


def generate_port():
    """Return the port for the next request.

    Without ELB every request goes to ``http_port``; with ELB a port is picked
    uniformly from ``[http_port, http_port + server_num)``.
    """
    if not use_elb:
        return http_port
    return http_port + random.randrange(server_num)


def build_start_fl_job():
    """Serialize a RequestFLJob for this client (fixed job name, data size 32)."""
    start_fl_job_builder = flatbuffers.Builder(1024)

    fl_name = start_fl_job_builder.CreateString('fl_test_job')
    fl_id = start_fl_job_builder.CreateString(str_fl_id)
    data_size = 32
    timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18')

    RequestFLJob.RequestFLJobStart(start_fl_job_builder)
    RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name)
    RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id)
    RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size)
    RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp)
    fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder)

    start_fl_job_builder.Finish(fl_job_req)
    buf = start_fl_job_builder.Output()
    return buf


def build_feature_map(builder, names, lengths):
    """Serialize one FeatureMap table per weight into *builder*.

    Each weight is filled with ``np.random.rand(length) * 32`` and stored as a
    float32 vector.

    Args:
        builder: flatbuffers.Builder the tables are written into.
        names: weight full names, e.g. ``"conv1.weight"``.
        lengths: element count of each weight, parallel to *names*.

    Returns:
        ``(feature_maps, np_data)``: the FeatureMap table offsets and the
        float64 numpy arrays that were serialized, both in *names* order.

    Raises:
        ValueError: if *names* and *lengths* differ in size.
    """
    if len(names) != len(lengths):
        raise ValueError("names and lengths must have the same size, got %d and %d."
                         % (len(names), len(lengths)))
    feature_maps = []
    np_data = []
    for name, length in zip(names, lengths):
        weight_full_name = builder.CreateString(name)
        FeatureMap.FeatureMapStartDataVector(builder, length)
        weight = np.random.rand(length) * 32
        np_data.append(weight)
        # Flatbuffers vectors are built back to front, so prepend in reverse.
        for value in reversed(weight):
            builder.PrependFloat32(value)
        data = builder.EndVector(length)
        FeatureMap.FeatureMapStart(builder)
        FeatureMap.FeatureMapAddData(builder, data)
        FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
        feature_map = FeatureMap.FeatureMapEnd(builder)
        feature_maps.append(feature_map)
    return feature_maps, np_data


def build_update_model(iteration):
    """Serialize a RequestUpdateModel carrying random LeNet weights.

    Returns:
        ``(buf, np_data)``: the serialized request and the float64 arrays that
        were uploaded, ordered as in ``weight_to_idx``.
    """
    builder_update_model = flatbuffers.Builder(1)
    fl_name = builder_update_model.CreateString('fl_test_job')
    fl_id = builder_update_model.CreateString(str_fl_id)
    timestamp = builder_update_model.CreateString('2020/11/16/19/18')

    feature_maps, np_data = build_feature_map(builder_update_model,
                                              ["conv1.weight", "conv2.weight", "fc1.weight",
                                               "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
                                              [450, 2400, 48000, 10080, 5208, 120, 84, 62])

    # Reserve space for every offset that is prepended below (was hard-coded to 1).
    RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, len(feature_maps))
    # NOTE: prepending in list order stores the maps in reverse order inside
    # the vector; kept as-is to preserve the wire format sent so far.
    for single_feature_map in feature_maps:
        builder_update_model.PrependUOffsetTRelative(single_feature_map)
    feature_map = builder_update_model.EndVector(len(feature_maps))

    RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
    RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
    RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
    RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
    RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
    RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
    req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
    builder_update_model.Finish(req_update_model)
    buf = builder_update_model.Output()
    return buf, np_data


def build_get_model(iteration):
    """Serialize a RequestGetModel for *iteration*."""
    builder_get_model = flatbuffers.Builder(1)
    fl_name = builder_get_model.CreateString('fl_test_job')
    timestamp = builder_get_model.CreateString('2020/12/16/19/18')

    RequestGetModel.RequestGetModelStart(builder_get_model)
    RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name)
    RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration)
    RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp)
    req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model)
    builder_get_model.Finish(req_get_model)
    buf = builder_get_model.Output()
    return buf


def datetime_to_timestamp(datetime_obj):
    """Return *datetime_obj* as float milliseconds since the epoch.

    The datetime is interpreted in local time (``time.mktime``); sub-millisecond
    digits are floored.
    """
    local_timestamp = time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0
    return local_timestamp


# Position of each weight in the ``np_data`` list returned by build_update_model.
weight_to_idx = {
    "conv1.weight": 0,
    "conv2.weight": 1,
    "fc1.weight": 2,
    "fc2.weight": 3,
    "fc3.weight": 4,
    "fc1.bias": 5,
    "fc2.bias": 6,
    "fc3.bias": 7
}

session = requests.Session()
current_iteration = 1
np.random.seed(0)


def start_fl_job():
    """Send startFLJob and classify the response.

    Returns:
        ``(result, iteration)`` where ``result['reason']`` is ``"Success"`` or
        ``"Restart iteration."`` and ``result['next_ts']`` is the millisecond
        timestamp to retry at (0 on success). ``iteration`` is 0 when the
        server is unavailable.

    Exits the process with status 1 on any other failure code.
    """
    start_fl_job_result = {}
    iteration = 0
    url = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("Start fl job url is ", url)

    x = session.post(url, data=build_start_fl_job())
    if _is_server_unavailable(x):
        start_fl_job_result['reason'] = "Restart iteration."
        # Retry 500 ms from now.
        start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
        print("Start fl job when safemode.")
        return start_fl_job_result, iteration

    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    iteration = rsp_fl_job.Iteration()
    if rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        if rsp_fl_job.Retcode() == ResponseCode.ResponseCode.OutOfTime:
            start_fl_job_result['reason'] = "Restart iteration."
            start_fl_job_result['next_ts'] = int(rsp_fl_job.NextReqTime().decode('utf-8'))
            print("Start fl job out of time. Next request at ",
                  start_fl_job_result['next_ts'], "reason:", rsp_fl_job.Reason())
        else:
            print("Start fl job failed, return code is ", rsp_fl_job.Retcode())
            # Non-zero status so callers/harnesses can detect the failure.
            sys.exit(1)
    else:
        start_fl_job_result['reason'] = "Success"
        start_fl_job_result['next_ts'] = 0
    return start_fl_job_result, iteration


def update_model(iteration):
    """Upload freshly generated weights for *iteration*.

    Returns:
        ``(result, np_data)``: ``result`` has the same ``reason``/``next_ts``
        keys as start_fl_job's result; ``np_data`` is the uploaded weights.

    Exits the process with status 1 on any failure code other than OutOfTime.
    """
    update_model_result = {}

    url = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
    print("Update model url:", url, ", iteration:", iteration)
    update_model_buf, update_model_np_data = build_update_model(iteration)
    x = session.post(url, data=update_model_buf)
    if _is_server_unavailable(x):
        update_model_result['reason'] = "Restart iteration."
        # Retry 500 ms from now.
        update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
        print("Update model when safemode.")
        return update_model_result, update_model_np_data

    rsp_update_model = ResponseUpdateModel.ResponseUpdateModel.GetRootAsResponseUpdateModel(x.content, 0)
    if rsp_update_model.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        if rsp_update_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
            update_model_result['reason'] = "Restart iteration."
            update_model_result['next_ts'] = int(rsp_update_model.NextReqTime().decode('utf-8'))
            print("Update model out of time. Next request at ",
                  update_model_result['next_ts'], "reason:", rsp_update_model.Reason())
        else:
            print("Update model failed, return code is ", rsp_update_model.Retcode())
            sys.exit(1)
    else:
        update_model_result['reason'] = "Success"
        update_model_result['next_ts'] = 0
    return update_model_result, update_model_np_data


def get_model(iteration, update_model_data):
    """Poll getModel for *iteration* and print the first returned weight.

    Retries every 0.5 s while the server is unavailable or answers
    SucNotReady; exits with status 1 on any other failure code.

    Args:
        iteration: iteration number to request.
        update_model_data: arrays returned by update_model, indexed via
            ``weight_to_idx``; used to print the "before" values.

    Returns:
        ``{'reason': "Success", 'next_ts': 0}``.
    """
    get_model_result = {}

    url = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("Get model url:", url, ", iteration:", iteration)

    while True:
        x = session.post(url, data=build_get_model(iteration))
        if _is_server_unavailable(x):
            print("Get model when safemode.")
            time.sleep(0.5)
            continue

        rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
        ret_code = rsp_get_model.Retcode()
        if ret_code == ResponseCode.ResponseCode.SUCCEED:
            break
        if ret_code == ResponseCode.ResponseCode.SucNotReady:
            time.sleep(0.5)
            continue
        print("Get model failed, return code is ", rsp_get_model.Retcode())
        sys.exit(1)

    # Only the first feature map is compared; guard against an empty response,
    # where FeatureMap(0) would not return a table.
    feature_map_count = min(1, rsp_get_model.FeatureMapLength())
    if feature_map_count == 0:
        print("Get model returned no feature map.")
    for i in range(feature_map_count):
        feature_map = rsp_get_model.FeatureMap(i)
        weight_name = feature_map.WeightFullname()
        print(weight_name)
        origin = update_model_data[weight_to_idx[weight_name.decode('utf-8')]]
        after = feature_map.DataAsNumpy() * 32
        print("Before update model", args.pid, origin[0:10])
        print("After get model", args.pid, after[0:10])
        sys.stdout.flush()

    get_model_result['reason'] = "Success"
    get_model_result['next_ts'] = 0
    return get_model_result


def _sleep_until(next_ts):
    """Sleep until millisecond timestamp *next_ts*; return at once if it has passed."""
    duration = next_ts - datetime_to_timestamp(datetime.datetime.now())
    if duration >= 0:
        time.sleep(duration / 1000)


def main():
    """Run startFLJob -> updateModel -> getModel rounds forever."""
    while True:
        result, iteration = start_fl_job()
        sys.stdout.flush()
        if result['reason'] == "Restart iteration.":
            _sleep_until(result['next_ts'])
            continue

        result, update_data = update_model(iteration)
        sys.stdout.flush()
        if result['reason'] == "Restart iteration.":
            _sleep_until(result['next_ts'])
            continue

        result = get_model(iteration, update_data)
        sys.stdout.flush()
        if result['reason'] == "Restart iteration.":
            _sleep_until(result['next_ts'])
            continue

        print("")
        sys.stdout.flush()


if __name__ == "__main__":
    main()