• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2008 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 // Written in NSPR style to also be suitable for adding to the NSS demo suite
5 
6 /* memio is a simple NSPR I/O layer that lets you decouple NSS from
7  * the real network.  It's rather like openssl's memory bio,
8  * and is useful when your app absolutely, positively doesn't
9  * want to let NSS do its own networking.
10  */
11 
12 #include <stdlib.h>
13 #include <string.h>
14 
15 #include <prerror.h>
16 #include <prinit.h>
17 #include <prlog.h>
18 
19 #include "nss_memio.h"
20 
21 /*--------------- private memio types -----------------------*/
22 
23 /*----------------------------------------------------------------------
24  Simple private circular buffer class.  Size cannot be changed once allocated.
25 ----------------------------------------------------------------------*/
26 
27 struct memio_buffer {
28     int head;     /* where to take next byte out of buf */
29     int tail;     /* where to put next byte into buf */
30     int bufsize;  /* number of bytes allocated to buf */
31     /* TODO(port): error handling is pessimistic right now.
32      * Once an error is set, the socket is considered broken
33      * (PR_WOULD_BLOCK_ERROR not included).
34      */
35     PRErrorCode last_err;
36     char *buf;
37 };
38 
39 
40 /* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
41  * to one of these.
42  * In the public header, we use struct memio_Private as a typesafe alias
43  * for this.  This causes a few ugly typecasts in the private file, but
44  * seems safer.
45  */
46 struct PRFilePrivate {
47     /* read requests are satisfied from this buffer */
48     struct memio_buffer readbuf;
49 
50     /* write requests are satisfied from this buffer */
51     struct memio_buffer writebuf;
52 
53     /* SSL needs to know socket peer's name */
54     PRNetAddr peername;
55 
56     /* if set, empty I/O returns EOF instead of EWOULDBLOCK */
57     int eof;
58 };
59 
60 /*--------------- private memio_buffer functions ---------------------*/
61 
/* Forward declarations for the circular-buffer helpers. */

/* Lifetime: allocate the storage for a buffer of the given size / free it. */
static void memio_buffer_new(struct memio_buffer *mb, int size);
static void memio_buffer_destroy(struct memio_buffer *mb);

/* Sizes of the runs reachable without wrapping past the end of the storage:
 * "used" counts readable bytes, "unused" counts writable free space. */
static int memio_buffer_used_contiguous(const struct memio_buffer *mb);
static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);

