• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import dbus
6import logging
7import socket
8import time
9import urllib2
10
11import common
12
13from autotest_lib.client.bin import utils
14from autotest_lib.client.common_lib import error
15from autotest_lib.client.cros import routing
16
17
class IpTablesContext(object):
    """Context manager that manages iptables rules.

    Rules added through AllowHost() while the context is active are
    recorded and deleted again when the context exits.

    @param initial_allowed_host: Optional host passed to AllowHost() when
            the context is entered.
    """
    IPTABLES = '/sbin/iptables'

    def __init__(self, initial_allowed_host=None):
        self.initial_allowed_host = initial_allowed_host
        # Rule specs (without the leading -A/-D) that this context added.
        self.rules = []

    def _IpTables(self, command):
        """
        Runs iptables with the given arguments and returns its output.

        @param command: Argument string appended to the iptables binary path.
        @return: Output of the iptables invocation.
        """
        return utils.system_output('%s %s' % (self.IPTABLES, command),
                                   retain_output=True)

    def _RemoveRule(self, rule):
        """
        Deletes a rule from iptables and stops tracking it.

        @param rule: Rule spec previously appended to self.rules.
        """
        self._IpTables('-D ' + rule)
        self.rules.remove(rule)

    def AllowHost(self, host):
        """
        Allows the specified host through the firewall.

        Appends an INPUT ACCEPT rule for TCP and one for UDP with |host| as
        the /32 source, skipping any rule already present in the INPUT chain.
        Only rules actually added here are tracked for cleanup.

        @param host: Name of host to allow
        """
        # The existing INPUT rules do not depend on the protocol being added,
        # so list them once rather than once per protocol.
        output = self._IpTables('-S INPUT')
        current = [x.rstrip() for x in output.splitlines()]
        logging.debug('current: %s', current)
        for proto in ['tcp', 'udp']:
            rule = 'INPUT -s %s/32 -p %s -m %s -j ACCEPT' % (host, proto, proto)
            if '-A ' + rule in current:
                # Already have the rule
                logging.info('Not adding redundant %s', rule)
                continue
            self._IpTables('-A ' + rule)
            self.rules.append(rule)

    def _CleanupRules(self):
        """Deletes every rule this context added, most recent first."""
        # Iterate over a snapshot: _RemoveRule() removes from self.rules, and
        # mutating a list while iterating it skips elements (previously this
        # left every other rule installed).
        for rule in reversed(list(self.rules)):
            self._RemoveRule(rule)

    def __enter__(self):
        """Allows the initial host, if one was given, and returns self."""
        if self.initial_allowed_host:
            self.AllowHost(self.initial_allowed_host)
        return self

    def __exit__(self, exception, value, traceback):
        """Removes the rules added by this context; never swallows errors."""
        self._CleanupRules()
        return False
65
66
def NameServersForService(conn_mgr, service):
    """
    Returns the list of name servers used by a connected service.

    Looks up the service's Device, then concatenates the NameServers of
    every IPConfig listed on that device, in IPConfig order.

    @param conn_mgr: Connection manager (shill); used to resolve Device and
            IPConfig object paths via GetObjectInterface(kind, path).
    @param service: The connected service object (its properties are read
            with GetProperties()), not a service name.
    @return: List of name servers used by |service|; an empty list if the
            service's device cannot be found.
    """
    service_properties = service.GetProperties(utf8_strings=True)
    device_path = service_properties['Device']
    device = conn_mgr.GetObjectInterface('Device', device_path)
    if device is None:
        logging.error('No device for service %s', service)
        return []

    properties = device.GetProperties(utf8_strings=True)

    hosts = []
    for path in properties['IPConfigs']:
        ipconfig = conn_mgr.GetObjectInterface('IPConfig', path)
        if ipconfig is None:
            # Same lookup as for the Device above, which can yield None; skip
            # the missing object instead of raising AttributeError on None.
            logging.warning('No IPConfig object at %s', path)
            continue
        ipconfig_properties = ipconfig.GetProperties(utf8_strings=True)
        hosts.extend(ipconfig_properties['NameServers'])

    logging.info('Name servers: %s', ', '.join(hosts))

    return hosts
