/*
 * lws-minimal-secure-streams-smd
 *
 * Written in 2010-2021 by Andy Green <andy@warmcat.com>
 *
 * This file is made available under the Creative Commons CC0 1.0
 * Universal Public Domain Dedication.
 *
 *
 * This demonstrates a minimal client using secure streams to access the
 * SMD api.  This file is only built when LWS_SS_USE_SSPC is defined.
 *
 * This is an alternative test implementation selected by --multi at runtime;
 * it's in its own file to stop muddying up the main test sources.  It's only
 * available when built with SSPC, ie, in the -client executable.
 *
 * We will fork several times; the original process and the forks each hook up
 * to the proxy with an smd SS, each fork waits a couple of seconds for
 * everyone to have joined, and then each fork (NOT the original process)
 * sends a bunch of user messages that all the forks should receive, having
 * been distributed by SMD and the ss proxy.
 *
 * The participants check they received all the messages expected from everyone
 * and then send a final message indicating success before exiting.  The
 * original process watches for these to arrive before the timeout; if they
 * all do, the test is a PASS.
 */

#include <libwebsockets.h>
#include <string.h>
#include <signal.h>

static int bad = 1, interrupted;

/* number of forks */
#define FORKS		4
/* number of messages each will send, eg, 4 forks x 64 messages == 256 messages */
#define MSGCOUNT	64
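
/*
 * Note the completeness check in the rx handler below compares each peer's
 * seen_mask against a full 64-bit mask, so MSGCOUNT is effectively pinned to
 * 64 (one bit per test message index).
 */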

typedef struct myss {
	struct lws_ss_handle	*ss;
	void			*opaque_data;
	/* ... application specific state ... */
	uint64_t		seen_mask[FORKS];
	int			seen_msgs[FORKS];
	lws_sorted_usec_list_t	sul;
	int			count;
	char			seen_all;
	char			send_seen_all;
	char			starting;
} myss_t;
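
/*
 * The ss_info structs below set .user_alloc to sizeof(myss_t) and point
 * .handle_offset / .opaque_user_data_offset at the ss and opaque_data members,
 * so lws allocates one of these per stream and fills those two members in for
 * us; everything after them is per-stream test state.
 */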


/* secure streams payload interface */

static lws_ss_state_return_t
multi_myss_rx(void *userobj, const uint8_t *buf, size_t len, int flags)
{
	myss_t *m = (myss_t *)userobj;
	const char *p;
	int fk, t, n;
	size_t al;

	/* ignore ourselves and other forks announcing their results */

	if (lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
		return LWSSSSRET_OK;

	/*
	 * otherwise, once we have seen all the expected messages, any further
	 * messages coming in this class are wrong
	 */

	if (m->seen_all) {
		lwsl_err("%s: unexpected extra messages\n", __func__);
		return LWSSSSRET_DESTROY_ME;
	}

	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
	if (!p)
		return LWSSSSRET_DESTROY_ME;
	fk = atoi(p);
	if (fk < 1 || fk > FORKS)
		return LWSSSSRET_DESTROY_ME;

	p = lws_json_simple_find((const char *)buf, len, "\"test\":", &al);
	if (!p)
		return LWSSSSRET_DESTROY_ME;
	t = atoi(p);

	if (t < 0 || t >= MSGCOUNT)
		return LWSSSSRET_DESTROY_ME;

	m->seen_mask[fk - 1] |= 1ull << t;
	m->seen_msgs[fk - 1]++; /* keep an eye on dupes */

	/* Have we seen a full set of messages from everyone? */

	for (n = 0; n < FORKS; n++) {
		if (m->seen_msgs[n] != (int)MSGCOUNT)
			return LWSSSSRET_OK;
		if (m->seen_mask[n] != 0xffffffffffffffffull)
			return LWSSSSRET_OK;
	}

	/*
	 * Oh... so we have finished collecting messages
	 */

	lwsl_user("%s: test thread %d: %s received all messages\n", __func__,
		  (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
		  lws_ss_tag(m->ss));
	m->seen_all = m->send_seen_all = 1;

	/*
	 * Prepare to inform the original process we saw everything from
	 * everyone OK
	 */

	lws_ss_request_tx(m->ss);

	return LWSSSSRET_OK;
}

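/*
 * Periodic tick for the forked senders: each tx below reschedules this 25ms
 * later, pacing out the MSGCOUNT test messages instead of dumping them on the
 * proxy at once.  Once the summary has gone out (send_seen_all cleared again
 * while seen_all stays set), the next tick destroys the stream.
 */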
static void
sul_multi_tx_periodic_cb(lws_sorted_usec_list_t *sul)
{
	myss_t *m = lws_container_of(sul, myss_t, sul);

	if (!m->send_seen_all && m->seen_all) {
		lws_ss_destroy(&m->ss);
		return;
	}

	m->starting = 1;
	if (m->count < MSGCOUNT || m->send_seen_all)
		lws_ss_request_tx(m->ss);
}

static lws_ss_state_return_t
multi_myss_tx(void *userobj, lws_ss_tx_ordinal_t ord, uint8_t *buf, size_t *len,
	      int *flags)
{
	myss_t *m = (myss_t *)userobj;

	/*
	 * We want to send exactly MSGCOUNT user class smd messages
	 */

	if (!m->starting || (m->count == MSGCOUNT && !m->send_seen_all))
		return LWSSSSRET_TX_DONT_SEND;

	// lwsl_notice("%s: sending SS smd\n", __func__);

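	/*
	 * The proxied smd payload starts with LWS_SMD_SS_RX_HEADER_LEN bytes
	 * of header: a 64-bit big-endian class bitmap, then a second 64-bit
	 * big-endian field (the message timestamp) that we simply zero here;
	 * the JSON body follows the header.  The two writes below lay that
	 * header down before the lws_snprintf() of the body.
	 */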
	lws_ser_wu64be(buf, 1 << LWSSMDCL_USER_BASE_BITNUM);
	lws_ser_wu64be(buf + 8, 0); /* valgrind notices uninitialized if left */

	if (m->send_seen_all) {
		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, *len,
				"{\"class\":\"user\",\"fork\": %d,\"seen_all\":true}",
				(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));

		m->send_seen_all = 0;
		lwsl_info("%s: test thread %d: sent summary message\n", __func__,
			  (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
	} else
		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, *len,
				"{\"class\":\"user\",\"fork\": %d,\"test\":%u}",
				(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
				m->count++);

	*flags = LWSSS_FLAG_SOM | LWSSS_FLAG_EOM;

	lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
			 sul_multi_tx_periodic_cb, 25 * LWS_US_PER_MS);

	return LWSSSSRET_OK;
}

static lws_ss_state_return_t
multi_myss_state(void *userobj, void *h_src, lws_ss_constate_t state,
		 lws_ss_tx_ordinal_t ack)
{
	myss_t *m = (myss_t *)userobj;
	int n;

	lwsl_notice("%s: %s: %s (%d), ord 0x%x\n", __func__, lws_ss_tag(m->ss),
		    lws_ss_state_name((int)state), state, (unsigned int)ack);

	switch (state) {
	case LWSSSCS_DESTROYING:
		lws_sul_cancel(&m->sul);
		interrupted = 1;
		return 0;

	case LWSSSCS_CONNECTED:
		lwsl_notice("%s: CONNECTED: test fork %d\n", __func__,
			    (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
		/*
		 * Because in this test everybody is watching and counting
		 * everybody else's messages from different forks, we have to
		 * hold off starting sending for 2s so all forks can join the
		 * proxy first and not miss anything
		 */
		lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
				 sul_multi_tx_periodic_cb, 2 * LWS_US_PER_SEC);
		m->starting = 0;
		return 0;
	case LWSSSCS_DISCONNECTED:
		for (n = 0; n < FORKS; n++)
			lwsl_notice("%s: testfork %d: peer %d: seen_msgs = %d, "
				    "seen_mask = 0x%llx\n", __func__,
				    (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
				    n, m->seen_msgs[n],
				    (unsigned long long)m->seen_mask[n]);
		break;
	default:
		break;
	}

	return 0;
}

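/*
 * SS info for the forked senders / checkers.  For the LWS_SMD_STREAMTYPENAME
 * streamtype, .manual_initial_tx_credit carries the smd rx class filter for
 * the link (see the note in direct_smd_cb() below), so these streams only see
 * user-class messages.  The JSON bodies exchanged look like, for example,
 *
 *   {"class":"user","fork": 2,"test":17}           (per-test message)
 *   {"class":"user","fork": 2,"seen_all":true}     (final summary)
 *
 * matching the formats produced in multi_myss_tx() above.
 */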
static const lws_ss_info_t ssi_multi_lws_smd = {
	.handle_offset = offsetof(myss_t, ss),
	.opaque_user_data_offset = offsetof(myss_t, opaque_data),
	.rx = multi_myss_rx,
	.tx = multi_myss_tx,
	.state = multi_myss_state,
	.user_alloc = sizeof(myss_t),
	.streamtype = LWS_SMD_STREAMTYPENAME,
	.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
};

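/*
 * The rx handler used by the original (monitor) process: it shares the myss_t
 * user object layout but only acts on the "seen_all" summaries, reusing
 * seen_msgs[] as a once-per-fork flag; when every fork has reported in, the
 * result is marked good and the event loop is told to stop.
 */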
static lws_ss_state_return_t
multi_myss_rx_monitor(void *userobj, const uint8_t *buf, size_t len, int flags)
{
	myss_t *m = (myss_t *)userobj;
	const char *p;
	size_t al;
	int fk, n;

	/* we only care about forks announcing their results */

	if (!lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
		return LWSSSSRET_OK;

	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
	if (!p)
		return LWSSSSRET_DESTROY_ME;
	fk = atoi(p);
	if (fk < 1 || fk > FORKS)
		return LWSSSSRET_DESTROY_ME;

	if (m->seen_msgs[fk - 1])
		/* expected only once ... dupe */
		return LWSSSSRET_DESTROY_ME;

	m->seen_msgs[fk - 1] = 1;

	for (n = 0; n < FORKS; n++)
		if (!m->seen_msgs[n])
			return LWSSSSRET_OK;

	/* the test has succeeded */

	bad = 0;
	interrupted = 1;

	return LWSSSSRET_OK;
}

static const lws_ss_info_t ssi_multi_lws_smd_monitor = {
	.handle_offset = offsetof(myss_t, ss),
	.opaque_user_data_offset = offsetof(myss_t, opaque_data),
	.rx = multi_myss_rx_monitor,
	// .state = multi_myss_state_monitor,
	.user_alloc = sizeof(myss_t),
	.streamtype = LWS_SMD_STREAMTYPENAME,
	.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
};

/* for comparison, this is a non-SS lws_smd participant */

static int
direct_smd_cb(void *opaque, lws_smd_class_t _class, lws_usec_t timestamp,
	      void *buf, size_t len)
{
	struct lws_context **pctx = (struct lws_context **)opaque;

	if (_class != LWSSMDCL_SYSTEM_STATE)
		return 0;

	if (!lws_json_simple_strcmp(buf, len, "\"state\":", "OPERATIONAL")) {

		/*
		 * Create the SSPC link to lws_smd... notice in
		 * ssi_multi_lws_smd above, we tell this link to use the user
		 * class filter.
		 *
		 * If context->user is zero, we are the original process
		 * monitoring the progress of the others, otherwise we are
		 * 1 .. FORKS and producing / checking the smd messages
		 */

		lwsl_info("%s: starting ss for test fork %d\n", __func__,
			  (int)(intptr_t)lws_context_user(*pctx));

		if (lws_ss_create(*pctx, 0, lws_context_user(*pctx) ?
				  &ssi_multi_lws_smd /* forked process send / check */ :
				  &ssi_multi_lws_smd_monitor /* original monitors */,
				  NULL, NULL, NULL, NULL)) {
			lwsl_err("%s: failed to create secure stream\n",
				 __func__);

			return -1;
		}
	}

	return 0;
}


static void
sul_timeout_cb(lws_sorted_usec_list_t *sul)
{
	interrupted = 1;
}

int
smd_ss_multi_test(int argc, const char **argv)
{
	struct lws_context_creation_info info;
	lws_sorted_usec_list_t sul_timeout;
	struct lws_context *context;
	pid_t pid;
	int n;

	lwsl_user("LWS Secure Streams SMD MULTI test client [-d<verb>]\n");

	for (n = 0; n < FORKS; n++) {
		pid = fork();
		if (!pid) { /* forked child */
			break;
		}
		lwsl_notice("%s: forked test process %u\n", __func__,
			    (unsigned int)pid);
	}

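	/*
	 * Each child broke out of the loop above with its own n in
	 * 0 .. FORKS - 1, so it gets a nonzero context user pointer of
	 * 1 .. FORKS below; only the original process completed the loop with
	 * n == FORKS.
	 */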
	if (n == FORKS)
		/* the original process */
		n = -1; /* so original ends up with context.user as 0 below */

	memset(&info, 0, sizeof info);
	memset(&sul_timeout, 0, sizeof sul_timeout);

	lws_cmdline_option_handle_builtin(argc, argv, &info);

	{
		const char *p;

		/* connect to ssproxy via UDS by default, else via
		 * tcp connection to this port */
		if ((p = lws_cmdline_option(argc, argv, "-p")))
			info.ss_proxy_port = (uint16_t)atoi(p);

		/* UDS "proxy.ss.lws" in abstract namespace, else this socket
		 * path; when -p given this can specify the network interface
		 * to bind to */
		if ((p = lws_cmdline_option(argc, argv, "-i")))
			info.ss_proxy_bind = p;

		/* if -p given, -a specifies the proxy address to connect to */
		if ((p = lws_cmdline_option(argc, argv, "-a")))
			info.ss_proxy_address = p;
	}

	info.fd_limit_per_thread = 1 + 6 + 1;
	info.port = CONTEXT_PORT_NO_LISTEN;
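	/*
	 * Built with LWS_SS_USE_SSPC, so selecting the sspc protocols here
	 * means lws_ss_create() gives us SSPC streams that are fulfilled by
	 * the separate ss proxy process (over UDS by default, or TCP when -p
	 * is given) rather than directly in this process.
	 */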
	info.protocols = lws_sspc_protocols;
	info.options = LWS_SERVER_OPTION_EXPLICIT_VHOSTS |
		       LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;

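	/*
	 * Register a direct (non-SS) smd participant while the context is
	 * still being created, so the early system state messages are not
	 * missed; direct_smd_cb() waits for OPERATIONAL and only then creates
	 * the proxied smd SS link for this process.
	 */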
	info.early_smd_cb = direct_smd_cb;
	info.early_smd_class_filter = 0xffffffff;
	info.early_smd_opaque = &context;

	info.user = (void *)(intptr_t)(n + 1);

	/* create the context */

	context = lws_create_context(&info);
	if (!context) {
		lwsl_err("lws init failed\n");
		return 1;
	}

	if (!lws_create_vhost(context, &info)) {
		lwsl_err("%s: failed to create default vhost\n", __func__);
		goto bail;
	}

	/* set up the test timeout */

	lws_sul_schedule(context, 0, &sul_timeout, sul_timeout_cb,
			 10 * LWS_US_PER_SEC);
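	/*
	 * The 10s budget is generous: each fork waits 2s for everyone to join
	 * and then paces 64 messages at 25ms intervals (~1.6s) before the
	 * summaries go out, so a passing run finishes well inside it.
	 */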

	/* the event loop */

	while (lws_service(context, 0) >= 0 && !interrupted)
		;

bail:
	lws_context_destroy(context);

	if (n == -1)
		lwsl_user("%s: finished %s\n", __func__, bad ? "FAIL" : "PASS");

	return bad;
}