• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2010 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import daemonserver
17import errno
18import logging
19import socket
20import SocketServer
21import threading
22import time
23
24from third_party.dns import flags
25from third_party.dns import message
26from third_party.dns import rcode
27from third_party.dns import resolver
28from third_party.dns import rdatatype
29from third_party import ipaddr
30
31
32
class DnsProxyException(Exception):
  """Error raised by the DNS proxy for unrecoverable setup problems.

  In this module it is raised for an unusable nameserver configuration and
  when binding the server socket fails with a permission error.
  """
35
36
class RealDnsLookup(object):
  """Resolve hostnames against real upstream nameservers, with a cache.

  Results are cached per (hostname, query type), so an 'AAAA' lookup never
  returns an address cached by an earlier 'A' lookup of the same host (and
  vice versa). Failed or empty lookups are not cached.
  """

  def __init__(self, name_servers):
    """Initialize RealDnsLookup.

    Args:
      name_servers: nameserver IP strings used for real lookups.
    Raises:
      DnsProxyException: if '127.0.0.1' is among name_servers (it would
        cause an infinite lookup loop).
    """
    if '127.0.0.1' in name_servers:
      raise DnsProxyException(
          'Invalid nameserver: 127.0.0.1 (causes an infinte loop)')
    self.resolver = resolver.get_default_resolver()
    self.resolver.nameservers = name_servers
    # Guards dns_cache; presumably lookups arrive from concurrent
    # request-handler threads (the server is a ThreadingUDPServer).
    self.dns_cache_lock = threading.Lock()
    # Maps (hostname, rdtype) -> IP address string.
    self.dns_cache = {}

  @staticmethod
  def _IsIPAddress(hostname):
    """Return True if hostname is accepted by socket.inet_aton (IPv4)."""
    try:
      socket.inet_aton(hostname)
      return True
    except socket.error:
      return False

  def __call__(self, hostname, rdtype=rdatatype.A):
    """Return real IP for a host.

    Args:
      hostname: a hostname ending with a period (e.g. "www.google.com.")
      rdtype: the query type (1 for 'A', 28 for 'AAAA')
    Returns:
      the IP address as a string (e.g. "192.168.25.2"), or None if the
      lookup fails or yields no answer. An IPv4 address literal passed as
      hostname is returned unchanged.
    """
    if self._IsIPAddress(hostname):
      return hostname
    # Include rdtype in the key: A and AAAA answers for one host differ.
    cache_key = (hostname, rdtype)
    with self.dns_cache_lock:
      ip = self.dns_cache.get(cache_key)
    if ip:
      return ip
    try:
      answers = self.resolver.query(hostname, rdtype)
    except resolver.NXDOMAIN:
      return None
    except resolver.NoNameservers:
      logging.debug('_real_dns_lookup(%s) -> No nameserver.',
                    hostname)
      return None
    except (resolver.NoAnswer, resolver.Timeout) as ex:
      logging.debug('_real_dns_lookup(%s) -> None (%s)',
                    hostname, ex.__class__.__name__)
      return None
    if not answers:
      return None
    ip = str(answers[0])
    with self.dns_cache_lock:
      self.dns_cache[cache_key] = ip
    return ip

  def ClearCache(self):
    """Clear the dns cache."""
    with self.dns_cache_lock:
      self.dns_cache.clear()
95
96
class ReplayDnsLookup(object):
  """Resolve DNS requests to replay host.

  A lookup starts from replay_ip and threads it through every filter in
  order. Each filter is called as filter(hostname, default_ip=ip), and its
  return value becomes the ip handed to the next filter.
  """

  def __init__(self, replay_ip, filters=None):
    """Store the replay IP and the (possibly empty) filter chain."""
    self.replay_ip = replay_ip
    self.filters = filters if filters else []

  def __call__(self, hostname):
    """Return the result of running hostname through all filters."""
    resolved_ip = self.replay_ip
    for lookup_filter in self.filters:
      resolved_ip = lookup_filter(hostname, default_ip=resolved_ip)
    return resolved_ip
