• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "stack/arbiter/acl_arbiter.h"
18 
19 #include <base/functional/bind.h>
20 
21 #include <iterator>
22 #include <unordered_map>
23 
24 #include "common/init_flags.h"
25 #include "os/log.h"
26 #include "osi/include/allocator.h"
27 #include "stack/gatt/gatt_int.h"
28 #include "stack/include/btu.h"  // do_in_main_thread
29 #include "stack/include/l2c_api.h"
30 
31 namespace bluetooth {
32 namespace shim {
33 namespace arbiter {
34 
35 class PassthroughAclArbiter : public AclArbiter {
36  public:
OnLeConnect(uint8_t tcb_idx,uint16_t advertiser_id)37   virtual void OnLeConnect(uint8_t tcb_idx, uint16_t advertiser_id) override {
38     // no-op
39   }
40 
OnLeDisconnect(uint8_t tcb_idx)41   virtual void OnLeDisconnect(uint8_t tcb_idx) override {
42     // no-op
43   }
44 
InterceptAttPacket(uint8_t tcb_idx,const BT_HDR * packet)45   virtual InterceptAction InterceptAttPacket(uint8_t tcb_idx,
46                                              const BT_HDR* packet) override {
47     return InterceptAction::FORWARD;
48   }
49 
OnOutgoingMtuReq(uint8_t tcb_idx)50   virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override {
51     // no-op
52   }
53 
OnIncomingMtuResp(uint8_t tcb_idx,size_t mtu)54   virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) {
55     // no-op
56   }
57 
OnIncomingMtuReq(uint8_t tcb_idx,size_t mtu)58   virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) {
59     // no-op
60   }
61 
Get()62   static PassthroughAclArbiter& Get() {
63     static auto singleton = PassthroughAclArbiter();
64     return singleton;
65   }
66 };
67 
68 namespace {
69 struct RustArbiterCallbacks {
70   ::rust::Fn<void(uint8_t tcb_idx, uint8_t advertiser)> on_le_connect;
71   ::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect;
72   ::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)>
73       intercept_packet;
74   ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req;
75   ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp;
76   ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req;
77 };
78 
79 RustArbiterCallbacks callbacks_{};
80 }  // namespace
81 
82 class RustGattAclArbiter : public AclArbiter {
83  public:
OnLeConnect(uint8_t tcb_idx,uint16_t advertiser_id)84   virtual void OnLeConnect(uint8_t tcb_idx, uint16_t advertiser_id) override {
85     LOG_INFO("Notifying Rust of LE connection");
86     callbacks_.on_le_connect(tcb_idx, advertiser_id);
87   }
88 
OnLeDisconnect(uint8_t tcb_idx)89   virtual void OnLeDisconnect(uint8_t tcb_idx) override {
90     LOG_INFO("Notifying Rust of LE disconnection");
91     callbacks_.on_le_disconnect(tcb_idx);
92   }
93 
InterceptAttPacket(uint8_t tcb_idx,const BT_HDR * packet)94   virtual InterceptAction InterceptAttPacket(uint8_t tcb_idx,
95                                              const BT_HDR* packet) override {
96     LOG_DEBUG("Intercepting ATT packet and forwarding to Rust");
97 
98     uint8_t* packet_start = (uint8_t*)(packet + 1) + packet->offset;
99     uint8_t* packet_end = packet_start + packet->len;
100 
101     auto vec = ::rust::Vec<uint8_t>();
102     std::copy(packet_start, packet_end, std::back_inserter(vec));
103     return callbacks_.intercept_packet(tcb_idx, std::move(vec));
104   }
105 
OnOutgoingMtuReq(uint8_t tcb_idx)106   virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override {
107     LOG_DEBUG("Notifying Rust of outgoing MTU request");
108     callbacks_.on_outgoing_mtu_req(tcb_idx);
109   }
110 
OnIncomingMtuResp(uint8_t tcb_idx,size_t mtu)111   virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) {
112     LOG_DEBUG("Notifying Rust of incoming MTU response %zu", mtu);
113     callbacks_.on_incoming_mtu_resp(tcb_idx, mtu);
114   }
115 
OnIncomingMtuReq(uint8_t tcb_idx,size_t mtu)116   virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) {
117     LOG_DEBUG("Notifying Rust of incoming MTU request %zu", mtu);
118     callbacks_.on_incoming_mtu_req(tcb_idx, mtu);
119   }
120 
SendPacketToPeer(uint8_t tcb_idx,::rust::Vec<uint8_t> buffer)121   void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer) {
122     tGATT_TCB* p_tcb = gatt_get_tcb_by_idx(tcb_idx);
123     if (p_tcb != nullptr) {
124       BT_HDR* p_buf = (BT_HDR*)osi_malloc(sizeof(BT_HDR) + buffer.size() +
125                                           L2CAP_MIN_OFFSET);
126       if (p_buf == nullptr) {
127         LOG_ALWAYS_FATAL("OOM when sending packet");
128       }
129       auto p = (uint8_t*)(p_buf + 1) + L2CAP_MIN_OFFSET;
130       std::copy(buffer.begin(), buffer.end(), p);
131       p_buf->offset = L2CAP_MIN_OFFSET;
132       p_buf->len = buffer.size();
133       L2CA_SendFixedChnlData(L2CAP_ATT_CID, p_tcb->peer_bda, p_buf);
134     } else {
135       LOG_ERROR("Dropping packet since connection no longer exists");
136     }
137   }
138 
Get()139   static RustGattAclArbiter& Get() {
140     static auto singleton = RustGattAclArbiter();
141     return singleton;
142   }
143 };
144 
StoreCallbacksFromRust(::rust::Fn<void (uint8_t tcb_idx,uint8_t advertiser)> on_le_connect,::rust::Fn<void (uint8_t tcb_idx)> on_le_disconnect,::rust::Fn<InterceptAction (uint8_t tcb_idx,::rust::Vec<uint8_t> buffer)> intercept_packet,::rust::Fn<void (uint8_t tcb_idx)> on_outgoing_mtu_req,::rust::Fn<void (uint8_t tcb_idx,size_t mtu)> on_incoming_mtu_resp,::rust::Fn<void (uint8_t tcb_idx,size_t mtu)> on_incoming_mtu_req)145 void StoreCallbacksFromRust(
146     ::rust::Fn<void(uint8_t tcb_idx, uint8_t advertiser)> on_le_connect,
147     ::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect,
148     ::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)>
149         intercept_packet,
150     ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req,
151     ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp,
152     ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req) {
153   LOG_INFO("Received callbacks from Rust, registering in Arbiter");
154   callbacks_ = {on_le_connect,       on_le_disconnect,     intercept_packet,
155                 on_outgoing_mtu_req, on_incoming_mtu_resp, on_incoming_mtu_req};
156 }
157 
SendPacketToPeer(uint8_t tcb_idx,::rust::Vec<uint8_t> buffer)158 void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer) {
159   do_in_main_thread(FROM_HERE,
160                     base::Bind(&RustGattAclArbiter::SendPacketToPeer,
161                                base::Unretained(&RustGattAclArbiter::Get()),
162                                tcb_idx, std::move(buffer)));
163 }
164 
GetArbiter()165 AclArbiter& GetArbiter() {
166   return common::init_flags::private_gatt_is_enabled()
167              ? static_cast<AclArbiter&>(RustGattAclArbiter::Get())
168              : static_cast<AclArbiter&>(PassthroughAclArbiter::Get());
169 }
170 
171 }  // namespace arbiter
172 }  // namespace shim
173 }  // namespace bluetooth
174