• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "ps/core/communicator/http_message_handler.h"

#include <event2/event.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/bufferevent_compat.h>
#include <event2/http.h>
#include <event2/http_compat.h>
#include <event2/http_struct.h>
#include <event2/listener.h>
#include <event2/util.h>

#include <fcntl.h>
#include <unistd.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
36 
37 namespace mindspore {
38 namespace ps {
39 namespace core {
InitHttpMessage()40 void HttpMessageHandler::InitHttpMessage() {
41   MS_EXCEPTION_IF_NULL(event_request_);
42   event_uri_ = evhttp_request_get_evhttp_uri(event_request_);
43   MS_EXCEPTION_IF_NULL(event_uri_);
44 
45   const char *query = evhttp_uri_get_query(event_uri_);
46   if (query != nullptr) {
47     MS_LOG(WARNING) << "The query is:" << query;
48     int result = evhttp_parse_query_str(query, &path_params_);
49     if (result < 0) {
50       MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
51     }
52   }
53 
54   head_params_ = evhttp_request_get_input_headers(event_request_);
55   resp_headers_ = evhttp_request_get_output_headers(event_request_);
56   resp_buf_ = evhttp_request_get_output_buffer(event_request_);
57   MS_EXCEPTION_IF_NULL(head_params_);
58   MS_EXCEPTION_IF_NULL(resp_headers_);
59   MS_EXCEPTION_IF_NULL(resp_buf_);
60 }
61 
GetHeadParam(const std::string & key) const62 std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
63   MS_EXCEPTION_IF_NULL(head_params_);
64   const char *val = evhttp_find_header(head_params_, key.c_str());
65   MS_EXCEPTION_IF_NULL(val);
66   return std::string(val);
67 }
68 
GetPathParam(const std::string & key) const69 std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
70   const char *val = evhttp_find_header(&path_params_, key.c_str());
71   MS_EXCEPTION_IF_NULL(val);
72   return std::string(val);
73 }
74 
ParsePostParam()75 void HttpMessageHandler::ParsePostParam() {
76   MS_EXCEPTION_IF_NULL(event_request_);
77   size_t len = evbuffer_get_length(event_request_->input_buffer);
78   if (len == 0) {
79     MS_LOG(EXCEPTION) << "The post parameter size is: " << len;
80   }
81   post_param_parsed_ = true;
82   const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1));
83   MS_EXCEPTION_IF_NULL(post_message);
84   post_message_ = std::make_unique<std::string>(post_message, len);
85   MS_EXCEPTION_IF_NULL(post_message_);
86   int ret = evhttp_parse_query_str(post_message_->c_str(), &post_params_);
87   if (ret == -1) {
88     MS_LOG(EXCEPTION) << "Parse post parameter failed!";
89   }
90 }
91 
ParsePostMessageToJson()92 RequestProcessResult HttpMessageHandler::ParsePostMessageToJson() {
93   MS_EXCEPTION_IF_NULL(event_request_);
94   RequestProcessResult result(RequestProcessResultCode::kSuccess);
95   std::string message;
96 
97   size_t len = evbuffer_get_length(event_request_->input_buffer);
98   if (len == 0) {
99     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message size is invalid.");
100     return result;
101   } else if (len > kMaxMessageSize) {
102     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message is bigger than 100mb.");
103     return result;
104   } else {
105     message.resize(len);
106     auto buffer = evbuffer_pullup(event_request_->input_buffer, -1);
107     if (buffer == nullptr) {
108       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Get http post message failed.");
109       return result;
110     }
111     size_t dest_size = len;
112     size_t src_size = len;
113     if (memcpy_s(message.data(), dest_size, buffer, src_size) != EOK) {
114       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Copy message failed.");
115       return result;
116     }
117 
118     try {
119       request_message_ = nlohmann::json::parse(message);
120     } catch (nlohmann::json::exception &e) {
121       std::string illegal_exception = e.what();
122       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Illegal JSON format:" + illegal_exception);
123       return result;
124     }
125   }
126   return result;
127 }
128 
GetPostParam(const std::string & key)129 std::string HttpMessageHandler::GetPostParam(const std::string &key) {
130   if (!post_param_parsed_) {
131     ParsePostParam();
132   }
133 
134   const char *val = evhttp_find_header(&post_params_, key.c_str());
135   MS_EXCEPTION_IF_NULL(val);
136   return std::string(val);
137 }
138 
GetRequestUri() const139 std::string HttpMessageHandler::GetRequestUri() const {
140   MS_EXCEPTION_IF_NULL(event_request_);
141   const char *uri = evhttp_request_get_uri(event_request_);
142   MS_EXCEPTION_IF_NULL(uri);
143   return std::string(uri);
144 }
145 
GetRequestHost()146 std::string HttpMessageHandler::GetRequestHost() {
147   MS_EXCEPTION_IF_NULL(event_request_);
148   const char *host = evhttp_request_get_host(event_request_);
149   MS_EXCEPTION_IF_NULL(host);
150   return std::string(host);
151 }
152 
GetHostByUri() const153 const char *HttpMessageHandler::GetHostByUri() const {
154   MS_EXCEPTION_IF_NULL(event_uri_);
155   const char *host = evhttp_uri_get_host(event_uri_);
156   MS_EXCEPTION_IF_NULL(host);
157   return host;
158 }
159 
GetUriPort() const160 int HttpMessageHandler::GetUriPort() const {
161   MS_EXCEPTION_IF_NULL(event_uri_);
162   int port = evhttp_uri_get_port(event_uri_);
163   if (port < 0) {
164     MS_LOG(EXCEPTION) << "The port:" << port << " should not be less than 0!";
165   }
166   return port;
167 }
168 
GetUriPath() const169 std::string HttpMessageHandler::GetUriPath() const {
170   MS_EXCEPTION_IF_NULL(event_uri_);
171   const char *path = evhttp_uri_get_path(event_uri_);
172   MS_EXCEPTION_IF_NULL(path);
173   return std::string(path);
174 }
175 
GetRequestPath()176 std::string HttpMessageHandler::GetRequestPath() {
177   MS_EXCEPTION_IF_NULL(event_uri_);
178   const char *path = evhttp_uri_get_path(event_uri_);
179   if (path == nullptr || strlen(path) == 0) {
180     path = "/";
181   }
182   std::string path_res(path);
183   const char *query = evhttp_uri_get_query(event_uri_);
184   if (query != nullptr) {
185     path_res.append("?");
186     path_res.append(query);
187   }
188   return path_res;
189 }
190 
GetUriQuery() const191 std::string HttpMessageHandler::GetUriQuery() const {
192   MS_EXCEPTION_IF_NULL(event_uri_);
193   const char *query = evhttp_uri_get_query(event_uri_);
194   MS_EXCEPTION_IF_NULL(query);
195   return std::string(query);
196 }
197 
GetUriFragment() const198 std::string HttpMessageHandler::GetUriFragment() const {
199   MS_EXCEPTION_IF_NULL(event_uri_);
200   const char *fragment = evhttp_uri_get_fragment(event_uri_);
201   MS_EXCEPTION_IF_NULL(fragment);
202   return std::string(fragment);
203 }
204 
GetPostMsg(unsigned char ** buffer)205 uint64_t HttpMessageHandler::GetPostMsg(unsigned char **buffer) {
206   MS_EXCEPTION_IF_NULL(event_request_);
207   MS_EXCEPTION_IF_NULL(buffer);
208 
209   size_t len = evbuffer_get_length(event_request_->input_buffer);
210   if (len == 0) {
211     MS_LOG(EXCEPTION) << "The post message is empty!";
212   }
213   *buffer = evbuffer_pullup(event_request_->input_buffer, -1);
214   MS_EXCEPTION_IF_NULL(*buffer);
215   return len;
216 }
217 
AddRespHeadParam(const std::string & key,const std::string & val)218 void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) {
219   MS_EXCEPTION_IF_NULL(resp_headers_);
220   if (evhttp_add_header(resp_headers_, key.c_str(), val.c_str()) != 0) {
221     MS_LOG(EXCEPTION) << "Add parameter of response header failed.";
222   }
223 }
224 
AddRespHeaders(const HttpHeaders & headers)225 void HttpMessageHandler::AddRespHeaders(const HttpHeaders &headers) {
226   for (auto iter = headers.begin(); iter != headers.end(); ++iter) {
227     auto list = iter->second;
228     for (auto iterator_val = list.begin(); iterator_val != list.end(); ++iterator_val) {
229       AddRespHeadParam(iter->first, *iterator_val);
230     }
231   }
232 }
233 
AddRespString(const std::string & str)234 void HttpMessageHandler::AddRespString(const std::string &str) {
235   MS_EXCEPTION_IF_NULL(resp_buf_);
236   if (evbuffer_add_printf(resp_buf_, "%s", str.c_str()) == -1) {
237     MS_LOG(EXCEPTION) << "Add string to response body failed.";
238   }
239 }
240 
SetRespCode(int code)241 void HttpMessageHandler::SetRespCode(int code) { resp_code_ = code; }
242 
SendResponse()243 void HttpMessageHandler::SendResponse() {
244   MS_EXCEPTION_IF_NULL(event_request_);
245   MS_EXCEPTION_IF_NULL(resp_buf_);
246   evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_);
247 }
248 
QuickResponse(int code,const unsigned char * body,size_t len)249 void HttpMessageHandler::QuickResponse(int code, const unsigned char *body, size_t len) {
250   MS_EXCEPTION_IF_NULL(event_request_);
251   MS_EXCEPTION_IF_NULL(body);
252   MS_EXCEPTION_IF_NULL(resp_buf_);
253   if (evbuffer_add(resp_buf_, body, len) == -1) {
254     MS_LOG(EXCEPTION) << "Add body to response body failed.";
255   }
256   evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
257 }
258 
SimpleResponse(int code,const HttpHeaders & headers,const std::string & body)259 void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, const std::string &body) {
260   MS_EXCEPTION_IF_NULL(event_request_);
261   MS_EXCEPTION_IF_NULL(resp_buf_);
262   AddRespHeaders(headers);
263   AddRespString(body);
264   evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
265 }
266 
ErrorResponse(int code,const RequestProcessResult & result)267 void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
268   nlohmann::json error_json = {{"error_message", result.StatusMessage()}};
269   std::string out_error = error_json.dump();
270   AddRespString(out_error);
271   SetRespCode(code);
272   SendResponse();
273 }
274 
RespError(int nCode,const std::string & message)275 void HttpMessageHandler::RespError(int nCode, const std::string &message) {
276   MS_EXCEPTION_IF_NULL(event_request_);
277   if (message.empty()) {
278     evhttp_send_error(event_request_, nCode, nullptr);
279   } else {
280     evhttp_send_error(event_request_, nCode, message.c_str());
281   }
282 }
283 
ReceiveMessage(const void * buffer,size_t num)284 void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
285   MS_EXCEPTION_IF_NULL(buffer);
286   MS_EXCEPTION_IF_NULL(body_);
287   size_t dest_size = num;
288   size_t src_size = num;
289   int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size);
290   if (ret != 0) {
291     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
292   }
293   offset_ += num;
294 }
295 
set_content_len(const uint64_t & len)296 void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
297 
content_len() const298 uint64_t HttpMessageHandler::content_len() const { return content_len_; }
299 
http_base() const300 const event_base *HttpMessageHandler::http_base() const { return event_base_; }
301 
set_http_base(const struct event_base * base)302 void HttpMessageHandler::set_http_base(const struct event_base *base) {
303   MS_EXCEPTION_IF_NULL(base);
304   event_base_ = const_cast<event_base *>(base);
305 }
306 
set_request(const struct evhttp_request * req)307 void HttpMessageHandler::set_request(const struct evhttp_request *req) {
308   MS_EXCEPTION_IF_NULL(req);
309   event_request_ = const_cast<evhttp_request *>(req);
310 }
311 
request() const312 const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
313 
InitBodySize()314 void HttpMessageHandler::InitBodySize() {
315   MS_EXCEPTION_IF_NULL(body_);
316   body_->resize(content_len());
317 }
318 
body()319 std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
320 
set_body(const std::shared_ptr<std::vector<char>> & body)321 void HttpMessageHandler::set_body(const std::shared_ptr<std::vector<char>> &body) { body_ = body; }
322 
request_message() const323 nlohmann::json HttpMessageHandler::request_message() const { return request_message_; }
324 
ParseValueFromKey(const std::string & key,int32_t * const value)325 RequestProcessResult HttpMessageHandler::ParseValueFromKey(const std::string &key, int32_t *const value) {
326   MS_EXCEPTION_IF_NULL(value);
327   RequestProcessResult result(RequestProcessResultCode::kSuccess);
328   if (!request_message_.contains(key)) {
329     std::string message = "The json is not contain the key:" + key;
330     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
331     return result;
332   }
333 
334   int32_t res = request_message_.at(key);
335   if (res < 0) {
336     std::string message = "The value should not be less than 0.";
337     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
338     return result;
339   }
340 
341   if (res > 0 && key == kWorkerNum) {
342     std::string message = "The Worker does not currently support scale out.";
343     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
344     return result;
345   }
346 
347   *value = res;
348   return result;
349 }
350 
ParseNodeIdsFromKey(const std::string & key,std::vector<std::string> * const value)351 RequestProcessResult HttpMessageHandler::ParseNodeIdsFromKey(const std::string &key,
352                                                              std::vector<std::string> *const value) {
353   MS_EXCEPTION_IF_NULL(value);
354   RequestProcessResult result(RequestProcessResultCode::kSuccess);
355   if (!request_message_.contains(key)) {
356     std::string message = "The json is not contain the key:" + key;
357     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
358     return result;
359   }
360   auto res = request_message_.at(key).get<std::vector<std::string>>();
361   for (const auto &val : res) {
362     MS_LOG(INFO) << "The node id is:" << val;
363     (*value).push_back(val);
364   }
365   return result;
366 }
367 }  // namespace core
368 }  // namespace ps
369 }  // namespace mindspore
370