108
109
class PrivateIpFilter(object):
  """Resolve private hosts to their real IPs and others to the Web proxy IP.

  Hosts found in the given http_archive resolve to the Web proxy IP without
  a real lookup.

  This only supports IPv4 lookups.
  """

  def __init__(self, real_dns_lookup, http_archive):
    """Initialize PrivateIpFilter.

    Args:
      real_dns_lookup: a function that resolves a host to an IP.
      http_archive: an instance of a HttpArchive; iterating it yields
        requests with a `host` attribute. Hosts in the archive always
        resolve to the default (Web proxy) IP.
    """
    self.real_dns_lookup = real_dns_lookup
    self.http_archive = http_archive
    self.InitializeArchiveHosts()

  def __call__(self, host, default_ip):
    """Return real IPv4 for private hosts and default_ip otherwise.

    Args:
      host: a hostname ending with a period (e.g. "www.google.com.")
      default_ip: the IP used for archived hosts and public hosts.
    Returns:
      IP address as a string or None (if the real lookup fails)
    """
    if host in self.archive_hosts:
      return default_ip
    real_ip = self.real_dns_lookup(host)
    if not real_ip:
      # Unknown upstream; a None result presumably surfaces as NXDOMAIN.
      return None
    if ipaddr.IPAddress(real_ip).is_private:
      return real_ip
    return default_ip

  def InitializeArchiveHosts(self):
    """Recompute the archive_hosts from the http_archive."""
    hosts = set()
    for request in self.http_archive:
      # Drop any ":port" suffix and append the trailing root-label period.
      hosts.add(request.host.split(':')[0] + '.')
    self.archive_hosts = hosts
152
153
class DelayFilter(object):
  """Add a delay to replayed lookups.

  In record mode lookups pass through immediately; in replay mode each
  lookup first sleeps for delay_ms milliseconds.
  """

  def __init__(self, is_record_mode, delay_ms):
    """Initialize DelayFilter.

    Args:
      is_record_mode: True to start in record mode (no delay).
      delay_ms: delay per replayed lookup in milliseconds (passed to int()).
    """
    self.is_record_mode = is_record_mode
    self.delay_ms = int(delay_ms)

  def __call__(self, host, default_ip):
    """Sleep for delay_ms when replaying, then return default_ip unchanged."""
    if not self.is_record_mode and self.delay_ms > 0:
      # delay_ms is in milliseconds, but time.sleep() takes seconds.
      time.sleep(self.delay_ms / 1000.0)
    return default_ip

  def SetRecordMode(self):
    """Disable the delay (record mode)."""
    self.is_record_mode = True

  def SetReplayMode(self):
    """Enable the delay (replay mode)."""
    self.is_record_mode = False
171
172
class UdpDnsHandler(SocketServer.DatagramRequestHandler):
  """Resolve DNS queries to localhost.

  Possible alternative implementation:
  http://howl.play-bow.org/pipermail/dnspython-users/2010-February/000119.html
  """

  STANDARD_QUERY_OPERATION_CODE = 0
  # Length of the fixed DNS message header; the question section follows it.
  HEADER_LENGTH = 12
  # In a question, QNAME is followed by a 2-byte QTYPE and a 2-byte QCLASS.
  QTYPE_QCLASS_LENGTH = 4

  def handle(self):
    """Handle a DNS query.

    IPv6 requests (with rdtype AAAA) receive mismatched IPv4 responses
    (with rdtype A). To properly support IPv6, the http proxy would
    need both types of addresses. By default, Windows XP does not
    support IPv6.

    Datagrams too short to hold a header, or whose question name is
    truncated, are logged and ignored (no response payload is written)
    instead of raising IndexError.
    """
    self.data = self.rfile.read()
    if len(self.data) < self.HEADER_LENGTH:
      logging.debug('dnsproxy: ignoring truncated DNS request (%d bytes)',
                    len(self.data))
      return
    # Bytes 0-1 are the 16-bit transaction ID. They are kept in these two
    # (historically named) attributes and echoed back verbatim.
    self.transaction_id = self.data[0]
    self.flags = self.data[1]
    self.qa_counts = self.data[4:6]
    self.domain = ''
    # Opcode is bits 3-6 of the third header byte.
    operation_code = (ord(self.data[2]) >> 3) & 15
    if operation_code == self.STANDARD_QUERY_OPERATION_CODE:
      question = self.data[self.HEADER_LENGTH:]
      try:
        self.domain, name_length = self._parse_name(question)
      except IndexError:
        logging.debug('dnsproxy: ignoring DNS request with truncated name')
        return
      # Echo only the first question (QNAME + QTYPE + QCLASS). Trailing
      # records, e.g. an EDNS OPT record in the additional section, must not
      # be copied or they would sit where the client expects the answer.
      self.wire_domain = question[:name_length + self.QTYPE_QCLASS_LENGTH]
    else:
      logging.debug("DNS request with non-zero operation code: %s",
                    operation_code)
    ip = self.server.dns_lookup(self.domain)
    if ip is None:
      logging.debug('dnsproxy: %s -> NXDOMAIN', self.domain)
      response = self.get_dns_no_such_name_response()
    else:
      if ip == self.server.server_address[0]:
        logging.debug('dnsproxy: %s -> %s (replay web proxy)', self.domain, ip)
      else:
        logging.debug('dnsproxy: %s -> %s', self.domain, ip)
      response = self.get_dns_response(ip)
    self.wfile.write(response)

  @classmethod
  def _parse_name(cls, wire_domain):
    """Decode the uncompressed wire-format name at the start of wire_domain.

    Args:
      wire_domain: data beginning with length-prefixed labels terminated by
        a zero-length label.
    Returns:
      (domain, length): the dotted name with a trailing period (e.g.
      "www.google.com.", or '' for the root name) and the number of bytes
      the encoded name occupies, including the terminating zero byte.
    Raises:
      IndexError: if wire_domain ends before the terminating zero byte.
    """
    labels = []
    index = 0
    length = ord(wire_domain[index])
    while length:
      labels.append(wire_domain[index + 1:index + length + 1] + '.')
      index += length + 1
      length = ord(wire_domain[index])
    return ''.join(labels), index + 1

  @classmethod
  def _domain(cls, wire_domain):
    """Return the dotted name (with trailing period) encoded in wire_domain."""
    return cls._parse_name(wire_domain)[0]

  def get_dns_response(self, ip):
    """Build an 'A' record response for the current query.

    Args:
      ip: IPv4 address string to place in the answer record.
    Returns:
      The response datagram, or '' if the query carried no domain name.
    """
    if not self.domain:
      return ''
    header = (
        self.transaction_id + self.flags +  # echoed 16-bit transaction ID
        '\x81\x80'          # standard query response, no error
        '\x00\x01'          # QDCOUNT: the one question echoed below
        '\x00\x01'          # ANCOUNT: the single A record below
        '\x00\x00'          # NSCOUNT: no authority records
        '\x00\x00')         # ARCOUNT: no additional records
    answer = (
        '\xc0\x0c'          # pointer to domain name
        '\x00\x01'          # resource record type ("A" host address)
        '\x00\x01'          # class of the data
        '\x00\x00\x00\x3c'  # ttl (seconds)
        '\x00\x04' +        # resource data length (4 bytes for ip)
        socket.inet_aton(ip))
    return header + self.wire_domain + answer

  def get_dns_no_such_name_response(self):
    """Return an NXDOMAIN response for the query in self.data."""
    query_message = message.from_wire(self.data)
    response_message = message.make_response(query_message)
    response_message.flags |= flags.AA | flags.RA
    response_message.set_rcode(rcode.NXDOMAIN)
    return response_message.to_wire()
