import operator
import unittest
import json

from django.test import client

from autotest_lib.frontend.afe import frontend_test_utils, models as afe_models

# String base type(s) used to tell a single attribute name apart from a list of
# attribute names in ResourceTestCase._read_attribute. `basestring` only exists
# on Python 2; fall back to `str` so the helper also works on Python 3.
try:
    _STRING_TYPES = basestring
except NameError:
    _STRING_TYPES = str


class ResourceTestCase(unittest.TestCase,
                       frontend_test_utils.FrontendTestMixin):
    """Base class for tests that exercise HTTP resources via Django's client.

    setUp() runs the frontend fixture setup from FrontendTestMixin, creates a
    'debug_user' that belongs to the 'my_acl' ACL group, and creates a Django
    test client on self.client. request() issues a request, checks the status
    code and decodes a JSON response body.
    """

    URI_PREFIX = None  # subclasses may override this to use partial URIs

    def setUp(self):
        super(ResourceTestCase, self).setUp()
        self._frontend_common_setup()
        self._setup_debug_user()
        self.client = client.Client()

    def tearDown(self):
        super(ResourceTestCase, self).tearDown()
        self._frontend_common_teardown()

    def _setup_debug_user(self):
        """Create the 'debug_user' login and add it to the 'my_acl' group."""
        user = afe_models.User.objects.create(login='debug_user')
        acl = afe_models.AclGroup.objects.get(name='my_acl')
        user.aclgroup_set.add(acl)

    def _expected_status(self, method):
        """Return the HTTP status a successful request with `method` yields.

        The comparison is case-insensitive. raw_request() lowercases the
        method, so the expected status must be derived the same way.
        Otherwise request('POST', ...) would expect 200 instead of 201.
        """
        method = method.lower()
        if method == 'post':
            return 201
        if method == 'delete':
            return 204
        return 200

    def raw_request(self, method, uri, **kwargs):
        """Issue a request through the Django test client and return it.

        @param method: HTTP method name (any case).
        @param uri: full URI to request.
        @param kwargs: passed through to the test client method.
        @returns the test client's response object.
        """
        method = method.lower()
        if method == 'put':
            # the put() implementation in Django's test client is poorly
            # implemented and only supports url-encoded keyvals for the data.
            # the post() implementation is correct, though, so use that, with a
            # trick to override the method.
            method = 'post'
            kwargs['REQUEST_METHOD'] = 'PUT'

        client_method = getattr(self.client, method)
        return client_method(uri, **kwargs)

    def request(self, method, uri, encode_body=True, **kwargs):
        """Issue a request, assert on the status code, and decode the body.

        If a `data` keyword argument is given, `content_type` defaults to
        'application/json'. When the content type is JSON, `data` is
        serialized with json.dumps before sending.

        @param method: HTTP method name.
        @param uri: absolute 'http://' URI, or a path that is appended to
                URI_PREFIX with a '/' separator.
        @param encode_body: currently unused. It is kept so existing callers
                that pass it keep working.
        @param kwargs: passed through to raw_request().
        @returns the decoded JSON body if the response content-type is
                'application/json', otherwise the raw response content.
        """
        expected_status = self._expected_status(method)

        if 'data' in kwargs:
            kwargs.setdefault('content_type', 'application/json')
            if kwargs['content_type'] == 'application/json':
                kwargs['data'] = json.dumps(kwargs['data'])

        if uri.startswith('http://'):
            full_uri = uri
        else:
            assert self.URI_PREFIX
            full_uri = self.URI_PREFIX + '/' + uri

        response = self.raw_request(method, full_uri, **kwargs)
        self.assertEqual(
                response.status_code, expected_status,
                'Requesting %s\nExpected %s, got %s: %s (headers: %s)'
                % (full_uri, expected_status, response.status_code,
                   response.content, response._headers))

        if response['content-type'] != 'application/json':
            return response.content

        try:
            return json.loads(response.content)
        except ValueError:
            self.fail('Invalid response body: %s' % response.content)

    def sorted_by(self, collection, attribute):
        """Return `collection` sorted by each item's `attribute` key."""
        return sorted(collection, key=operator.itemgetter(attribute))

    def _read_attribute(self, item, attribute_or_list):
        """Read item[a1][a2]... for the given attribute name or list of names.

        A single string is treated as a one-element list.
        """
        if isinstance(attribute_or_list, _STRING_TYPES):
            attribute_or_list = [attribute_or_list]
        for attribute in attribute_or_list:
            item = item[attribute]
        return item

    def check_collection(self, collection, attribute_or_list, expected_list,
                         length=None, check_number=None):
        """Check the members of a collection of dicts.

        @param collection: a dict whose 'members' entry is an iterable of
                dicts
        @param attribute_or_list: an attribute or list of attributes to read.
                the results will be sorted and compared with expected_list. if
                a list of attributes is given, the attributes will be read
                hierarchically, i.e. item[attribute1][attribute2]...
        @param expected_list: list of expected values
        @param check_number: if given, only check this number of entries
        @param length: expected length of list, only necessary if check_number
                is given
        """
        actual_list = sorted(self._read_attribute(item, attribute_or_list)
                             for item in collection['members'])
        if length is None and check_number is None:
            length = len(expected_list)
        if length is not None:
            self.assertEqual(len(actual_list), length,
                             'Expected %s, got %s: %s'
                             % (length, len(actual_list),
                                ', '.join(str(item) for item in actual_list)))
        if check_number:
            actual_list = actual_list[:check_number]
        self.assertEqual(actual_list, expected_list)

    def check_relationship(self, resource_uri, relationship_name,
                           other_entry_name, field, expected_values,
                           length=None, check_number=None):
        """Check the members of a relationship collection.

        Fetches the base resource, follows the 'href' of its
        `relationship_name` entry, and checks the resulting collection with
        check_collection().

        @param resource_uri: URI of base resource
        @param relationship_name: name of relationship attribute on base
                resource
        @param other_entry_name: name of other entry in relationship
        @param field: name of field to grab on other entry
        @param expected_values: list of expected values for the given field
        @param length: passed through to check_collection()
        @param check_number: passed through to check_collection()
        """
        response = self.request('get', resource_uri)
        relationship_uri = response[relationship_name]['href']
        relationships = self.request('get', relationship_uri)
        self.check_collection(relationships, [other_entry_name, field],
                              expected_values, length, check_number)