1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/communicator/http_message_handler.h"
18
19 #include <event2/event.h>
20 #include <event2/buffer.h>
21 #include <event2/bufferevent.h>
22 #include <event2/bufferevent_compat.h>
23 #include <event2/http.h>
24 #include <event2/http_compat.h>
25 #include <event2/http_struct.h>
26 #include <event2/listener.h>
27 #include <event2/util.h>
28
29 #include <fcntl.h>
30 #include <unistd.h>
31 #include <cstdio>
32 #include <cstdlib>
33 #include <cstring>
34 #include <string>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
36
37 namespace mindspore {
38 namespace ps {
39 namespace core {
InitHttpMessage()40 void HttpMessageHandler::InitHttpMessage() {
41 MS_EXCEPTION_IF_NULL(event_request_);
42 event_uri_ = evhttp_request_get_evhttp_uri(event_request_);
43 MS_EXCEPTION_IF_NULL(event_uri_);
44
45 const char *query = evhttp_uri_get_query(event_uri_);
46 if (query != nullptr) {
47 MS_LOG(WARNING) << "The query is:" << query;
48 int result = evhttp_parse_query_str(query, &path_params_);
49 if (result < 0) {
50 MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
51 }
52 }
53
54 head_params_ = evhttp_request_get_input_headers(event_request_);
55 resp_headers_ = evhttp_request_get_output_headers(event_request_);
56 resp_buf_ = evhttp_request_get_output_buffer(event_request_);
57 MS_EXCEPTION_IF_NULL(head_params_);
58 MS_EXCEPTION_IF_NULL(resp_headers_);
59 MS_EXCEPTION_IF_NULL(resp_buf_);
60 }
61
GetHeadParam(const std::string & key) const62 std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
63 MS_EXCEPTION_IF_NULL(head_params_);
64 const char *val = evhttp_find_header(head_params_, key.c_str());
65 MS_EXCEPTION_IF_NULL(val);
66 return std::string(val);
67 }
68
GetPathParam(const std::string & key) const69 std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
70 const char *val = evhttp_find_header(&path_params_, key.c_str());
71 MS_EXCEPTION_IF_NULL(val);
72 return std::string(val);
73 }
74
ParsePostParam()75 void HttpMessageHandler::ParsePostParam() {
76 MS_EXCEPTION_IF_NULL(event_request_);
77 size_t len = evbuffer_get_length(event_request_->input_buffer);
78 if (len == 0) {
79 MS_LOG(EXCEPTION) << "The post parameter size is: " << len;
80 }
81 post_param_parsed_ = true;
82 const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1));
83 MS_EXCEPTION_IF_NULL(post_message);
84 post_message_ = std::make_unique<std::string>(post_message, len);
85 MS_EXCEPTION_IF_NULL(post_message_);
86 int ret = evhttp_parse_query_str(post_message_->c_str(), &post_params_);
87 if (ret == -1) {
88 MS_LOG(EXCEPTION) << "Parse post parameter failed!";
89 }
90 }
91
ParsePostMessageToJson()92 RequestProcessResult HttpMessageHandler::ParsePostMessageToJson() {
93 MS_EXCEPTION_IF_NULL(event_request_);
94 RequestProcessResult result(RequestProcessResultCode::kSuccess);
95 std::string message;
96
97 size_t len = evbuffer_get_length(event_request_->input_buffer);
98 if (len == 0) {
99 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message size is invalid.");
100 return result;
101 } else if (len > kMaxMessageSize) {
102 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message is bigger than 100mb.");
103 return result;
104 } else {
105 message.resize(len);
106 auto buffer = evbuffer_pullup(event_request_->input_buffer, -1);
107 if (buffer == nullptr) {
108 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Get http post message failed.");
109 return result;
110 }
111 size_t dest_size = len;
112 size_t src_size = len;
113 if (memcpy_s(message.data(), dest_size, buffer, src_size) != EOK) {
114 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Copy message failed.");
115 return result;
116 }
117
118 try {
119 request_message_ = nlohmann::json::parse(message);
120 } catch (nlohmann::json::exception &e) {
121 std::string illegal_exception = e.what();
122 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Illegal JSON format:" + illegal_exception);
123 return result;
124 }
125 }
126 return result;
127 }
128
GetPostParam(const std::string & key)129 std::string HttpMessageHandler::GetPostParam(const std::string &key) {
130 if (!post_param_parsed_) {
131 ParsePostParam();
132 }
133
134 const char *val = evhttp_find_header(&post_params_, key.c_str());
135 MS_EXCEPTION_IF_NULL(val);
136 return std::string(val);
137 }
138
GetRequestUri() const139 std::string HttpMessageHandler::GetRequestUri() const {
140 MS_EXCEPTION_IF_NULL(event_request_);
141 const char *uri = evhttp_request_get_uri(event_request_);
142 MS_EXCEPTION_IF_NULL(uri);
143 return std::string(uri);
144 }
145
GetRequestHost()146 std::string HttpMessageHandler::GetRequestHost() {
147 MS_EXCEPTION_IF_NULL(event_request_);
148 const char *host = evhttp_request_get_host(event_request_);
149 MS_EXCEPTION_IF_NULL(host);
150 return std::string(host);
151 }
152
GetHostByUri() const153 const char *HttpMessageHandler::GetHostByUri() const {
154 MS_EXCEPTION_IF_NULL(event_uri_);
155 const char *host = evhttp_uri_get_host(event_uri_);
156 MS_EXCEPTION_IF_NULL(host);
157 return host;
158 }
159
GetUriPort() const160 int HttpMessageHandler::GetUriPort() const {
161 MS_EXCEPTION_IF_NULL(event_uri_);
162 int port = evhttp_uri_get_port(event_uri_);
163 if (port < 0) {
164 MS_LOG(EXCEPTION) << "The port:" << port << " should not be less than 0!";
165 }
166 return port;
167 }
168
GetUriPath() const169 std::string HttpMessageHandler::GetUriPath() const {
170 MS_EXCEPTION_IF_NULL(event_uri_);
171 const char *path = evhttp_uri_get_path(event_uri_);
172 MS_EXCEPTION_IF_NULL(path);
173 return std::string(path);
174 }
175
GetRequestPath()176 std::string HttpMessageHandler::GetRequestPath() {
177 MS_EXCEPTION_IF_NULL(event_uri_);
178 const char *path = evhttp_uri_get_path(event_uri_);
179 if (path == nullptr || strlen(path) == 0) {
180 path = "/";
181 }
182 std::string path_res(path);
183 const char *query = evhttp_uri_get_query(event_uri_);
184 if (query != nullptr) {
185 path_res.append("?");
186 path_res.append(query);
187 }
188 return path_res;
189 }
190
GetUriQuery() const191 std::string HttpMessageHandler::GetUriQuery() const {
192 MS_EXCEPTION_IF_NULL(event_uri_);
193 const char *query = evhttp_uri_get_query(event_uri_);
194 MS_EXCEPTION_IF_NULL(query);
195 return std::string(query);
196 }
197
GetUriFragment() const198 std::string HttpMessageHandler::GetUriFragment() const {
199 MS_EXCEPTION_IF_NULL(event_uri_);
200 const char *fragment = evhttp_uri_get_fragment(event_uri_);
201 MS_EXCEPTION_IF_NULL(fragment);
202 return std::string(fragment);
203 }
204
GetPostMsg(size_t * len,uint8_t ** buffer)205 bool HttpMessageHandler::GetPostMsg(size_t *len, uint8_t **buffer) {
206 MS_EXCEPTION_IF_NULL(event_request_);
207 if (len == nullptr || buffer == nullptr) {
208 MS_LOG(ERROR) << "Input parameter len or buffer cannot be nullptr";
209 return false;
210 }
211 *len = evbuffer_get_length(event_request_->input_buffer);
212 const size_t max_http_bytes_len = UINT32_MAX; // 4GB
213 if (*len == 0 || *len > max_http_bytes_len) {
214 MS_LOG(ERROR) << "The post message length " << *len << " is invalid!";
215 return false;
216 }
217 *buffer = evbuffer_pullup(event_request_->input_buffer, -1);
218 if (*buffer == nullptr) {
219 MS_LOG(ERROR) << "Failed to pull post message buffer!";
220 return false;
221 }
222 return true;
223 }
224
AddRespHeadParam(const std::string & key,const std::string & val)225 void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) {
226 MS_EXCEPTION_IF_NULL(resp_headers_);
227 if (evhttp_add_header(resp_headers_, key.c_str(), val.c_str()) != 0) {
228 MS_LOG(EXCEPTION) << "Add parameter of response header failed.";
229 }
230 }
231
AddRespHeaders(const HttpHeaders & headers)232 void HttpMessageHandler::AddRespHeaders(const HttpHeaders &headers) {
233 for (auto iter = headers.begin(); iter != headers.end(); ++iter) {
234 auto list = iter->second;
235 for (auto iterator_val = list.begin(); iterator_val != list.end(); ++iterator_val) {
236 AddRespHeadParam(iter->first, *iterator_val);
237 }
238 }
239 }
240
AddRespString(const std::string & str)241 void HttpMessageHandler::AddRespString(const std::string &str) {
242 MS_EXCEPTION_IF_NULL(resp_buf_);
243 if (evbuffer_add_printf(resp_buf_, "%s", str.c_str()) == -1) {
244 MS_LOG(EXCEPTION) << "Add string to response body failed.";
245 }
246 }
247
SetRespCode(int code)248 void HttpMessageHandler::SetRespCode(int code) { resp_code_ = code; }
249
SendResponse()250 void HttpMessageHandler::SendResponse() {
251 MS_EXCEPTION_IF_NULL(event_request_);
252 MS_EXCEPTION_IF_NULL(resp_buf_);
253 evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_);
254 }
255
QuickResponse(int code,const void * body,size_t len)256 void HttpMessageHandler::QuickResponse(int code, const void *body, size_t len) {
257 MS_EXCEPTION_IF_NULL(event_request_);
258 MS_EXCEPTION_IF_NULL(body);
259 MS_EXCEPTION_IF_NULL(resp_buf_);
260 auto ret = evbuffer_add(resp_buf_, body, len);
261 if (ret == -1) {
262 MS_LOG(WARNING) << "Add body to response body failed.";
263 return;
264 }
265 evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
266 }
267
QuickResponseInference(int code,const void * body,size_t len,evbuffer_ref_cleanup_cb cb)268 void HttpMessageHandler::QuickResponseInference(int code, const void *body, size_t len, evbuffer_ref_cleanup_cb cb) {
269 MS_EXCEPTION_IF_NULL(event_request_);
270 MS_EXCEPTION_IF_NULL(body);
271 MS_EXCEPTION_IF_NULL(resp_buf_);
272 auto ret = evbuffer_add_reference(resp_buf_, body, len, cb, nullptr);
273 if (ret == -1) { // -1 if an error occurred
274 MS_LOG(WARNING) << "Add body to response body failed.";
275 if (cb != nullptr) {
276 cb(body, len, nullptr);
277 }
278 return;
279 }
280 evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
281 }
282
SimpleResponse(int code,const HttpHeaders & headers,const std::string & body)283 void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, const std::string &body) {
284 MS_EXCEPTION_IF_NULL(event_request_);
285 MS_EXCEPTION_IF_NULL(resp_buf_);
286 AddRespHeaders(headers);
287 AddRespString(body);
288 evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
289 }
290
ErrorResponse(int code,const RequestProcessResult & result)291 void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
292 nlohmann::json error_json = {{"error_message", result.StatusMessage()}, {"code", kErrorCode}};
293 std::string out_error = error_json.dump();
294 AddRespString(out_error);
295 SetRespCode(code);
296 SendResponse();
297 }
298
RespError(int nCode,const std::string & message)299 void HttpMessageHandler::RespError(int nCode, const std::string &message) {
300 MS_EXCEPTION_IF_NULL(event_request_);
301 if (message.empty()) {
302 evhttp_send_error(event_request_, nCode, nullptr);
303 } else {
304 evhttp_send_error(event_request_, nCode, message.c_str());
305 }
306 }
307
ReceiveMessage(const void * buffer,size_t num)308 void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
309 MS_EXCEPTION_IF_NULL(buffer);
310 MS_EXCEPTION_IF_NULL(body_);
311 size_t dest_size = num;
312 size_t src_size = num;
313 int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size);
314 if (ret != EOK) {
315 MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
316 }
317 offset_ += num;
318 }
319
set_content_len(const uint64_t & len)320 void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
321
content_len() const322 uint64_t HttpMessageHandler::content_len() const { return content_len_; }
323
http_base() const324 const event_base *HttpMessageHandler::http_base() const { return event_base_; }
325
set_http_base(const struct event_base * base)326 void HttpMessageHandler::set_http_base(const struct event_base *base) {
327 MS_EXCEPTION_IF_NULL(base);
328 event_base_ = const_cast<event_base *>(base);
329 }
330
set_request(const struct evhttp_request * req)331 void HttpMessageHandler::set_request(const struct evhttp_request *req) {
332 MS_EXCEPTION_IF_NULL(req);
333 event_request_ = const_cast<evhttp_request *>(req);
334 }
335
request() const336 const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
337
InitBodySize()338 void HttpMessageHandler::InitBodySize() {
339 MS_EXCEPTION_IF_NULL(body_);
340 body_->resize(content_len());
341 }
342
body()343 std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
344
set_body(const std::shared_ptr<std::vector<char>> & body)345 void HttpMessageHandler::set_body(const std::shared_ptr<std::vector<char>> &body) { body_ = body; }
346
request_message() const347 nlohmann::json HttpMessageHandler::request_message() const { return request_message_; }
348
ParseValueFromKey(const std::string & key,uint32_t * const value)349 RequestProcessResult HttpMessageHandler::ParseValueFromKey(const std::string &key, uint32_t *const value) {
350 MS_EXCEPTION_IF_NULL(value);
351 RequestProcessResult result(RequestProcessResultCode::kSuccess);
352 if (!request_message_.contains(key)) {
353 std::string message = "The json is not contain the key:" + key;
354 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
355 return result;
356 }
357
358 int32_t res = IntToUint(request_message_.at(key));
359 if (res < 0) {
360 std::string message = "The value should not be less than 0.";
361 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
362 return result;
363 }
364
365 if (res > 0 && key == kWorkerNum) {
366 std::string message = "The Worker does not currently support scale out.";
367 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
368 return result;
369 }
370
371 *value = res;
372 return result;
373 }
374
ParseNodeIdsFromKey(const std::string & key,std::vector<std::string> * const value)375 RequestProcessResult HttpMessageHandler::ParseNodeIdsFromKey(const std::string &key,
376 std::vector<std::string> *const value) {
377 MS_EXCEPTION_IF_NULL(value);
378 RequestProcessResult result(RequestProcessResultCode::kSuccess);
379 if (!request_message_.contains(key)) {
380 std::string message = "The json is not contain the key:" + key;
381 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
382 return result;
383 }
384 auto res = request_message_.at(key).get<std::vector<std::string>>();
385 for (const auto &val : res) {
386 MS_LOG(INFO) << "The node id is:" << val;
387 (*value).push_back(val);
388 }
389 return result;
390 }
391 } // namespace core
392 } // namespace ps
393 } // namespace mindspore
394