1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ps/core/communicator/http_message_handler.h"

#include <event2/event.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/bufferevent_compat.h>
#include <event2/http.h>
#include <event2/http_compat.h>
#include <event2/http_struct.h>
#include <event2/listener.h>
#include <event2/util.h>

#include <fcntl.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
36
37 namespace mindspore {
38 namespace ps {
39 namespace core {
InitHttpMessage()40 void HttpMessageHandler::InitHttpMessage() {
41 MS_EXCEPTION_IF_NULL(event_request_);
42 event_uri_ = evhttp_request_get_evhttp_uri(event_request_);
43 MS_EXCEPTION_IF_NULL(event_uri_);
44
45 const char *query = evhttp_uri_get_query(event_uri_);
46 if (query != nullptr) {
47 MS_LOG(WARNING) << "The query is:" << query;
48 int result = evhttp_parse_query_str(query, &path_params_);
49 if (result < 0) {
50 MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
51 }
52 }
53
54 head_params_ = evhttp_request_get_input_headers(event_request_);
55 resp_headers_ = evhttp_request_get_output_headers(event_request_);
56 resp_buf_ = evhttp_request_get_output_buffer(event_request_);
57 MS_EXCEPTION_IF_NULL(head_params_);
58 MS_EXCEPTION_IF_NULL(resp_headers_);
59 MS_EXCEPTION_IF_NULL(resp_buf_);
60 }
61
GetHeadParam(const std::string & key) const62 std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
63 MS_EXCEPTION_IF_NULL(head_params_);
64 const char *val = evhttp_find_header(head_params_, key.c_str());
65 MS_EXCEPTION_IF_NULL(val);
66 return std::string(val);
67 }
68
GetPathParam(const std::string & key) const69 std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
70 const char *val = evhttp_find_header(&path_params_, key.c_str());
71 MS_EXCEPTION_IF_NULL(val);
72 return std::string(val);
73 }
74
ParsePostParam()75 void HttpMessageHandler::ParsePostParam() {
76 MS_EXCEPTION_IF_NULL(event_request_);
77 size_t len = evbuffer_get_length(event_request_->input_buffer);
78 if (len == 0) {
79 MS_LOG(EXCEPTION) << "The post parameter size is: " << len;
80 }
81 post_param_parsed_ = true;
82 const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1));
83 MS_EXCEPTION_IF_NULL(post_message);
84 post_message_ = std::make_unique<std::string>(post_message, len);
85 MS_EXCEPTION_IF_NULL(post_message_);
86 int ret = evhttp_parse_query_str(post_message_->c_str(), &post_params_);
87 if (ret == -1) {
88 MS_LOG(EXCEPTION) << "Parse post parameter failed!";
89 }
90 }
91
ParsePostMessageToJson()92 RequestProcessResult HttpMessageHandler::ParsePostMessageToJson() {
93 MS_EXCEPTION_IF_NULL(event_request_);
94 RequestProcessResult result(RequestProcessResultCode::kSuccess);
95 std::string message;
96
97 size_t len = evbuffer_get_length(event_request_->input_buffer);
98 if (len == 0) {
99 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message size is invalid.");
100 return result;
101 } else if (len > kMaxMessageSize) {
102 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message is bigger than 100mb.");
103 return result;
104 } else {
105 message.resize(len);
106 auto buffer = evbuffer_pullup(event_request_->input_buffer, -1);
107 if (buffer == nullptr) {
108 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Get http post message failed.");
109 return result;
110 }
111 size_t dest_size = len;
112 size_t src_size = len;
113 if (memcpy_s(message.data(), dest_size, buffer, src_size) != EOK) {
114 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Copy message failed.");
115 return result;
116 }
117
118 try {
119 request_message_ = nlohmann::json::parse(message);
120 } catch (nlohmann::json::exception &e) {
121 std::string illegal_exception = e.what();
122 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Illegal JSON format:" + illegal_exception);
123 return result;
124 }
125 }
126 return result;
127 }
128
GetPostParam(const std::string & key)129 std::string HttpMessageHandler::GetPostParam(const std::string &key) {
130 if (!post_param_parsed_) {
131 ParsePostParam();
132 }
133
134 const char *val = evhttp_find_header(&post_params_, key.c_str());
135 MS_EXCEPTION_IF_NULL(val);
136 return std::string(val);
137 }
138
GetRequestUri() const139 std::string HttpMessageHandler::GetRequestUri() const {
140 MS_EXCEPTION_IF_NULL(event_request_);
141 const char *uri = evhttp_request_get_uri(event_request_);
142 MS_EXCEPTION_IF_NULL(uri);
143 return std::string(uri);
144 }
145
GetRequestHost()146 std::string HttpMessageHandler::GetRequestHost() {
147 MS_EXCEPTION_IF_NULL(event_request_);
148 const char *host = evhttp_request_get_host(event_request_);
149 MS_EXCEPTION_IF_NULL(host);
150 return std::string(host);
151 }
152
GetHostByUri() const153 const char *HttpMessageHandler::GetHostByUri() const {
154 MS_EXCEPTION_IF_NULL(event_uri_);
155 const char *host = evhttp_uri_get_host(event_uri_);
156 MS_EXCEPTION_IF_NULL(host);
157 return host;
158 }
159
GetUriPort() const160 int HttpMessageHandler::GetUriPort() const {
161 MS_EXCEPTION_IF_NULL(event_uri_);
162 int port = evhttp_uri_get_port(event_uri_);
163 if (port < 0) {
164 MS_LOG(EXCEPTION) << "The port:" << port << " should not be less than 0!";
165 }
166 return port;
167 }
168
GetUriPath() const169 std::string HttpMessageHandler::GetUriPath() const {
170 MS_EXCEPTION_IF_NULL(event_uri_);
171 const char *path = evhttp_uri_get_path(event_uri_);
172 MS_EXCEPTION_IF_NULL(path);
173 return std::string(path);
174 }
175
GetRequestPath()176 std::string HttpMessageHandler::GetRequestPath() {
177 MS_EXCEPTION_IF_NULL(event_uri_);
178 const char *path = evhttp_uri_get_path(event_uri_);
179 if (path == nullptr || strlen(path) == 0) {
180 path = "/";
181 }
182 std::string path_res(path);
183 const char *query = evhttp_uri_get_query(event_uri_);
184 if (query != nullptr) {
185 path_res.append("?");
186 path_res.append(query);
187 }
188 return path_res;
189 }
190
GetUriQuery() const191 std::string HttpMessageHandler::GetUriQuery() const {
192 MS_EXCEPTION_IF_NULL(event_uri_);
193 const char *query = evhttp_uri_get_query(event_uri_);
194 MS_EXCEPTION_IF_NULL(query);
195 return std::string(query);
196 }
197
GetUriFragment() const198 std::string HttpMessageHandler::GetUriFragment() const {
199 MS_EXCEPTION_IF_NULL(event_uri_);
200 const char *fragment = evhttp_uri_get_fragment(event_uri_);
201 MS_EXCEPTION_IF_NULL(fragment);
202 return std::string(fragment);
203 }
204
GetPostMsg(unsigned char ** buffer)205 uint64_t HttpMessageHandler::GetPostMsg(unsigned char **buffer) {
206 MS_EXCEPTION_IF_NULL(event_request_);
207 MS_EXCEPTION_IF_NULL(buffer);
208
209 size_t len = evbuffer_get_length(event_request_->input_buffer);
210 if (len == 0) {
211 MS_LOG(EXCEPTION) << "The post message is empty!";
212 }
213 *buffer = evbuffer_pullup(event_request_->input_buffer, -1);
214 MS_EXCEPTION_IF_NULL(*buffer);
215 return len;
216 }
217
AddRespHeadParam(const std::string & key,const std::string & val)218 void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) {
219 MS_EXCEPTION_IF_NULL(resp_headers_);
220 if (evhttp_add_header(resp_headers_, key.c_str(), val.c_str()) != 0) {
221 MS_LOG(EXCEPTION) << "Add parameter of response header failed.";
222 }
223 }
224
AddRespHeaders(const HttpHeaders & headers)225 void HttpMessageHandler::AddRespHeaders(const HttpHeaders &headers) {
226 for (auto iter = headers.begin(); iter != headers.end(); ++iter) {
227 auto list = iter->second;
228 for (auto iterator_val = list.begin(); iterator_val != list.end(); ++iterator_val) {
229 AddRespHeadParam(iter->first, *iterator_val);
230 }
231 }
232 }
233
AddRespString(const std::string & str)234 void HttpMessageHandler::AddRespString(const std::string &str) {
235 MS_EXCEPTION_IF_NULL(resp_buf_);
236 if (evbuffer_add_printf(resp_buf_, "%s", str.c_str()) == -1) {
237 MS_LOG(EXCEPTION) << "Add string to response body failed.";
238 }
239 }
240
SetRespCode(int code)241 void HttpMessageHandler::SetRespCode(int code) { resp_code_ = code; }
242
SendResponse()243 void HttpMessageHandler::SendResponse() {
244 MS_EXCEPTION_IF_NULL(event_request_);
245 MS_EXCEPTION_IF_NULL(resp_buf_);
246 evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_);
247 }
248
QuickResponse(int code,const unsigned char * body,size_t len)249 void HttpMessageHandler::QuickResponse(int code, const unsigned char *body, size_t len) {
250 MS_EXCEPTION_IF_NULL(event_request_);
251 MS_EXCEPTION_IF_NULL(body);
252 MS_EXCEPTION_IF_NULL(resp_buf_);
253 if (evbuffer_add(resp_buf_, body, len) == -1) {
254 MS_LOG(EXCEPTION) << "Add body to response body failed.";
255 }
256 evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
257 }
258
SimpleResponse(int code,const HttpHeaders & headers,const std::string & body)259 void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, const std::string &body) {
260 MS_EXCEPTION_IF_NULL(event_request_);
261 MS_EXCEPTION_IF_NULL(resp_buf_);
262 AddRespHeaders(headers);
263 AddRespString(body);
264 evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
265 }
266
ErrorResponse(int code,const RequestProcessResult & result)267 void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
268 nlohmann::json error_json = {{"error_message", result.StatusMessage()}};
269 std::string out_error = error_json.dump();
270 AddRespString(out_error);
271 SetRespCode(code);
272 SendResponse();
273 }
274
RespError(int nCode,const std::string & message)275 void HttpMessageHandler::RespError(int nCode, const std::string &message) {
276 MS_EXCEPTION_IF_NULL(event_request_);
277 if (message.empty()) {
278 evhttp_send_error(event_request_, nCode, nullptr);
279 } else {
280 evhttp_send_error(event_request_, nCode, message.c_str());
281 }
282 }
283
ReceiveMessage(const void * buffer,size_t num)284 void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
285 MS_EXCEPTION_IF_NULL(buffer);
286 MS_EXCEPTION_IF_NULL(body_);
287 size_t dest_size = num;
288 size_t src_size = num;
289 int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size);
290 if (ret != 0) {
291 MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
292 }
293 offset_ += num;
294 }
295
set_content_len(const uint64_t & len)296 void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
297
content_len() const298 uint64_t HttpMessageHandler::content_len() const { return content_len_; }
299
http_base() const300 const event_base *HttpMessageHandler::http_base() const { return event_base_; }
301
set_http_base(const struct event_base * base)302 void HttpMessageHandler::set_http_base(const struct event_base *base) {
303 MS_EXCEPTION_IF_NULL(base);
304 event_base_ = const_cast<event_base *>(base);
305 }
306
set_request(const struct evhttp_request * req)307 void HttpMessageHandler::set_request(const struct evhttp_request *req) {
308 MS_EXCEPTION_IF_NULL(req);
309 event_request_ = const_cast<evhttp_request *>(req);
310 }
311
request() const312 const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
313
InitBodySize()314 void HttpMessageHandler::InitBodySize() {
315 MS_EXCEPTION_IF_NULL(body_);
316 body_->resize(content_len());
317 }
318
body()319 std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
320
set_body(const std::shared_ptr<std::vector<char>> & body)321 void HttpMessageHandler::set_body(const std::shared_ptr<std::vector<char>> &body) { body_ = body; }
322
request_message() const323 nlohmann::json HttpMessageHandler::request_message() const { return request_message_; }
324
ParseValueFromKey(const std::string & key,int32_t * const value)325 RequestProcessResult HttpMessageHandler::ParseValueFromKey(const std::string &key, int32_t *const value) {
326 MS_EXCEPTION_IF_NULL(value);
327 RequestProcessResult result(RequestProcessResultCode::kSuccess);
328 if (!request_message_.contains(key)) {
329 std::string message = "The json is not contain the key:" + key;
330 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
331 return result;
332 }
333
334 int32_t res = request_message_.at(key);
335 if (res < 0) {
336 std::string message = "The value should not be less than 0.";
337 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
338 return result;
339 }
340
341 if (res > 0 && key == kWorkerNum) {
342 std::string message = "The Worker does not currently support scale out.";
343 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
344 return result;
345 }
346
347 *value = res;
348 return result;
349 }
350
ParseNodeIdsFromKey(const std::string & key,std::vector<std::string> * const value)351 RequestProcessResult HttpMessageHandler::ParseNodeIdsFromKey(const std::string &key,
352 std::vector<std::string> *const value) {
353 MS_EXCEPTION_IF_NULL(value);
354 RequestProcessResult result(RequestProcessResultCode::kSuccess);
355 if (!request_message_.contains(key)) {
356 std::string message = "The json is not contain the key:" + key;
357 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
358 return result;
359 }
360 auto res = request_message_.at(key).get<std::vector<std::string>>();
361 for (const auto &val : res) {
362 MS_LOG(INFO) << "The node id is:" << val;
363 (*value).push_back(val);
364 }
365 return result;
366 }
367 } // namespace core
368 } // namespace ps
369 } // namespace mindspore
370