/* Copy up to n bytes in (put) or out (get); each returns the bytes moved. */
static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);
static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);
81 
82 /* Allocate a memio_buffer of given size. */
memio_buffer_new(struct memio_buffer * mb,int size)83 static void memio_buffer_new(struct memio_buffer *mb, int size)
84 {
85     mb->head = 0;
86     mb->tail = 0;
87     mb->bufsize = size;
88     mb->buf = malloc(size);
89 }
90 
91 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
memio_buffer_destroy(struct memio_buffer * mb)92 static void memio_buffer_destroy(struct memio_buffer *mb)
93 {
94     free(mb->buf);
95     mb->buf = NULL;
96     mb->head = 0;
97     mb->tail = 0;
98 }
99 
100 /* How many bytes can be read out of the buffer without wrapping */
memio_buffer_used_contiguous(const struct memio_buffer * mb)101 static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
102 {
103     return (((mb->tail >= mb->head) ? mb->tail : mb->bufsize) - mb->head);
104 }
105 
106 /* How many bytes can be written into the buffer without wrapping */
memio_buffer_unused_contiguous(const struct memio_buffer * mb)107 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
108 {
109     if (mb->head > mb->tail) return mb->head - mb->tail - 1;
110     return mb->bufsize - mb->tail - (mb->head == 0);
111 }
112 
113 /* Write n bytes into the buffer.  Returns number of bytes written. */
memio_buffer_put(struct memio_buffer * mb,const char * buf,int n)114 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
115 {
116     int len;
117     int transferred = 0;
118 
119     /* Handle part before wrap */
120     len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
121     if (len > 0) {
122         /* Buffer not full */
123         memcpy(&mb->buf[mb->tail], buf, len);
124         mb->tail += len;
125         if (mb->tail == mb->bufsize)
126             mb->tail = 0;
127         n -= len;
128         buf += len;
129         transferred += len;
130 
131         /* Handle part after wrap */
132         len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
133         if (len > 0) {
134             /* Output buffer still not full, input buffer still not empty */
135             memcpy(&mb->buf[mb->tail], buf, len);
136             mb->tail += len;
137             if (mb->tail == mb->bufsize)
138                 mb->tail = 0;
139                 transferred += len;
140         }
141     }
142 
143     return transferred;
144 }
145 
146 
147 /* Read n bytes from the buffer.  Returns number of bytes read. */
memio_buffer_get(struct memio_buffer * mb,char * buf,int n)148 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
149 {
150     int len;
151     int transferred = 0;
152 
153     /* Handle part before wrap */
154     len = PR_MIN(n, memio_buffer_used_contiguous(mb));
155     if (len) {
156         memcpy(buf, &mb->buf[mb->head], len);
157         mb->head += len;
158         if (mb->head == mb->bufsize)
159             mb->head = 0;
160         n -= len;
161         buf += len;
162         transferred += len;
163 
164         /* Handle part after wrap */
165         len = PR_MIN(n, memio_buffer_used_contiguous(mb));
166         if (len) {
167         memcpy(buf, &mb->buf[mb->head], len);
168         mb->head += len;
169             if (mb->head == mb->bufsize)
170                 mb->head = 0;
171                 transferred += len;
172         }
173     }
174 
175     return transferred;
176 }
177 
178 /*--------------- private memio functions -----------------------*/
179 
memio_Close(PRFileDesc * fd)180 static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
181 {
182     struct PRFilePrivate *secret = fd->secret;
183     memio_buffer_destroy(&secret->readbuf);
184     memio_buffer_destroy(&secret->writebuf);
185     free(secret);
186     fd->dtor(fd);
187     return PR_SUCCESS;
188 }
189 
memio_Shutdown(PRFileDesc * fd,PRIntn how)190 static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
191 {
192     /* TODO: pass shutdown status to app somehow */
193     return PR_SUCCESS;
194 }
195 
196 /* If there was a network error in the past taking bytes
197  * out of the buffer, return it to the next call that
198  * tries to read from an empty buffer.
199  */
memio_Recv(PRFileDesc * fd,void * buf,PRInt32 len,PRIntn flags,PRIntervalTime timeout)200 static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
201                                   PRIntn flags, PRIntervalTime timeout)
202 {
203     struct PRFilePrivate *secret;
204     struct memio_buffer *mb;
205     int rv;
206 
207     if (flags) {
208         PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
209         return -1;
210     }
211 
212     secret = fd->secret;
213     mb = &secret->readbuf;
214     PR_ASSERT(mb->bufsize);
215     rv = memio_buffer_get(mb, buf, len);
216     if (rv == 0 && !secret->eof) {
217         if (mb->last_err)
218             PR_SetError(mb->last_err, 0);
219         else
220             PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
221         return -1;
222     }
223 
224     return rv;
225 }
226 
memio_Read(PRFileDesc * fd,void * buf,PRInt32 len)227 static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
228 {
229     /* pull bytes from buffer */
230     return memio_Recv(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
231 }
232 
memio_Send(PRFileDesc * fd,const void * buf,PRInt32 len,PRIntn flags,PRIntervalTime timeout)233 static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
234                                   PRIntn flags, PRIntervalTime timeout)
235 {
236     struct PRFilePrivate *secret;
237     struct memio_buffer *mb;
238     int rv;
239 
240     secret = fd->secret;
241     mb = &secret->writebuf;
242     PR_ASSERT(mb->bufsize);
243 
244     if (mb->last_err) {
245         PR_SetError(mb->last_err, 0);
246         return -1;
247     }
248     rv = memio_buffer_put(mb, buf, len);
249     if (rv == 0) {
250         PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
251         return -1;
252     }
253     return rv;
254 }
255 
memio_Write(PRFileDesc * fd,const void * buf,PRInt32 len)256 static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
257 {
258     /* append bytes to buffer */
259     return memio_Send(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
260 }
261 
memio_GetPeerName(PRFileDesc * fd,PRNetAddr * addr)262 static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
263 {
264     /* TODO: fail if memio_SetPeerName has not been called */
265     struct PRFilePrivate *secret = fd->secret;
266     *addr = secret->peername;
267     return PR_SUCCESS;
268 }
269 
memio_GetSocketOption(PRFileDesc * fd,PRSocketOptionData * data)270 static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
271 {
272     /*
273      * Even in the original version for real tcp sockets,
274      * PR_SockOpt_Nonblocking is a special case that does not
275      * translate to a getsockopt() call
276      */
277     if (PR_SockOpt_Nonblocking == data->option) {
278         data->value.non_blocking = PR_TRUE;
279         return PR_SUCCESS;
280     }
281     PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
282     return PR_FAILURE;
283 }
284 
285 /*--------------- private memio data -----------------------*/
286 
287 /*
288  * Implement just the bare minimum number of methods needed to make ssl happy.
289  *
290  * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
291  * PR_GetSocketOption, so we have to provide an implementation of
292  * PR_GetSocketOption that just says "I'm nonblocking".
293  */
294 
295 static struct PRIOMethods  memio_layer_methods = {
296     PR_DESC_LAYERED,
297     memio_Close,
298     memio_Read,
299     memio_Write,
300     NULL,
301     NULL,
302     NULL,
303     NULL,
304     NULL,
305     NULL,
306     NULL,
307     NULL,
308     NULL,
309     NULL,
310     NULL,
311     NULL,
312     memio_Shutdown,
313     memio_Recv,
314     memio_Send,
315     NULL,
316     NULL,
317     NULL,
318     NULL,
319     NULL,
320     NULL,
321     memio_GetPeerName,
322     NULL,
323     NULL,
324     memio_GetSocketOption,
325     NULL,
326     NULL,
327     NULL,
328     NULL,
329     NULL,
330     NULL,
331     NULL,
332 };
333 
334 static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;
335 
memio_InitializeLayerName(void)336 static PRStatus memio_InitializeLayerName(void)
337 {
338     memio_identity = PR_GetUniqueIdentity("memio");
339     return PR_SUCCESS;
340 }
341 
342 /*--------------- public memio functions -----------------------*/
343 
memio_CreateIOLayer(int bufsize)344 PRFileDesc *memio_CreateIOLayer(int bufsize)
345 {
346     PRFileDesc *fd;
347     struct PRFilePrivate *secret;
348     static PRCallOnceType once;
349 
350     PR_CallOnce(&once, memio_InitializeLayerName);
351 
352     fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
353     secret = malloc(sizeof(struct PRFilePrivate));
354     memset(secret, 0, sizeof(*secret));
355 
356     memio_buffer_new(&secret->readbuf, bufsize);
357     memio_buffer_new(&secret->writebuf, bufsize);
358     fd->secret = secret;
359     return fd;
360 }
361 
/* Store the address that memio_GetPeerName will report.  fd may be the
 * memio layer itself or any descriptor in a stack that contains it; the
 * layer is found by its identity. */
void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
{
    struct PRFilePrivate *priv =
        PR_GetIdentitiesLayer(fd, memio_identity)->secret;

    priv->peername = *peername;
}
368 
memio_GetSecret(PRFileDesc * fd)369 memio_Private *memio_GetSecret(PRFileDesc *fd)
370 {
371     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
372     struct PRFilePrivate *secret =  memiofd->secret;
373     return (memio_Private *)secret;
374 }
375 
/* Tell the application where incoming network data may be deposited.
 *
 * Stores in *buf the address of the first free byte of the read buffer, and
 * returns how many bytes may be written there without wrapping.  After
 * filling some of that space, the caller reports the result with
 * memio_PutReadResult.
 */
int memio_GetReadParams(memio_Private *secret, char **buf)
{
    struct PRFilePrivate *priv = (PRFilePrivate *)secret;
    struct memio_buffer *rb = &priv->readbuf;

    PR_ASSERT(rb->bufsize);
    *buf = rb->buf + rb->tail;
    return memio_buffer_unused_contiguous(rb);
}
384 
/* Report the outcome of the application's network read into the space handed
 * out by memio_GetReadParams.
 *
 * bytes_read > 0:  that many bytes were stored at the address returned by
 *                  memio_GetReadParams and become readable via memio_Recv.
 *                  The count must not exceed what memio_GetReadParams
 *                  returned, otherwise tail would move past the end of the
 *                  storage; debug builds now assert this.
 * bytes_read == 0: EOF; once the buffer drains, memio_Recv returns 0.
 * bytes_read < 0:  saved in last_err and later passed to PR_SetError by
 *                  memio_Recv when the buffer is empty.
 *                  NOTE(review): the value is used directly as an NSPR
 *                  error code; confirm callers pass NSPR codes.
 */
void memio_PutReadResult(memio_Private *secret, int bytes_read)
{
    struct PRFilePrivate *priv = (PRFilePrivate *)secret;
    struct memio_buffer *mb = &priv->readbuf;
    PR_ASSERT(mb->bufsize);

    if (bytes_read > 0) {
        PR_ASSERT(bytes_read <= memio_buffer_unused_contiguous(mb));
        mb->tail += bytes_read;
        if (mb->tail == mb->bufsize)
            mb->tail = 0;
    } else if (bytes_read == 0) {
        /* Record EOF condition and report to caller when buffer runs dry */
        priv->eof = PR_TRUE;
    } else /* if (bytes_read < 0) */ {
        mb->last_err = bytes_read;
    }
}
401 
/* Tell the application which queued outgoing bytes to send next.
 *
 * Stores in *buf the address of the oldest unsent byte in the write buffer,
 * and returns how many bytes are available there without wrapping.  The
 * caller reports how many were actually sent with memio_PutWriteResult.
 */
int memio_GetWriteParams(memio_Private *secret, const char **buf)
{
    struct memio_buffer *wb = &((PRFilePrivate *)secret)->writebuf;
    int avail;

    PR_ASSERT(wb->bufsize);
    avail = memio_buffer_used_contiguous(wb);
    *buf = wb->buf + wb->head;
    return avail;
}
410 
/* Report how much of the data offered by memio_GetWriteParams the
 * application actually sent.
 *
 * bytes_written > 0:  that many bytes are dropped from the front of the write
 *                     buffer.  The count must not exceed what
 *                     memio_GetWriteParams returned, otherwise head would
 *                     move past the end of the storage; debug builds now
 *                     assert this.
 * bytes_written == 0: nothing was sent; the buffer is left unchanged.
 * bytes_written < 0:  saved in last_err; every later memio_Send then fails
 *                     with it.
 */
void memio_PutWriteResult(memio_Private *secret, int bytes_written)
{
    struct memio_buffer *mb = &((PRFilePrivate *)secret)->writebuf;
    PR_ASSERT(mb->bufsize);

    if (bytes_written > 0) {
        PR_ASSERT(bytes_written <= memio_buffer_used_contiguous(mb));
        mb->head += bytes_written;
        if (mb->head == mb->bufsize)
            mb->head = 0;
    } else if (bytes_written < 0) {
        mb->last_err = bytes_written;
    }
}
424 
425 /*--------------- private memio_buffer self-test -----------------*/
426 
427 /* Even a trivial unit test is very helpful when doing circular buffers. */
428 /*#define TRIVIAL_SELF_TEST*/
429 #ifdef TRIVIAL_SELF_TEST
430 #include <stdio.h>
431 
432 #define TEST_BUFLEN 7
433 
434 #define CHECKEQ(a, b) { \
435     if ((a) != (b)) { \
436         printf("%d != %d, Test failed line %d\n", a, b, __LINE__); \
437         exit(1); \
438     } \
439 }
440 
main()441 int main()
442 {
443     struct memio_buffer mb;
444     char buf[100];
445     int i;
446 
447     memio_buffer_new(&mb, TEST_BUFLEN);
448 
449     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
450     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
451 
452     CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);
453 
454     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
455     CHECKEQ(memio_buffer_used_contiguous(&mb), 5);
456 
457     CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);
458 
459     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
460     CHECKEQ(memio_buffer_used_contiguous(&mb), 6);
461 
462     CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
463     CHECKEQ(memcmp(buf, "howdy!", 6), 0);
464 
465     CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
466     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
467 
468     CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);
469 
470     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
471     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
472 
473     CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);
474 
475     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
476     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
477 
478     /* TODO: add more cases */
479 
480     printf("Test passed\n");
481     exit(0);
482 }
483 
484 #endif
485