• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * lws-minimal-secure-streams-smd
3  *
4  * Written in 2010-2021 by Andy Green <andy@warmcat.com>
5  *
6  * This file is made available under the Creative Commons CC0 1.0
7  * Universal Public Domain Dedication.
8  *
9  *
10  * This demonstrates a minimal http client using secure streams to access the
11  * SMD api.  This file is only built when LWS_SS_USE_SSPC defined.
12  *
13  * This is an alternative test implementation selected by --multi at runtime,
14  * it's in its own file to stop muddying up the main test sources.  It's only
15  * available when built with SSPC / produces -client executable.
16  *
17  * We will fork several times; the original process and the forks hook up to
18  * the proxy with smd SS, each fork waits two seconds for everyone to have
19  * joined, and then each fork (NOT the original process) sends a bunch of user
20  * messages that all the forks should receive, having been distributed by SMD
21  * and the ss proxy.
22  *
23  * The participants check they received all the messages expected from everyone
24  * and each then sends a final success message and exits.  The original process
25  * watches for these to arrive before the timeout; if they do, it's a PASS.
26  */
27 
28 #include <libwebsockets.h>
29 #include <string.h>
30 #include <signal.h>
31 
/*
 * Process-wide test state.
 *
 * bad:         the value smd_ss_multi_test() returns; it starts out as
 *              failure and is only cleared by the monitor's rx handler.
 * interrupted: set nonzero to make the service loop exit.
 */
static int bad = 1;
static int interrupted;

