• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/communicator/http_message_handler.h"
18 
19 #include <event2/event.h>
20 #include <event2/buffer.h>
21 #include <event2/bufferevent.h>
22 #include <event2/bufferevent_compat.h>
23 #include <event2/http.h>
24 #include <event2/http_compat.h>
25 #include <event2/http_struct.h>
26 #include <event2/listener.h>
27 #include <event2/util.h>
28 
29 #include <fcntl.h>
30 #include <unistd.h>
31 #include <cstdio>
32 #include <cstdlib>
33 #include <cstring>
34 #include <string>
35 #include <functional>
36 
37 namespace mindspore {
38 namespace ps {
39 namespace core {
InitHttpMessage()40 void HttpMessageHandler::InitHttpMessage() {
41   MS_EXCEPTION_IF_NULL(event_request_);
42   event_uri_ = evhttp_request_get_evhttp_uri(event_request_);
43   MS_EXCEPTION_IF_NULL(event_uri_);
44 
45   const char *query = evhttp_uri_get_query(event_uri_);
46   if (query != nullptr) {
47     MS_LOG(WARNING) << "The query is:" << query;
48     int result = evhttp_parse_query_str(query, &path_params_);
49     if (result < 0) {
50       MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
51     }
52   }
53 
54   head_params_ = evhttp_request_get_input_headers(event_request_);
55   resp_headers_ = evhttp_request_get_output_headers(event_request_);
56   resp_buf_ = evhttp_request_get_output_buffer(event_request_);
57   MS_EXCEPTION_IF_NULL(head_params_);
58   MS_EXCEPTION_IF_NULL(resp_headers_);
59   MS_EXCEPTION_IF_NULL(resp_buf_);
60 }
61 
GetHeadParam(const std::string & key) const62 std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
63   MS_EXCEPTION_IF_NULL(head_params_);
64   const char *val = evhttp_find_header(head_params_, key.c_str());
65   MS_EXCEPTION_IF_NULL(val);
66   return std::string(val);
67 }
68 
GetPathParam(const std::string & key) const69 std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
70   const char *val = evhttp_find_header(&path_params_, key.c_str());
71   MS_EXCEPTION_IF_NULL(val);
72   return std::string(val);
73 }
74 
ParsePostParam()75 void HttpMessageHandler::ParsePostParam() {
76   MS_EXCEPTION_IF_NULL(event_request_);
77   size_t len = evbuffer_get_length(event_request_->input_buffer);
78   if (len == 0) {
79     MS_LOG(EXCEPTION) << "The post parameter size is: " << len;
80   }
81   post_param_parsed_ = true;
82   const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1));
83   MS_EXCEPTION_IF_NULL(post_message);
84   post_message_ = std::make_unique<std::string>(post_message, len);
85   MS_EXCEPTION_IF_NULL(post_message_);
86   int ret = evhttp_parse_query_str(post_message_->c_str(), &post_params_);
87   if (ret == -1) {
88     MS_LOG(EXCEPTION) << "Parse post parameter failed!";
89   }
90 }
91 
ParsePostMessageToJson()92 RequestProcessResult HttpMessageHandler::ParsePostMessageToJson() {
93   MS_EXCEPTION_IF_NULL(event_request_);
94   RequestProcessResult result(RequestProcessResultCode::kSuccess);
95   std::string message;
96 
97   size_t len = evbuffer_get_length(event_request_->input_buffer);
98   if (len == 0) {
99     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message size is invalid.");
100     return result;
101   } else if (len > kMaxMessageSize) {
102     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message is bigger than 100mb.");
103     return result;
104   } else {
105     message.resize(len);
106     auto buffer = evbuffer_pullup(event_request_->input_buffer, -1);
107     if (buffer == nullptr) {
108       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Get http post message failed.");
109       return result;
110     }
111     size_t dest_size = len;
112     size_t src_size = len;
113     if (memcpy_s(message.data(), dest_size, buffer, src_size) != EOK) {
114       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Copy message failed.");
115       return result;
116     }
117 
118     try {
119       request_message_ = nlohmann::json::parse(message);
120     } catch (nlohmann::json::exception &e) {
121       std::string illegal_exception = e.what();
122       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Illegal JSON format:" + illegal_exception);
123       return result;
124     }
125   }
126   return result;
127 }
128 
GetPostParam(const std::string & key)129 std::string HttpMessageHandler::GetPostParam(const std::string &key) {
130   if (!post_param_parsed_) {
131     ParsePostParam();
132   }
133 
134   const char *val = evhttp_find_header(&post_params_, key.c_str());
135   MS_EXCEPTION_IF_NULL(val);
136   return std::string(val);
137 }
138 
GetRequestUri() const139 std::string HttpMessageHandler::GetRequestUri() const {
140   MS_EXCEPTION_IF_NULL(event_request_);
141   const char *uri = evhttp_request_get_uri(event_request_);
142   MS_EXCEPTION_IF_NULL(uri);
143   return std::string(uri);
144 }
145 
GetRequestHost()146 std::string HttpMessageHandler::GetRequestHost() {
147   MS_EXCEPTION_IF_NULL(event_request_);
148   const char *host = evhttp_request_get_host(event_request_);
149   MS_EXCEPTION_IF_NULL(host);
150   return std::string(host);
151 }
152 
GetHostByUri() const153 const char *HttpMessageHandler::GetHostByUri() const {
154   MS_EXCEPTION_IF_NULL(event_uri_);
155   const char *host = evhttp_uri_get_host(event_uri_);
156   MS_EXCEPTION_IF_NULL(host);
157   return host;
158 }
159 
GetUriPort() const160 int HttpMessageHandler::GetUriPort() const {
161   MS_EXCEPTION_IF_NULL(event_uri_);
162   int port = evhttp_uri_get_port(event_uri_);
163   if (port < 0) {
164     MS_LOG(EXCEPTION) << "The port:" << port << " should not be less than 0!";
165   }
166   return port;
167 }
168 
GetUriPath() const169 std::string HttpMessageHandler::GetUriPath() const {
170   MS_EXCEPTION_IF_NULL(event_uri_);
171   const char *path = evhttp_uri_get_path(event_uri_);
172   MS_EXCEPTION_IF_NULL(path);
173   return std::string(path);
174 }
175 
GetRequestPath()176 std::string HttpMessageHandler::GetRequestPath() {
177   MS_EXCEPTION_IF_NULL(event_uri_);
178   const char *path = evhttp_uri_get_path(event_uri_);
179   if (path == nullptr || strlen(path) == 0) {
180     path = "/";
181   }
182   std::string path_res(path);
183   const char *query = evhttp_uri_get_query(event_uri_);
184   if (query != nullptr) {
185     path_res.append("?");
186     path_res.append(query);
187   }
188   return path_res;
189 }
190 
GetUriQuery() const191 std::string HttpMessageHandler::GetUriQuery() const {
192   MS_EXCEPTION_IF_NULL(event_uri_);
193   const char *query = evhttp_uri_get_query(event_uri_);
194   MS_EXCEPTION_IF_NULL(query);
195   return std::string(query);
196 }
197 
GetUriFragment() const198 std::string HttpMessageHandler::GetUriFragment() const {
199   MS_EXCEPTION_IF_NULL(event_uri_);
200   const char *fragment = evhttp_uri_get_fragment(event_uri_);
201   MS_EXCEPTION_IF_NULL(fragment);
202   return std::string(fragment);
203 }
204 
GetPostMsg(size_t * len,uint8_t ** buffer)205 bool HttpMessageHandler::GetPostMsg(size_t *len, uint8_t **buffer) {
206   MS_EXCEPTION_IF_NULL(event_request_);
207   if (len == nullptr || buffer == nullptr) {
208     MS_LOG(ERROR) << "Input parameter len or buffer cannot be nullptr";
209     return false;
210   }
211   *len = evbuffer_get_length(event_request_->input_buffer);
212   const size_t max_http_bytes_len = UINT32_MAX;  // 4GB
213   if (*len == 0 || *len > max_http_bytes_len) {
214     MS_LOG(ERROR) << "The post message length " << *len << " is invalid!";
215     return false;
216   }
217   *buffer = evbuffer_pullup(event_request_->input_buffer, -1);
218   if (*buffer == nullptr) {
219     MS_LOG(ERROR) << "Failed to pull post message buffer!";
220     return false;
221   }
222   return true;
223 }
224 
AddRespHeadParam(const std::string & key,const std::string & val)225 void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) {
226   MS_EXCEPTION_IF_NULL(resp_headers_);
227   if (evhttp_add_header(resp_headers_, key.c_str(), val.c_str()) != 0) {
228     MS_LOG(EXCEPTION) << "Add parameter of response header failed.";
229   }
230 }
231 
AddRespHeaders(const HttpHeaders & headers)232 void HttpMessageHandler::AddRespHeaders(const HttpHeaders &headers) {
233   for (auto iter = headers.begin(); iter != headers.end(); ++iter) {
234     auto list = iter->second;
235     for (auto iterator_val = list.begin(); iterator_val != list.end(); ++iterator_val) {
236       AddRespHeadParam(iter->first, *iterator_val);
237     }
238   }
239 }
240 
AddRespString(const std::string & str)241 void HttpMessageHandler::AddRespString(const std::string &str) {
242   MS_EXCEPTION_IF_NULL(resp_buf_);
243   if (evbuffer_add_printf(resp_buf_, "%s", str.c_str()) == -1) {
244     MS_LOG(EXCEPTION) << "Add string to response body failed.";
245   }
246 }
247 
SetRespCode(int code)248 void HttpMessageHandler::SetRespCode(int code) { resp_code_ = code; }
249 
SendResponse()250 void HttpMessageHandler::SendResponse() {
251   MS_EXCEPTION_IF_NULL(event_request_);
252   MS_EXCEPTION_IF_NULL(resp_buf_);
253   evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_);
254 }
255 
QuickResponse(int code,const void * body,size_t len)256 void HttpMessageHandler::QuickResponse(int code, const void *body, size_t len) {
257   MS_EXCEPTION_IF_NULL(event_request_);
258   MS_EXCEPTION_IF_NULL(body);
259   MS_EXCEPTION_IF_NULL(resp_buf_);
260   auto ret = evbuffer_add(resp_buf_, body, len);
261   if (ret == -1) {
262     MS_LOG(WARNING) << "Add body to response body failed.";
263     return;
264   }
265   evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
266 }
267 
QuickResponseInference(int code,const void * body,size_t len,evbuffer_ref_cleanup_cb cb)268 void HttpMessageHandler::QuickResponseInference(int code, const void *body, size_t len, evbuffer_ref_cleanup_cb cb) {
269   MS_EXCEPTION_IF_NULL(event_request_);
270   MS_EXCEPTION_IF_NULL(body);
271   MS_EXCEPTION_IF_NULL(resp_buf_);
272   auto ret = evbuffer_add_reference(resp_buf_, body, len, cb, nullptr);
273   if (ret == -1) {  // -1 if an error occurred
274     MS_LOG(WARNING) << "Add body to response body failed.";
275     if (cb != nullptr) {
276       cb(body, len, nullptr);
277     }
278     return;
279   }
280   evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
281 }
282 
SimpleResponse(int code,const HttpHeaders & headers,const std::string & body)283 void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, const std::string &body) {
284   MS_EXCEPTION_IF_NULL(event_request_);
285   MS_EXCEPTION_IF_NULL(resp_buf_);
286   AddRespHeaders(headers);
287   AddRespString(body);
288   evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
289 }
290 
ErrorResponse(int code,const RequestProcessResult & result)291 void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
292   nlohmann::json error_json = {{"error_message", result.StatusMessage()}, {"code", kErrorCode}};
293   std::string out_error = error_json.dump();
294   AddRespString(out_error);
295   SetRespCode(code);
296   SendResponse();
297 }
298 
RespError(int nCode,const std::string & message)299 void HttpMessageHandler::RespError(int nCode, const std::string &message) {
300   MS_EXCEPTION_IF_NULL(event_request_);
301   if (message.empty()) {
302     evhttp_send_error(event_request_, nCode, nullptr);
303   } else {
304     evhttp_send_error(event_request_, nCode, message.c_str());
305   }
306 }
307 
ReceiveMessage(const void * buffer,size_t num)308 void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
309   MS_EXCEPTION_IF_NULL(buffer);
310   MS_EXCEPTION_IF_NULL(body_);
311   size_t dest_size = num;
312   size_t src_size = num;
313   int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size);
314   if (ret != EOK) {
315     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
316   }
317   offset_ += num;
318 }
319 
set_content_len(const uint64_t & len)320 void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
321 
content_len() const322 uint64_t HttpMessageHandler::content_len() const { return content_len_; }
323 
http_base() const324 const event_base *HttpMessageHandler::http_base() const { return event_base_; }
325 
set_http_base(const struct event_base * base)326 void HttpMessageHandler::set_http_base(const struct event_base *base) {
327   MS_EXCEPTION_IF_NULL(base);
328   event_base_ = const_cast<event_base *>(base);
329 }
330 
set_request(const struct evhttp_request * req)331 void HttpMessageHandler::set_request(const struct evhttp_request *req) {
332   MS_EXCEPTION_IF_NULL(req);
333   event_request_ = const_cast<evhttp_request *>(req);
334 }
335 
request() const336 const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
337 
InitBodySize()338 void HttpMessageHandler::InitBodySize() {
339   MS_EXCEPTION_IF_NULL(body_);
340   body_->resize(content_len());
341 }
342 
body()343 std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
344 
set_body(const std::shared_ptr<std::vector<char>> & body)345 void HttpMessageHandler::set_body(const std::shared_ptr<std::vector<char>> &body) { body_ = body; }
346 
request_message() const347 nlohmann::json HttpMessageHandler::request_message() const { return request_message_; }
348 
ParseValueFromKey(const std::string & key,uint32_t * const value)349 RequestProcessResult HttpMessageHandler::ParseValueFromKey(const std::string &key, uint32_t *const value) {
350   MS_EXCEPTION_IF_NULL(value);
351   RequestProcessResult result(RequestProcessResultCode::kSuccess);
352   if (!request_message_.contains(key)) {
353     std::string message = "The json is not contain the key:" + key;
354     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
355     return result;
356   }
357 
358   int32_t res = IntToUint(request_message_.at(key));
359   if (res < 0) {
360     std::string message = "The value should not be less than 0.";
361     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
362     return result;
363   }
364 
365   if (res > 0 && key == kWorkerNum) {
366     std::string message = "The Worker does not currently support scale out.";
367     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
368     return result;
369   }
370 
371   *value = res;
372   return result;
373 }
374 
ParseNodeIdsFromKey(const std::string & key,std::vector<std::string> * const value)375 RequestProcessResult HttpMessageHandler::ParseNodeIdsFromKey(const std::string &key,
376                                                              std::vector<std::string> *const value) {
377   MS_EXCEPTION_IF_NULL(value);
378   RequestProcessResult result(RequestProcessResultCode::kSuccess);
379   if (!request_message_.contains(key)) {
380     std::string message = "The json is not contain the key:" + key;
381     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
382     return result;
383   }
384   auto res = request_message_.at(key).get<std::vector<std::string>>();
385   for (const auto &val : res) {
386     MS_LOG(INFO) << "The node id is:" << val;
387     (*value).push_back(val);
388   }
389   return result;
390 }
391 }  // namespace core
392 }  // namespace ps
393 }  // namespace mindspore
394