• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import cgi, datetime, re, time, urllib
2from django import http
3import django.core.exceptions
4from django.core import urlresolvers
5from django.utils import datastructures
6import json
7from autotest_lib.frontend.shared import exceptions, query_lib
8from autotest_lib.frontend.afe import model_logic
9
10
# MIME type for JSON bodies: set on the responses built by
# Resource._basic_response() and assumed for request bodies that carry no
# CONTENT_TYPE header (see Resource._decoded_input()).
_JSON_CONTENT_TYPE = 'application/json'
12
13
def _resolve_class_path(class_path):
    """Import and return the attribute named by a dotted path.

    @param class_path: string of the form 'package.module.Name'.  Everything
            before the last dot is imported as a module; the remainder is
            looked up on that module.

    @returns the attribute called Name on the imported module.
    """
    pieces = class_path.rsplit('.', 1)
    module_path, class_name = pieces
    # a non-empty fromlist makes __import__ return the leaf module instead of
    # the top-level package
    containing_module = __import__(module_path, globals={}, locals={},
                                   fromlist=[''])
    return getattr(containing_module, class_name)
18
19
# unique marker returned for fields that are absent from the input; being a
# fresh object() it can never be confused with a real value such as None
_NO_VALUE_SPECIFIED = object()


class _InputDict(dict):
    """Dict whose get() falls back to _NO_VALUE_SPECIFIED instead of None."""

    def get(self, key, default=_NO_VALUE_SPECIFIED):
        """Return self[key] if present, otherwise default.

        Unlike dict.get(), the default default is the _NO_VALUE_SPECIFIED
        marker rather than None.
        """
        return dict.get(self, key, default)


    @classmethod
    def remove_unspecified_fields(cls, field_dict):
        """Return a plain dict copy of field_dict without marker values.

        Every item whose value is the _NO_VALUE_SPECIFIED marker is dropped;
        all other items are kept unchanged.
        """
        specified = {}
        for name, field_value in field_dict.iteritems():
            if field_value is _NO_VALUE_SPECIFIED:
                continue
            specified[name] = field_value
        return specified