/* number of forks */
#define FORKS		4
/* number of messages each will send, eg, 4 forks 64 message == 256 messages */
#define MSGCOUNT	64
38 
39 typedef struct myss {
40 	struct lws_ss_handle 		*ss;
41 	void				*opaque_data;
42 	/* ... application specific state ... */
43 	uint64_t			seen_mask[FORKS];
44 	int				seen_msgs[FORKS];
45 	lws_sorted_usec_list_t		sul;
46 	int				count;
47 	char				seen_all;
48 	char				send_seen_all;
49 	char				starting;
50 } myss_t;
51 
52 
53 /* secure streams payload interface */
54 
55 static lws_ss_state_return_t
multi_myss_rx(void * userobj,const uint8_t * buf,size_t len,int flags)56 multi_myss_rx(void *userobj, const uint8_t *buf, size_t len, int flags)
57 {
58 	myss_t *m = (myss_t *)userobj;
59 	const char *p;
60 	int fk, t, n;
61 	size_t al;
62 
63 	/* ignore our and other forks announcing their result */
64 
65 	if (lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
66 		return LWSSSSRET_OK;
67 
68 	/*
69 	 * otherwise once we saw the expected messages, any other messages
70 	 * coming in this class are wrong
71 	 */
72 
73 	if (m->seen_all) {
74 		lwsl_err("%s: unexpected extra messages\n", __func__);
75 		return LWSSSSRET_DESTROY_ME;
76 	}
77 
78 	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
79 	if (!p)
80 		return LWSSSSRET_DESTROY_ME;
81 	fk = atoi(p);
82 	if (fk < 1 || fk > FORKS)
83 		return LWSSSSRET_DESTROY_ME;
84 
85 	p = lws_json_simple_find((const char *)buf, len, "\"test\":", &al);
86 	if (!p)
87 		return LWSSSSRET_DESTROY_ME;
88 	t = atoi(p);
89 
90 	if (t < 0 || t >= MSGCOUNT)
91 		return LWSSSSRET_DESTROY_ME;
92 
93 	m->seen_mask[fk - 1] |= 1ull << t;
94 	m->seen_msgs[fk - 1]++; /* keep an eye on dupes */
95 
96 	/* Have we seen a full set of messages from everyone? */
97 
98 	for (n = 0; n < FORKS; n++) {
99 		if (m->seen_msgs[n] != (int)MSGCOUNT)
100 			return LWSSSSRET_OK;
101 		if (m->seen_mask[n] != 0xffffffffffffffffull)
102 			return LWSSSSRET_OK;
103 	}
104 
105 	/*
106 	 * Oh... so we have finished collecting messages
107 	 */
108 
109 	lwsl_user("%s: test thread %d: %s received all messages\n", __func__,
110 			(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
111 			lws_ss_tag(m->ss));
112 	m->seen_all = m->send_seen_all = 1;
113 
114 	/*
115 	 * Prepare to inform the original process we saw everything
116 	 * from everyone OK
117 	 */
118 
119 	lws_ss_request_tx(m->ss);
120 
121 	return LWSSSSRET_OK;
122 }
123 
124 static void
sul_multi_tx_periodic_cb(lws_sorted_usec_list_t * sul)125 sul_multi_tx_periodic_cb(lws_sorted_usec_list_t *sul)
126 {
127 	myss_t *m = lws_container_of(sul, myss_t, sul);
128 
129 	if (!m->send_seen_all && m->seen_all) {
130 		lws_ss_destroy(&m->ss);
131 		return;
132 	}
133 
134 	m->starting = 1;
135 	if (m->count < MSGCOUNT ||  m->send_seen_all)
136 		lws_ss_request_tx(m->ss);
137 }
138 
139 static lws_ss_state_return_t
multi_myss_tx(void * userobj,lws_ss_tx_ordinal_t ord,uint8_t * buf,size_t * len,int * flags)140 multi_myss_tx(void *userobj, lws_ss_tx_ordinal_t ord, uint8_t *buf, size_t *len,
141 	int *flags)
142 {
143 	myss_t *m = (myss_t *)userobj;
144 
145 	/*
146 	 * We want to send exactly MSGCOUNT user class smd messages
147 	 */
148 
149 	if (!m->starting || (m->count == MSGCOUNT && !m->send_seen_all))
150 		return LWSSSSRET_TX_DONT_SEND;
151 
152 //	lwsl_notice("%s: sending SS smd\n", __func__);
153 
154 	lws_ser_wu64be(buf, 1 << LWSSMDCL_USER_BASE_BITNUM);
155 	lws_ser_wu64be(buf + 8, 0); /* valgrind notices uninitialized if left */
156 
157 	if (m->send_seen_all) {
158 		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
159 			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, *len,
160 			     "{\"class\":\"user\",\"fork\": %d,\"seen_all\":true}",
161 			     (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
162 
163 		m->send_seen_all = 0;
164 		lwsl_info("%s: test thread %d: sent summary message\n", __func__,
165 				(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
166 	} else
167 		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
168 			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, *len,
169 			     "{\"class\":\"user\",\"fork\": %d,\"test\":%u}",
170 			     (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
171 			     m->count++);
172 
173 	*flags = LWSSS_FLAG_SOM | LWSSS_FLAG_EOM;
174 
175 	lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
176 			sul_multi_tx_periodic_cb, 25 * LWS_US_PER_MS);
177 
178 	return LWSSSSRET_OK;
179 }
180 
181 static lws_ss_state_return_t
multi_myss_state(void * userobj,void * h_src,lws_ss_constate_t state,lws_ss_tx_ordinal_t ack)182 multi_myss_state(void *userobj, void *h_src, lws_ss_constate_t state,
183 	   lws_ss_tx_ordinal_t ack)
184 {
185 	myss_t *m = (myss_t *)userobj;
186 	int n;
187 
188 	lwsl_notice("%s: %s: %s (%d), ord 0x%x\n", __func__, lws_ss_tag(m->ss),
189 		    lws_ss_state_name((int)state), state, (unsigned int)ack);
190 
191 	switch (state) {
192 	case LWSSSCS_DESTROYING:
193 		lws_sul_cancel(&m->sul);
194 		interrupted = 1;
195 		return 0;
196 
197 	case LWSSSCS_CONNECTED:
198 		lwsl_notice("%s: CONNECTED: test fork %d\n", __func__,
199 				(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
200 		/*
201 		 * Because in this test everybody is watching and counting
202 		 * everybody else's messages from different forks, we have to
203 		 * hold off starting sending for 2s so all forks can join the
204 		 * proxy first and not miss anything
205 		 */
206 		lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
207 				sul_multi_tx_periodic_cb, 2 * LWS_US_PER_SEC);
208 		m->starting = 0;
209 		return 0;
210 	case LWSSSCS_DISCONNECTED:
211 		for (n = 0; n < FORKS; n++)
212 			lwsl_notice("%s: testfork %d: peer %d: seen_msg = %d, "
213 				    "seen make = 0x%llx\n", __func__,
214 				    (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
215 				    n, m->seen_msgs[n],
216 				    (unsigned long long)m->seen_mask[n]);
217 		break;
218 	default:
219 		break;
220 	}
221 
222 	return 0;
223 }
224 
225 static const lws_ss_info_t ssi_multi_lws_smd = {
226 	.handle_offset		  = offsetof(myss_t, ss),
227 	.opaque_user_data_offset  = offsetof(myss_t, opaque_data),
228 	.rx			  = multi_myss_rx,
229 	.tx			  = multi_myss_tx,
230 	.state			  = multi_myss_state,
231 	.user_alloc		  = sizeof(myss_t),
232 	.streamtype		  = LWS_SMD_STREAMTYPENAME,
233 	.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
234 };
235 
236 static lws_ss_state_return_t
multi_myss_rx_monitor(void * userobj,const uint8_t * buf,size_t len,int flags)237 multi_myss_rx_monitor(void *userobj, const uint8_t *buf, size_t len, int flags)
238 {
239 	myss_t *m = (myss_t *)userobj;
240 	const char *p;
241 	size_t al;
242 	int fk, n;
243 
244 	/* ignore our and other forks announcing their result */
245 
246 	if (!lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
247 		return LWSSSSRET_OK;
248 
249 	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
250 	if (!p)
251 		return LWSSSSRET_DESTROY_ME;
252 	fk = atoi(p);
253 	if (fk < 1 || fk > FORKS)
254 		return LWSSSSRET_DESTROY_ME;
255 
256 	if (m->seen_msgs[fk - 1])
257 		/* expected only once ... dupe */
258 		return LWSSSSRET_DESTROY_ME;
259 
260 	m->seen_msgs[fk - 1] = 1;
261 
262 	for (n = 0; n < FORKS; n++)
263 		if (!m->seen_msgs[n])
264 			return LWSSSSRET_OK;
265 
266 	/* the test has succeeded */
267 
268 	bad = 0;
269 	interrupted = 1;
270 
271 	return LWSSSSRET_OK;
272 }
273 
274 static const lws_ss_info_t ssi_multi_lws_smd_monitor = {
275 	.handle_offset		  = offsetof(myss_t, ss),
276 	.opaque_user_data_offset  = offsetof(myss_t, opaque_data),
277 	.rx			  = multi_myss_rx_monitor,
278 //	.state			  = multi_myss_state_monitor,
279 	.user_alloc		  = sizeof(myss_t),
280 	.streamtype		  = LWS_SMD_STREAMTYPENAME,
281 	.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
282 };
283 
284 /* for comparison, this is a non-SS lws_smd participant */
285 
286 static int
direct_smd_cb(void * opaque,lws_smd_class_t _class,lws_usec_t timestamp,void * buf,size_t len)287 direct_smd_cb(void *opaque, lws_smd_class_t _class, lws_usec_t timestamp,
288 	      void *buf, size_t len)
289 {
290 	struct lws_context **pctx = (struct lws_context **)opaque;
291 
292 	if (_class != LWSSMDCL_SYSTEM_STATE)
293 		return 0;
294 
295 	if (!lws_json_simple_strcmp(buf, len, "\"state\":", "OPERATIONAL")) {
296 
297 		/*
298 		 * Create the SSPC link to lws_smd... notice in ssi_lws_smd
299 		 * above, we tell this link to use the user class filter.
300 		 *
301 		 * If context->user is zero, we are the original process
302 		 * monitoring the progress of the others, otherwise we are
303 		 * 1 .. FORKS and producing / checking the smd messages
304 		 */
305 
306 		lwsl_info("%s: starting ss for test fork %d\n", __func__,
307 				(int)(intptr_t)lws_context_user(*pctx));
308 
309 		if (lws_ss_create(*pctx, 0, lws_context_user(*pctx) ?
310 				&ssi_multi_lws_smd /* forked process send / check */:
311 				&ssi_multi_lws_smd_monitor /* original monitors */,
312 				NULL, NULL, NULL, NULL)) {
313 			lwsl_err("%s: failed to create secure stream\n",
314 				 __func__);
315 
316 			return -1;
317 		}
318 	}
319 
320 	return 0;
321 }
322 
323 
324 static void
sul_timeout_cb(lws_sorted_usec_list_t * sul)325 sul_timeout_cb(lws_sorted_usec_list_t *sul)
326 {
327 	interrupted = 1;
328 }
329 
330 int
smd_ss_multi_test(int argc,const char ** argv)331 smd_ss_multi_test(int argc, const char **argv)
332 {
333 	struct lws_context_creation_info info;
334 	lws_sorted_usec_list_t sul_timeout;
335 	struct lws_context *context;
336 	pid_t pid;
337 	int n;
338 
339 	lwsl_user("LWS Secure Streams SMD MULTI test client [-d<verb>]\n");
340 
341 	for (n = 0; n < FORKS; n++) {
342 		pid = fork();
343 		if (!pid) /* forked child */ {
344 			break;
345 		}
346 		lwsl_notice("%s: forked test process %u\n", __func__, pid);
347 	}
348 
349 	if (n == FORKS)
350 		/* the original process */
351 		n = -1; /* so original ends up with context.user as 0 below */
352 
353 	memset(&info, 0, sizeof info);
354 	memset(&sul_timeout, 0, sizeof sul_timeout);
355 
356 	lws_cmdline_option_handle_builtin(argc, argv, &info);
357 
358 	{
359 		const char *p;
360 
361 		/* connect to ssproxy via UDS by default, else via
362 		 * tcp connection to this port */
363 		if ((p = lws_cmdline_option(argc, argv, "-p")))
364 			info.ss_proxy_port = (uint16_t)atoi(p);
365 
366 		/* UDS "proxy.ss.lws" in abstract namespace, else this socket
367 		 * path; when -p given this can specify the network interface
368 		 * to bind to */
369 		if ((p = lws_cmdline_option(argc, argv, "-i")))
370 			info.ss_proxy_bind = p;
371 
372 		/* if -p given, -a specifies the proxy address to connect to */
373 		if ((p = lws_cmdline_option(argc, argv, "-a")))
374 			info.ss_proxy_address = p;
375 	}
376 
377 	info.fd_limit_per_thread	= 1 + 6 + 1;
378 	info.port			= CONTEXT_PORT_NO_LISTEN;
379 	info.protocols			= lws_sspc_protocols;
380 	info.options			= LWS_SERVER_OPTION_EXPLICIT_VHOSTS |
381 					  LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
382 
383 	info.early_smd_cb		= direct_smd_cb;
384 	info.early_smd_class_filter	= 0xffffffff;
385 	info.early_smd_opaque		= &context;
386 
387 	info.user			= (void *)(intptr_t)(n + 1);
388 
389 	/* create the context */
390 
391 	context = lws_create_context(&info);
392 	if (!context) {
393 		lwsl_err("lws init failed\n");
394 		return 1;
395 	}
396 
397 	if (!lws_create_vhost(context, &info)) {
398 		lwsl_err("%s: failed to create default vhost\n", __func__);
399 		goto bail;
400 	}
401 
402 	/* set up the test timeout */
403 
404 	lws_sul_schedule(context, 0, &sul_timeout, sul_timeout_cb,
405 			 10 * LWS_US_PER_SEC);
406 
407 	/* the event loop */
408 
409 	while (lws_service(context, 0) >= 0 && !interrupted)
410 		;
411 
412 bail:
413 	lws_context_destroy(context);
414 
415 	if (n == -1)
416 		lwsl_user("%s: finished %s\n", __func__, bad ? "FAIL" : "PASS");
417 
418 	return bad;
419 }
420