1 // Copyright 2013 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_DNS_MDNS_CLIENT_IMPL_H_ 6 #define NET_DNS_MDNS_CLIENT_IMPL_H_ 7 8 #include <stdint.h> 9 10 #include <map> 11 #include <memory> 12 #include <string> 13 #include <utility> 14 #include <vector> 15 16 #include "base/cancelable_callback.h" 17 #include "base/containers/queue.h" 18 #include "base/gtest_prod_util.h" 19 #include "base/memory/raw_ptr.h" 20 #include "base/memory/weak_ptr.h" 21 #include "base/observer_list.h" 22 #include "base/time/time.h" 23 #include "net/base/io_buffer.h" 24 #include "net/base/ip_endpoint.h" 25 #include "net/base/net_export.h" 26 #include "net/dns/mdns_cache.h" 27 #include "net/dns/mdns_client.h" 28 #include "net/socket/datagram_server_socket.h" 29 #include "net/socket/udp_server_socket.h" 30 #include "net/socket/udp_socket.h" 31 32 namespace base { 33 class Clock; 34 class OneShotTimer; 35 } // namespace base 36 37 namespace net { 38 39 class NetLog; 40 41 class MDnsSocketFactoryImpl : public MDnsSocketFactory { 42 public: MDnsSocketFactoryImpl()43 MDnsSocketFactoryImpl() : net_log_(nullptr) {} MDnsSocketFactoryImpl(NetLog * net_log)44 explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {} 45 46 MDnsSocketFactoryImpl(const MDnsSocketFactoryImpl&) = delete; 47 MDnsSocketFactoryImpl& operator=(const MDnsSocketFactoryImpl&) = delete; 48 49 ~MDnsSocketFactoryImpl() override = default; 50 51 void CreateSockets( 52 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override; 53 54 private: 55 const raw_ptr<NetLog> net_log_; 56 }; 57 58 // A connection to the network for multicast DNS clients. It reads data into 59 // DnsResponse objects and alerts the delegate that a packet has been received. 60 class NET_EXPORT_PRIVATE MDnsConnection { 61 public: 62 class Delegate { 63 public: 64 // Handle an mDNS packet buffered in |response| with a size of |bytes_read|. 65 virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0; 66 virtual void OnConnectionError(int error) = 0; 67 virtual ~Delegate() = default; 68 }; 69 70 explicit MDnsConnection(MDnsConnection::Delegate* delegate); 71 72 MDnsConnection(const MDnsConnection&) = delete; 73 MDnsConnection& operator=(const MDnsConnection&) = delete; 74 75 virtual ~MDnsConnection(); 76 77 // Succeeds if at least one of the socket handlers succeeded. 78 int Init(MDnsSocketFactory* socket_factory); 79 void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size); 80 81 private: 82 class SocketHandler { 83 public: 84 SocketHandler(std::unique_ptr<DatagramServerSocket> socket, 85 MDnsConnection* connection); 86 87 SocketHandler(const SocketHandler&) = delete; 88 SocketHandler& operator=(const SocketHandler&) = delete; 89 90 ~SocketHandler(); 91 92 int Start(); 93 void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size); 94 95 private: 96 int DoLoop(int rv); 97 void OnDatagramReceived(int rv); 98 99 // Callback for when sending a query has finished. 100 void SendDone(int rv); 101 102 std::unique_ptr<DatagramServerSocket> socket_; 103 raw_ptr<MDnsConnection> connection_; 104 IPEndPoint recv_addr_; 105 DnsResponse response_; 106 IPEndPoint multicast_addr_; 107 bool send_in_progress_ = false; 108 base::queue<std::pair<scoped_refptr<IOBuffer>, unsigned>> send_queue_; 109 }; 110 111 // Callback for handling a datagram being received on either ipv4 or ipv6. 112 void OnDatagramReceived(DnsResponse* response, 113 const IPEndPoint& recv_addr, 114 int bytes_read); 115 116 void PostOnError(SocketHandler* loop, int rv); 117 void OnError(int rv); 118 119 // Only socket handlers which successfully bound and started are kept. 120 std::vector<std::unique_ptr<SocketHandler>> socket_handlers_; 121 122 raw_ptr<Delegate> delegate_; 123 124 base::WeakPtrFactory<MDnsConnection> weak_ptr_factory_{this}; 125 }; 126 127 class MDnsListenerImpl; 128 129 class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient { 130 public: 131 // The core object exists while the MDnsClient is listening, and is deleted 132 // whenever the number of listeners reaches zero. The deletion happens 133 // asychronously, so destroying the last listener does not immediately 134 // invalidate the core. 135 class Core final : public MDnsConnection::Delegate { 136 public: 137 Core(base::Clock* clock, base::OneShotTimer* timer); 138 139 Core(const Core&) = delete; 140 Core& operator=(const Core&) = delete; 141 142 ~Core() override; 143 144 // Initialize the core. 145 int Init(MDnsSocketFactory* socket_factory); 146 147 // Send a query with a specific rrtype and name. Returns true on success. 148 bool SendQuery(uint16_t rrtype, const std::string& name); 149 150 // Add/remove a listener to the list of listeners. 151 void AddListener(MDnsListenerImpl* listener); 152 void RemoveListener(MDnsListenerImpl* listener); 153 154 // Query the cache for records of a specific type and name. 155 void QueryCache(uint16_t rrtype, 156 const std::string& name, 157 std::vector<const RecordParsed*>* records) const; 158 159 // Parse the response and alert relevant listeners. 160 void HandlePacket(DnsResponse* response, int bytes_read) override; 161 162 void OnConnectionError(int error) override; 163 cache_for_testing()164 MDnsCache* cache_for_testing() { return &cache_; } 165 166 private: 167 FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL); 168 169 class ListenerKey { 170 public: 171 ListenerKey(const std::string& name, uint16_t type); 172 ListenerKey(const ListenerKey&) = default; 173 ListenerKey(ListenerKey&&) = default; 174 bool operator<(const ListenerKey& key) const; name_lowercase()175 const std::string& name_lowercase() const { return name_lowercase_; } type()176 uint16_t type() const { return type_; } 177 178 private: 179 std::string name_lowercase_; 180 uint16_t type_; 181 }; 182 typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType; 183 typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>> 184 ListenerMap; 185 186 // Alert listeners of an update to the cache. 187 void AlertListeners(MDnsCache::UpdateType update_type, 188 const ListenerKey& key, const RecordParsed* record); 189 190 // Schedule a cache cleanup to a specific time, cancelling other cleanups. 191 void ScheduleCleanup(base::Time cleanup); 192 193 // Clean up the cache and schedule a new cleanup. 194 void DoCleanup(); 195 196 // Callback for when a record is removed from the cache. 197 void OnRecordRemoved(const RecordParsed* record); 198 199 void NotifyNsecRecord(const RecordParsed* record); 200 201 // Delete and erase the observer list for |key|. Only deletes the observer 202 // list if is empty. 203 void CleanupObserverList(const ListenerKey& key); 204 205 ListenerMap listeners_; 206 207 MDnsCache cache_; 208 209 raw_ptr<base::Clock> clock_; 210 raw_ptr<base::OneShotTimer> cleanup_timer_; 211 base::Time scheduled_cleanup_; 212 213 std::unique_ptr<MDnsConnection> connection_; 214 base::WeakPtrFactory<Core> weak_ptr_factory_{this}; 215 }; 216 217 MDnsClientImpl(); 218 219 // Test constructor, takes a mock clock and mock timer. 220 MDnsClientImpl(base::Clock* clock, 221 std::unique_ptr<base::OneShotTimer> cleanup_timer); 222 223 MDnsClientImpl(const MDnsClientImpl&) = delete; 224 MDnsClientImpl& operator=(const MDnsClientImpl&) = delete; 225 226 ~MDnsClientImpl() override; 227 228 // MDnsClient implementation: 229 std::unique_ptr<MDnsListener> CreateListener( 230 uint16_t rrtype, 231 const std::string& name, 232 MDnsListener::Delegate* delegate) override; 233 234 std::unique_ptr<MDnsTransaction> CreateTransaction( 235 uint16_t rrtype, 236 const std::string& name, 237 int flags, 238 const MDnsTransaction::ResultCallback& callback) override; 239 240 int StartListening(MDnsSocketFactory* socket_factory) override; 241 void StopListening() override; 242 bool IsListening() const override; 243 core()244 Core* core() { return core_.get(); } 245 246 private: 247 raw_ptr<base::Clock> clock_; 248 std::unique_ptr<base::OneShotTimer> cleanup_timer_; 249 250 std::unique_ptr<Core> core_; 251 }; 252 253 class MDnsListenerImpl final : public MDnsListener { 254 public: 255 MDnsListenerImpl(uint16_t rrtype, 256 const std::string& name, 257 base::Clock* clock, 258 MDnsListener::Delegate* delegate, 259 MDnsClientImpl* client); 260 261 MDnsListenerImpl(const MDnsListenerImpl&) = delete; 262 MDnsListenerImpl& operator=(const MDnsListenerImpl&) = delete; 263 264 ~MDnsListenerImpl() override; 265 266 // MDnsListener implementation: 267 bool Start() override; 268 269 // Actively refresh any received records. 270 void SetActiveRefresh(bool active_refresh) override; 271 272 const std::string& GetName() const override; 273 274 uint16_t GetType() const override; 275 delegate()276 MDnsListener::Delegate* delegate() { return delegate_; } 277 278 // Alert the delegate of a record update. 279 void HandleRecordUpdate(MDnsCache::UpdateType update_type, 280 const RecordParsed* record_parsed); 281 282 // Alert the delegate of the existence of an Nsec record. 283 void AlertNsecRecord(); 284 285 private: 286 void ScheduleNextRefresh(); 287 void DoRefresh(); 288 289 uint16_t rrtype_; 290 std::string name_; 291 raw_ptr<base::Clock> clock_; 292 raw_ptr<MDnsClientImpl> client_; 293 raw_ptr<MDnsListener::Delegate> delegate_; 294 295 base::Time last_update_; 296 uint32_t ttl_; 297 bool started_ = false; 298 bool active_refresh_ = false; 299 300 base::CancelableRepeatingClosure next_refresh_; 301 base::WeakPtrFactory<MDnsListenerImpl> weak_ptr_factory_{this}; 302 }; 303 304 class MDnsTransactionImpl final : public MDnsTransaction, 305 public MDnsListener::Delegate { 306 public: 307 MDnsTransactionImpl(uint16_t rrtype, 308 const std::string& name, 309 int flags, 310 const MDnsTransaction::ResultCallback& callback, 311 MDnsClientImpl* client); 312 313 MDnsTransactionImpl(const MDnsTransactionImpl&) = delete; 314 MDnsTransactionImpl& operator=(const MDnsTransactionImpl&) = delete; 315 316 ~MDnsTransactionImpl() override; 317 318 // MDnsTransaction implementation: 319 bool Start() override; 320 321 const std::string& GetName() const override; 322 uint16_t GetType() const override; 323 324 // MDnsListener::Delegate implementation: 325 void OnRecordUpdate(MDnsListener::UpdateType update, 326 const RecordParsed* record) override; 327 void OnNsecRecord(const std::string& name, unsigned type) override; 328 329 void OnCachePurged() override; 330 331 private: is_active()332 bool is_active() { return !callback_.is_null(); } 333 334 void Reset(); 335 336 // Trigger the callback and reset all related variables. 337 void TriggerCallback(MDnsTransaction::Result result, 338 const RecordParsed* record); 339 340 // Internal callback for when a cache record is found. 341 void CacheRecordFound(const RecordParsed* record); 342 343 // Signal the transactionis over and release all related resources. 344 void SignalTransactionOver(); 345 346 // Reads records from the cache and calls the callback for every 347 // record read. 348 void ServeRecordsFromCache(); 349 350 // Send a query to the network and set up a timeout to time out the 351 // transaction. Returns false if it fails to start listening on the network 352 // or if it fails to send a query. 353 bool QueryAndListen(); 354 355 uint16_t rrtype_; 356 std::string name_; 357 MDnsTransaction::ResultCallback callback_; 358 359 std::unique_ptr<MDnsListener> listener_; 360 base::CancelableOnceCallback<void()> timeout_; 361 362 raw_ptr<MDnsClientImpl> client_; 363 364 bool started_ = false; 365 int flags_; 366 base::WeakPtrFactory<MDnsTransactionImpl> weak_ptr_factory_{this}; 367 }; 368 369 } // namespace net 370 #endif // NET_DNS_MDNS_CLIENT_IMPL_H_ 371