1 /*
2 * Copyright (c) 2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
#define LOG printf
#define ERROR printf
#include "proxy_server.h"
#include "securec.h"
#include "sys/time.h"
#include <algorithm>
#include <arpa/inet.h>
#include <cctype>
#include <cerrno>
#include <csignal>
#include <cstring>
#include <fcntl.h>
#include <iostream>
#include <netdb.h>
#include <poll.h>
#include <random>
#include <sstream>
#include <sys/socket.h>
#include <thread>
#include <unistd.h>
#include <utility>
#include <vector>
35
36 #define PRINT_RED_FMT_LN(fmt, ...) printf("\033[31m" fmt "\n\033[0m", ##__VA_ARGS__)
37 #define NUM_4 4
38 #define NUM_10 10
39 #define TIME_OUT 30
40 #define PORT_8080 8080
41 #define PORT_80 80
42 #define PORT_1080 1080
43 #define PORT_443 443
44 using namespace OHOS::NetManagerStandard;
45
46 std::map<std::string, std::string> ProxyServer::pacScripts;
47 std::string ProxyServer::proxServerTargetUrl;
48 int32_t ProxyServer::proxServerPort;
49
ProxyServer(int32_t port,int32_t numThreads)50 ProxyServer::ProxyServer(int32_t port, int32_t numThreads)
51 : port_(port), serverSocket_(-1), numThreads_(numThreads), running_(false)
52 {
53 if (numThreads_ <= 0) {
54 numThreads_ = std::thread::hardware_concurrency();
55 if (numThreads_ <= 0) {
56 numThreads_ = NUM_4;
57 }
58 }
59 pacScripts = {
60 {LOCAL_PROXY_9000,
61 "function FindProxyForURL(url, host) {\n"
62 " return \"PROXY 127.0.0.1:9000\";\n"
63 "}"},
64 {LOCAL_PROXY_9001,
65 "function FindProxyForURL(url, host) {\n"
66 " return \"PROXY 127.0.0.1:9001\";\n"
67 "}"},
68 {ALL_DIRECT,
69 "function FindProxyForURL(url, host) {\n"
70 " return \"PROXY 127.0.0.1:9000;PROXY 127.0.0.1:9001; DIRECT\";\n"
71 "}"},
72 };
73 if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
74 perror("signal");
75 }
76 ResetStats();
77 }
78
~ProxyServer()79 ProxyServer::~ProxyServer()
80 {
81 Stop();
82 }
83
SetFindPacProxyFunction(std::function<std::string (std::string,std::string)> pac)84 void ProxyServer::SetFindPacProxyFunction(std::function<std::string(std::string, std::string)> pac)
85 {
86 pacFunction_ = pac;
87 }
88
Start()89 bool ProxyServer::Start()
90 {
91 if (running_) {
92 PRINT_RED_FMT_LN("server is runing \n");
93 return false;
94 }
95 serverSocket_ = socket(AF_INET, SOCK_STREAM, 0);
96 if (serverSocket_ < 0) {
97 PRINT_RED_FMT_LN("create socket fail \n");
98 return false;
99 }
100 int32_t opt = 1;
101 if (setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
102 PRINT_RED_FMT_LN("SO_REUSEADDR fail \n");
103 close(serverSocket_);
104 serverSocket_ = -1;
105 return false;
106 }
107 int32_t flags = fcntl(serverSocket_, F_GETFL, 0);
108 if (flags < 0 || fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK) < 0) {
109 printf("O_NONBLOCK fail \n");
110 close(serverSocket_);
111 serverSocket_ = -1;
112 return false;
113 }
114 struct sockaddr_in serverAddr;
115 serverAddr.sin_family = AF_INET;
116 serverAddr.sin_addr.s_addr = INADDR_ANY;
117 serverAddr.sin_port = htons(port_);
118 if (bind(serverSocket_, reinterpret_cast<const sockaddr *>(&serverAddr), sizeof(serverAddr)) < 0) {
119 close(serverSocket_);
120 serverSocket_ = -1;
121 printf("bind port %d fail \n", port_);
122 return false;
123 }
124 if (listen(serverSocket_, backlog) < 0) {
125 close(serverSocket_);
126 serverSocket_ = -1;
127 printf("listen port %d fail \n", port_);
128 return false;
129 }
130 LOG("run localserver on port:%d threads: %d \n", port_, numThreads_);
131 running_ = true;
132 ResetStats();
133 for (int32_t i = 0; i < numThreads_; i++) {
134 workers_.push_back(std::thread(&ProxyServer::WorkerThread, this));
135 }
136 acceptThread_ = std::thread(&ProxyServer::AcceptLoop, this);
137 return true;
138 }
139
Stop()140 void ProxyServer::Stop()
141 {
142 if (!running_) {
143 return;
144 }
145 running_ = false;
146 queueCondition_.notify_all();
147 if (acceptThread_.joinable()) {
148 acceptThread_.join();
149 }
150 for (auto &worker : workers_) {
151 if (worker.joinable()) {
152 worker.join();
153 }
154 }
155 workers_.clear();
156 {
157 std::lock_guard<std::mutex> lock(queueMutex_);
158 while (!taskQueue_.empty()) {
159 close(taskQueue_.front().clientSocket);
160 taskQueue_.pop();
161 }
162 }
163 if (serverSocket_ >= 0) {
164 close(serverSocket_);
165 serverSocket_ = -1;
166 }
167 }
168
IsRunning() const169 bool ProxyServer::IsRunning() const
170 {
171 return running_;
172 }
173
GetStats()174 std::shared_ptr<Stats> ProxyServer::GetStats()
175 {
176 std::lock_guard<std::mutex> lock(statsMutex_);
177 return stats_;
178 }
179
ResetStats()180 void ProxyServer::ResetStats()
181 {
182 std::lock_guard<std::mutex> lock(statsMutex_);
183 stats_ = std::make_shared<Stats>();
184 stats_->startTime = std::chrono::steady_clock::now();
185 }
186
GetThroughput() const187 double ProxyServer::GetThroughput() const
188 {
189 std::lock_guard<std::mutex> lock(statsMutex_);
190 auto now = std::chrono::steady_clock::now();
191 auto duration = std::chrono::duration_cast<std::chrono::seconds>(now - stats_->startTime).count();
192 if (duration <= 0) {
193 return 0.0;
194 }
195 uint64_t totalBytes = stats_->bytesReceived + stats_->bytesSent;
196 return static_cast<double>(totalBytes) / duration;
197 }
198
GetRequestMethod(const std::string & header)199 std::string ProxyServer::GetRequestMethod(const std::string &header)
200 {
201 size_t spacePos = header.find(' ');
202 if (spacePos == std::string::npos) {
203 return "";
204 }
205 return header.substr(0, spacePos);
206 }
207
ParseConnectRequest(const std::string & header,std::string & host,int32_t & port)208 bool ProxyServer::ParseConnectRequest(const std::string &header, std::string &host, int32_t &port)
209 {
210 size_t methodEnd = header.find(' ');
211 if (methodEnd == std::string::npos) {
212 return false;
213 }
214 size_t hostStart = methodEnd + 1;
215 size_t hostEnd = header.find(' ', hostStart);
216 if (hostEnd == std::string::npos) {
217 return false;
218 }
219 std::string hostPort = header.substr(hostStart, hostEnd - hostStart);
220 size_t colonPos = hostPort.find(':');
221 if (colonPos != std::string::npos) {
222 host = hostPort.substr(0, colonPos);
223 port = std::stoi(hostPort.substr(colonPos + 1));
224 } else {
225 host = hostPort;
226 port = PORT_443;
227 }
228 return true;
229 }
230
ParseHttpRequest(const std::string & header,std::string & host,int32_t & port)231 bool ProxyServer::ParseHttpRequest(const std::string &header, std::string &host, int32_t &port)
232 {
233 size_t hostPos = header.find("Host: ");
234 if (hostPos == std::string::npos) {
235 return false;
236 }
237 size_t hostEnd = header.find("\r\n", hostPos);
238 if (hostEnd == std::string::npos) {
239 return false;
240 }
241 std::string hostLine = header.substr(hostPos + 6, hostEnd - hostPos - 6);
242 size_t colonPos = hostLine.find(':');
243 if (colonPos != std::string::npos) {
244 host = hostLine.substr(0, colonPos);
245 port = std::stoi(hostLine.substr(colonPos + 1));
246 } else {
247 host = hostLine;
248 port = PORT_80;
249 }
250 return true;
251 }
252
/**
 * Classifies a poll() result.
 *
 * @param ret      Return value of poll().
 * @param errnoVal errno captured right after poll().
 * @return true (after logging to stderr) only for a failure other than EINTR; false for success,
 *         timeout (0), or a signal interruption.
 */
