• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import collections
6import dpkt
7import logging
8import six
9import socket
10import time
11
12
# A cached peer mDNS record: (rrname, rrtype, data, ts), where ts is the time
# at which the cached record expires. The typename matches the variable name so
# that repr() and pickle resolve the class correctly.
DnsRecord = collections.namedtuple('DnsRecord',
                                   ['rrname', 'rrtype', 'data', 'ts'])

MDNS_IP_ADDR = '224.0.0.251'
MDNS_PORT = 5353

# Bit to OR into a record's class value to signal a cache flush (the
# "cache-flush" bit of RFC 6762, section 10.2).
DNS_CACHE_FLUSH = 0x8000

# When considering SRV records, clients are supposed to unilaterally prefer
# numerically lower priorities, then pick probabilistically by weight.
# See RFC2782.
# An arbitrary number that will fit in 16 bits.
DEFAULT_PRIORITY = 500
# An arbitrary number that will fit in 16 bits.
DEFAULT_WEIGHT = 500
28
def _RR_equals(rra, rrb):
    """Returns whether the two dpkt.dns.DNS.RR objects are equal.

    Every attribute present on either object, plus the fields declared on
    dpkt.dns.DNS.RR, is compared, except 'rdata'. The cache-flush bit of the
    'cls' attribute is ignored, so the same record sent with and without that
    bit compares equal.

    @param rra: A dpkt.dns.DNS.RR object.
    @param rrb: A dpkt.dns.DNS.RR object.
    @return: True if both records carry the same information, False otherwise.
    """
    # Compare all the members present in either object and on any RR object.
    # A set union is used because dict.keys() returns a view on Python 3,
    # which does not support "+".
    keys = (set(rra.__dict__) | set(rrb.__dict__) |
            set(dpkt.dns.DNS.RR.__slots__))
    # On RR objects, rdata is packed based on the other members and the final
    # packed string depends on other RR and Q elements on the same DNS/mDNS
    # packet.
    keys.discard('rdata')
    for key in keys:
        if hasattr(rra, key) != hasattr(rrb, key):
            return False
        if not hasattr(rra, key):
            continue
        val_a = getattr(rra, key)
        val_b = getattr(rrb, key)
        if key == 'cls':
            # cls attribute should be masked for the cache flush bit.
            val_a &= ~DNS_CACHE_FLUSH
            val_b &= ~DNS_CACHE_FLUSH
        if val_a != val_b:
            return False
    return True
