• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_DNS_MDNS_CLIENT_IMPL_H_
6 #define NET_DNS_MDNS_CLIENT_IMPL_H_
7 
8 #include <stdint.h>
9 
10 #include <map>
11 #include <memory>
12 #include <string>
13 #include <utility>
14 #include <vector>
15 
16 #include "base/cancelable_callback.h"
17 #include "base/containers/queue.h"
18 #include "base/gtest_prod_util.h"
19 #include "base/memory/raw_ptr.h"
20 #include "base/memory/weak_ptr.h"
21 #include "base/observer_list.h"
22 #include "base/time/time.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/net_export.h"
26 #include "net/dns/mdns_cache.h"
27 #include "net/dns/mdns_client.h"
28 #include "net/socket/datagram_server_socket.h"
29 #include "net/socket/udp_server_socket.h"
30 #include "net/socket/udp_socket.h"
31 
32 namespace base {
33 class Clock;
34 class OneShotTimer;
35 }  // namespace base
36 
37 namespace net {
38 
39 class NetLog;
40 
41 class MDnsSocketFactoryImpl : public MDnsSocketFactory {
42  public:
MDnsSocketFactoryImpl()43   MDnsSocketFactoryImpl() : net_log_(nullptr) {}
MDnsSocketFactoryImpl(NetLog * net_log)44   explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {}
45 
46   MDnsSocketFactoryImpl(const MDnsSocketFactoryImpl&) = delete;
47   MDnsSocketFactoryImpl& operator=(const MDnsSocketFactoryImpl&) = delete;
48 
49   ~MDnsSocketFactoryImpl() override = default;
50 
51   void CreateSockets(
52       std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override;
53 
54  private:
55   const raw_ptr<NetLog> net_log_;
56 };
57 
58 // A connection to the network for multicast DNS clients. It reads data into
59 // DnsResponse objects and alerts the delegate that a packet has been received.
60 class NET_EXPORT_PRIVATE MDnsConnection {
61  public:
62   class Delegate {
63    public:
64     // Handle an mDNS packet buffered in |response| with a size of |bytes_read|.
65     virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0;
66     virtual void OnConnectionError(int error) = 0;
67     virtual ~Delegate() = default;
68   };
69 
70   explicit MDnsConnection(MDnsConnection::Delegate* delegate);
71 
72   MDnsConnection(const MDnsConnection&) = delete;
73   MDnsConnection& operator=(const MDnsConnection&) = delete;
74 
75   virtual ~MDnsConnection();
76 
77   // Succeeds if at least one of the socket handlers succeeded.
78   int Init(MDnsSocketFactory* socket_factory);
79   void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size);
80 
81  private:
82   class SocketHandler {
83    public:
84     SocketHandler(std::unique_ptr<DatagramServerSocket> socket,
85                   MDnsConnection* connection);
86 
87     SocketHandler(const SocketHandler&) = delete;
88     SocketHandler& operator=(const SocketHandler&) = delete;
89 
90     ~SocketHandler();
91 
92     int Start();
93     void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size);
94 
95    private:
96     int DoLoop(int rv);
97     void OnDatagramReceived(int rv);
98 
99     // Callback for when sending a query has finished.
100     void SendDone(int rv);
101 
102     std::unique_ptr<DatagramServerSocket> socket_;
103     raw_ptr<MDnsConnection> connection_;
104     IPEndPoint recv_addr_;
105     DnsResponse response_;
106     IPEndPoint multicast_addr_;
107     bool send_in_progress_ = false;
108     base::queue<std::pair<scoped_refptr<IOBuffer>, unsigned>> send_queue_;
109   };
110 
111   // Callback for handling a datagram being received on either ipv4 or ipv6.
112   void OnDatagramReceived(DnsResponse* response,
113                           const IPEndPoint& recv_addr,
114                           int bytes_read);
115 
116   void PostOnError(SocketHandler* loop, int rv);
117   void OnError(int rv);
118 
119   // Only socket handlers which successfully bound and started are kept.
120   std::vector<std::unique_ptr<SocketHandler>> socket_handlers_;
121 
122   raw_ptr<Delegate> delegate_;
123 
124   base::WeakPtrFactory<MDnsConnection> weak_ptr_factory_{this};
125 };
126 
127 class MDnsListenerImpl;
128 
129 class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient {
130  public:
131   // The core object exists while the MDnsClient is listening, and is deleted
132   // whenever the number of listeners reaches zero. The deletion happens
133   // asychronously, so destroying the last listener does not immediately
134   // invalidate the core.
135   class Core final : public MDnsConnection::Delegate {
136    public:
137     Core(base::Clock* clock, base::OneShotTimer* timer);
138 
139     Core(const Core&) = delete;
140     Core& operator=(const Core&) = delete;
141 
142     ~Core() override;
143 
144     // Initialize the core.
145     int Init(MDnsSocketFactory* socket_factory);
146 
147     // Send a query with a specific rrtype and name. Returns true on success.
148     bool SendQuery(uint16_t rrtype, const std::string& name);
149 
150     // Add/remove a listener to the list of listeners.
151     void AddListener(MDnsListenerImpl* listener);
152     void RemoveListener(MDnsListenerImpl* listener);
153 
154     // Query the cache for records of a specific type and name.
155     void QueryCache(uint16_t rrtype,
156                     const std::string& name,
157                     std::vector<const RecordParsed*>* records) const;
158 
159     // Parse the response and alert relevant listeners.
160     void HandlePacket(DnsResponse* response, int bytes_read) override;
161 
162     void OnConnectionError(int error) override;
163 
cache_for_testing()164     MDnsCache* cache_for_testing() { return &cache_; }
165 
166    private:
167     FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL);
168 
169     class ListenerKey {
170      public:
171       ListenerKey(const std::string& name, uint16_t type);
172       ListenerKey(const ListenerKey&) = default;
173       ListenerKey(ListenerKey&&) = default;
174       bool operator<(const ListenerKey& key) const;
name_lowercase()175       const std::string& name_lowercase() const { return name_lowercase_; }
type()176       uint16_t type() const { return type_; }
177 
178      private:
179       std::string name_lowercase_;
180       uint16_t type_;
181     };
182     typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType;
183     typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>>
184         ListenerMap;
185 
186     // Alert listeners of an update to the cache.
187     void AlertListeners(MDnsCache::UpdateType update_type,
188                         const ListenerKey& key, const RecordParsed* record);
189 
190     // Schedule a cache cleanup to a specific time, cancelling other cleanups.
191     void ScheduleCleanup(base::Time cleanup);
192 
193     // Clean up the cache and schedule a new cleanup.
194     void DoCleanup();
195 
196     // Callback for when a record is removed from the cache.
197     void OnRecordRemoved(const RecordParsed* record);
198 
199     void NotifyNsecRecord(const RecordParsed* record);
200 
201     // Delete and erase the observer list for |key|. Only deletes the observer
202     // list if is empty.
203     void CleanupObserverList(const ListenerKey& key);
204 
205     ListenerMap listeners_;
206 
207     MDnsCache cache_;
208 
209     raw_ptr<base::Clock> clock_;
210     raw_ptr<base::OneShotTimer> cleanup_timer_;
211     base::Time scheduled_cleanup_;
212 
213     std::unique_ptr<MDnsConnection> connection_;
214     base::WeakPtrFactory<Core> weak_ptr_factory_{this};
215   };
216 
217   MDnsClientImpl();
218 
219   // Test constructor, takes a mock clock and mock timer.
220   MDnsClientImpl(base::Clock* clock,
221                  std::unique_ptr<base::OneShotTimer> cleanup_timer);
222 
223   MDnsClientImpl(const MDnsClientImpl&) = delete;
224   MDnsClientImpl& operator=(const MDnsClientImpl&) = delete;
225 
226   ~MDnsClientImpl() override;
227 
228   // MDnsClient implementation:
229   std::unique_ptr<MDnsListener> CreateListener(
230       uint16_t rrtype,
231       const std::string& name,
232       MDnsListener::Delegate* delegate) override;
233 
234   std::unique_ptr<MDnsTransaction> CreateTransaction(
235       uint16_t rrtype,
236       const std::string& name,
237       int flags,
238       const MDnsTransaction::ResultCallback& callback) override;
239 
240   int StartListening(MDnsSocketFactory* socket_factory) override;
241   void StopListening() override;
242   bool IsListening() const override;
243 
core()244   Core* core() { return core_.get(); }
245 
246  private:
247   raw_ptr<base::Clock> clock_;
248   std::unique_ptr<base::OneShotTimer> cleanup_timer_;
249 
250   std::unique_ptr<Core> core_;
251 };
252 
253 class MDnsListenerImpl final : public MDnsListener {
254  public:
255   MDnsListenerImpl(uint16_t rrtype,
256                    const std::string& name,
257                    base::Clock* clock,
258                    MDnsListener::Delegate* delegate,
259                    MDnsClientImpl* client);
260 
261   MDnsListenerImpl(const MDnsListenerImpl&) = delete;
262   MDnsListenerImpl& operator=(const MDnsListenerImpl&) = delete;
263 
264   ~MDnsListenerImpl() override;
265 
266   // MDnsListener implementation:
267   bool Start() override;
268 
269   // Actively refresh any received records.
270   void SetActiveRefresh(bool active_refresh) override;
271 
272   const std::string& GetName() const override;
273 
274   uint16_t GetType() const override;
275 
delegate()276   MDnsListener::Delegate* delegate() { return delegate_; }
277 
278   // Alert the delegate of a record update.
279   void HandleRecordUpdate(MDnsCache::UpdateType update_type,
280                           const RecordParsed* record_parsed);
281 
282   // Alert the delegate of the existence of an Nsec record.
283   void AlertNsecRecord();
284 
285  private:
286   void ScheduleNextRefresh();
287   void DoRefresh();
288 
289   uint16_t rrtype_;
290   std::string name_;
291   raw_ptr<base::Clock> clock_;
292   raw_ptr<MDnsClientImpl> client_;
293   raw_ptr<MDnsListener::Delegate> delegate_;
294 
295   base::Time last_update_;
296   uint32_t ttl_;
297   bool started_ = false;
298   bool active_refresh_ = false;
299 
300   base::CancelableRepeatingClosure next_refresh_;
301   base::WeakPtrFactory<MDnsListenerImpl> weak_ptr_factory_{this};
302 };
303 
304 class MDnsTransactionImpl final : public MDnsTransaction,
305                                   public MDnsListener::Delegate {
306  public:
307   MDnsTransactionImpl(uint16_t rrtype,
308                       const std::string& name,
309                       int flags,
310                       const MDnsTransaction::ResultCallback& callback,
311                       MDnsClientImpl* client);
312 
313   MDnsTransactionImpl(const MDnsTransactionImpl&) = delete;
314   MDnsTransactionImpl& operator=(const MDnsTransactionImpl&) = delete;
315 
316   ~MDnsTransactionImpl() override;
317 
318   // MDnsTransaction implementation:
319   bool Start() override;
320 
321   const std::string& GetName() const override;
322   uint16_t GetType() const override;
323 
324   // MDnsListener::Delegate implementation:
325   void OnRecordUpdate(MDnsListener::UpdateType update,
326                       const RecordParsed* record) override;
327   void OnNsecRecord(const std::string& name, unsigned type) override;
328 
329   void OnCachePurged() override;
330 
331  private:
is_active()332   bool is_active() { return !callback_.is_null(); }
333 
334   void Reset();
335 
336   // Trigger the callback and reset all related variables.
337   void TriggerCallback(MDnsTransaction::Result result,
338                        const RecordParsed* record);
339 
340   // Internal callback for when a cache record is found.
341   void CacheRecordFound(const RecordParsed* record);
342 
343   // Signal the transactionis over and release all related resources.
344   void SignalTransactionOver();
345 
346   // Reads records from the cache and calls the callback for every
347   // record read.
348   void ServeRecordsFromCache();
349 
350   // Send a query to the network and set up a timeout to time out the
351   // transaction. Returns false if it fails to start listening on the network
352   // or if it fails to send a query.
353   bool QueryAndListen();
354 
355   uint16_t rrtype_;
356   std::string name_;
357   MDnsTransaction::ResultCallback callback_;
358 
359   std::unique_ptr<MDnsListener> listener_;
360   base::CancelableOnceCallback<void()> timeout_;
361 
362   raw_ptr<MDnsClientImpl> client_;
363 
364   bool started_ = false;
365   int flags_;
366   base::WeakPtrFactory<MDnsTransactionImpl> weak_ptr_factory_{this};
367 };
368 
369 }  // namespace net
370 #endif  // NET_DNS_MDNS_CLIENT_IMPL_H_
371