#!/usr/bin/python3
# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit tests for zeroconf.ZeroconfDaemon, run against a fake_host.FakeHost."""

import socket
import unittest

import dpkt

# NOTE(review): `common` is not referenced below; presumably it is imported for
# its side effect of making the autotest_lib package importable, so it must
# stay ahead of the autotest_lib imports — confirm.
import common

from autotest_lib.client.cros.netprotos import fake_host
from autotest_lib.client.cros.netprotos import zeroconf


FAKE_HOSTNAME = 'fakehost1'

FAKE_IPADDR = '192.168.11.22'


class TestZeroconfDaemon(unittest.TestCase):
    """Test class for ZeroconfDaemon."""

    def setUp(self):
        """Create a fake host and a ZeroconfDaemon running on top of it."""
        self._host = fake_host.FakeHost(FAKE_IPADDR)
        self._zero = zeroconf.ZeroconfDaemon(self._host, FAKE_HOSTNAME)


    def _query_A(self, name):
        """Returns the list of A records matching the given name.

        @param name: A domain name.
        @return a list of dpkt.dns.DNS.RR objects, one for each matching record.
        """
        q = dpkt.dns.DNS.Q(name=name, type=dpkt.dns.DNS_A)
        return self._zero._process_A(q)


    def testRegisterService(self):
        """Tests that we get appropriate records after registering a service."""
        SERVICE_PORT = 9
        SERVICE_TXT_LIST = ['lies=lies']
        self._zero.register_service('unique_prefix', '_service_type',
                                    '_tcp', SERVICE_PORT, SERVICE_TXT_LIST)
        name = '_service_type._tcp.local'
        fq_name = 'unique_prefix.' + name
        # Issue SRV, PTR, and TXT queries.
        q_srv = dpkt.dns.DNS.Q(name=fq_name, type=dpkt.dns.DNS_SRV)
        q_txt = dpkt.dns.DNS.Q(name=fq_name, type=dpkt.dns.DNS_TXT)
        q_ptr = dpkt.dns.DNS.Q(name=name, type=dpkt.dns.DNS_PTR)
        ptr_responses = self._zero._process_PTR(q_ptr)
        srv_responses = self._zero._process_SRV(q_srv)
        txt_responses = self._zero._process_TXT(q_txt)
        self.assertTrue(ptr_responses)
        self.assertTrue(srv_responses)
        self.assertTrue(txt_responses)
        ptr_resp = ptr_responses[0]
        # The SRV response may carry records of other types alongside the SRV
        # one, so pick the first record that is actually of type SRV. Using
        # next() turns "no SRV record at all" into a clear test failure
        # instead of an IndexError.
        srv_resp = next((resp for resp in srv_responses
                         if resp.type == dpkt.dns.DNS_SRV), None)
        self.assertIsNotNone(srv_resp,
                             'No SRV record in the response to the SRV query')
        txt_resp = txt_responses[0]
        # Check that basic things are right.
        self.assertEqual(fq_name, ptr_resp.ptrname)
        self.assertEqual(FAKE_HOSTNAME + '.' + self._zero.domain,
                         srv_resp.srvname)
        self.assertEqual(SERVICE_PORT, srv_resp.port)
        self.assertEqual(SERVICE_TXT_LIST, txt_resp.text)


    def testProperties(self):
        """Test the initial properties set by the constructor."""
        self.assertEqual(self._zero.host, self._host)
        self.assertEqual(self._zero.hostname, FAKE_HOSTNAME)
        self.assertEqual(self._zero.domain, 'local') # Default domain
        self.assertEqual(self._zero.full_hostname, FAKE_HOSTNAME + '.local')


    def testSocketInit(self):
        """Test that the constructor listens for mDNS traffic."""

        # Should create a UDP socket and bind it to the mDNS address and port.
        self.assertEqual(len(self._host._sockets), 1)
        sock = self._host._sockets[0]

        self.assertEqual(sock._family, socket.AF_INET) # IPv4
        self.assertEqual(sock._sock_type, socket.SOCK_DGRAM) # UDP

        # Check it is listening for UDP packets on the mDNS address and port.
        self.assertTrue(sock._bound)
        self.assertEqual(sock._bind_ip_addr, '224.0.0.251') # mDNS address
        self.assertEqual(sock._bind_port, 5353) # mDNS port
        self.assertTrue(callable(sock._bind_recv_callback))


    def testRecordsInit(self):
        """Test that the A record of the host is registered."""
        host_A = self._query_A(self._zero.full_hostname)
        self.assertGreater(len(host_A), 0)

        record = host_A[0]
        # Check the hostname and the packed IP address.
        self.assertEqual(record.name, self._zero.full_hostname)
        self.assertEqual(record.ip, socket.inet_aton(self._host.ip_addr))


    def testDoubleTXTProcessing(self):
        """Test when more than one TXT record is present in a packet.

        An mDNS packet can include several answer records for several domains
        and record types. A corner case found in the field is an mDNS packet
        carrying two TXT records for the same domain name in its authority
        section, while the packet itself is a query.
        """
        # Build the mDNS packet with two TXT records.
        domain_name = 'other_host.local'
        answers = [
            dpkt.dns.DNS.RR(type=dpkt.dns.DNS_TXT,
                            cls=dpkt.dns.DNS_IN,
                            ttl=120,
                            name=domain_name,
                            text=['one'.encode(), 'two'.encode()]),
            dpkt.dns.DNS.RR(type=dpkt.dns.DNS_TXT,
                            cls=dpkt.dns.DNS_IN,
                            ttl=120,
                            name=domain_name,
                            text=['two'.encode()])
        ]
        # The packet is a query packet, with extra answers in the authority
        # (ns) section.
        mdns = dpkt.dns.DNS(
            op = dpkt.dns.DNS_QUERY, # Standard query
            rcode = dpkt.dns.DNS_RCODE_NOERR,
            q = [],
            an = [],
            ns = answers)

        # Record the new answers received on the answer_calls list.
        answer_calls = []
        self._zero.add_answer_observer(lambda args: answer_calls.extend(args))

        # Send the packet to the registered callback.
        # NOTE(review): the meaning of the trailing (1234, 5353) arguments is
        # defined by fake_host's receive-callback contract — confirm there.
        sock = self._host._sockets[0]
        cbk = sock._bind_recv_callback
        cbk(bytes(mdns), 1234, 5353)

        # Check that the answers callback is called with all the answers in the
        # received order.
        self.assertEqual(len(answer_calls), 2)
        ans1, ans2 = answer_calls # Each ans is a (rrtype, rrname, data)
        self.assertEqual(ans1[2], ('one', 'two'))
        self.assertEqual(ans2[2], ('two',))

        # Check that the two records were cached.
        records = self._zero.cached_results(domain_name, dpkt.dns.DNS_TXT)
        self.assertEqual(len(records), 2)


if __name__ == '__main__':
    unittest.main()