52
53
class ZeroconfDaemon(object):
    """Implements a simulated Zeroconf daemon running on the given host.

    This class implements part of the Multicast DNS RFC 6762 able to simulate
    a host exposing services or consuming services over mDNS.
    """
    def __init__(self, host, hostname, domain='local'):
        """Initializes the ZeroconfDaemon running on the given host.

        For the purposes of the Zeroconf implementation, a host must have a
        hostname and a domain that defaults to 'local'. The ZeroconfDaemon will
        by default advertise the host address it is running on, which is
        required by some services.

        @param host: The Host instance where this daemon runs on. It must
                provide an ip_addr attribute and a socket() factory.
        @param hostname: A string representing the hostname.
        @param domain: The domain the hostname belongs to, 'local' by default.
        """
        self._host = host
        self._hostname = hostname
        self._domain = domain
        self._response_ttl = 60  # Default TTL in seconds.

        self._a_records = {}  # Local A records.
        self._srv_records = {}  # Local SRV records.
        self._ptr_records = {}  # Local PTR records.
        self._txt_records = {}  # Local TXT records.

        # dict() of name --> (dict() of type --> (dict() of data --> timeout))
        # For example: _peer_records['somehost.local'][dpkt.dns.DNS_A] \
        #     ['192.168.0.1'] = time.time() + 3600
        self._peer_records = {}

        # Observer list for new responses. Created before the socket starts
        # listening so that a packet delivered right away finds it ready.
        self._answer_callbacks = []

        # Register the host address locally.
        self.register_A(self.full_hostname, host.ip_addr)

        # Attend all the traffic to the mDNS port (unicast, multicast or
        # broadcast).
        self._sock = host.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._sock.listen(MDNS_IP_ADDR, MDNS_PORT, self._mdns_request)


    def __del__(self):
        # __init__ may have raised before the socket was created.
        sock = getattr(self, '_sock', None)
        if sock is not None:
            sock.close()


    @property
    def host(self):
        """The Host object where this daemon is running."""
        return self._host


    @property
    def hostname(self):
        """The hostname part within a domain."""
        return self._hostname


    @property
    def domain(self):
        """The domain where the given hostname is running."""
        return self._domain


    @property
    def full_hostname(self):
        """The full hostname designation including host and domain name."""
        return self._hostname + '.' + self._domain


    def _mdns_request(self, data, addr, port):
        """Handles a mDNS multicast packet.

        This method will generate and send a mDNS response to any query
        for which it has new authoritative information. Called by the Simulator
        as a callback for every mDNS received packet. Records received in the
        authority section, and in the answer section of responses, are stored
        in the peer cache and reported to the registered observers.

        @param data: The string contained on the UDP message.
        @param addr: The address where the message comes from.
        @param port: The port number where the message comes from.
        """
        # Parse the mDNS request using dpkt's DNS module.
        mdns = dpkt.dns.DNS(data)
        if mdns.op == 0x0000:  # Query
            self._answer_query(mdns)

        # Always process the received authoritative answers. Copy the list so
        # that extending it below does not modify the parsed packet.
        answers = list(mdns.ns)

        # Process the answers for response packets.
        if mdns.op == 0x8400:  # Standard response
            answers.extend(mdns.an)

        records = []
        for rr in answers:  # Answers RRs
            # dpkt decodes the information on different fields depending on
            # the response type.
            if rr.type == dpkt.dns.DNS_A:
                rdata = socket.inet_ntoa(rr.ip)
            elif rr.type == dpkt.dns.DNS_PTR:
                rdata = rr.ptrname
            elif rr.type == dpkt.dns.DNS_TXT:
                rdata = tuple(rr.text)  # Convert the list to a hashable tuple
            elif rr.type == dpkt.dns.DNS_SRV:
                rdata = rr.srvname, rr.priority, rr.weight, rr.port
            else:
                continue  # Ignore unsupported records.
            records.append((rr.name, rr.type, rdata, rr.ttl,
                            bool(rr.cls & DNS_CACHE_FLUSH)))

        if not records:
            return
        self._cache_peer_records(records, time.time())
        # Observers receive (rrtype, rrname, data) tuples, in this order.
        new_answers = [(rrtype, rrname, rdata)
                       for rrname, rrtype, rdata, _, _ in records]
        for cbk in self._answer_callbacks:
            cbk(new_answers)


    def _answer_query(self, mdns):
        """Sends the local authoritative answers to a parsed mDNS query.

        @param mdns: The dpkt.dns.DNS query packet. Answers already listed in
                its answer section (known answers) are not sent again.
        """
        query_handlers = {
            dpkt.dns.DNS_A: self._process_A,
            dpkt.dns.DNS_PTR: self._process_PTR,
            dpkt.dns.DNS_TXT: self._process_TXT,
            dpkt.dns.DNS_SRV: self._process_SRV,
        }

        answers = []
        for q in mdns.qd:  # Query entries
            if q.type in query_handlers:
                answers += query_handlers[q.type](q)
            elif q.type == dpkt.dns.DNS_ANY:
                # Special type matching any known type.
                for handler in query_handlers.values():
                    answers += handler(q)
        # Remove all the already known answers from the list.
        answers = [ans for ans in answers
                   if not any(_RR_equals(known_ans, ans)
                              for known_ans in mdns.an)]

        self._send_answers(answers)


    def _cache_peer_records(self, records, now):
        """Stores the records received in a single mDNS packet in the cache.

        A record with the cache-flush bit set replaces the previously cached
        records with the same name and type. RFC 6762 (section 10.2) requires
        that records received within the last second are not flushed, so the
        other records of the same set carried by this very packet are kept.
        This is implemented by flushing at most once per (name, type) per call.

        @param records: An iterable of (rrname, rrtype, data, ttl, cache_flush)
                tuples from one packet, where data is hashable and cache_flush
                is a boolean.
        @param now: The reception timestamp, as returned by time.time().
        """
        flushed = set()
        for rrname, rrtype, rdata, ttl, cache_flush in records:
            by_type = self._peer_records.setdefault(rrname, {})
            if cache_flush and (rrname, rrtype) not in flushed:
                flushed.add((rrname, rrtype))
                by_type[rrtype] = {}
            cached_ans = by_type.setdefault(rrtype, {})
            rr_timeout = now + ttl
            # Update the answer timeout if already cached.
            cached_ans[rdata] = max(cached_ans.get(rdata, rr_timeout),
                                    rr_timeout)


    def clear_cache(self):
        """Discards all the cached records."""
        self._peer_records = {}


    def _send_answers(self, answers):
        """Send a mDNS reply with the provided answers.

        This method uses the underlying Host to send an IP packet with a mDNS
        response containing the list of answers of the type dpkt.dns.DNS.RR.
        If the list is empty, no packet is sent.

        @param answers: The list of answers to send.
        """
        if not answers:
            return
        logging.debug('Sending response with answers: %r.', answers)
        resp_dns = dpkt.dns.DNS(
            op=dpkt.dns.DNS_AA,  # Authoritative Answer.
            rcode=dpkt.dns.DNS_RCODE_NOERR,
            an=answers)
        # This property modifies the "op" field. It must be assigned the flag
        # value itself; a trailing comma would assign the tuple (DNS_R,).
        resp_dns.qr = dpkt.dns.DNS_R  # Response.
        # bytes() yields the packed packet on both Python 2 and Python 3.
        self._sock.send(bytes(resp_dns), MDNS_IP_ADDR, MDNS_PORT)


    ### RFC 2782 - RR for specifying the location of services (DNS SRV).
    def register_SRV(self, service, proto, priority, weight, port):
        """Publishes the SRV specified record.

        A SRV record defines a service on a port of a host with given properties
        like priority and weight. The service has a name of the form
        "service.proto.domain". The target host, this is, the host where the
        announced service is running on is set to the host where this zeroconf
        daemon is running, "hostname.domain".

        @param service: A string with the service name.
        @param proto: A string with the protocol name, for example "_tcp".
        @param priority: The service priority number as defined by RFC2782.
        @param weight: The service weight number as defined by RFC2782.
        @param port: The port number where the service is running on.
        """
        srvname = service + '.' + proto + '.' + self._domain
        self._srv_records[srvname] = priority, weight, port


    def _process_SRV(self, q):
        """Process a SRV query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_SRV.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty. A matching SRV answer is followed by the A records of the
        target host.
        """
        if q.name not in self._srv_records:
            return []
        priority, weight, port = self._srv_records[q.name]
        full_hostname = self.full_hostname
        ans = dpkt.dns.DNS.RR(
            type=dpkt.dns.DNS_SRV,
            cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
            ttl=self._response_ttl,
            name=q.name,
            srvname=full_hostname,
            priority=priority,
            weight=weight,
            port=port)
        # The target host (srvname) requires to send an A record with its IP
        # address. We do this as if a query for it was sent.
        a_qry = dpkt.dns.DNS.Q(name=full_hostname, type=dpkt.dns.DNS_A)
        return [ans] + self._process_A(a_qry)


    ### RFC 1035 - 3.4.1, Domains Names - A (IPv4 address).
    def register_A(self, hostname, ip_addr):
        """Registers an Address record (A) pointing to the given IP address.

        Records registered with method are assumed authoritative. Registering
        several addresses for the same hostname keeps all of them.

        @param hostname: The full host name, for example, "somehost.local".
        @param ip_addr: The IPv4 address of the host, for example, "192.0.1.1".
        """
        self._a_records.setdefault(hostname, []).append(
            socket.inet_aton(ip_addr))


    def _process_A(self, q):
        """Process an A query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_A.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        return [dpkt.dns.DNS.RR(
                    type=dpkt.dns.DNS_A,
                    cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
                    ttl=self._response_ttl,
                    name=q.name,
                    ip=ip_addr)
                for ip_addr in self._a_records.get(q.name, [])]


    ### RFC 1035 - 3.3.12, Domain names - PTR (domain name pointer).
    def register_PTR(self, domain, destination):
        """Register a domain pointer record.

        A domain pointer record is simply a pointer to a hostname on the domain.

        @param domain: A domain name that can include a proto name, for
        example, "_workstation._tcp.local".
        @param destination: The hostname inside the given domain, for example,
        "my-desktop".
        """
        self._ptr_records.setdefault(domain, []).append(destination)


    def _process_PTR(self, q):
        """Process a PTR query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_PTR.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        return [dpkt.dns.DNS.RR(
                    type=dpkt.dns.DNS_PTR,
                    cls=dpkt.dns.DNS_IN,  # Don't cache flush for PTR records.
                    ttl=self._response_ttl,
                    name=q.name,
                    ptrname=dest + '.' + q.name)
                for dest in self._ptr_records.get(q.name, [])]


    ### RFC 1035 - 3.3.14, Domain names - TXT (descriptive text).
    def register_TXT(self, domain, txt_list, announce=False):
        """Register a TXT record on a domain with given list of strings.

        A TXT record can hold any list of text entries whose format depends on
        the domain. This method replaces any TXT record previously registered
        for the given domain.

        @param domain: A domain name that normally can include a proto name and
        a service or host name.
        @param txt_list: A list of strings.
        @param announce: If True, the method will also announce the changes
        on the network.
        """
        self._txt_records[domain] = txt_list
        if announce:
            self._send_answers(self._process_TXT(dpkt.dns.DNS.Q(name=domain)))


    def _process_TXT(self, q):
        """Process a TXT query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_TXT.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        if q.name not in self._txt_records:
            return []
        answer = dpkt.dns.DNS.RR(
            type=dpkt.dns.DNS_TXT,
            cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
            ttl=self._response_ttl,
            name=q.name,
            text=self._txt_records[q.name])
        return [answer]


    def register_service(self, unique_prefix, service_type,
                         protocol, port, txt_list):
        """Register a service in the Avahi style.

        Avahi exposes a convenient set of methods for manipulating "services"
        which are a trio of PTR, SRV, and TXT records.  This is a similar
        helper method for our daemon.

        @param unique_prefix: string unique prefix of service (part of the
                              canonical name).
        @param service_type: string type of service (e.g. '_privet').
        @param protocol: string protocol to use for service (e.g. '_tcp').
        @param port: IP port of service (e.g. 53).
        @param txt_list: list of txt records (e.g. ['vers=1.0', 'foo']).
        """
        service_name = '.'.join([unique_prefix, service_type])
        fq_service_name = '.'.join([service_name, protocol, self._domain])
        logging.debug('Registering service=%s on port=%d with txt records=%r',
                      fq_service_name, port, txt_list)
        self.register_SRV(
                service_name, protocol, DEFAULT_PRIORITY, DEFAULT_WEIGHT, port)
        self.register_PTR('.'.join([service_type, protocol, self._domain]),
                          unique_prefix)
        self.register_TXT(fq_service_name, txt_list)


    def cached_results(self, rrname, rrtype, timestamp=None):
        """Return all the cached results for the requested rrname and rrtype.

        This method is used to request all the received mDNS answers present
        on the cache that were valid at the provided timestamp or later.
        Answers received before this timestamp whose TTL isn't long enough to
        make them valid at the timestamp aren't returned. On the other hand,
        answers received *after* the provided timestamp will always be
        considered, even if they weren't known at the provided timestamp point.
        A timestamp of None will return them all.

        This method allows to retrieve "volatile" answers with a TTL of zero.
        According to the RFC, these answers should be only considered for the
        "ongoing" request. To do this, call this method after a few seconds (the
        request timeout) after calling the send_request() method, passing to
        this method the returned timestamp.

        @param rrname: The requested domain name.
        @param rrtype: The DNS record type. For example, dpkt.dns.DNS_TXT.
        @param timestamp: The request timestamp. See description.
        @return: The list of matching DnsRecord entries of the form (rrname,
                 rrtype, data, timeout).
        """
        if timestamp is None:
            timestamp = 0
        by_type = self._peer_records.get(rrname, {})
        if rrtype not in by_type:
            return []
        return [DnsRecord(rrname, rrtype, data, data_ts)
                for data, data_ts in six.iteritems(by_type[rrtype])
                if data_ts >= timestamp]


    def send_request(self, queries):
        """Sends a request for the provided rrname and rrtype pairs.

        The RFC recommends including all the known and valid answers for the
        request in its list of known answers to avoid unnecessary responses.
        This is not implemented yet: the known-answer list is sent empty.

        @param queries: A list of pairs (rrname, rrtype) where rrname is the
        domain name you are requesting for and the rrtype is the DNS record
        type. For example, ('somehost.local', dpkt.dns.DNS_ANY).
        @return: The timestamp where this request is sent. See cached_results().
        """
        questions = [dpkt.dns.DNS.Q(name=rrname, type=rrtype)
                     for rrname, rrtype in queries]
        # TODO(deymo): Include the already known answers on the request.
        known_answers = []
        mdns = dpkt.dns.DNS(
            op=dpkt.dns.DNS_QUERY,
            qd=questions,
            an=known_answers)
        # bytes() yields the packed packet on both Python 2 and Python 3.
        self._sock.send(bytes(mdns), MDNS_IP_ADDR, MDNS_PORT)
        return time.time()


    def add_answer_observer(self, callback):
        """Adds the callback to the list of observers for new answers.

        @param callback: A callable object accepting a list of tuples (rrtype,
        rrname, data) where rrtype is the DNS record type, rrname the domain
        name and data the information associated with the answer, decoded the
        same way as the data returned by cached_results(). Note that the tuple
        order (rrtype first) differs from the cached_results() records.
        """
        self._answer_callbacks.append(callback)
479