• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <alloca.h>
2 #include <errno.h>
3 #include <pthread.h>
4 #include <signal.h>
5 #include <string.h>
6 #include <arpa/inet.h>
7 #include <sys/socket.h>
8 #include <sys/types.h>
9 
10 #define LOG_TAG "SocketClient"
11 #include <cutils/log.h>
12 
13 #include <sysutils/SocketClient.h>
14 
SocketClient(int socket,bool owned)15 SocketClient::SocketClient(int socket, bool owned) {
16     init(socket, owned, false);
17 }
18 
SocketClient(int socket,bool owned,bool useCmdNum)19 SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
20     init(socket, owned, useCmdNum);
21 }
22 
init(int socket,bool owned,bool useCmdNum)23 void SocketClient::init(int socket, bool owned, bool useCmdNum) {
24     mSocket = socket;
25     mSocketOwned = owned;
26     mUseCmdNum = useCmdNum;
27     pthread_mutex_init(&mWriteMutex, NULL);
28     pthread_mutex_init(&mRefCountMutex, NULL);
29     mPid = -1;
30     mUid = -1;
31     mGid = -1;
32     mRefCount = 1;
33     mCmdNum = 0;
34 
35     struct ucred creds;
36     socklen_t szCreds = sizeof(creds);
37     memset(&creds, 0, szCreds);
38 
39     int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
40     if (err == 0) {
41         mPid = creds.pid;
42         mUid = creds.uid;
43         mGid = creds.gid;
44     }
45 }
46 
~SocketClient()47 SocketClient::~SocketClient() {
48     if (mSocketOwned) {
49         close(mSocket);
50     }
51 }
52 
sendMsg(int code,const char * msg,bool addErrno)53 int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
54     return sendMsg(code, msg, addErrno, mUseCmdNum);
55 }
56 
sendMsg(int code,const char * msg,bool addErrno,bool useCmdNum)57 int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
58     char *buf;
59     int ret = 0;
60 
61     if (addErrno) {
62         if (useCmdNum) {
63             ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
64         } else {
65             ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
66         }
67     } else {
68         if (useCmdNum) {
69             ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
70         } else {
71             ret = asprintf(&buf, "%d %s", code, msg);
72         }
73     }
74     // Send the zero-terminated message
75     if (ret != -1) {
76         ret = sendMsg(buf);
77         free(buf);
78     }
79     return ret;
80 }
81 
82 // send 3-digit code, null, binary-length, binary data
sendBinaryMsg(int code,const void * data,int len)83 int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
84 
85     // 4 bytes for the code & null + 4 bytes for the len
86     char buf[8];
87     // Write the code
88     snprintf(buf, 4, "%.3d", code);
89     // Write the len
90     uint32_t tmp = htonl(len);
91     memcpy(buf + 4, &tmp, sizeof(uint32_t));
92 
93     struct iovec vec[2];
94     vec[0].iov_base = (void *) buf;
95     vec[0].iov_len = sizeof(buf);
96     vec[1].iov_base = (void *) data;
97     vec[1].iov_len = len;
98 
99     pthread_mutex_lock(&mWriteMutex);
100     int result = sendDataLockedv(vec, (len > 0) ? 2 : 1);
101     pthread_mutex_unlock(&mWriteMutex);
102 
103     return result;
104 }
105 
106 // Sends the code (c-string null-terminated).
sendCode(int code)107 int SocketClient::sendCode(int code) {
108     char buf[4];
109     snprintf(buf, sizeof(buf), "%.3d", code);
110     return sendData(buf, sizeof(buf));
111 }
112 
quoteArg(const char * arg)113 char *SocketClient::quoteArg(const char *arg) {
114     int len = strlen(arg);
115     char *result = (char *)malloc(len * 2 + 3);
116     char *current = result;
117     const char *end = arg + len;
118     char *oldresult;
119 
120     if(result == NULL) {
121         SLOGW("malloc error (%s)", strerror(errno));
122         return NULL;
123     }
124 
125     *(current++) = '"';
126     while (arg < end) {
127         switch (*arg) {
128         case '\\':
129         case '"':
130             *(current++) = '\\'; // fallthrough
131         default:
132             *(current++) = *(arg++);
133         }
134     }
135     *(current++) = '"';
136     *(current++) = '\0';
137     oldresult = result; // save pointer in case realloc fails
138     result = (char *)realloc(result, current-result);
139     return result ? result : oldresult;
140 }
141 
142 
sendMsg(const char * msg)143 int SocketClient::sendMsg(const char *msg) {
144     // Send the message including null character
145     if (sendData(msg, strlen(msg) + 1) != 0) {
146         SLOGW("Unable to send msg '%s'", msg);
147         return -1;
148     }
149     return 0;
150 }
151 
sendData(const void * data,int len)152 int SocketClient::sendData(const void *data, int len) {
153     struct iovec vec[1];
154     vec[0].iov_base = (void *) data;
155     vec[0].iov_len = len;
156 
157     pthread_mutex_lock(&mWriteMutex);
158     int rc = sendDataLockedv(vec, 1);
159     pthread_mutex_unlock(&mWriteMutex);
160 
161     return rc;
162 }
163 
sendDatav(struct iovec * iov,int iovcnt)164 int SocketClient::sendDatav(struct iovec *iov, int iovcnt) {
165     pthread_mutex_lock(&mWriteMutex);
166     int rc = sendDataLockedv(iov, iovcnt);
167     pthread_mutex_unlock(&mWriteMutex);
168 
169     return rc;
170 }
171 
sendDataLockedv(struct iovec * iov,int iovcnt)172 int SocketClient::sendDataLockedv(struct iovec *iov, int iovcnt) {
173 
174     if (mSocket < 0) {
175         errno = EHOSTUNREACH;
176         return -1;
177     }
178 
179     if (iovcnt <= 0) {
180         return 0;
181     }
182 
183     int ret = 0;
184     int e = 0; // SLOGW and sigaction are not inert regarding errno
185     int current = 0;
186 
187     struct sigaction new_action, old_action;
188     memset(&new_action, 0, sizeof(new_action));
189     new_action.sa_handler = SIG_IGN;
190     sigaction(SIGPIPE, &new_action, &old_action);
191 
192     for (;;) {
193         ssize_t rc = TEMP_FAILURE_RETRY(
194             writev(mSocket, iov + current, iovcnt - current));
195 
196         if (rc > 0) {
197             size_t written = rc;
198             while ((current < iovcnt) && (written >= iov[current].iov_len)) {
199                 written -= iov[current].iov_len;
200                 current++;
201             }
202             if (current == iovcnt) {
203                 break;
204             }
205             iov[current].iov_base = (char *)iov[current].iov_base + written;
206             iov[current].iov_len -= written;
207             continue;
208         }
209 
210         if (rc == 0) {
211             e = EIO;
212             SLOGW("0 length write :(");
213         } else {
214             e = errno;
215             SLOGW("write error (%s)", strerror(e));
216         }
217         ret = -1;
218         break;
219     }
220 
221     sigaction(SIGPIPE, &old_action, &new_action);
222 
223     errno = e;
224     return ret;
225 }
226 
incRef()227 void SocketClient::incRef() {
228     pthread_mutex_lock(&mRefCountMutex);
229     mRefCount++;
230     pthread_mutex_unlock(&mRefCountMutex);
231 }
232 
decRef()233 bool SocketClient::decRef() {
234     bool deleteSelf = false;
235     pthread_mutex_lock(&mRefCountMutex);
236     mRefCount--;
237     if (mRefCount == 0) {
238         deleteSelf = true;
239     } else if (mRefCount < 0) {
240         SLOGE("SocketClient refcount went negative!");
241     }
242     pthread_mutex_unlock(&mRefCountMutex);
243     if (deleteSelf) {
244         delete this;
245     }
246     return deleteSelf;
247 }
248