• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
// Route the LOG and ERROR logging macros to printf (stdout).
// NOTE(review): these are defined before the #include lines that follow, so any included
// header that uses the identifiers LOG or ERROR sees these expansions -- presumably
// intentional so proxy_server.h can use them; confirm before moving them.
15 #define LOG printf
16 #define ERROR printf
#include "proxy_server.h"
#include "securec.h"
#include "sys/time.h"
#include <algorithm>
#include <arpa/inet.h>
#include <cctype>
#include <cerrno>
#include <chrono>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <fcntl.h>
#include <iostream>
#include <netdb.h>
#include <poll.h>
#include <random>
#include <sstream>
#include <sys/socket.h>
#include <thread>
#include <unistd.h>
#include <utility>
#include <vector>
35 
36 #define PRINT_RED_FMT_LN(fmt, ...) printf("\033[31m" fmt "\n\033[0m", ##__VA_ARGS__)
37 #define NUM_4 4
38 #define NUM_10 10
39 #define TIME_OUT 30
40 #define PORT_8080 8080
41 #define PORT_80 80
42 #define PORT_1080 1080
43 #define PORT_443 443
44 using namespace OHOS::NetManagerStandard;
45 
46 std::map<std::string, std::string> ProxyServer::pacScripts;
47 std::string ProxyServer::proxServerTargetUrl;
48 int32_t ProxyServer::proxServerPort;
49 
ProxyServer(int32_t port,int32_t numThreads)50 ProxyServer::ProxyServer(int32_t port, int32_t numThreads)
51     : port_(port), serverSocket_(-1), numThreads_(numThreads), running_(false)
52 {
53     if (numThreads_ <= 0) {
54         numThreads_ = std::thread::hardware_concurrency();
55         if (numThreads_ <= 0) {
56             numThreads_ = NUM_4;
57         }
58     }
59     pacScripts = {
60         {LOCAL_PROXY_9000,
61          "function FindProxyForURL(url, host) {\n"
62          "    return \"PROXY 127.0.0.1:9000\";\n"
63          "}"},
64         {LOCAL_PROXY_9001,
65          "function FindProxyForURL(url, host) {\n"
66          "    return \"PROXY 127.0.0.1:9001\";\n"
67          "}"},
68         {ALL_DIRECT,
69          "function FindProxyForURL(url, host) {\n"
70          "    return \"PROXY 127.0.0.1:9000;PROXY 127.0.0.1:9001; DIRECT\";\n"
71          "}"},
72     };
73     if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
74         perror("signal");
75     }
76     ResetStats();
77 }
78 
~ProxyServer()79 ProxyServer::~ProxyServer()
80 {
81     Stop();
82 }
83 
SetFindPacProxyFunction(std::function<std::string (std::string,std::string)> pac)84 void ProxyServer::SetFindPacProxyFunction(std::function<std::string(std::string, std::string)> pac)
85 {
86     pacFunction_ = pac;
87 }
88 
Start()89 bool ProxyServer::Start()
90 {
91     if (running_) {
92         PRINT_RED_FMT_LN("server is runing \n");
93         return false;
94     }
95     serverSocket_ = socket(AF_INET, SOCK_STREAM, 0);
96     if (serverSocket_ < 0) {
97         PRINT_RED_FMT_LN("create socket fail \n");
98         return false;
99     }
100     int32_t opt = 1;
101     if (setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
102         PRINT_RED_FMT_LN("SO_REUSEADDR fail \n");
103         close(serverSocket_);
104         serverSocket_ = -1;
105         return false;
106     }
107     int32_t flags = fcntl(serverSocket_, F_GETFL, 0);
108     if (flags < 0 || fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK) < 0) {
109         printf("O_NONBLOCK fail \n");
110         close(serverSocket_);
111         serverSocket_ = -1;
112         return false;
113     }
114     struct sockaddr_in serverAddr;
115     serverAddr.sin_family = AF_INET;
116     serverAddr.sin_addr.s_addr = INADDR_ANY;
117     serverAddr.sin_port = htons(port_);
118     if (bind(serverSocket_, reinterpret_cast<const sockaddr *>(&serverAddr), sizeof(serverAddr)) < 0) {
119         close(serverSocket_);
120         serverSocket_ = -1;
121         printf("bind port %d fail \n", port_);
122         return false;
123     }
124     if (listen(serverSocket_, backlog) < 0) {
125         close(serverSocket_);
126         serverSocket_ = -1;
127         printf("listen port %d fail \n", port_);
128         return false;
129     }
130     LOG("run localserver on port:%d threads: %d \n", port_, numThreads_);
131     running_ = true;
132     ResetStats();
133     for (int32_t i = 0; i < numThreads_; i++) {
134         workers_.push_back(std::thread(&ProxyServer::WorkerThread, this));
135     }
136     acceptThread_ = std::thread(&ProxyServer::AcceptLoop, this);
137     return true;
138 }
139 
Stop()140 void ProxyServer::Stop()
141 {
142     if (!running_) {
143         return;
144     }
145     running_ = false;
146     queueCondition_.notify_all();
147     if (acceptThread_.joinable()) {
148         acceptThread_.join();
149     }
150     for (auto &worker : workers_) {
151         if (worker.joinable()) {
152             worker.join();
153         }
154     }
155     workers_.clear();
156     {
157         std::lock_guard<std::mutex> lock(queueMutex_);
158         while (!taskQueue_.empty()) {
159             close(taskQueue_.front().clientSocket);
160             taskQueue_.pop();
161         }
162     }
163     if (serverSocket_ >= 0) {
164         close(serverSocket_);
165         serverSocket_ = -1;
166     }
167 }
168 
IsRunning() const169 bool ProxyServer::IsRunning() const
170 {
171     return running_;
172 }
173 
GetStats()174 std::shared_ptr<Stats> ProxyServer::GetStats()
175 {
176     std::lock_guard<std::mutex> lock(statsMutex_);
177     return stats_;
178 }
179 
ResetStats()180 void ProxyServer::ResetStats()
181 {
182     std::lock_guard<std::mutex> lock(statsMutex_);
183     stats_ = std::make_shared<Stats>();
184     stats_->startTime = std::chrono::steady_clock::now();
185 }
186 
GetThroughput() const187 double ProxyServer::GetThroughput() const
188 {
189     std::lock_guard<std::mutex> lock(statsMutex_);
190     auto now = std::chrono::steady_clock::now();
191     auto duration = std::chrono::duration_cast<std::chrono::seconds>(now - stats_->startTime).count();
192     if (duration <= 0) {
193         return 0.0;
194     }
195     uint64_t totalBytes = stats_->bytesReceived + stats_->bytesSent;
196     return static_cast<double>(totalBytes) / duration;
197 }
198 
GetRequestMethod(const std::string & header)199 std::string ProxyServer::GetRequestMethod(const std::string &header)
200 {
201     size_t spacePos = header.find(' ');
202     if (spacePos == std::string::npos) {
203         return "";
204     }
205     return header.substr(0, spacePos);
206 }
207 
ParseConnectRequest(const std::string & header,std::string & host,int32_t & port)208 bool ProxyServer::ParseConnectRequest(const std::string &header, std::string &host, int32_t &port)
209 {
210     size_t methodEnd = header.find(' ');
211     if (methodEnd == std::string::npos) {
212         return false;
213     }
214     size_t hostStart = methodEnd + 1;
215     size_t hostEnd = header.find(' ', hostStart);
216     if (hostEnd == std::string::npos) {
217         return false;
218     }
219     std::string hostPort = header.substr(hostStart, hostEnd - hostStart);
220     size_t colonPos = hostPort.find(':');
221     if (colonPos != std::string::npos) {
222         host = hostPort.substr(0, colonPos);
223         port = std::stoi(hostPort.substr(colonPos + 1));
224     } else {
225         host = hostPort;
226         port = PORT_443;
227     }
228     return true;
229 }
230 
ParseHttpRequest(const std::string & header,std::string & host,int32_t & port)231 bool ProxyServer::ParseHttpRequest(const std::string &header, std::string &host, int32_t &port)
232 {
233     size_t hostPos = header.find("Host: ");
234     if (hostPos == std::string::npos) {
235         return false;
236     }
237     size_t hostEnd = header.find("\r\n", hostPos);
238     if (hostEnd == std::string::npos) {
239         return false;
240     }
241     std::string hostLine = header.substr(hostPos + 6, hostEnd - hostPos - 6);
242     size_t colonPos = hostLine.find(':');
243     if (colonPos != std::string::npos) {
244         host = hostLine.substr(0, colonPos);
245         port = std::stoi(hostLine.substr(colonPos + 1));
246     } else {
247         host = hostLine;
248         port = PORT_80;
249     }
250     return true;
251 }
252 
/**
 * Decides whether a poll() result is a fatal error.
 *
 * @param ret      Return value of poll().
 * @param errnoVal errno captured right after poll(); only consulted when ret < 0.
 * @return true (after logging to std::cerr) when ret < 0 for any reason other than EINTR;
 *         false otherwise, including EINTR, which the caller may simply retry.
 */