93
94
def CheckInterfaceForDestination(host, expected_interface):
    """
    Checks that routes for host go through a given interface.

    The concern here is that our network setup may have gone wrong
    and our test connections may go over some other network than
    the one we're trying to test.  So we take all the IP addresses
    for the supplied host and make sure they go through the given
    network interface.  Addresses with no route are skipped, but at
    least one address must have a route.

    @param host: Destination host
    @param expected_interface: Expected interface name
    @raises: error.TestFail if the routes for the given host go through
            a different interface than the expected one, or if none of
            the host's addresses has a route.

    """
    # addrinfo records: (family, type, proto, canonname, (addr, port)).
    # getaddrinfo() can return several records for the same address (one per
    # socket type), so de-duplicate while keeping the resolver's order.
    server_addresses = []
    for record in socket.getaddrinfo(host, 80):
        address = record[4][0]
        if address not in server_addresses:
            server_addresses.append(address)

    route_found = False
    routes = routing.NetworkRoutes()
    for address in server_addresses:
        route = routes.getRouteFor(address)
        if not route:
            continue

        route_found = True

        interface = route.interface
        logging.info('interface for %s: %s', address, interface)
        if interface != expected_interface:
            raise error.TestFail('Target server %s uses interface %s '
                                 '(%s expected).' %
                                 (address, interface, expected_interface))

    if not route_found:
        raise error.TestFail('No route found for "%s".' % host)
133
FETCH_URL_PATTERN_FOR_TEST = (
    'http://testing-chargen.appspot.com/download?size=%d')


def FetchUrl(url_pattern, bytes_to_fetch=10, fetch_timeout=10):
    """
    Fetches a specified number of bytes from a URL.

    @param url_pattern: URL pattern for fetching a specified number of bytes.
            %d in the pattern is to be filled in with the number of bytes to
            fetch.
    @param bytes_to_fetch: Number of bytes to fetch.
    @param fetch_timeout: Number of seconds to wait for the fetch to complete
            before it times out.  Also passed to urlopen() as the connection
            timeout.
    @return: The time in seconds spent for fetching the specified number of
            bytes, measured from just before the URL is opened.
    @raises: error.TestError if one of the following happens:
            - The fetch takes no time.
            - The number of bytes fetched differs from the specified
              number.
            - The fetch exceeds |fetch_timeout| seconds.

    """
    # Limit the amount of bytes to read at a time.
    _MAX_FETCH_READ_BYTES = 1024 * 1024

    url = url_pattern % bytes_to_fetch
    logging.info('FetchUrl %s', url)
    start_time = time.time()
    result = urllib2.urlopen(url, timeout=fetch_timeout)
    try:
        # Initialized before the loop so it is defined even when there is
        # nothing to read (bytes_to_fetch <= 0).
        fetch_time = time.time() - start_time
        bytes_fetched = 0
        while bytes_fetched < bytes_to_fetch:
            bytes_left = bytes_to_fetch - bytes_fetched
            bytes_to_read = min(bytes_left, _MAX_FETCH_READ_BYTES)
            bytes_read = len(result.read(bytes_to_read))
            bytes_fetched += bytes_read
            if bytes_read != bytes_to_read:
                raise error.TestError(
                        'FetchUrl tried to read %d bytes, but got %d bytes '
                        'instead.' % (bytes_to_read, bytes_read))
            # Checked after every chunk so an overlong transfer aborts early.
            fetch_time = time.time() - start_time
            if fetch_time > fetch_timeout:
                raise error.TestError('FetchUrl exceeded timeout.')
    finally:
        # The response object holds an open connection; always release it.
        result.close()

    if fetch_time <= 0:
        raise error.TestError('FetchUrl took no time.')
    return fetch_time
177