1 //! This module handles "arbitration" of ATT packets, to determine whether they
2 //! should be handled by the primary stack or by the "Private GATT" stack
3
4 use std::{collections::HashMap, sync::Mutex};
5
6 use log::{error, info, trace};
7
8 use crate::{
9 do_in_rust_thread,
10 packets::{AttOpcode, OwnedAttView, OwnedPacket},
11 };
12
13 use super::{
14 ffi::{InterceptAction, StoreCallbacksFromRust},
15 ids::{AdvertiserId, ConnectionId, ServerId, TransportIndex},
16 mtu::MtuEvent,
17 opcode_types::{classify_opcode, OperationType},
18 };
19
20 static ARBITER: Mutex<Option<Arbiter>> = Mutex::new(None);
21
22 /// This class is responsible for tracking which connections and advertising we
23 /// own, and using this information to decide what packets should be
24 /// intercepted, and which should be forwarded to the legacy stack.
25 #[derive(Default)]
26 pub struct Arbiter {
27 advertiser_to_server: HashMap<AdvertiserId, ServerId>,
28 transport_to_owned_connection: HashMap<TransportIndex, ConnectionId>,
29 }
30
31 /// Initialize the Arbiter
initialize_arbiter()32 pub fn initialize_arbiter() {
33 *ARBITER.lock().unwrap() = Some(Arbiter::new());
34
35 StoreCallbacksFromRust(
36 on_le_connect,
37 on_le_disconnect,
38 intercept_packet,
39 |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest),
40 |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)),
41 |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)),
42 );
43 }
44
45 /// Acquire the mutex holding the Arbiter and provide a mutable reference to the
46 /// supplied closure
with_arbiter<T>(f: impl FnOnce(&mut Arbiter) -> T) -> T47 pub fn with_arbiter<T>(f: impl FnOnce(&mut Arbiter) -> T) -> T {
48 f(ARBITER.lock().unwrap().as_mut().unwrap())
49 }
50
51 impl Arbiter {
52 /// Constructor
new() -> Self53 pub fn new() -> Self {
54 Arbiter {
55 advertiser_to_server: HashMap::new(),
56 transport_to_owned_connection: HashMap::new(),
57 }
58 }
59
60 /// Link a given GATT server to an LE advertising set, so incoming
61 /// connections to this advertiser will be visible only by the linked
62 /// server
associate_server_with_advertiser( &mut self, server_id: ServerId, advertiser_id: AdvertiserId, )63 pub fn associate_server_with_advertiser(
64 &mut self,
65 server_id: ServerId,
66 advertiser_id: AdvertiserId,
67 ) {
68 info!("associating server {server_id:?} with advertising set {advertiser_id:?}");
69 let old = self.advertiser_to_server.insert(advertiser_id, server_id);
70 if let Some(old) = old {
71 error!("new server {server_id:?} associated with same advertiser {advertiser_id:?}, displacing old server {old:?}");
72 }
73 }
74
75 /// Remove all linked advertising sets from the provided server
clear_server(&mut self, server_id: ServerId)76 pub fn clear_server(&mut self, server_id: ServerId) {
77 info!("clearing advertisers associated with {server_id:?}");
78 self.advertiser_to_server.retain(|_, server| *server != server_id);
79 }
80
81 /// Clear the server associated with this advertiser, if one exists
clear_advertiser(&mut self, advertiser_id: AdvertiserId)82 pub fn clear_advertiser(&mut self, advertiser_id: AdvertiserId) {
83 info!("removing server (if any) associated with advertiser {advertiser_id:?}");
84 self.advertiser_to_server.remove(&advertiser_id);
85 }
86
87 /// Check if this conn_id is currently owned by the Rust stack
is_connection_isolated(&self, conn_id: ConnectionId) -> bool88 pub fn is_connection_isolated(&self, conn_id: ConnectionId) -> bool {
89 self.transport_to_owned_connection.values().any(|owned_conn_id| *owned_conn_id == conn_id)
90 }
91
92 /// Test to see if a buffer contains a valid ATT packet with an opcode we
93 /// are interested in intercepting (those intended for servers that are isolated)
try_parse_att_server_packet( &self, tcb_idx: TransportIndex, packet: Box<[u8]>, ) -> Option<OwnedAttView>94 pub fn try_parse_att_server_packet(
95 &self,
96 tcb_idx: TransportIndex,
97 packet: Box<[u8]>,
98 ) -> Option<OwnedAttView> {
99 if !self.transport_to_owned_connection.contains_key(&tcb_idx) {
100 return None;
101 }
102
103 let att = OwnedAttView::try_parse(packet).ok()?;
104
105 if att.view().get_opcode() == AttOpcode::EXCHANGE_MTU_REQUEST {
106 // special case: this server opcode is handled by legacy stack, and we snoop
107 // on its handling, since the MTU is shared between the client + server
108 return None;
109 }
110
111 match classify_opcode(att.view().get_opcode()) {
112 OperationType::Command | OperationType::Request | OperationType::Confirmation => {
113 Some(att)
114 }
115 _ => None,
116 }
117 }
118
119 /// Check if an incoming connection should be intercepted and, if so, on
120 /// what conn_id
on_le_connect( &mut self, tcb_idx: TransportIndex, advertiser: AdvertiserId, ) -> Option<ConnectionId>121 pub fn on_le_connect(
122 &mut self,
123 tcb_idx: TransportIndex,
124 advertiser: AdvertiserId,
125 ) -> Option<ConnectionId> {
126 info!(
127 "processing incoming connection on transport {tcb_idx:?} to advertiser {advertiser:?}"
128 );
129 let server_id = *self.advertiser_to_server.get(&advertiser)?;
130 info!("connection is isolated to server {server_id:?}");
131
132 let conn_id = ConnectionId::new(tcb_idx, server_id);
133 let old = self.transport_to_owned_connection.insert(tcb_idx, conn_id);
134 if old.is_some() {
135 error!("new server {server_id:?} on transport {tcb_idx:?} displacing existing registered connection {conn_id:?}")
136 }
137 Some(conn_id)
138 }
139
140 /// Handle a disconnection, if any, and return whether the disconnection was registered
on_le_disconnect(&mut self, tcb_idx: TransportIndex) -> bool141 pub fn on_le_disconnect(&mut self, tcb_idx: TransportIndex) -> bool {
142 info!("processing disconnection on transport {tcb_idx:?}");
143 self.transport_to_owned_connection.remove(&tcb_idx).is_some()
144 }
145
146 /// Look up the conn_id for a given tcb_idx, if present
get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId>147 pub fn get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId> {
148 self.transport_to_owned_connection.get(&tcb_idx).copied()
149 }
150 }
151
on_le_connect(tcb_idx: u8, advertiser: u8)152 fn on_le_connect(tcb_idx: u8, advertiser: u8) {
153 if let Some(conn_id) = with_arbiter(|arbiter| {
154 arbiter.on_le_connect(TransportIndex(tcb_idx), AdvertiserId(advertiser))
155 }) {
156 do_in_rust_thread(move |modules| {
157 if let Err(err) = modules.gatt_module.on_le_connect(conn_id) {
158 error!("{err:?}")
159 }
160 })
161 }
162 }
163
on_le_disconnect(tcb_idx: u8)164 fn on_le_disconnect(tcb_idx: u8) {
165 let tcb_idx = TransportIndex(tcb_idx);
166 if with_arbiter(|arbiter| arbiter.on_le_disconnect(tcb_idx)) {
167 do_in_rust_thread(move |modules| {
168 if let Err(err) = modules.gatt_module.on_le_disconnect(tcb_idx) {
169 error!("{err:?}")
170 }
171 })
172 }
173 }
174
intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction175 fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
176 let tcb_idx = TransportIndex(tcb_idx);
177 if let Some(att) = with_arbiter(|arbiter| {
178 arbiter.try_parse_att_server_packet(tcb_idx, packet.into_boxed_slice())
179 }) {
180 do_in_rust_thread(move |modules| {
181 trace!("pushing packet to GATT");
182 if let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) {
183 bearer.handle_packet(att.view())
184 } else {
185 error!("Bearer for {tcb_idx:?} not found");
186 }
187 });
188 InterceptAction::Drop
189 } else {
190 InterceptAction::Forward
191 }
192 }
193
on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent)194 fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
195 if with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)).is_some() {
196 do_in_rust_thread(move |modules| {
197 let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) else {
198 error!("Bearer for {tcb_idx:?} not found");
199 return;
200 };
201 if let Err(err) = bearer.handle_mtu_event(event) {
202 error!("{err:?}")
203 }
204 });
205 }
206 }
207
#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        gatt::ids::AttHandle,
        packets::{
            AttBuilder, AttExchangeMtuRequestBuilder, AttOpcode, AttReadRequestBuilder,
            Serializable,
        },
    };

    const TCB_IDX: TransportIndex = TransportIndex(1);
    const ANOTHER_TCB_IDX: TransportIndex = TransportIndex(2);
    const ADVERTISER_ID: AdvertiserId = AdvertiserId(3);
    const SERVER_ID: ServerId = ServerId(4);

    const CONN_ID: ConnectionId = ConnectionId::new(TCB_IDX, SERVER_ID);

    const ANOTHER_ADVERTISER_ID: AdvertiserId = AdvertiserId(5);

    /// A fresh arbiter in which SERVER_ID is linked to ADVERTISER_ID.
    fn arbiter_with_isolated_server() -> Arbiter {
        let mut arbiter = Arbiter::new();
        arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
        arbiter
    }

    /// Serialize a read-request body for attribute handle 1 under `opcode`.
    fn read_request_bytes(opcode: AttOpcode) -> Box<[u8]> {
        AttBuilder {
            opcode,
            _child_: AttReadRequestBuilder { attribute_handle: AttHandle(1).into() }.into(),
        }
        .to_vec()
        .unwrap()
        .into()
    }

    /// Serialize an EXCHANGE_MTU_REQUEST asking for an MTU of 64.
    fn exchange_mtu_request_bytes() -> Box<[u8]> {
        AttBuilder {
            opcode: AttOpcode::EXCHANGE_MTU_REQUEST,
            _child_: AttExchangeMtuRequestBuilder { mtu: 64 }.into(),
        }
        .to_vec()
        .unwrap()
        .into()
    }

    #[test]
    fn test_non_isolated_connect() {
        let mut arbiter = Arbiter::new();

        assert!(arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).is_none())
    }

    #[test]
    fn test_isolated_connect() {
        let mut arbiter = arbiter_with_isolated_server();

        assert_eq!(arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID), Some(CONN_ID));
    }

    #[test]
    fn test_non_isolated_connect_with_isolated_advertiser() {
        let mut arbiter = arbiter_with_isolated_server();

        assert!(arbiter.on_le_connect(TCB_IDX, ANOTHER_ADVERTISER_ID).is_none())
    }

    #[test]
    fn test_non_isolated_disconnect() {
        let mut arbiter = Arbiter::new();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert!(!arbiter.on_le_disconnect(TCB_IDX))
    }

    #[test]
    fn test_isolated_disconnect() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert!(arbiter.on_le_disconnect(TCB_IDX))
    }

    #[test]
    fn test_advertiser_id_reuse() {
        let mut arbiter = arbiter_with_isolated_server();
        // the advertiser linked to the server goes away...
        arbiter.clear_advertiser(ADVERTISER_ID);

        // ...so a connection to a new advertiser that happens to reuse the same
        // ID must not be isolated
        assert!(arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).is_none())
    }

    #[test]
    fn test_server_closed() {
        let mut arbiter = arbiter_with_isolated_server();
        // the server goes away while its advertiser is still linked...
        arbiter.clear_server(SERVER_ID);

        // ...so a later connection to that advertiser must not be captured
        assert!(arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).is_none())
    }

    #[test]
    fn test_connection_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).unwrap();

        assert!(arbiter.is_connection_isolated(conn_id))
    }

    #[test]
    fn test_connection_isolated_after_advertiser_stops() {
        let mut arbiter = arbiter_with_isolated_server();
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).unwrap();
        arbiter.clear_advertiser(ADVERTISER_ID);

        assert!(arbiter.is_connection_isolated(conn_id))
    }

    #[test]
    fn test_connection_isolated_after_server_stops() {
        let mut arbiter = arbiter_with_isolated_server();
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).unwrap();
        arbiter.clear_server(SERVER_ID);

        assert!(arbiter.is_connection_isolated(conn_id))
    }

    #[test]
    fn test_packet_capture_when_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = read_request_bytes(AttOpcode::READ_REQUEST);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_some());
    }

    #[test]
    fn test_packet_bypass_when_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = read_request_bytes(AttOpcode::ERROR_RESPONSE);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_none());
    }

    #[test]
    fn test_mtu_bypass() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = exchange_mtu_request_bytes();

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_bypass_when_not_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ANOTHER_ADVERTISER_ID);
        let packet = read_request_bytes(AttOpcode::READ_REQUEST);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_bypass_when_different_connection() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.on_le_connect(ANOTHER_TCB_IDX, ANOTHER_ADVERTISER_ID);
        let packet = read_request_bytes(AttOpcode::READ_REQUEST);

        let out = arbiter.try_parse_att_server_packet(ANOTHER_TCB_IDX, packet);

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_capture_when_isolated_after_advertiser_closes() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = read_request_bytes(AttOpcode::READ_REQUEST);
        arbiter.clear_advertiser(ADVERTISER_ID);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_some());
    }

    #[test]
    fn test_packet_capture_when_isolated_after_server_closes() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = read_request_bytes(AttOpcode::READ_REQUEST);
        arbiter.clear_server(SERVER_ID);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_some());
    }

    #[test]
    fn test_not_isolated_after_disconnection() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        arbiter.on_le_disconnect(TCB_IDX);

        assert!(!arbiter.is_connection_isolated(CONN_ID));
    }

    #[test]
    fn test_tcb_idx_reuse_after_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.clear_advertiser(ADVERTISER_ID);
        arbiter.on_le_disconnect(TCB_IDX);

        // the transport index comes back for an unrelated connection
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert!(conn_id.is_none());
        assert!(!arbiter.is_connection_isolated(CONN_ID));
    }
}
478