• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//! This module handles "arbitration" of ATT packets: it determines whether each
//! packet should be handled by the primary stack or by the "Private GATT" stack.
3 
4 use std::{collections::HashMap, sync::Mutex};
5 
6 use log::{error, info, trace};
7 
8 use crate::{
9     do_in_rust_thread,
10     packets::{AttOpcode, OwnedAttView, OwnedPacket},
11 };
12 
13 use super::{
14     ffi::{InterceptAction, StoreCallbacksFromRust},
15     ids::{AdvertiserId, ConnectionId, ServerId, TransportIndex},
16     mtu::MtuEvent,
17     opcode_types::{classify_opcode, OperationType},
18 };
19 
20 static ARBITER: Mutex<Option<Arbiter>> = Mutex::new(None);
21 
22 /// This class is responsible for tracking which connections and advertising we
23 /// own, and using this information to decide what packets should be
24 /// intercepted, and which should be forwarded to the legacy stack.
25 #[derive(Default)]
26 pub struct Arbiter {
27     advertiser_to_server: HashMap<AdvertiserId, ServerId>,
28     transport_to_owned_connection: HashMap<TransportIndex, ConnectionId>,
29 }
30 
31 /// Initialize the Arbiter
initialize_arbiter()32 pub fn initialize_arbiter() {
33     *ARBITER.lock().unwrap() = Some(Arbiter::new());
34 
35     StoreCallbacksFromRust(
36         on_le_connect,
37         on_le_disconnect,
38         intercept_packet,
39         |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest),
40         |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)),
41         |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)),
42     );
43 }
44 
45 /// Acquire the mutex holding the Arbiter and provide a mutable reference to the
46 /// supplied closure
with_arbiter<T>(f: impl FnOnce(&mut Arbiter) -> T) -> T47 pub fn with_arbiter<T>(f: impl FnOnce(&mut Arbiter) -> T) -> T {
48     f(ARBITER.lock().unwrap().as_mut().unwrap())
49 }
50 
51 impl Arbiter {
52     /// Constructor
new() -> Self53     pub fn new() -> Self {
54         Arbiter {
55             advertiser_to_server: HashMap::new(),
56             transport_to_owned_connection: HashMap::new(),
57         }
58     }
59 
60     /// Link a given GATT server to an LE advertising set, so incoming
61     /// connections to this advertiser will be visible only by the linked
62     /// server
associate_server_with_advertiser( &mut self, server_id: ServerId, advertiser_id: AdvertiserId, )63     pub fn associate_server_with_advertiser(
64         &mut self,
65         server_id: ServerId,
66         advertiser_id: AdvertiserId,
67     ) {
68         info!("associating server {server_id:?} with advertising set {advertiser_id:?}");
69         let old = self.advertiser_to_server.insert(advertiser_id, server_id);
70         if let Some(old) = old {
71             error!("new server {server_id:?} associated with same advertiser {advertiser_id:?}, displacing old server {old:?}");
72         }
73     }
74 
75     /// Remove all linked advertising sets from the provided server
clear_server(&mut self, server_id: ServerId)76     pub fn clear_server(&mut self, server_id: ServerId) {
77         info!("clearing advertisers associated with {server_id:?}");
78         self.advertiser_to_server.retain(|_, server| *server != server_id);
79     }
80 
81     /// Clear the server associated with this advertiser, if one exists
clear_advertiser(&mut self, advertiser_id: AdvertiserId)82     pub fn clear_advertiser(&mut self, advertiser_id: AdvertiserId) {
83         info!("removing server (if any) associated with advertiser {advertiser_id:?}");
84         self.advertiser_to_server.remove(&advertiser_id);
85     }
86 
87     /// Check if this conn_id is currently owned by the Rust stack
is_connection_isolated(&self, conn_id: ConnectionId) -> bool88     pub fn is_connection_isolated(&self, conn_id: ConnectionId) -> bool {
89         self.transport_to_owned_connection.values().any(|owned_conn_id| *owned_conn_id == conn_id)
90     }
91 
92     /// Test to see if a buffer contains a valid ATT packet with an opcode we
93     /// are interested in intercepting (those intended for servers that are isolated)
try_parse_att_server_packet( &self, tcb_idx: TransportIndex, packet: Box<[u8]>, ) -> Option<OwnedAttView>94     pub fn try_parse_att_server_packet(
95         &self,
96         tcb_idx: TransportIndex,
97         packet: Box<[u8]>,
98     ) -> Option<OwnedAttView> {
99         if !self.transport_to_owned_connection.contains_key(&tcb_idx) {
100             return None;
101         }
102 
103         let att = OwnedAttView::try_parse(packet).ok()?;
104 
105         if att.view().get_opcode() == AttOpcode::EXCHANGE_MTU_REQUEST {
106             // special case: this server opcode is handled by legacy stack, and we snoop
107             // on its handling, since the MTU is shared between the client + server
108             return None;
109         }
110 
111         match classify_opcode(att.view().get_opcode()) {
112             OperationType::Command | OperationType::Request | OperationType::Confirmation => {
113                 Some(att)
114             }
115             _ => None,
116         }
117     }
118 
119     /// Check if an incoming connection should be intercepted and, if so, on
120     /// what conn_id
on_le_connect( &mut self, tcb_idx: TransportIndex, advertiser: AdvertiserId, ) -> Option<ConnectionId>121     pub fn on_le_connect(
122         &mut self,
123         tcb_idx: TransportIndex,
124         advertiser: AdvertiserId,
125     ) -> Option<ConnectionId> {
126         info!(
127             "processing incoming connection on transport {tcb_idx:?} to advertiser {advertiser:?}"
128         );
129         let server_id = *self.advertiser_to_server.get(&advertiser)?;
130         info!("connection is isolated to server {server_id:?}");
131 
132         let conn_id = ConnectionId::new(tcb_idx, server_id);
133         let old = self.transport_to_owned_connection.insert(tcb_idx, conn_id);
134         if old.is_some() {
135             error!("new server {server_id:?} on transport {tcb_idx:?} displacing existing registered connection {conn_id:?}")
136         }
137         Some(conn_id)
138     }
139 
140     /// Handle a disconnection, if any, and return whether the disconnection was registered
on_le_disconnect(&mut self, tcb_idx: TransportIndex) -> bool141     pub fn on_le_disconnect(&mut self, tcb_idx: TransportIndex) -> bool {
142         info!("processing disconnection on transport {tcb_idx:?}");
143         self.transport_to_owned_connection.remove(&tcb_idx).is_some()
144     }
145 
146     /// Look up the conn_id for a given tcb_idx, if present
get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId>147     pub fn get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId> {
148         self.transport_to_owned_connection.get(&tcb_idx).copied()
149     }
150 }
151 
on_le_connect(tcb_idx: u8, advertiser: u8)152 fn on_le_connect(tcb_idx: u8, advertiser: u8) {
153     if let Some(conn_id) = with_arbiter(|arbiter| {
154         arbiter.on_le_connect(TransportIndex(tcb_idx), AdvertiserId(advertiser))
155     }) {
156         do_in_rust_thread(move |modules| {
157             if let Err(err) = modules.gatt_module.on_le_connect(conn_id) {
158                 error!("{err:?}")
159             }
160         })
161     }
162 }
163 
on_le_disconnect(tcb_idx: u8)164 fn on_le_disconnect(tcb_idx: u8) {
165     let tcb_idx = TransportIndex(tcb_idx);
166     if with_arbiter(|arbiter| arbiter.on_le_disconnect(tcb_idx)) {
167         do_in_rust_thread(move |modules| {
168             if let Err(err) = modules.gatt_module.on_le_disconnect(tcb_idx) {
169                 error!("{err:?}")
170             }
171         })
172     }
173 }
174 
intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction175 fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
176     let tcb_idx = TransportIndex(tcb_idx);
177     if let Some(att) = with_arbiter(|arbiter| {
178         arbiter.try_parse_att_server_packet(tcb_idx, packet.into_boxed_slice())
179     }) {
180         do_in_rust_thread(move |modules| {
181             trace!("pushing packet to GATT");
182             if let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) {
183                 bearer.handle_packet(att.view())
184             } else {
185                 error!("Bearer for {tcb_idx:?} not found");
186             }
187         });
188         InterceptAction::Drop
189     } else {
190         InterceptAction::Forward
191     }
192 }
193 
on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent)194 fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
195     if with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)).is_some() {
196         do_in_rust_thread(move |modules| {
197             let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) else {
198                 error!("Bearer for {tcb_idx:?} not found");
199                 return;
200             };
201             if let Err(err) = bearer.handle_mtu_event(event) {
202                 error!("{err:?}")
203             }
204         });
205     }
206 }
207 
#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        gatt::ids::AttHandle,
        packets::{
            AttBuilder, AttExchangeMtuRequestBuilder, AttOpcode, AttReadRequestBuilder,
            Serializable,
        },
    };

    const TCB_IDX: TransportIndex = TransportIndex(1);
    const ANOTHER_TCB_IDX: TransportIndex = TransportIndex(2);
    const ADVERTISER_ID: AdvertiserId = AdvertiserId(3);
    const SERVER_ID: ServerId = ServerId(4);

    const CONN_ID: ConnectionId = ConnectionId::new(TCB_IDX, SERVER_ID);

    const ANOTHER_ADVERTISER_ID: AdvertiserId = AdvertiserId(5);
    const ANOTHER_SERVER_ID: ServerId = ServerId(6);

    /// An arbiter in which `SERVER_ID` is linked to `ADVERTISER_ID`.
    fn arbiter_with_isolated_server() -> Arbiter {
        let mut arbiter = Arbiter::new();
        arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
        arbiter
    }

    /// Serialize an ATT packet with the given opcode and a READ_REQUEST body.
    fn packet_with_read_request_body(opcode: AttOpcode) -> Box<[u8]> {
        let packet = AttBuilder {
            opcode,
            _child_: AttReadRequestBuilder { attribute_handle: AttHandle(1).into() }.into(),
        };
        packet.to_vec().unwrap().into()
    }

    /// Serialize a well-formed READ_REQUEST packet.
    fn read_request_packet() -> Box<[u8]> {
        packet_with_read_request_body(AttOpcode::READ_REQUEST)
    }

    #[test]
    fn test_non_isolated_connect() {
        let mut arbiter = Arbiter::new();

        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert!(conn_id.is_none())
    }

    #[test]
    fn test_isolated_connect() {
        let mut arbiter = arbiter_with_isolated_server();

        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert_eq!(conn_id, Some(CONN_ID));
    }

    #[test]
    fn test_non_isolated_connect_with_isolated_advertiser() {
        let mut arbiter = arbiter_with_isolated_server();

        let conn_id = arbiter.on_le_connect(TCB_IDX, ANOTHER_ADVERTISER_ID);

        assert!(conn_id.is_none())
    }

    #[test]
    fn test_non_isolated_disconnect() {
        let mut arbiter = Arbiter::new();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        let ok = arbiter.on_le_disconnect(TCB_IDX);

        assert!(!ok)
    }

    #[test]
    fn test_isolated_disconnect() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        let ok = arbiter.on_le_disconnect(TCB_IDX);

        assert!(ok)
    }

    #[test]
    fn test_advertiser_id_reuse() {
        // start an advertiser associated with the server, then kill the advertiser
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.clear_advertiser(ADVERTISER_ID);

        // a new advertiser appeared with the same ID and got a connection
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        // but we should not be isolated since this is a new advertiser reusing the old
        // ID
        assert!(conn_id.is_none())
    }

    #[test]
    fn test_server_closed() {
        // start an advertiser associated with the server, then kill the server
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.clear_server(SERVER_ID);

        // then afterwards we get a connection to this advertiser
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        // since the server is gone, we should not capture the connection
        assert!(conn_id.is_none())
    }

    #[test]
    fn test_connection_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).unwrap();

        let is_isolated = arbiter.is_connection_isolated(conn_id);

        assert!(is_isolated)
    }

    #[test]
    fn test_connection_isolated_after_advertiser_stops() {
        let mut arbiter = arbiter_with_isolated_server();
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).unwrap();
        arbiter.clear_advertiser(ADVERTISER_ID);

        let is_isolated = arbiter.is_connection_isolated(conn_id);

        assert!(is_isolated)
    }

    #[test]
    fn test_connection_isolated_after_server_stops() {
        let mut arbiter = arbiter_with_isolated_server();
        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID).unwrap();
        arbiter.clear_server(SERVER_ID);

        let is_isolated = arbiter.is_connection_isolated(conn_id);

        assert!(is_isolated)
    }

    #[test]
    fn test_packet_capture_when_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, read_request_packet());

        assert!(out.is_some());
    }

    #[test]
    fn test_packet_bypass_when_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = packet_with_read_request_body(AttOpcode::ERROR_RESPONSE);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet);

        assert!(out.is_none());
    }

    #[test]
    fn test_mtu_bypass() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = AttBuilder {
            opcode: AttOpcode::EXCHANGE_MTU_REQUEST,
            _child_: AttExchangeMtuRequestBuilder { mtu: 64 }.into(),
        };

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into());

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_bypass_when_not_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ANOTHER_ADVERTISER_ID);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, read_request_packet());

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_bypass_when_different_connection() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.on_le_connect(ANOTHER_TCB_IDX, ANOTHER_ADVERTISER_ID);

        let out = arbiter.try_parse_att_server_packet(ANOTHER_TCB_IDX, read_request_packet());

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_capture_when_isolated_after_advertiser_closes() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.clear_advertiser(ADVERTISER_ID);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, read_request_packet());

        assert!(out.is_some());
    }

    #[test]
    fn test_packet_capture_when_isolated_after_server_closes() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.clear_server(SERVER_ID);

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, read_request_packet());

        assert!(out.is_some());
    }

    #[test]
    fn test_not_isolated_after_disconnection() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        arbiter.on_le_disconnect(TCB_IDX);
        let is_isolated = arbiter.is_connection_isolated(CONN_ID);

        assert!(!is_isolated);
    }

    #[test]
    fn test_tcb_idx_reuse_after_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.clear_advertiser(ADVERTISER_ID);
        arbiter.on_le_disconnect(TCB_IDX);

        let conn_id = arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert!(conn_id.is_none());
        assert!(!arbiter.is_connection_isolated(CONN_ID));
    }

    #[test]
    fn test_get_conn_id_when_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        assert_eq!(arbiter.get_conn_id(TCB_IDX), Some(CONN_ID));
        assert_eq!(arbiter.get_conn_id(ANOTHER_TCB_IDX), None);
    }

    #[test]
    fn test_get_conn_id_when_not_isolated() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ANOTHER_ADVERTISER_ID);

        assert_eq!(arbiter.get_conn_id(TCB_IDX), None);
    }

    #[test]
    fn test_get_conn_id_after_disconnection() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.on_le_disconnect(TCB_IDX);

        assert_eq!(arbiter.get_conn_id(TCB_IDX), None);
    }

    #[test]
    fn test_reconnect_on_same_transport_replaces_connection() {
        let mut arbiter = arbiter_with_isolated_server();
        arbiter.associate_server_with_advertiser(ANOTHER_SERVER_ID, ANOTHER_ADVERTISER_ID);
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        let new_conn_id = arbiter.on_le_connect(TCB_IDX, ANOTHER_ADVERTISER_ID);

        let expected = ConnectionId::new(TCB_IDX, ANOTHER_SERVER_ID);
        assert_eq!(new_conn_id, Some(expected));
        assert_eq!(arbiter.get_conn_id(TCB_IDX), Some(expected));
        assert!(!arbiter.is_connection_isolated(CONN_ID));
    }
}
478