# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import collections
import dpkt
import logging
import socket
import time


# NOTE: the namedtuple typename ('DnsResult') differs from the name it is
# bound to. It is kept as is because it shows up in repr() of the records.
DnsRecord = collections.namedtuple('DnsResult',
                                   ['rrname', 'rrtype', 'data', 'ts'])

MDNS_IP_ADDR = '224.0.0.251'
MDNS_PORT = 5353

# Value to | to a class value to signal cache flush.
DNS_CACHE_FLUSH = 0x8000

# When considering SRV records, clients are supposed to unilaterally prefer
# numerically lower priorities, then pick probabilistically by weight.
# See RFC2782.
# An arbitrary number that will fit in 16 bits.
DEFAULT_PRIORITY = 500
# An arbitrary number that will fit in 16 bits.
DEFAULT_WEIGHT = 500


def _RR_equals(rra, rrb):
    """Returns whether the two dpkt.dns.DNS.RR objects are equal.

    Two records are considered equal when every attribute present on either
    of them (or listed in dpkt.dns.DNS.RR.__slots__) is present on both and
    compares equal. The 'rdata' attribute is ignored and the 'cls' attribute
    is compared with the cache flush bit masked out.

    @param rra: The first dpkt.dns.DNS.RR object.
    @param rrb: The second dpkt.dns.DNS.RR object.
    @return: True if both records are equal, False otherwise.
    """
    # Compare all the members present in either object and on any RR object.
    # Set unions are used instead of list concatenation so this works
    # regardless of whether keys() returns a list or a view.
    keys = (set(rra.__dict__) | set(rrb.__dict__) |
            set(dpkt.dns.DNS.RR.__slots__))
    # On RR objects, rdata is packed based on the other members and the final
    # packed string depends on other RR and Q elements on the same DNS/mDNS
    # packet.
    keys.discard('rdata')
    for key in keys:
        in_a = hasattr(rra, key)
        if in_a != hasattr(rrb, key):
            return False
        if not in_a:
            continue
        val_a = getattr(rra, key)
        val_b = getattr(rrb, key)
        if key == 'cls':
            # cls attribute should be masked for the cache flush bit.
            val_a &= ~DNS_CACHE_FLUSH
            val_b &= ~DNS_CACHE_FLUSH
        if val_a != val_b:
            return False
    return True