249
250
class DnsProxyServer(SocketServer.ThreadingUDPServer,
                     daemonserver.DaemonServer):
  """Threaded UDP DNS server whose answers come from a dns_lookup callable.

  Each request is handled by UdpDnsHandler.
  """
  # Increase the request queue size. The default value, 5, is set in
  # SocketServer.TCPServer (the parent of BaseHTTPServer.HTTPServer).
  # Since we're intercepting many domains through this single server,
  # it is quite possible to get more than 5 concurrent requests.
  # NOTE(review): UDPServer does not call listen(), so this value is likely
  # unused by this UDP server; confirm before relying on it.
  request_queue_size = 256

  # Allow sockets to be reused. See
  # http://svn.python.org/projects/python/trunk/Lib/SocketServer.py for more
  # details.
  allow_reuse_address = True

  # Don't prevent python from exiting when there is thread activity.
  daemon_threads = True

  def __init__(self, host='', port=53, dns_lookup=None):
    """Initialize DnsProxyServer.

    Args:
      host: a host string (name or IP) to bind the dns proxy and to which
        DNS requests will be resolved.
      port: an integer port on which to bind the proxy.
      dns_lookup: a callable taking a hostname (ending with a period) and
        returning an IP string, or None to answer NXDOMAIN. Defaults to
        resolving every name to the server's own bound address.
    Raises:
      DnsProxyException: if binding fails with EACCES (permission denied).
      socket.error: for any other failure while creating the server.
    """
    try:
      SocketServer.ThreadingUDPServer.__init__(
          self, (host, port), UdpDnsHandler)
    except socket.error as e:
      # e.errno avoids unpacking e.args, which need not be a 2-tuple.
      if e.errno == errno.EACCES:
        raise DnsProxyException(
            'Unable to bind DNS server on (%s:%s)' % (host, port))
      raise
    self.dns_lookup = dns_lookup or (lambda host: self.server_address[0])
    self.server_port = self.server_address[1]
    logging.warning('DNS server started on %s:%d', self.server_address[0],
                    self.server_address[1])

  def cleanup(self):
    """Stop serving, close the socket and log the shutdown.

    A KeyboardInterrupt raised during shutdown is ignored.
    """
    try:
      self.shutdown()
      self.server_close()
    except KeyboardInterrupt:
      pass
    logging.info('Stopped DNS server')
296