31
32
class Resource(object):
    """Base class for REST resources that exchange JSON over HTTP.

    A resource wraps a Django request, keeps its own copy of the query
    parameters that apply to it, and routes a request to the handler method
    named after the HTTP verb (get(), post(), put(), delete()).
    """
    _permitted_methods = None # subclasses must override this


    def __init__(self, request):
        assert self._permitted_methods
        # this request should be used for global environment info, like
        # constructing absolute URIs.  it should not be used for query
        # parameters, because the request may not have been for this particular
        # resource.
        self._request = request
        # this dict will contain the applicable query parameters
        self._query_params = datastructures.MultiValueDict()


    @classmethod
    def dispatch_request(cls, request, *args, **kwargs):
        """Handle a request directly; this is the view callable for the class.

        The resource is built from the URI keyword arguments (positional
        args are accepted but not used), loaded with the accepted parameters
        from request.GET and asked to handle the request.

        @returns an HttpResponse.  A RequestError raised along the way is
                turned into its own response.
        @raises Http404 if constructing the resource raises
                ObjectDoesNotExist.
        """
        try:
            try:
                instance = cls.from_uri_args(request, **kwargs)
            except django.core.exceptions.ObjectDoesNotExist as exc:
                raise http.Http404(exc)

            instance.read_query_parameters(request.GET)
            return instance.handle_request()
        except exceptions.RequestError as exc:
            return exc.response


    def handle_request(self):
        """Call the handler method matching the request's HTTP method.

        @returns the handler's HttpResponse, or an HttpResponseNotAllowed
                listing the permitted methods if the method is not permitted.
        """
        if self._request.method.upper() not in self._permitted_methods:
            return http.HttpResponseNotAllowed(self._permitted_methods)

        # handlers are named after the lower-cased HTTP method
        handler = getattr(self, self._request.method.lower())
        return handler()


    # the handler methods below only need to be overridden if the resource
    # supports the method

    def get(self):
        """Handle a GET request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    def post(self):
        """Handle a POST request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    def put(self):
        """Handle a PUT request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    def delete(self):
        """Handle a DELETE request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    @classmethod
    def from_uri_args(cls, request, **kwargs):
        """Construct an instance from URI args.

        Default implementation for resources with no URI args.
        """
        return cls(request)


    def _uri_args(self):
        """Return kwargs for a URI reference to this resource.

        Default implementation for resources with no URI args.
        """
        return {}


    def _query_parameters_accepted(self):
        """Return sequence of tuples (name, description) for query parameters.

        Documents the available query parameters for GETting this resource.
        Default implementation for resources with no parameters.
        """
        return ()


    def read_query_parameters(self, parameters):
        """Read relevant query parameters from a Django MultiValueDict.

        Only parameters whose name is listed by _query_parameters_accepted()
        are copied into this resource; all others are ignored.
        """
        params_accepted = set(param_name for param_name, _
                              in self._query_parameters_accepted())
        for name, values in parameters.iterlists():
            # a parameter may carry a ":suffix"; only the part before the
            # colon is matched against the accepted names
            base_name = name.split(':', 1)[0]
            if base_name in params_accepted:
                self._query_params.setlist(name, values)


    def set_query_parameters(self, **parameters):
        """Set query parameters programmatically."""
        self._query_params.update(parameters)


    def href(self, query_params=None):
        """Return URI to this resource.

        @param query_params: optional extra parameters merged with the ones
                already stored on this resource for this URI only.

        @returns an absolute URI, with a query string if any parameters
                apply.
        """
        kwargs = self._uri_args()
        path = urlresolvers.reverse(self.dispatch_request, kwargs=kwargs)
        # work on a copy so the stored parameters are left untouched
        full_query_params = datastructures.MultiValueDict(self._query_params)
        if query_params:
            full_query_params.update(query_params)
        if full_query_params:
            path += '?' + urllib.urlencode(full_query_params.lists(),
                                           doseq=True)
        return self._request.build_absolute_uri(path)


    def resolve_uri(self, uri):
        """Return the resource instance that the given URI refers to.

        @param uri: a path, or an absolute http(s) URI on this same host.

        @raises BadRequest if the URI points at a different host or does not
                resolve to a view.
        """
        # check for absolute URIs
        match = re.match(r'(?P<root>https?://[^/]+)(?P<path>/.*)', uri)
        if match:
            # is this URI for a different host?
            my_root = self._request.build_absolute_uri('/')
            request_root = match.group('root') + '/'
            if my_root != request_root:
                # might support this in the future, but not now
                raise exceptions.BadRequest('Unable to resolve remote URI %s'
                                            % uri)
            uri = match.group('path')

        try:
            view_method, args, kwargs = urlresolvers.resolve(uri)
        except http.Http404:
            raise exceptions.BadRequest('Unable to resolve URI %s' % uri)
        resource_class = view_method.im_self # class owning this classmethod
        return resource_class.from_uri_args(self._request, **kwargs)


    def resolve_link(self, link):
        """Return the resource instance referred to by a link.

        @param link: either a dict with an 'href' key or a URI string.

        @raises BadRequest if link is neither a dict nor a string.
        """
        if isinstance(link, dict):
            uri = link['href']
        elif isinstance(link, basestring):
            uri = link
        else:
            raise exceptions.BadRequest('Unable to understand link %s' % link)
        return self.resolve_uri(uri)


    def link(self, query_params=None):
        """Return a dict {'href': <URI of this resource>}."""
        return {'href': self.href(query_params=query_params)}


    def _query_parameters_response(self):
        """Return the accepted query parameters as a name->description dict."""
        return dict((name, description)
                    for name, description in self._query_parameters_accepted())


    def _basic_response(self, content):
        """Construct and return a simple 200 response.

        @param content: dict to serialize as JSON.  It is modified in place:
                a 'query_parameters' entry is added when this resource
                accepts any query parameters.
        """
        assert isinstance(content, dict)
        query_parameters = self._query_parameters_response()
        if query_parameters:
            content['query_parameters'] = query_parameters
        encoded_content = json.dumps(content)
        return http.HttpResponse(encoded_content,
                                 content_type=_JSON_CONTENT_TYPE)


    def _decoded_input(self):
        """Decode the request body into an _InputDict.

        JSON bodies must decode to a dict.  Form-encoded bodies are parsed
        by hand; when a key is repeated the last value wins, and each value
        is decoded as JSON where possible (numbers, booleans, null) and kept
        as a string otherwise.  A missing CONTENT_TYPE is treated as JSON.

        @raises BadRequest if a JSON body is malformed or is not a dict.
        @raises RequestError (415) for any other content type.
        """
        raw_content_type = self._request.META.get('CONTENT_TYPE',
                                                  _JSON_CONTENT_TYPE)
        # the header may carry parameters ("application/json; charset=utf-8")
        # and media type names are case-insensitive, so compare only the
        # normalized media type itself
        content_type = raw_content_type.split(';', 1)[0].strip().lower()
        raw_data = self._request.raw_post_data
        if content_type == _JSON_CONTENT_TYPE:
            try:
                raw_dict = json.loads(raw_data)
            except ValueError as exc:
                raise exceptions.BadRequest('Error decoding request body: '
                                            '%s\n%r' % (exc, raw_data))
            if not isinstance(raw_dict, dict):
                raise exceptions.BadRequest('Expected dict input, got %s: %r' %
                                            (type(raw_dict), raw_dict))
        elif content_type == 'application/x-www-form-urlencoded':
            cgi_dict = cgi.parse_qs(raw_data) # django won't do this for PUT
            raw_dict = {}
            for key, values in cgi_dict.items():
                value = values[-1] # take last value if multiple were given
                try:
                    # attempt to parse numbers, booleans and nulls
                    raw_dict[key] = json.loads(value)
                except ValueError:
                    # otherwise, leave it as a string
                    raw_dict[key] = value
        else:
            raise exceptions.RequestError(415, 'Unsupported media type: %s'
                                          % raw_content_type)

        return _InputDict(raw_dict)


    def _format_datetime(self, date_time):
        """Return ISO 8601 string for the given datetime.

        The datetime is formatted as given and suffixed with the UTC offset
        of the process's local timezone, e.g. '2020-01-02T03:04:05-08:00'.
        NOTE(review): this assumes date_time is a naive value expressed in
        local time -- confirm against callers.

        @returns the formatted string, or None if date_time is None.
        """
        if date_time is None:
            return None

        # time.timezone and time.altzone count seconds WEST of UTC, whereas
        # ISO 8601 offsets are positive EAST of UTC
        seconds_west = time.timezone
        if time.daylight:
            try:
                local_time = time.localtime(
                        time.mktime(date_time.timetuple()))
            except (OverflowError, ValueError):
                # date is outside the range the platform can convert; fall
                # back to the standard (non-DST) offset
                pass
            else:
                if local_time.tm_isdst > 0:
                    seconds_west = time.altzone

        if seconds_west > 0:
            sign = '-'
        else:
            sign = '+'
        # split into zero-padded hours and minutes so that offsets such as
        # +05:30 are not truncated
        hours, minutes = divmod(abs(seconds_west) // 60, 60)
        timezone_spec = '%s%02d:%02d' % (sign, hours, minutes)
        return date_time.strftime('%Y-%m-%dT%H:%M:%S') + timezone_spec


    @classmethod
    def _check_for_required_fields(cls, input_dict, fields):
        """Verify that every name in fields is a key of input_dict.

        @param fields: list or tuple of required key names.

        @raises BadRequest naming all the missing fields, if any.
        """
        assert isinstance(fields, (list, tuple)), fields
        missing_fields = ', '.join(field for field in fields
                                   if field not in input_dict)
        if missing_fields:
            raise exceptions.BadRequest('Missing input: ' + missing_fields)
262
263
class Entry(Resource):
    """Resource for a single item that can be read, updated and deleted.

    Subclasses supply the representation and implement update(),
    _delete_entry() and create_instance() as needed.
    """
    @classmethod
    def add_query_selectors(cls, query_processor):
        """Subclasses may override this to support querying."""
        pass


    def short_representation(self):
        """Return the brief representation: just a link to this entry."""
        return self.link()


    def full_representation(self):
        """Return the full representation; defaults to the short one."""
        return self.short_representation()


    def get(self):
        """Handle GET by returning the full representation as JSON."""
        return self._basic_response(self.full_representation())


    def put(self):
        """Handle PUT: update from the decoded body, return the new state.

        @raises BadRequest if update() raises a ValidationError.
        """
        try:
            self.update(self._decoded_input())
        except model_logic.ValidationError as exc:
            raise exceptions.BadRequest('Invalid input: %s' % exc)
        return self._basic_response(self.full_representation())


    def _delete_entry(self):
        """Delete the underlying object.  Subclasses must override this."""
        raise NotImplementedError


    def delete(self):
        """Handle DELETE: delete the entry and answer 204 No Content."""
        self._delete_entry()
        return http.HttpResponse(status=204) # No content


    @classmethod
    def create_instance(cls, input_dict, containing_collection):
        """Create and return a new underlying object for this entry class.

        This is a classmethod because Collection.post() calls it on the
        entry class, before any entry exists.  Subclasses that can be
        created through a collection must override it.

        @param input_dict: the decoded request input.
        @param containing_collection: the Collection handling the request.
        """
        raise NotImplementedError


    def update(self, input_dict):
        """Apply input_dict to this entry.  Subclasses must override this."""
        raise NotImplementedError
306
307
class InstanceEntry(Entry):
    """Entry backed by a single model instance, held in self.instance."""

    class NullEntry(object):
        """Stand-in for a missing instance: every representation is None."""

        def link(self, query_params=None):
            # accepts (and ignores) query_params so it can be called exactly
            # like Resource.link()
            return None


        def short_representation(self):
            return None


    # single shared stand-in handed out by from_optional_instance()
    _null_entry = NullEntry()
    _permitted_methods = ('GET', 'PUT', 'DELETE')
    model = None # subclasses must override this with a Django model class


    def __init__(self, request, instance):
        assert self.model is not None
        super(InstanceEntry, self).__init__(request)
        self.instance = instance
        # set by prepare_for_full_representation() so the preparation work is
        # done at most once per entry
        self._is_prepared_for_full_representation = False


    @classmethod
    def from_optional_instance(cls, request, instance):
        """Return an entry for instance, or the null entry if it is None."""
        if instance is None:
            return cls._null_entry
        return cls(request, instance)


    def _delete_entry(self):
        """Delete the wrapped instance."""
        self.instance.delete()


    def full_representation(self):
        """Prepare this entry if necessary, then return its representation."""
        self.prepare_for_full_representation([self])
        return super(InstanceEntry, self).full_representation()


    @classmethod
    def prepare_for_full_representation(cls, entries):
        """
        Prepare the given list of entries to generate full representations.

        This method delegates to _do_prepare_for_full_representation(), which
        subclasses may override as necessary to do the actual processing.  This
        method also marks the instance as prepared, so it's safe to call this
        multiple times with the same instance(s) without wasting work.
        """
        not_prepared = [entry for entry in entries
                        if not entry._is_prepared_for_full_representation]
        cls._do_prepare_for_full_representation([entry.instance
                                                 for entry in not_prepared])
        for entry in not_prepared:
            entry._is_prepared_for_full_representation = True


    @classmethod
    def _do_prepare_for_full_representation(cls, instances):
        """
        Subclasses may override this to gather data as needed for full
        representations of the given model instances.  Typically, this involves
        querying over related objects, and this method offers a chance to query
        for many instances at once, which can provide a great performance
        benefit.
        """
        pass
374
375
class Collection(Resource):
    """Resource listing many entries, with paging and query selectors.

    GET returns one page of member representations; POST creates a new
    member through the entry class.
    """
    _DEFAULT_ITEMS_PER_PAGE = 50

    _permitted_methods = ('GET', 'POST')

    # subclasses must override these
    queryset = None # or override _fresh_queryset() directly
    entry_class = None


    def __init__(self, request):
        super(Collection, self).__init__(request)
        assert self.entry_class is not None
        if isinstance(self.entry_class, basestring):
            # entry_class may be given as a dotted path; resolve it once and
            # store the result on the class for later instances
            type(self).entry_class = _resolve_class_path(self.entry_class)

        self._query_processor = query_lib.QueryProcessor()
        self.entry_class.add_query_selectors(self._query_processor)


    def _query_parameters_accepted(self):
        """Return the paging parameters plus one per registered selector."""
        params = [('start_index', 'Index of first member to include'),
                  ('items_per_page', 'Number of members to include'),
                  ('full_representations',
                   'True to include full representations of members')]
        for selector in self._query_processor.selectors():
            params.append((selector.name, selector.doc))
        return params


    def _fresh_queryset(self):
        """Return a new, unevaluated copy of the class-level queryset."""
        assert self.queryset is not None
        # always copy the queryset before using it to avoid caching
        return self.queryset.all()


    def _entry_from_instance(self, instance):
        """Wrap a model instance in this collection's entry class."""
        return self.entry_class(self._request, instance)


    def _representation(self, entry_instances):
        """Return this collection's link dict plus a 'members' list.

        Members are full representations if the 'full_representations'
        query parameter is true, short representations otherwise.
        """
        entries = [self._entry_from_instance(instance)
                   for instance in entry_instances]

        want_full_representation = self._read_bool_parameter(
                'full_representations')
        if want_full_representation:
            # prepare all entries in one go rather than one at a time
            self.entry_class.prepare_for_full_representation(entries)

        members = []
        for entry in entries:
            if want_full_representation:
                rep = entry.full_representation()
            else:
                rep = entry.short_representation()
            members.append(rep)

        rep = self.link()
        rep.update({'members': members})
        return rep


    def _read_bool_parameter(self, name):
        """Return True only if the parameter is present and equals 'true'.

        The comparison ignores case; an absent parameter yields False.
        """
        if name not in self._query_params:
            return False
        return (self._query_params[name].lower() == 'true')


    def _read_int_parameter(self, name, default):
        """Return the named parameter as an int, or default if absent.

        @raises BadRequest if the value cannot be converted to an int.
        """
        if name not in self._query_params:
            return default
        input_value = self._query_params[name]
        try:
            return int(input_value)
        except ValueError:
            raise exceptions.BadRequest('Invalid non-numeric value for %s: %r'
                                        % (name, input_value))


    def _apply_form_query(self, queryset):
        """Apply any query selectors passed as form variables."""
        for parameter, values in self._query_params.lists():
            # a parameter may be written "name:comparison_type"
            if ':' in parameter:
                parameter, comparison_type = parameter.split(':', 1)
            else:
                comparison_type = None

            if not self._query_processor.has_selector(parameter):
                continue
            for value in values: # forms keys can have multiple values
                queryset = self._query_processor.apply_selector(
                        queryset, parameter, value,
                        comparison_type=comparison_type)
        return queryset


    def _filtered_queryset(self):
        """Return a fresh queryset narrowed by the query selectors."""
        return self._apply_form_query(self._fresh_queryset())


    def get(self):
        """Handle GET by returning one page of the filtered collection.

        The page is chosen by the 'start_index' and 'items_per_page' query
        parameters (defaults: 0 and _DEFAULT_ITEMS_PER_PAGE).

        @raises BadRequest if either paging parameter is non-numeric or
                negative.
        """
        queryset = self._filtered_queryset()

        items_per_page = self._read_int_parameter('items_per_page',
                                                  self._DEFAULT_ITEMS_PER_PAGE)
        start_index = self._read_int_parameter('start_index', 0)
        # querysets do not support negative indexing; reject such input as a
        # client error instead of letting the slice below blow up
        if start_index < 0 or items_per_page < 0:
            raise exceptions.BadRequest(
                    'start_index and items_per_page must not be negative')
        page = queryset[start_index:(start_index + items_per_page)]

        rep = self._representation(page)
        # NOTE(review): len() evaluates the entire queryset just to count it;
        # queryset.count() would avoid that if _fresh_queryset() always
        # returns a real QuerySet -- confirm before changing
        rep.update({'total_results': len(queryset),
                    'start_index': start_index,
                    'items_per_page': items_per_page})
        return self._basic_response(rep)


    def full_representation(self):
        """Return a representation of every member, unfiltered and unpaged."""
        # careful, this rep can be huge for large collections
        return self._representation(self._fresh_queryset())


    def post(self):
        """Handle POST: create a new member from the decoded request body.

        @returns a 201 Created response carrying the new entry's URI.
        @raises BadRequest if creating or updating the entry raises a
                ValidationError.
        """
        input_dict = self._decoded_input()
        try:
            instance = self.entry_class.create_instance(input_dict, self)
            entry = self._entry_from_instance(instance)
            entry.update(input_dict)
        except model_logic.ValidationError as exc:
            raise exceptions.BadRequest('Invalid input: %s' % exc)
        # RFC 2616 specifies that we provide the new URI in both the Location
        # header and the body
        entry_uri = entry.href()
        response = http.HttpResponse(status=201, # Created
                                     content=entry_uri)
        response['Location'] = entry_uri
        return response
510
511
class Relationship(Entry):
    """Entry representing the association between two related entries.

    self.entries maps each name in related_classes to the entry taking part
    in the relationship under that name.
    """
    _permitted_methods = ('GET', 'DELETE')

    # subclasses must override this with a dict mapping name to entry class
    related_classes = None


    def __init__(self, **kwargs):
        """Build the relationship from one entry per related class name.

        @param kwargs: must contain, for each name in related_classes, an
                instance of the corresponding entry class.
        """
        assert len(self.related_classes) == 2
        self.entries = dict((name, kwargs[name])
                            for name in self.related_classes)
        for name in self.related_classes: # sanity check
            assert isinstance(self.entries[name], self.related_classes[name])

        # just grab the request from one of the entries
        some_entry = self.entries.itervalues().next()
        super(Relationship, self).__init__(some_entry._request)


    @classmethod
    def from_uri_args(cls, request, **kwargs):
        """Construct the relationship by constructing both of its entries."""
        # kwargs contains URI args for each entry
        entries = {}
        for name, entry_class in cls.related_classes.iteritems():
            entries[name] = entry_class.from_uri_args(request, **kwargs)
        return cls(**entries)


    def _uri_args(self):
        """Return the combined URI kwargs of both entries."""
        kwargs = {}
        for name, entry in self.entries.iteritems():
            kwargs.update(entry._uri_args())
        return kwargs


    def short_representation(self):
        """Return this link plus each entry's short representation by name."""
        rep = self.link()
        for name, entry in self.entries.iteritems():
            rep[name] = entry.short_representation()
        return rep


    @classmethod
    def _get_related_manager(cls, instance):
        """Get the related objects manager for the given instance.

        The instance must be one of the related classes.  This method will
        return the related manager from that instance to instances of the other
        related class.
        """
        models = [entry_class.model for entry_class
                  in cls.related_classes.values()]
        if isinstance(instance, models[0]):
            this_model, other_model = models
        else:
            # same sanity check as RelationshipCollection._set_fixed_entry()
            assert isinstance(instance, models[1])
            other_model, this_model = models

        # only the second element of the result, the relating field, is
        # needed here
        _, field = this_model.objects.determine_relationship(other_model)
        this_models_fields = (this_model._meta.fields
                              + this_model._meta.many_to_many)
        if field in this_models_fields:
            manager_name = field.attname
        else:
            # related manager is on other_model, get name of reverse related
            # manager on this_model
            manager_name = field.related.get_accessor_name()

        return getattr(instance, manager_name)


    def _delete_entry(self):
        """Remove the association between the two entries' instances."""
        # choose order arbitrarily
        entry, other_entry = self.entries.itervalues()
        related_manager = self._get_related_manager(entry.instance)
        related_manager.remove(other_entry.instance)


    @classmethod
    def create_instance(cls, input_dict, containing_collection):
        """Relate the collection's fixed entry to the entry linked in input.

        @param input_dict: must contain a link to the other entry under the
                collection's unfixed name.
        @param containing_collection: collection supplying fixed_entry and
                unfixed_name.

        @returns the model instance of the newly related entry.
        @raises BadRequest if the required link is missing from input_dict.
        """
        other_name = containing_collection.unfixed_name
        cls._check_for_required_fields(input_dict, (other_name,))
        entry = containing_collection.fixed_entry
        other_entry = containing_collection.resolve_link(input_dict[other_name])
        related_manager = cls._get_related_manager(entry.instance)
        related_manager.add(other_entry.instance)
        return other_entry.instance


    def update(self, input_dict):
        """Do nothing; this class has no fields of its own to update."""
        pass
603
604
class RelationshipCollection(Collection):
    """Collection of Relationship entries that share one "fixed" entry.

    The members are the relationships between the fixed entry and each
    related instance of the other ("unfixed") class.
    """
    def __init__(self, request=None, fixed_entry=None):
        """
        @param request: the Django request; if None, fixed_entry's request is
                used instead.
        @param fixed_entry: optional entry whose relationships are listed.
                If omitted, it is determined later by
                read_query_parameters().
        """
        if request is None:
            request = fixed_entry._request
        super(RelationshipCollection, self).__init__(request)

        assert issubclass(self.entry_class, Relationship)
        self.related_classes = self.entry_class.related_classes
        # all of these are filled in by _set_fixed_entry()
        self.fixed_name = None
        self.fixed_entry = None
        self.unfixed_name = None
        self.unfixed_class = None
        self.related_manager = None

        if fixed_entry is not None:
            self._set_fixed_entry(fixed_entry)
            # record the fixed entry as a query parameter so that URIs built
            # for this collection carry it
            entry_uri_arg = self.fixed_entry._uri_args().values()[0]
            self._query_params[self.fixed_name] = entry_uri_arg


    def _set_fixed_entry(self, entry):
        """Set the fixed entry for this collection.

        The entry must be an instance of one of the related entry classes.  This
        method must be called before a relationship is used.  It gets called
        either from the constructor (when collections are instantiated from
        other resource handling code) or from read_query_parameters() (when a
        request is made directly for the collection).
        """
        names = self.related_classes.keys()
        if isinstance(entry, self.related_classes[names[0]]):
            self.fixed_name, self.unfixed_name = names
        else:
            assert isinstance(entry, self.related_classes[names[1]])
            self.unfixed_name, self.fixed_name = names
        self.fixed_entry = entry
        self.unfixed_class = self.related_classes[self.unfixed_name]
        self.related_manager = self.entry_class._get_related_manager(
                entry.instance)


    def _query_parameters_accepted(self):
        """Return one accepted parameter per related class name."""
        return [(name, 'Show relationships for this %s' % entry_class.__name__)
                for name, entry_class
                in self.related_classes.iteritems()]


    def _resolve_query_param(self, name, uri_arg):
        """Return the entry of the class called name identified by uri_arg."""
        entry_class = self.related_classes[name]
        return entry_class.from_uri_args(self._request, uri_arg)


    def read_query_parameters(self, query_params):
        """Read the parameters and derive the fixed entry from them.

        The first parameter found determines the fixed entry.  If both
        related entries are specified, the results are further restricted to
        the second one.

        @raises BadRequest if neither related entry is specified.
        """
        super(RelationshipCollection, self).read_query_parameters(query_params)
        if not self._query_params:
            raise exceptions.BadRequest(
                    'You must specify one of the parameters %s and %s'
                    % tuple(self.related_classes.keys()))
        query_items = self._query_params.items()
        fixed_entry = self._resolve_query_param(*query_items[0])
        self._set_fixed_entry(fixed_entry)

        if len(query_items) > 1:
            other_fixed_entry = self._resolve_query_param(*query_items[1])
            self.related_manager = self.related_manager.filter(
                    pk=other_fixed_entry.instance.id)


    def _entry_from_instance(self, instance):
        """Return the Relationship between the fixed entry and instance."""
        unfixed_entry = self.unfixed_class(self._request, instance)
        entries = {self.fixed_name: self.fixed_entry,
                   self.unfixed_name: unfixed_entry}
        return self.entry_class(**entries)


    def _fresh_queryset(self):
        """Return the instances related to the fixed entry."""
        return self.related_manager.all()
681