class ZeroconfDaemon(object):
    """Implements a simulated Zeroconf daemon running on the given host.

    This class implements part of the Multicast DNS RFC 6762 able to simulate
    a host exposing services or consuming services over mDNS.
    """

    def __init__(self, host, hostname, domain='local'):
        """Initializes the ZeroconfDaemon running on the given host.

        For the purposes of the Zeroconf implementation, a host must have a
        hostname and a domain that defaults to 'local'.

        The ZeroconfDaemon will by default advertise the host address it is
        running on, which is required by some services.

        @param host: The Host instance where this daemon runs on.
        @param hostname: A string representing the hostname
        @param domain: A string with the domain name, 'local' by default.
        """
        self._host = host
        self._hostname = hostname
        self._domain = domain
        self._response_ttl = 60  # Default TTL in seconds.

        self._a_records = {}  # Local A records.
        self._srv_records = {}  # Local SRV records.
        self._ptr_records = {}  # Local PTR records.
        self._txt_records = {}  # Local TXT records.

        # dict() of name --> (dict() of type --> (dict() of data --> timeout))
        # For example: _peer_records['somehost.local'][dpkt.dns.DNS_A] \
        #     ['192.168.0.1'] = time.time() + 3600
        self._peer_records = {}

        # Observer list for new responses. Initialized before the socket
        # callback is registered so _mdns_request() never sees a partially
        # constructed object.
        self._answer_callbacks = []

        # Register the host address locally.
        self.register_A(self.full_hostname, host.ip_addr)

        # Attend all the traffic to the mDNS port (unicast, multicast or
        # broadcast).
        self._sock = host.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._sock.listen(MDNS_IP_ADDR, MDNS_PORT, self._mdns_request)

    def __del__(self):
        """Closes the daemon socket, if it was ever created."""
        # __init__ may have failed before _sock was assigned.
        sock = getattr(self, '_sock', None)
        if sock is not None:
            sock.close()

    @property
    def host(self):
        """The Host object where this daemon is running."""
        return self._host

    @property
    def hostname(self):
        """The hostname part within a domain."""
        return self._hostname

    @property
    def domain(self):
        """The domain where the given hostname is running."""
        return self._domain

    @property
    def full_hostname(self):
        """The full hostname designation including host and domain name."""
        return self._hostname + '.' + self._domain

    def _mdns_request(self, data, addr, port):
        """Handles a mDNS multicast packet.

        This method will generate and send a mDNS response to any query
        for which it has new authoritative information. Called by the Simulator
        as a callback for every mDNS received packet.

        Besides answering queries, the records found in the packet are stored
        in the peer cache and the registered answer observers are notified.

        @param data: The string contained on the UDP message.
        @param addr: The address where the message comes from.
        @param port: The port number where the message comes from.
        """
        # Parse the mDNS request using dpkt's DNS module.
        mdns = dpkt.dns.DNS(data)
        if mdns.op == 0x0000:  # Query
            query_handlers = {
                dpkt.dns.DNS_A: self._process_A,
                dpkt.dns.DNS_PTR: self._process_PTR,
                dpkt.dns.DNS_TXT: self._process_TXT,
                dpkt.dns.DNS_SRV: self._process_SRV,
            }

            answers = []
            for q in mdns.qd:  # Query entries
                if q.type in query_handlers:
                    answers += query_handlers[q.type](q)
                elif q.type == dpkt.dns.DNS_ANY:
                    # Special type matching any known type.
                    for handler in query_handlers.values():
                        answers += handler(q)
            # Remove all the already known answers from the list.
            answers = [ans for ans in answers
                       if not any(_RR_equals(known_ans, ans)
                                  for known_ans in mdns.an)]

            self._send_answers(answers)

        # Always process the received authoritative answers. Work on a copy
        # so the parsed packet is not modified below.
        records = list(mdns.ns)

        # Process the answers for response packets.
        if mdns.op == 0x8400:  # Standard response
            records.extend(mdns.an)

        if not records:
            return

        cur_time = time.time()
        new_answers = []
        for rr in records:  # Answers RRs
            # dpkt decodes the information on different fields depending on
            # the response type.
            if rr.type == dpkt.dns.DNS_A:
                data = socket.inet_ntoa(rr.ip)
            elif rr.type == dpkt.dns.DNS_PTR:
                data = rr.ptrname
            elif rr.type == dpkt.dns.DNS_TXT:
                data = tuple(rr.text)  # Convert the list to a hashable tuple
            elif rr.type == dpkt.dns.DNS_SRV:
                data = rr.srvname, rr.priority, rr.weight, rr.port
            else:
                continue  # Ignore unsupported records.

            records_by_type = self._peer_records.setdefault(rr.name, {})
            # Start a new cache or clear the existing if required.
            if (rr.type not in records_by_type or
                    rr.cls & DNS_CACHE_FLUSH):
                records_by_type[rr.type] = {}

            new_answers.append((rr.type, rr.name, data))
            cached_ans = records_by_type[rr.type]
            rr_timeout = cur_time + rr.ttl
            # Update the answer timeout if already cached, keeping the
            # latest of both timeouts.
            cached_ans[data] = max(cached_ans.get(data, rr_timeout),
                                   rr_timeout)

        if new_answers:
            for cbk in self._answer_callbacks:
                cbk(new_answers)

    def clear_cache(self):
        """Discards all the cached records."""
        self._peer_records = {}

    def _send_answers(self, answers):
        """Send a mDNS reply with the provided answers.

        This method uses the underlying Host to send an IP packet with a mDNS
        response containing the list of answers of the type dpkt.dns.DNS.RR.
        If the list is empty, no packet is sent.

        @param answers: The list of answers to send.
        """
        if not answers:
            return
        logging.debug('Sending response with answers: %r.', answers)
        resp_dns = dpkt.dns.DNS(
            op=dpkt.dns.DNS_AA,  # Authoritative Answer.
            rcode=dpkt.dns.DNS_RCODE_NOERR,
            an=answers)
        # This property modifies the "op" field. The plain constant is
        # assigned here; a trailing comma would turn it into a tuple.
        resp_dns.qr = dpkt.dns.DNS_R  # Response.
        # bytes() is the same as str() on Python 2.
        self._sock.send(bytes(resp_dns), MDNS_IP_ADDR, MDNS_PORT)

    ### RFC 2782 - RR for specifying the location of services (DNS SRV).
    def register_SRV(self, service, proto, priority, weight, port):
        """Publishes the SRV specified record.

        A SRV record defines a service on a port of a host with given
        properties like priority and weight. The service has a name of the
        form "service.proto.domain". The target host, this is, the host where
        the announced service is running on is set to the host where this
        zeroconf daemon is running, "hostname.domain".

        @param service: A string with the service name.
        @param proto: A string with the protocol name, for example "_tcp".
        @param priority: The service priority number as defined by RFC2782.
        @param weight: The service weight number as defined by RFC2782.
        @param port: The port number where the service is running on.
        """
        srvname = '.'.join([service, proto, self._domain])
        self._srv_records[srvname] = priority, weight, port

    def _process_SRV(self, q):
        """Process a SRV query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_SRV.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        if q.name not in self._srv_records:
            return []
        priority, weight, port = self._srv_records[q.name]
        full_hostname = self._hostname + '.' + self._domain
        ans = dpkt.dns.DNS.RR(
            type=dpkt.dns.DNS_SRV,
            cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
            ttl=self._response_ttl,
            name=q.name,
            srvname=full_hostname,
            priority=priority,
            weight=weight,
            port=port)
        # The target host (srvname) requires to send an A record with its IP
        # address. We do this as if a query for it was sent.
        a_qry = dpkt.dns.DNS.Q(name=full_hostname, type=dpkt.dns.DNS_A)
        return [ans] + self._process_A(a_qry)

    ### RFC 1035 - 3.4.1, Domains Names - A (IPv4 address).
    def register_A(self, hostname, ip_addr):
        """Registers an Address record (A) pointing to the given IP address.

        Records registered with this method are assumed authoritative.
        Several addresses can be registered for the same hostname.

        @param hostname: The full host name, for example, "somehost.local".
        @param ip_addr: The IPv4 address of the host, for example, "192.0.1.1".
        """
        # The address is stored in its packed form, as returned by
        # socket.inet_aton().
        packed_addr = socket.inet_aton(ip_addr)
        self._a_records.setdefault(hostname, []).append(packed_addr)

    def _process_A(self, q):
        """Process an A query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_A.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        return [
            dpkt.dns.DNS.RR(
                type=dpkt.dns.DNS_A,
                cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
                ttl=self._response_ttl,
                name=q.name,
                ip=ip_addr)
            for ip_addr in self._a_records.get(q.name, [])]

    ### RFC 1035 - 3.3.12, Domain names - PTR (domain name pointer).
    def register_PTR(self, domain, destination):
        """Register a domain pointer record.

        A domain pointer record is simply a pointer to a hostname on the
        domain.

        @param domain: A domain name that can include a proto name, for
        example, "_workstation._tcp.local".
        @param destination: The hostname inside the given domain, for example,
        "my-desktop".
        """
        self._ptr_records.setdefault(domain, []).append(destination)

    def _process_PTR(self, q):
        """Process a PTR query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_PTR.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        return [
            dpkt.dns.DNS.RR(
                type=dpkt.dns.DNS_PTR,
                cls=dpkt.dns.DNS_IN,  # Don't cache flush for PTR records.
                ttl=self._response_ttl,
                name=q.name,
                ptrname=dest + '.' + q.name)
            for dest in self._ptr_records.get(q.name, [])]

    ### RFC 1035 - 3.3.14, Domain names - TXT (descriptive text).
    def register_TXT(self, domain, txt_list, announce=False):
        """Register a TXT record on a domain with given list of strings.

        A TXT record can hold any list of text entries whose format depends on
        the domain. This method replaces any previous TXT record previously
        registered for the given domain.

        @param domain: A domain name that normally can include a proto name
        and a service or host name.
        @param txt_list: A list of strings.
        @param announce: If True, the method will also announce the changes
        on the network.
        """
        self._txt_records[domain] = txt_list
        if announce:
            # Announce the record as if a query for it had been received.
            self._send_answers(self._process_TXT(dpkt.dns.DNS.Q(name=domain)))

    def _process_TXT(self, q):
        """Process a TXT query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_TXT.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        if q.name not in self._txt_records:
            return []
        answer = dpkt.dns.DNS.RR(
            type=dpkt.dns.DNS_TXT,
            cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
            ttl=self._response_ttl,
            name=q.name,
            text=self._txt_records[q.name])
        return [answer]

    def register_service(self, unique_prefix, service_type,
                         protocol, port, txt_list):
        """Register a service in the Avahi style.

        Avahi exposes a convenient set of methods for manipulating "services"
        which are a trio of PTR, SRV, and TXT records. This is a similar
        helper method for our daemon.

        @param unique_prefix: string unique prefix of service (part of the
                              canonical name).
        @param service_type: string type of service (e.g. '_privet').
        @param protocol: string protocol to use for service (e.g. '_tcp').
        @param port: IP port of service (e.g. 53).
        @param txt_list: list of txt records (e.g. ['vers=1.0', 'foo']).
        """
        service_name = '.'.join([unique_prefix, service_type])
        fq_service_name = '.'.join([service_name, protocol, self._domain])
        logging.debug('Registering service=%s on port=%d with txt records=%r',
                      fq_service_name, port, txt_list)
        self.register_SRV(
            service_name, protocol, DEFAULT_PRIORITY, DEFAULT_WEIGHT, port)
        self.register_PTR('.'.join([service_type, protocol, self._domain]),
                          unique_prefix)
        self.register_TXT(fq_service_name, txt_list)

    def cached_results(self, rrname, rrtype, timestamp=None):
        """Return all the cached results for the requested rrname and rrtype.

        This method is used to request all the received mDNS answers present
        on the cache that were valid at the provided timestamp or later.
        Answers received before this timestamp whose TTL isn't long enough to
        make them valid at the timestamp aren't returned. On the other hand,
        answers received *after* the provided timestamp will always be
        considered, even if they weren't known at the provided timestamp point.
        A timestamp of None will return them all.

        This method allows to retrieve "volatile" answers with a TTL of zero.
        According to the RFC, these answers should be only considered for the
        "ongoing" request. To do this, call this method after a few seconds (the
        request timeout) after calling the send_request() method, passing to
        this method the returned timestamp.

        @param rrname: The requested domain name.
        @param rrtype: The DNS record type. For example, dpkt.dns.DNS_TXT.
        @param timestamp: The request timestamp. See description.
        @return: The list of matching DnsRecord tuples of the form (rrname,
        rrtype, data, ts), where ts is the time at which the record expires.
        """
        if timestamp is None:
            timestamp = 0
        cached = self._peer_records.get(rrname, {}).get(rrtype, {})
        return [DnsRecord(rrname, rrtype, data, data_ts)
                for data, data_ts in cached.items()
                if data_ts >= timestamp]

    def send_request(self, queries):
        """Sends a request for the provided rrname and rrtype.

        All the known and valid answers for this request will be included in the
        non authoritative list of known answers together with the request. This
        is recommended by the RFC and avoid unnecessary responses.

        NOTE: the known answers are not included yet; the request is currently
        sent with an empty answer list (see the TODO below).

        @param queries: A list of pairs (rrname, rrtype) where rrname is the
        domain name you are requesting for and the rrtype is the DNS record
        type. For example, ('somehost.local', dpkt.dns.DNS_ANY).
        @return: The timestamp where this request is sent. See cached_results().
        """
        queries = [dpkt.dns.DNS.Q(name=rrname, type=rrtype)
                   for rrname, rrtype in queries]
        # TODO(deymo): Include the already known answers on the request.
        answers = []
        mdns = dpkt.dns.DNS(
            op=dpkt.dns.DNS_QUERY,
            qd=queries,
            an=answers)
        # bytes() is the same as str() on Python 2.
        self._sock.send(bytes(mdns), MDNS_IP_ADDR, MDNS_PORT)
        return time.time()

    def add_answer_observer(self, callback):
        """Adds the callback to the list of observers for new answers.

        @param callback: A callable object accepting a list of tuples (rrtype,
        rrname, data) where rrtype is the DNS record type, rrname the domain
        name and data is the information associated with the answers, similar
        to the data field of what cached_results() returns.
        """
        self._answer_callbacks.append(callback)