static bool HandlePollError(int32_t ret, int32_t errnoVal)
{
    const bool fatal = (ret < 0) && (errnoVal != EINTR);
    if (fatal) {
        std::cerr << "Poll失败: " << strerror(errnoVal) << std::endl;
    }
    return fatal;
}
264 
/**
 * Returns true if either of the first two pollfd entries reported POLLHUP or POLLERR.
 * @param fds Array with at least two entries, as filled in by poll().
 */
static bool CheckPollHupOrErr(const struct pollfd *fds)
{
    const int badEvents = POLLHUP | POLLERR;
    for (int i = 0; i < 2; ++i) {
        if ((fds[i].revents & badEvents) != 0) {
            return true;
        }
    }
    return false;
}
269 
TransferData(int32_t srcFd,int32_t dstFd,char * buffer,size_t bufferSize,std::shared_ptr<Stats> stats)270 static bool TransferData(int32_t srcFd, int32_t dstFd, char *buffer, size_t bufferSize, std::shared_ptr<Stats> stats)
271 {
272     int32_t n = recv(srcFd, buffer, bufferSize, 0);
273     if (n <= 0) {
274         return true;
275     }
276     stats->bytesReceived += n;
277     if (send(dstFd, buffer, n, 0) <= 0) {
278         return true;
279     }
280     stats->bytesSent += n;
281     return false;
282 }
283 
TunnelData(int32_t client,int32_t server)284 void ProxyServer::TunnelData(int32_t client, int32_t server)
285 {
286     struct pollfd fds[2];
287     fds[0].fd = client;
288     fds[0].events = POLLIN;
289     fds[1].fd = server;
290     fds[1].events = POLLIN;
291     char buffer[bufferSize];
292     bool clientClosed = false;
293     bool serverClosed = false;
294     while (!clientClosed && !serverClosed && running_) {
295         int32_t ret = poll(fds, 2, 1000);
296         if (HandlePollError(ret, errno)) {
297             break;
298         }
299         if (ret == 0) {
300             continue;
301         }
302         if (CheckPollHupOrErr(fds)) {
303             break;
304         }
305         if (fds[0].revents & POLLIN) {
306             clientClosed = TransferData(client, server, buffer, bufferSize, stats_);
307             if (clientClosed) {
308                 continue;
309             }
310         }
311         if (fds[1].revents & POLLIN) {
312             serverClosed = TransferData(server, client, buffer, bufferSize, stats_);
313         }
314     }
315 }
ReceiveResponseHeader(int32_t socket)316 std::string ProxyServer::ReceiveResponseHeader(int32_t socket)
317 {
318     std::string header;
319     char buffer[bufferSize];
320 
321     int bytesRead;
322     while ((bytesRead = recv(socket, buffer, sizeof(buffer) - 1, 0)) > 0) {
323         buffer[bytesRead] = '\0';
324         header.append(buffer);
325         if (header.find("\r\n\r\n") != std::string::npos) {
326             break;
327         }
328         if (header.size() > maxHeaderSize) {
329             break;
330         }
331     }
332     return header;
333 }
334 
IsPortAvailable(int32_t port)335 bool ProxyServer::IsPortAvailable(int32_t port)
336 {
337     int sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
338     if (sockfd < 0) {
339         return false;
340     }
341 
342     int optval = 1;
343     if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&optval), sizeof(optval)) < 0) {
344         close(sockfd);
345         return false;
346     }
347 
348     struct sockaddr_in serverAddr;
349     serverAddr.sin_family = AF_INET;
350     serverAddr.sin_port = htons(static_cast<uint16_t>(port));
351     serverAddr.sin_addr.s_addr = inet_addr("127.0.0.1");
352 
353     int bindResult = bind(sockfd, reinterpret_cast<struct sockaddr *>(&serverAddr), sizeof(serverAddr));
354     close(sockfd);
355     return (bindResult == 0);
356 }
357 
FindAvailablePort(int32_t startPort,int32_t endPort)358 int ProxyServer::FindAvailablePort(int32_t startPort, int32_t endPort)
359 {
360     std::vector<int> portsToTry;
361     for (int port = startPort; port <= endPort; ++port) {
362         portsToTry.push_back(port);
363     }
364 
365     unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
366     std::shuffle(portsToTry.begin(), portsToTry.end(), std::default_random_engine(seed));
367 
368     for (int port : portsToTry) {
369         if (IsPortAvailable(port)) {
370             return port;
371         }
372     }
373 
374     return -1;
375 }
376 
ConnectToServer(const std::string & host,int port)377 int ProxyServer::ConnectToServer(const std::string &host, int port)
378 {
379     int serverSocket = socket(AF_INET, SOCK_STREAM, 0);
380     if (serverSocket < 0) {
381         std::cerr << "无法创建socket" << std::endl;
382         return -1;
383     }
384 
385     struct timeval timeout;
386     timeout.tv_sec = NUM_10;
387     timeout.tv_usec = 0;
388 
389     if (setsockopt(serverSocket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0 ||
390         setsockopt(serverSocket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
391         std::cerr << "设置socket超时选项失败" << std::endl;
392         close(serverSocket);
393         return -1;
394     }
395 
396     struct hostent *server = gethostbyname(host.c_str());
397     if (server == nullptr) {
398         std::cerr << "无法解析主机: " << host << std::endl;
399         close(serverSocket);
400         return -1;
401     }
402 
403     struct sockaddr_in serverAddr;
404     serverAddr.sin_family = AF_INET;
405     serverAddr.sin_port = htons(port);
406     memcpy_s(&serverAddr.sin_addr.s_addr, sizeof(serverAddr.sin_addr.s_addr), server->h_addr, server->h_length);
407 
408     if (connect(serverSocket, reinterpret_cast<const sockaddr *>(&serverAddr), sizeof(serverAddr)) < 0) {
409         std::cerr << "无法连接到目标服务器 " << host << ":" << port << std::endl;
410         close(serverSocket);
411         return -1;
412     }
413 
414     return serverSocket;
415 }
416 
ConnectViaUpstreamProxy(const std::string & targetHost,int targetPort,const std::string & originalRequest,std::string proxyHost,int proxyPort)417 int ProxyServer::ConnectViaUpstreamProxy(const std::string &targetHost, int targetPort,
418                                          const std::string &originalRequest, std::string proxyHost, int proxyPort)
419 {
420     int proxySocket = ConnectToServer(proxyHost, proxyPort);
421     if (proxySocket < 0) {
422         ERROR("connect upstream proxy fail %s:%d\n", proxyHost.c_str(), proxyPort);
423         return -1;
424     }
425     printf("ConnectToServer %s : %d \n", proxyHost.c_str(), proxyPort);
426     printf("originalRequest ########## \n %s\n###########\n", originalRequest.c_str());
427     if (send(proxySocket, originalRequest.c_str(), originalRequest.length(), 0) < 0) {
428         close(proxySocket);
429         return -1;
430     }
431 
432     stats_->bytesSent += originalRequest.length();
433 
434     return proxySocket;
435 }
436 
ConnectViaUpstreamProxyHttps(const std::string & targetHost,int targetPort,std::string proxyHost,int proxyPort)437 int ProxyServer::ConnectViaUpstreamProxyHttps(const std::string &targetHost, int targetPort, std::string proxyHost,
438                                               int proxyPort)
439 {
440     int proxySocket = ConnectToServer(proxyHost, proxyPort);
441     if (proxySocket < 0) {
442         ERROR("connect upproxy fail %s:%d fail", proxyHost.c_str(), proxyPort);
443         return -1;
444     }
445     std::ostringstream connectRequest;
446     connectRequest << "CONNECT " << targetHost << ":" << targetPort << " HTTP/1.1\r\n"
447                    << "Host: " << targetHost << ":" << targetPort << "\r\n"
448                    << "Proxy-Connection: Keep-Alive\r\n"
449                    << "\r\n";
450     std::string requestStr = connectRequest.str();
451     if (send(proxySocket, requestStr.c_str(), requestStr.length(), 0) < 0) {
452         ERROR("send CONNECT to upstream proxy fail \n");
453         close(proxySocket);
454         return -1;
455     }
456     stats_->bytesSent += requestStr.length();
457     std::string response = ReceiveResponseHeader(proxySocket);
458     if (response.empty() || response.find("HTTP/1.1 200") == std::string::npos) {
459         std::cerr << "上游代理拒绝CONNECT请求: " << response << std::endl;
460         close(proxySocket);
461         return -1;
462     }
463     stats_->bytesReceived += response.length();
464     return proxySocket;
465 }
466 
GetRequestUrl(const std::string & header)467 std::string ProxyServer::GetRequestUrl(const std::string &header)
468 {
469     std::string method = GetRequestMethod(header);
470     std::string url;
471     if (method == "CONNECT") {
472         size_t methodEnd = header.find(' ');
473         if (methodEnd == std::string::npos) {
474             return "";
475         }
476         size_t hostStart = methodEnd + 1;
477         size_t hostEnd = header.find(' ', hostStart);
478         if (hostEnd == std::string::npos) {
479             return "";
480         }
481         std::string hostPort = header.substr(hostStart, hostEnd - hostStart);
482         url = "https://" + hostPort;
483     } else {
484         size_t methodEnd = header.find(' ');
485         if (methodEnd == std::string::npos) {
486             return "";
487         }
488         size_t pathStart = methodEnd + 1;
489         size_t pathEnd = header.find(' ', pathStart);
490         if (pathEnd == std::string::npos) {
491             return "";
492         }
493         std::string path = header.substr(pathStart, pathEnd - pathStart);
494         if (path.find("://") != std::string::npos) {
495             return path;
496         }
497         std::string host;
498         int port = PORT_80;
499         if (!ParseHttpRequest(header, host, port)) {
500             return path;
501         }
502         url = "http://";
503         url += host;
504         if (port != PORT_80) {
505             url += ":" + std::to_string(port);
506         }
507         if (!path.empty() && path[0] != '/') {
508             url += "/";
509         }
510         url += path;
511     }
512 
513     return url;
514 }
515 
ParseProxyInfo(std::string url,std::string host,std::string & proxyType,std::string & proxyHost,int32_t & proxyPort)516 bool ProxyServer::ParseProxyInfo(std::string url, std::string host, std::string &proxyType, std::string &proxyHost,
517                                  int32_t &proxyPort)
518 {
519     std::string pacScirpt;
520     if (pacFunction_) {
521         pacScirpt = pacFunction_(url, host);
522     }
523     if (pacScirpt.empty()) {
524         return false;
525     }
526     return ParsePacResult(pacScirpt, proxyType, proxyHost, proxyPort);
527 }
528 
ParsePacResult(const std::string & pacResult,std::string & proxyType,std::string & proxyHost,int32_t & proxyPort)529 bool ProxyServer::ParsePacResult(const std::string &pacResult, std::string &proxyType, std::string &proxyHost,
530                                  int32_t &proxyPort)
531 {
532     proxyType = "";
533     proxyHost = "";
534     proxyPort = 0;
535     if (pacResult.empty()) {
536         return false;
537     }
538     std::istringstream stream(pacResult);
539     std::string rule;
540     while (std::getline(stream, rule, ';')) {
541         rule.erase(0, rule.find_first_not_of(" \t"));
542         rule.erase(rule.find_last_not_of(" \t") + 1);
543         if (rule.empty()) {
544             continue;
545         }
546         size_t spacePos = rule.find(' ');
547         if (spacePos == std::string::npos) {
548             proxyType = rule;
549             if (proxyType == "DIRECT") {
550                 return true;
551             }
552             continue;
553         }
554         proxyType = rule.substr(0, spacePos);
555         size_t hostStart = rule.find_first_not_of(" \t", spacePos);
556         if (hostStart == std::string::npos) {
557             continue;
558         }
559         std::string hostPort = rule.substr(hostStart);
560         size_t colonPos = hostPort.find(':');
561         if (colonPos == std::string::npos) {
562             proxyHost = hostPort;
563             if (proxyType == "PROXY" || proxyType == "HTTP") {
564                 proxyPort = PORT_8080;
565             } else if (proxyType == "SOCKS" || proxyType == "SOCKS5") {
566                 proxyPort = PORT_1080;
567             } else if (proxyType == "SOCKS4") {
568                 proxyPort = PORT_1080;
569             } else {
570                 proxyPort = PORT_8080;
571             }
572         } else {
573             proxyHost = hostPort.substr(0, colonPos);
574             proxyPort = std::stoi(hostPort.substr(colonPos + 1));
575         }
576         return true;
577     }
578     return false;
579 }
580 
/**
 * Inserts "headerName: headerValue" as the last header field of an HTTP message.
 *
 * The new line is placed right after the CRLF that ends the final existing header line. The
 * "\r\n\r\n" separator and any body bytes after it are kept unchanged. For example,
 * "GET / HTTP/1.1\r\nHost: a\r\n\r\n" becomes "GET / HTTP/1.1\r\nHost: a\r\nX: 1\r\n\r\n".
 * Inserting before that CRLF, as the previous version did, glued the new field onto the
 * previous field's value and added an extra blank line.
 * If the message has no header terminator, the field is appended at the end. A CRLF is added
 * first when the message is non-empty and does not already end with one.
 *
 * @param httpMessage Raw HTTP message (request line, headers, optional body).
 * @param headerName  Field name to add.
 * @param headerValue Field value to add.
 * @return A copy of httpMessage with the field inserted.
 */
std::string AddHttpHeader(const std::string &httpMessage, const std::string &headerName, const std::string &headerValue)
{
    constexpr size_t crlfLen = 2;
    const std::string newHeader = headerName + ": " + headerValue + "\r\n";
    const size_t headersEnd = httpMessage.find("\r\n\r\n");
    if (headersEnd == std::string::npos) {
        std::string result = httpMessage;
        const bool endsWithCrlf =
            result.size() >= crlfLen && result.compare(result.size() - crlfLen, crlfLen, "\r\n") == 0;
        if (!result.empty() && !endsWithCrlf) {
            result += "\r\n";
        }
        result += newHeader;
        return result;
    }
    // Split after the CRLF terminating the last header line, before the empty line.
    const size_t insertPos = headersEnd + crlfLen;
    std::string result;
    result.reserve(httpMessage.size() + newHeader.size());
    result.append(httpMessage, 0, insertPos);
    result += newHeader;
    result.append(httpMessage, insertPos, std::string::npos);
    return result;
}
590 
HandleClient(int32_t clientSocket)591 void ProxyServer::HandleClient(int32_t clientSocket)
592 {
593     stats_->activeConnections++;
594 
595     // 设置超时
596     if (!SetSocketTimeout(clientSocket)) {
597         CleanupConnection(clientSocket);
598         return;
599     }
600 
601     // 读取请求头
602     std::string requestHeader;
603     if (!ReadRequestHeader(clientSocket, requestHeader)) {
604         CleanupConnection(clientSocket);
605         return;
606     }
607 
608     printf("\033[33mproxy server read client data %.32s \n\033[0m", requestHeader.c_str());
609 
610     std::string method = GetRequestMethod(requestHeader);
611     std::string url = GetRequestUrl(requestHeader);
612 
613     if (method == "CONNECT") {
614         HandleConnectRequest(clientSocket, requestHeader, url);
615     } else {
616         HandleHttpRequest(clientSocket, requestHeader, url);
617     }
618 }
619 
SetSocketTimeout(int32_t socket)620 bool ProxyServer::SetSocketTimeout(int32_t socket)
621 {
622     struct timeval timeout;
623     timeout.tv_sec = TIME_OUT;
624     timeout.tv_usec = 0;
625 
626     if (setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
627         return false;
628     }
629     return true;
630 }
631 
ReadRequestHeader(int32_t clientSocket,std::string & requestHeader)632 bool ProxyServer::ReadRequestHeader(int32_t clientSocket, std::string &requestHeader)
633 {
634     char buffer[bufferSize];
635     int bytesReceived = 0;
636 
637     while ((bytesReceived = recv(clientSocket, buffer, bufferSize - 1, 0)) > 0) {
638         buffer[bytesReceived] = '\0';
639         requestHeader += buffer;
640         stats_->bytesReceived += bytesReceived;
641 
642         if (requestHeader.find("\r\n\r\n") != std::string::npos) {
643             break;
644         }
645 
646         if (requestHeader.size() > maxHeaderSize) {
647             return false;
648         }
649     }
650 
651     return bytesReceived > 0;
652 }
653 
CleanupConnection(int32_t clientSocket)654 void ProxyServer::CleanupConnection(int32_t clientSocket)
655 {
656     close(clientSocket);
657     stats_->activeConnections--;
658 }
659 
HandleConnectRequest(int clientSocket,const std::string & requestHeader,const std::string & url)660 void ProxyServer::HandleConnectRequest(int clientSocket, const std::string &requestHeader, const std::string &url)
661 {
662     stats_->httpsRequests++;
663     std::string host;
664     int port;
665     if (!ParseConnectRequest(requestHeader, host, port)) {
666         SendErrorResponse(clientSocket, "HTTP/1.1 400 Bad Request\r\n\r\n");
667         return;
668     }
669     proxServerTargetUrl = url;
670     proxServerPort = port_;
671     LOG("local proxy server port:%d url:%s host:%s port:%d \n", port_, url.c_str(), host.c_str(), port);
672     std::string header;
673     int serverSocket = EstablishServerConnection(url, host, port, "HTTPS", header);
674     if (serverSocket < 0) {
675         SendErrorResponse(clientSocket, "HTTP/1.1 502 Bad Gateway\r\n\r\n");
676         return;
677     }
678     const char *response = "HTTP/1.1 200 Connection Established\r\n\r\n";
679     int responseLen = strlen(response);
680     if (send(clientSocket, response, responseLen, 0) < 0) {
681         std::cerr << "发送CONNECT响应失败" << std::endl;
682         close(serverSocket);
683         CleanupConnection(clientSocket);
684         return;
685     }
686     stats_->bytesSent += responseLen;
687     TunnelData(clientSocket, serverSocket);
688     close(serverSocket);
689     CleanupConnection(clientSocket);
690 }
691 
HandleHttpRequest(int32_t clientSocket,std::string & requestHeader,const std::string & url)692 void ProxyServer::HandleHttpRequest(int32_t clientSocket, std::string &requestHeader, const std::string &url)
693 {
694     stats_->httpRequests++;
695     std::string host;
696     int port;
697     if (!ParseHttpRequest(requestHeader, host, port)) {
698         std::cerr << "无法解析HTTP请求" << std::endl;
699         SendErrorResponse(clientSocket, "HTTP/1.1 400 Bad Request\r\n\r\n");
700         return;
701     }
702     LOG("\033[33mconnect to localport:%d %s %s \n\033[0m", port_, url.c_str(), host.c_str());
703     int serverSocket = EstablishServerConnection(url, host, port, "HTTP", requestHeader);
704     if (serverSocket < 0) {
705         SendErrorResponse(clientSocket, "HTTP/1.1 502 Bad Gateway\r\n\r\n");
706         return;
707     }
708     ForwardResponseToClient(clientSocket, serverSocket);
709     close(serverSocket);
710     CleanupConnection(clientSocket);
711 }
712 
EstablishServerConnection(const std::string & url,const std::string & host,int32_t port,const std::string & requestType,std::string & requestHeader)713 int ProxyServer::EstablishServerConnection(const std::string &url, const std::string &host, int32_t port,
714                                            const std::string &requestType, std::string &requestHeader)
715 {
716     std::string proxyType;
717     std::string proxyHost;
718     int proxyPort;
719     bool useUpstreamProxy = ParseProxyInfo(url, host, proxyType, proxyHost, proxyPort);
720     int serverSocket = -1;
721     if (useUpstreamProxy && (proxyType == "PROXY" || proxyType == "HTTP")) {
722         if (requestType == "HTTPS") {
723             serverSocket = ConnectViaUpstreamProxyHttps(host, port, proxyHost, proxyPort);
724         } else {
725             serverSocket = ConnectViaUpstreamProxy(host, port, requestHeader, proxyHost, proxyPort);
726         }
727     } else if (requestType == "HTTP") {
728         requestHeader = AddHttpHeader(requestHeader, "Proxy-Port", std::to_string(port_));
729         printf("\033[33mnot proxy info direct send %.32s \n\033[0m", requestHeader.c_str());
730         serverSocket = ConnectToServer(host, port);
731         if (serverSocket >= 0) {
732             if (send(serverSocket, requestHeader.c_str(), requestHeader.length(), 0) < 0) {
733                 close(serverSocket);
734                 serverSocket = -1;
735             } else {
736                 stats_->bytesSent += requestHeader.length();
737             }
738         }
739     } else {
740         serverSocket = ConnectToServer(host, port);
741     }
742     return serverSocket;
743 }
744 
SendErrorResponse(int32_t clientSocket,const char * response)745 void ProxyServer::SendErrorResponse(int32_t clientSocket, const char *response)
746 {
747     send(clientSocket, response, strlen(response), 0);
748     CleanupConnection(clientSocket);
749 }
750 
ForwardResponseToClient(int32_t clientSocket,int32_t serverSocket)751 void ProxyServer::ForwardResponseToClient(int32_t clientSocket, int32_t serverSocket)
752 {
753     char buffer[bufferSize];
754     int bytesReceived;
755 
756     while ((bytesReceived = recv(serverSocket, buffer, bufferSize, 0)) > 0) {
757         stats_->bytesReceived += bytesReceived;
758 
759         if (send(clientSocket, buffer, bytesReceived, 0) < 0) {
760             break;
761         }
762 
763         stats_->bytesSent += bytesReceived;
764     }
765 }
AddTask(const ClientTask & task)766 void ProxyServer::AddTask(const ClientTask &task)
767 {
768     {
769         std::lock_guard<std::mutex> lock(queueMutex_);
770         taskQueue_.push(task);
771     }
772     queueCondition_.notify_one();
773 }
774 
WorkerThread()775 void ProxyServer::WorkerThread()
776 {
777     while (running_) {
778         ClientTask task = {-1, {}};
779 
780         {
781             std::unique_lock<std::mutex> lock(queueMutex_);
782             queueCondition_.wait(lock, [this] { return !taskQueue_.empty() || !running_; });
783 
784             if (!running_ && taskQueue_.empty()) {
785                 break;
786             }
787 
788             if (!taskQueue_.empty()) {
789                 task = taskQueue_.front();
790                 taskQueue_.pop();
791             }
792         }
793 
794         if (task.clientSocket >= 0) {
795             HandleClient(task.clientSocket);
796         }
797     }
798 }
799 
AcceptLoop()800 void ProxyServer::AcceptLoop()
801 {
802     while (running_) {
803         struct pollfd fd;
804         fd.fd = serverSocket_;
805         fd.events = POLLIN;
806         int ret = poll(&fd, 1, 1000);
807         if (ret < 0) {
808             if (errno == EINTR) {
809                 continue;
810             }
811             std::cerr << "Poll失败: " << strerror(errno) << std::endl;
812             break;
813         }
814         if (ret == 0) {
815             continue;
816         }
817         if (!(fd.revents & POLLIN)) {
818             continue;
819         }
820         struct sockaddr_in clientAddr;
821         socklen_t clientAddrLen = sizeof(clientAddr);
822         int clientSocket = accept(serverSocket_, reinterpret_cast<sockaddr *>(&clientAddr), &clientAddrLen);
823         if (clientSocket < 0) {
824             if (errno == EAGAIN || errno == EWOULDBLOCK) {
825                 continue;
826             }
827             std::cerr << "接受连接失败: " << strerror(errno) << std::endl;
828             continue;
829         }
830         stats_->totalConnections++;
831         AddTask(ClientTask(clientSocket, clientAddr));
832     }
833 }
834