static bool HandlePollError(int32_t ret, int32_t errnoVal)
{
    const bool fatal = (ret < 0) && (errnoVal != EINTR);
    if (fatal) {
        std::cerr << "Poll失败: " << strerror(errnoVal) << std::endl;
    }
    return fatal;
}
264
/**
 * @param fds Array of exactly two pollfd entries, as filled in by poll().
 * @return true if either entry reports POLLHUP or POLLERR.
 */
static bool CheckPollHupOrErr(const struct pollfd *fds)
{
    const int hangupMask = POLLHUP | POLLERR;
    for (int i = 0; i < 2; ++i) {
        if ((fds[i].revents & hangupMask) != 0) {
            return true;
        }
    }
    return false;
}
269
TransferData(int32_t srcFd,int32_t dstFd,char * buffer,size_t bufferSize,std::shared_ptr<Stats> stats)270 static bool TransferData(int32_t srcFd, int32_t dstFd, char *buffer, size_t bufferSize, std::shared_ptr<Stats> stats)
271 {
272 int32_t n = recv(srcFd, buffer, bufferSize, 0);
273 if (n <= 0) {
274 return true;
275 }
276 stats->bytesReceived += n;
277 if (send(dstFd, buffer, n, 0) <= 0) {
278 return true;
279 }
280 stats->bytesSent += n;
281 return false;
282 }
283
TunnelData(int32_t client,int32_t server)284 void ProxyServer::TunnelData(int32_t client, int32_t server)
285 {
286 struct pollfd fds[2];
287 fds[0].fd = client;
288 fds[0].events = POLLIN;
289 fds[1].fd = server;
290 fds[1].events = POLLIN;
291 char buffer[bufferSize];
292 bool clientClosed = false;
293 bool serverClosed = false;
294 while (!clientClosed && !serverClosed && running_) {
295 int32_t ret = poll(fds, 2, 1000);
296 if (HandlePollError(ret, errno)) {
297 break;
298 }
299 if (ret == 0) {
300 continue;
301 }
302 if (CheckPollHupOrErr(fds)) {
303 break;
304 }
305 if (fds[0].revents & POLLIN) {
306 clientClosed = TransferData(client, server, buffer, bufferSize, stats_);
307 if (clientClosed) {
308 continue;
309 }
310 }
311 if (fds[1].revents & POLLIN) {
312 serverClosed = TransferData(server, client, buffer, bufferSize, stats_);
313 }
314 }
315 }
ReceiveResponseHeader(int32_t socket)316 std::string ProxyServer::ReceiveResponseHeader(int32_t socket)
317 {
318 std::string header;
319 char buffer[bufferSize];
320
321 int bytesRead;
322 while ((bytesRead = recv(socket, buffer, sizeof(buffer) - 1, 0)) > 0) {
323 buffer[bytesRead] = '\0';
324 header.append(buffer);
325 if (header.find("\r\n\r\n") != std::string::npos) {
326 break;
327 }
328 if (header.size() > maxHeaderSize) {
329 break;
330 }
331 }
332 return header;
333 }
334
IsPortAvailable(int32_t port)335 bool ProxyServer::IsPortAvailable(int32_t port)
336 {
337 int sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
338 if (sockfd < 0) {
339 return false;
340 }
341
342 int optval = 1;
343 if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&optval), sizeof(optval)) < 0) {
344 close(sockfd);
345 return false;
346 }
347
348 struct sockaddr_in serverAddr;
349 serverAddr.sin_family = AF_INET;
350 serverAddr.sin_port = htons(static_cast<uint16_t>(port));
351 serverAddr.sin_addr.s_addr = inet_addr("127.0.0.1");
352
353 int bindResult = bind(sockfd, reinterpret_cast<struct sockaddr *>(&serverAddr), sizeof(serverAddr));
354 close(sockfd);
355 return (bindResult == 0);
356 }
357
FindAvailablePort(int32_t startPort,int32_t endPort)358 int ProxyServer::FindAvailablePort(int32_t startPort, int32_t endPort)
359 {
360 std::vector<int> portsToTry;
361 for (int port = startPort; port <= endPort; ++port) {
362 portsToTry.push_back(port);
363 }
364
365 unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
366 std::shuffle(portsToTry.begin(), portsToTry.end(), std::default_random_engine(seed));
367
368 for (int port : portsToTry) {
369 if (IsPortAvailable(port)) {
370 return port;
371 }
372 }
373
374 return -1;
375 }
376
ConnectToServer(const std::string & host,int port)377 int ProxyServer::ConnectToServer(const std::string &host, int port)
378 {
379 int serverSocket = socket(AF_INET, SOCK_STREAM, 0);
380 if (serverSocket < 0) {
381 std::cerr << "无法创建socket" << std::endl;
382 return -1;
383 }
384
385 struct timeval timeout;
386 timeout.tv_sec = NUM_10;
387 timeout.tv_usec = 0;
388
389 if (setsockopt(serverSocket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0 ||
390 setsockopt(serverSocket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
391 std::cerr << "设置socket超时选项失败" << std::endl;
392 close(serverSocket);
393 return -1;
394 }
395
396 struct hostent *server = gethostbyname(host.c_str());
397 if (server == nullptr) {
398 std::cerr << "无法解析主机: " << host << std::endl;
399 close(serverSocket);
400 return -1;
401 }
402
403 struct sockaddr_in serverAddr;
404 serverAddr.sin_family = AF_INET;
405 serverAddr.sin_port = htons(port);
406 memcpy_s(&serverAddr.sin_addr.s_addr, sizeof(serverAddr.sin_addr.s_addr), server->h_addr, server->h_length);
407
408 if (connect(serverSocket, reinterpret_cast<const sockaddr *>(&serverAddr), sizeof(serverAddr)) < 0) {
409 std::cerr << "无法连接到目标服务器 " << host << ":" << port << std::endl;
410 close(serverSocket);
411 return -1;
412 }
413
414 return serverSocket;
415 }
416
ConnectViaUpstreamProxy(const std::string & targetHost,int targetPort,const std::string & originalRequest,std::string proxyHost,int proxyPort)417 int ProxyServer::ConnectViaUpstreamProxy(const std::string &targetHost, int targetPort,
418 const std::string &originalRequest, std::string proxyHost, int proxyPort)
419 {
420 int proxySocket = ConnectToServer(proxyHost, proxyPort);
421 if (proxySocket < 0) {
422 ERROR("connect upstream proxy fail %s:%d\n", proxyHost.c_str(), proxyPort);
423 return -1;
424 }
425 printf("ConnectToServer %s : %d \n", proxyHost.c_str(), proxyPort);
426 printf("originalRequest ########## \n %s\n###########\n", originalRequest.c_str());
427 if (send(proxySocket, originalRequest.c_str(), originalRequest.length(), 0) < 0) {
428 close(proxySocket);
429 return -1;
430 }
431
432 stats_->bytesSent += originalRequest.length();
433
434 return proxySocket;
435 }
436
ConnectViaUpstreamProxyHttps(const std::string & targetHost,int targetPort,std::string proxyHost,int proxyPort)437 int ProxyServer::ConnectViaUpstreamProxyHttps(const std::string &targetHost, int targetPort, std::string proxyHost,
438 int proxyPort)
439 {
440 int proxySocket = ConnectToServer(proxyHost, proxyPort);
441 if (proxySocket < 0) {
442 ERROR("connect upproxy fail %s:%d fail", proxyHost.c_str(), proxyPort);
443 return -1;
444 }
445 std::ostringstream connectRequest;
446 connectRequest << "CONNECT " << targetHost << ":" << targetPort << " HTTP/1.1\r\n"
447 << "Host: " << targetHost << ":" << targetPort << "\r\n"
448 << "Proxy-Connection: Keep-Alive\r\n"
449 << "\r\n";
450 std::string requestStr = connectRequest.str();
451 if (send(proxySocket, requestStr.c_str(), requestStr.length(), 0) < 0) {
452 ERROR("send CONNECT to upstream proxy fail \n");
453 close(proxySocket);
454 return -1;
455 }
456 stats_->bytesSent += requestStr.length();
457 std::string response = ReceiveResponseHeader(proxySocket);
458 if (response.empty() || response.find("HTTP/1.1 200") == std::string::npos) {
459 std::cerr << "上游代理拒绝CONNECT请求: " << response << std::endl;
460 close(proxySocket);
461 return -1;
462 }
463 stats_->bytesReceived += response.length();
464 return proxySocket;
465 }
466
GetRequestUrl(const std::string & header)467 std::string ProxyServer::GetRequestUrl(const std::string &header)
468 {
469 std::string method = GetRequestMethod(header);
470 std::string url;
471 if (method == "CONNECT") {
472 size_t methodEnd = header.find(' ');
473 if (methodEnd == std::string::npos) {
474 return "";
475 }
476 size_t hostStart = methodEnd + 1;
477 size_t hostEnd = header.find(' ', hostStart);
478 if (hostEnd == std::string::npos) {
479 return "";
480 }
481 std::string hostPort = header.substr(hostStart, hostEnd - hostStart);
482 url = "https://" + hostPort;
483 } else {
484 size_t methodEnd = header.find(' ');
485 if (methodEnd == std::string::npos) {
486 return "";
487 }
488 size_t pathStart = methodEnd + 1;
489 size_t pathEnd = header.find(' ', pathStart);
490 if (pathEnd == std::string::npos) {
491 return "";
492 }
493 std::string path = header.substr(pathStart, pathEnd - pathStart);
494 if (path.find("://") != std::string::npos) {
495 return path;
496 }
497 std::string host;
498 int port = PORT_80;
499 if (!ParseHttpRequest(header, host, port)) {
500 return path;
501 }
502 url = "http://";
503 url += host;
504 if (port != PORT_80) {
505 url += ":" + std::to_string(port);
506 }
507 if (!path.empty() && path[0] != '/') {
508 url += "/";
509 }
510 url += path;
511 }
512
513 return url;
514 }
515
ParseProxyInfo(std::string url,std::string host,std::string & proxyType,std::string & proxyHost,int32_t & proxyPort)516 bool ProxyServer::ParseProxyInfo(std::string url, std::string host, std::string &proxyType, std::string &proxyHost,
517 int32_t &proxyPort)
518 {
519 std::string pacScirpt;
520 if (pacFunction_) {
521 pacScirpt = pacFunction_(url, host);
522 }
523 if (pacScirpt.empty()) {
524 return false;
525 }
526 return ParsePacResult(pacScirpt, proxyType, proxyHost, proxyPort);
527 }
528
ParsePacResult(const std::string & pacResult,std::string & proxyType,std::string & proxyHost,int32_t & proxyPort)529 bool ProxyServer::ParsePacResult(const std::string &pacResult, std::string &proxyType, std::string &proxyHost,
530 int32_t &proxyPort)
531 {
532 proxyType = "";
533 proxyHost = "";
534 proxyPort = 0;
535 if (pacResult.empty()) {
536 return false;
537 }
538 std::istringstream stream(pacResult);
539 std::string rule;
540 while (std::getline(stream, rule, ';')) {
541 rule.erase(0, rule.find_first_not_of(" \t"));
542 rule.erase(rule.find_last_not_of(" \t") + 1);
543 if (rule.empty()) {
544 continue;
545 }
546 size_t spacePos = rule.find(' ');
547 if (spacePos == std::string::npos) {
548 proxyType = rule;
549 if (proxyType == "DIRECT") {
550 return true;
551 }
552 continue;
553 }
554 proxyType = rule.substr(0, spacePos);
555 size_t hostStart = rule.find_first_not_of(" \t", spacePos);
556 if (hostStart == std::string::npos) {
557 continue;
558 }
559 std::string hostPort = rule.substr(hostStart);
560 size_t colonPos = hostPort.find(':');
561 if (colonPos == std::string::npos) {
562 proxyHost = hostPort;
563 if (proxyType == "PROXY" || proxyType == "HTTP") {
564 proxyPort = PORT_8080;
565 } else if (proxyType == "SOCKS" || proxyType == "SOCKS5") {
566 proxyPort = PORT_1080;
567 } else if (proxyType == "SOCKS4") {
568 proxyPort = PORT_1080;
569 } else {
570 proxyPort = PORT_8080;
571 }
572 } else {
573 proxyHost = hostPort.substr(0, colonPos);
574 proxyPort = std::stoi(hostPort.substr(colonPos + 1));
575 }
576 return true;
577 }
578 return false;
579 }
580
/**
 * Returns a copy of `httpMessage` with "headerName: headerValue" added as the last header field.
 *
 * The field is inserted after the CRLF that terminates the last existing header line, i.e. immediately
 * before the blank line. Any body after the blank line is preserved unchanged. Inserting before that
 * CRLF (as the previous version did) glued the new field onto the last header's value and duplicated the
 * blank line.
 *
 * If the message has no blank line yet, the field is appended as a new line. An unterminated trailing line
 * is first terminated with CRLF.
 */
std::string AddHttpHeader(const std::string &httpMessage, const std::string &headerName, const std::string &headerValue)
{
    const std::string crlf = "\r\n";
    const std::string newHeader = headerName + ": " + headerValue + crlf;
    const size_t headersEnd = httpMessage.find(crlf + crlf);
    if (headersEnd != std::string::npos) {
        const size_t insertPos = headersEnd + crlf.size();
        return httpMessage.substr(0, insertPos) + newHeader + httpMessage.substr(insertPos);
    }
    std::string result = httpMessage;
    const bool endsWithCrlf = result.size() >= crlf.size() &&
        result.compare(result.size() - crlf.size(), crlf.size(), crlf) == 0;
    if (!result.empty() && !endsWithCrlf) {
        result += crlf;
    }
    return result + newHeader;
}
590
HandleClient(int32_t clientSocket)591 void ProxyServer::HandleClient(int32_t clientSocket)
592 {
593 stats_->activeConnections++;
594
595 // 设置超时
596 if (!SetSocketTimeout(clientSocket)) {
597 CleanupConnection(clientSocket);
598 return;
599 }
600
601 // 读取请求头
602 std::string requestHeader;
603 if (!ReadRequestHeader(clientSocket, requestHeader)) {
604 CleanupConnection(clientSocket);
605 return;
606 }
607
608 printf("\033[33mproxy server read client data %.32s \n\033[0m", requestHeader.c_str());
609
610 std::string method = GetRequestMethod(requestHeader);
611 std::string url = GetRequestUrl(requestHeader);
612
613 if (method == "CONNECT") {
614 HandleConnectRequest(clientSocket, requestHeader, url);
615 } else {
616 HandleHttpRequest(clientSocket, requestHeader, url);
617 }
618 }
619
SetSocketTimeout(int32_t socket)620 bool ProxyServer::SetSocketTimeout(int32_t socket)
621 {
622 struct timeval timeout;
623 timeout.tv_sec = TIME_OUT;
624 timeout.tv_usec = 0;
625
626 if (setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
627 return false;
628 }
629 return true;
630 }
631
ReadRequestHeader(int32_t clientSocket,std::string & requestHeader)632 bool ProxyServer::ReadRequestHeader(int32_t clientSocket, std::string &requestHeader)
633 {
634 char buffer[bufferSize];
635 int bytesReceived = 0;
636
637 while ((bytesReceived = recv(clientSocket, buffer, bufferSize - 1, 0)) > 0) {
638 buffer[bytesReceived] = '\0';
639 requestHeader += buffer;
640 stats_->bytesReceived += bytesReceived;
641
642 if (requestHeader.find("\r\n\r\n") != std::string::npos) {
643 break;
644 }
645
646 if (requestHeader.size() > maxHeaderSize) {
647 return false;
648 }
649 }
650
651 return bytesReceived > 0;
652 }
653
CleanupConnection(int32_t clientSocket)654 void ProxyServer::CleanupConnection(int32_t clientSocket)
655 {
656 close(clientSocket);
657 stats_->activeConnections--;
658 }
659
HandleConnectRequest(int clientSocket,const std::string & requestHeader,const std::string & url)660 void ProxyServer::HandleConnectRequest(int clientSocket, const std::string &requestHeader, const std::string &url)
661 {
662 stats_->httpsRequests++;
663 std::string host;
664 int port;
665 if (!ParseConnectRequest(requestHeader, host, port)) {
666 SendErrorResponse(clientSocket, "HTTP/1.1 400 Bad Request\r\n\r\n");
667 return;
668 }
669 proxServerTargetUrl = url;
670 proxServerPort = port_;
671 LOG("local proxy server port:%d url:%s host:%s port:%d \n", port_, url.c_str(), host.c_str(), port);
672 std::string header;
673 int serverSocket = EstablishServerConnection(url, host, port, "HTTPS", header);
674 if (serverSocket < 0) {
675 SendErrorResponse(clientSocket, "HTTP/1.1 502 Bad Gateway\r\n\r\n");
676 return;
677 }
678 const char *response = "HTTP/1.1 200 Connection Established\r\n\r\n";
679 int responseLen = strlen(response);
680 if (send(clientSocket, response, responseLen, 0) < 0) {
681 std::cerr << "发送CONNECT响应失败" << std::endl;
682 close(serverSocket);
683 CleanupConnection(clientSocket);
684 return;
685 }
686 stats_->bytesSent += responseLen;
687 TunnelData(clientSocket, serverSocket);
688 close(serverSocket);
689 CleanupConnection(clientSocket);
690 }
691
HandleHttpRequest(int32_t clientSocket,std::string & requestHeader,const std::string & url)692 void ProxyServer::HandleHttpRequest(int32_t clientSocket, std::string &requestHeader, const std::string &url)
693 {
694 stats_->httpRequests++;
695 std::string host;
696 int port;
697 if (!ParseHttpRequest(requestHeader, host, port)) {
698 std::cerr << "无法解析HTTP请求" << std::endl;
699 SendErrorResponse(clientSocket, "HTTP/1.1 400 Bad Request\r\n\r\n");
700 return;
701 }
702 LOG("\033[33mconnect to localport:%d %s %s \n\033[0m", port_, url.c_str(), host.c_str());
703 int serverSocket = EstablishServerConnection(url, host, port, "HTTP", requestHeader);
704 if (serverSocket < 0) {
705 SendErrorResponse(clientSocket, "HTTP/1.1 502 Bad Gateway\r\n\r\n");
706 return;
707 }
708 ForwardResponseToClient(clientSocket, serverSocket);
709 close(serverSocket);
710 CleanupConnection(clientSocket);
711 }
712
EstablishServerConnection(const std::string & url,const std::string & host,int32_t port,const std::string & requestType,std::string & requestHeader)713 int ProxyServer::EstablishServerConnection(const std::string &url, const std::string &host, int32_t port,
714 const std::string &requestType, std::string &requestHeader)
715 {
716 std::string proxyType;
717 std::string proxyHost;
718 int proxyPort;
719 bool useUpstreamProxy = ParseProxyInfo(url, host, proxyType, proxyHost, proxyPort);
720 int serverSocket = -1;
721 if (useUpstreamProxy && (proxyType == "PROXY" || proxyType == "HTTP")) {
722 if (requestType == "HTTPS") {
723 serverSocket = ConnectViaUpstreamProxyHttps(host, port, proxyHost, proxyPort);
724 } else {
725 serverSocket = ConnectViaUpstreamProxy(host, port, requestHeader, proxyHost, proxyPort);
726 }
727 } else if (requestType == "HTTP") {
728 requestHeader = AddHttpHeader(requestHeader, "Proxy-Port", std::to_string(port_));
729 printf("\033[33mnot proxy info direct send %.32s \n\033[0m", requestHeader.c_str());
730 serverSocket = ConnectToServer(host, port);
731 if (serverSocket >= 0) {
732 if (send(serverSocket, requestHeader.c_str(), requestHeader.length(), 0) < 0) {
733 close(serverSocket);
734 serverSocket = -1;
735 } else {
736 stats_->bytesSent += requestHeader.length();
737 }
738 }
739 } else {
740 serverSocket = ConnectToServer(host, port);
741 }
742 return serverSocket;
743 }
744
SendErrorResponse(int32_t clientSocket,const char * response)745 void ProxyServer::SendErrorResponse(int32_t clientSocket, const char *response)
746 {
747 send(clientSocket, response, strlen(response), 0);
748 CleanupConnection(clientSocket);
749 }
750
ForwardResponseToClient(int32_t clientSocket,int32_t serverSocket)751 void ProxyServer::ForwardResponseToClient(int32_t clientSocket, int32_t serverSocket)
752 {
753 char buffer[bufferSize];
754 int bytesReceived;
755
756 while ((bytesReceived = recv(serverSocket, buffer, bufferSize, 0)) > 0) {
757 stats_->bytesReceived += bytesReceived;
758
759 if (send(clientSocket, buffer, bytesReceived, 0) < 0) {
760 break;
761 }
762
763 stats_->bytesSent += bytesReceived;
764 }
765 }
AddTask(const ClientTask & task)766 void ProxyServer::AddTask(const ClientTask &task)
767 {
768 {
769 std::lock_guard<std::mutex> lock(queueMutex_);
770 taskQueue_.push(task);
771 }
772 queueCondition_.notify_one();
773 }
774
WorkerThread()775 void ProxyServer::WorkerThread()
776 {
777 while (running_) {
778 ClientTask task = {-1, {}};
779
780 {
781 std::unique_lock<std::mutex> lock(queueMutex_);
782 queueCondition_.wait(lock, [this] { return !taskQueue_.empty() || !running_; });
783
784 if (!running_ && taskQueue_.empty()) {
785 break;
786 }
787
788 if (!taskQueue_.empty()) {
789 task = taskQueue_.front();
790 taskQueue_.pop();
791 }
792 }
793
794 if (task.clientSocket >= 0) {
795 HandleClient(task.clientSocket);
796 }
797 }
798 }
799
AcceptLoop()800 void ProxyServer::AcceptLoop()
801 {
802 while (running_) {
803 struct pollfd fd;
804 fd.fd = serverSocket_;
805 fd.events = POLLIN;
806 int ret = poll(&fd, 1, 1000);
807 if (ret < 0) {
808 if (errno == EINTR) {
809 continue;
810 }
811 std::cerr << "Poll失败: " << strerror(errno) << std::endl;
812 break;
813 }
814 if (ret == 0) {
815 continue;
816 }
817 if (!(fd.revents & POLLIN)) {
818 continue;
819 }
820 struct sockaddr_in clientAddr;
821 socklen_t clientAddrLen = sizeof(clientAddr);
822 int clientSocket = accept(serverSocket_, reinterpret_cast<sockaddr *>(&clientAddr), &clientAddrLen);
823 if (clientSocket < 0) {
824 if (errno == EAGAIN || errno == EWOULDBLOCK) {
825 continue;
826 }
827 std::cerr << "接受连接失败: " << strerror(errno) << std::endl;
828 continue;
829 }
830 stats_->totalConnections++;
831 AddTask(ClientTask(clientSocket, clientAddr));
832 }
833 }
834