• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from boto.compat import http_client
2from tests.compat import mock, unittest
3
4
class AWSMockServiceTestCase(unittest.TestCase):
    """Base class for mocking aws services.

    Subclasses set ``connection_class``. ``setUp`` instantiates it with a
    mocked HTTPS connection factory and dummy credentials, then wraps the
    connection's ``_mexe`` so the most recent request object is recorded in
    ``self.actual_request`` before being passed on to the real ``_mexe``.
    """
    # This param is used by the unittest module to display a full
    # diff when assert*Equal methods produce an error message.
    maxDiff = None
    # Must be set by subclasses; create_service_connection raises
    # ValueError while it is None.
    connection_class = None

    def setUp(self):
        """Build the mocked HTTPS layer and the service connection under test."""
        self.https_connection = mock.Mock(spec=http_client.HTTPSConnection)
        self.https_connection.debuglevel = 0
        # A (callable, empty tuple) pair; the callable always returns the
        # single mocked HTTPS connection created above.
        self.https_connection_factory = (
            mock.Mock(return_value=self.https_connection), ())
        self.service_connection = self.create_service_connection(
            https_connection_factory=self.https_connection_factory,
            aws_access_key_id='aws_access_key_id',
            aws_secret_access_key='aws_secret_access_key')
        self.initialize_service_connection()

    def initialize_service_connection(self):
        """Install the ``_mexe`` spy and reset per-test request/proxy state."""
        self.actual_request = None
        self.original_mexe = self.service_connection._mexe
        self.service_connection._mexe = self._mexe_spy
        self.proxy = None
        self.use_proxy = False

    def create_service_connection(self, **kwargs):
        """Instantiate ``connection_class`` with ``kwargs``.

        :raises ValueError: if ``connection_class`` has not been set.
        """
        if self.connection_class is None:
            raise ValueError("The connection_class class attribute must be "
                             "set to a non-None value.")
        return self.connection_class(**kwargs)

    def _mexe_spy(self, request, *args, **kwargs):
        """Record ``request`` and delegate to the connection's original ``_mexe``."""
        self.actual_request = request
        return self.original_mexe(request, *args, **kwargs)

    def create_response(self, status_code, reason='', header=None, body=None):
        """Build a mocked ``HTTPResponse``.

        :param status_code: value exposed as ``response.status``.
        :param reason: value exposed as ``response.reason``.
        :param header: sequence of ``(name, value)`` pairs. It is returned
            as-is from ``getheaders()`` and backs ``getheader()`` and
            ``msg``. Defaults to no headers.
        :param body: value returned by ``read()``; defaults to
            ``self.default_body()``.
        :returns: the configured mock response.
        """
        # Use a fresh list per call. getheaders() hands this exact object to
        # the caller, so a shared default list could be mutated by one test
        # and leak into every later response.
        if header is None:
            header = []
        if body is None:
            body = self.default_body()
        response = mock.Mock(spec=http_client.HTTPResponse)
        response.status = status_code
        response.read.return_value = body
        response.reason = reason

        response.getheaders.return_value = header
        response.msg = dict(header)

        def overwrite_header(arg, default=None):
            # Exact-match lookup. The dict is rebuilt on every call so it
            # reflects the current contents of ``header``.
            return dict(header).get(arg, default)
        response.getheader.side_effect = overwrite_header

        return response

    def assert_request_parameters(self, params, ignore_params_values=None):
        """Verify the actual parameters sent to the service API.

        :param params: dict the captured request's ``params`` must equal.
        :param ignore_params_values: optional iterable of parameter names to
            remove from the captured params before comparing. Names that are
            not present are skipped.
        """
        # Fail with a readable message rather than an AttributeError on None
        # when the code under test never sent a request.
        self.assertIsNotNone(
            self.actual_request,
            'No request was sent through the service connection.')
        request_params = self.actual_request.params.copy()
        for param in ignore_params_values or ():
            request_params.pop(param, None)
        self.assertDictEqual(request_params, params)

    def set_http_response(self, status_code, reason='', header=None, body=None):
        """Make the mocked connection's ``getresponse()`` return a new mock response.

        Arguments have the same meaning as in ``create_response``.
        """
        # Normalise here as well, so subclasses that override create_response
        # still receive a list rather than None.
        if header is None:
            header = []
        http_response = self.create_response(status_code, reason, header, body)
        self.https_connection.getresponse.return_value = http_response

    def default_body(self):
        """Return the body used by create_response when none is given.

        Subclasses may override this to supply a different default payload.
        """
        return ''
78
79
class MockServiceWithConfigTestCase(AWSMockServiceTestCase):
    """Mock-service test case with patched config lookups and environment.

    For the duration of each test, ``boto.provider.config.get`` and
    ``boto.provider.config.has_option`` are patched to read from
    ``self.config`` (nested ``{section: {key: value}}`` dicts). ``os.environ``
    is patched to the ``self.environ`` dict.
    """

    def setUp(self):
        """Create empty config/environ dicts and start the three patches."""
        super(MockServiceWithConfigTestCase, self).setUp()
        self.environ = {}
        self.config = {}
        self.config_patch = mock.patch('boto.provider.config.get',
                                       self.get_config)
        self.has_config_patch = mock.patch('boto.provider.config.has_option',
                                           self.has_config)
        self.environ_patch = mock.patch('os.environ', self.environ)
        # unittest skips tearDown when setUp raises, including a subclass's
        # setUp failing after this one returns. Register the cleanup before
        # starting anything so no started patch can leak into later tests.
        self._active_patches = []
        self.addCleanup(self._stop_active_patches)
        for patcher in (self.config_patch, self.has_config_patch,
                        self.environ_patch):
            patcher.start()
            self._active_patches.append(patcher)

    def tearDown(self):
        """Stop the patches started in setUp, then run the parent tearDown."""
        self._stop_active_patches()
        super(MockServiceWithConfigTestCase, self).tearDown()

    def _stop_active_patches(self):
        """Stop every still-active patch once, in the order it was started."""
        while self._active_patches:
            self._active_patches.pop(0).stop()

    def has_config(self, section_name, key):
        """Replacement for ``has_option``: whether ``self.config`` has section/key."""
        try:
            self.config[section_name][key]
        except KeyError:
            return False
        return True

    def get_config(self, section_name, key, default=None):
        """Replacement for ``get``: look up section/key in ``self.config``.

        :returns: the stored value, or ``default`` when the section or key
            is missing.
        """
        try:
            return self.config[section_name][key]
        except KeyError:
            # Return the caller-supplied fallback rather than discarding it.
            return default
111