• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for protorpc.service_handlers."""
19
20__author__ = 'rafek@google.com (Rafe Kaplan)'
21
22
23import cgi
24import cStringIO
25import os
26import re
27import sys
28import unittest
29import urllib
30
31from protorpc import messages
32from protorpc import protobuf
33from protorpc import protojson
34from protorpc import protourlencode
35from protorpc import message_types
36from protorpc import registry
37from protorpc import remote
38from protorpc import test_util
39from protorpc import util
40from protorpc import webapp_test_util
41from protorpc.webapp import forms
42from protorpc.webapp import service_handlers
43from protorpc.webapp.google_imports import webapp
44
45import mox
46
# Module-level protorpc package name for the messages defined in this file.
# NOTE(review): presumably picked up by protorpc.messages to qualify the
# definition names of the Enum/Message classes below — confirm.
package = 'test_package'
48
49
class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Module interface test for service_handlers.

  NOTE(review): the actual checks live in test_util.ModuleInterfaceTest and
  are presumably driven by the MODULE attribute below — confirm there.
  """

  # Module whose public interface is verified.
  MODULE = service_handlers
54
55
class Enum1(messages.Enum):
  """A test enum class.

  Used as the type of the enum fields in Request1, Response1 and
  RepeatedMessage.
  """

  VAL1 = 1
  VAL2 = 2
  VAL3 = 3
62
63
class Request1(messages.Message):
  """A test request message type.

  Request type of Service.method1; has the same field layout as Response1.
  """

  # Field numbers 1-3 mirror Response1 so method1 can copy fields one-to-one.
  integer_field = messages.IntegerField(1)
  string_field = messages.StringField(2)
  enum_field = messages.EnumField(Enum1, 3)
70
71
class Response1(messages.Message):
  """A test response message type.

  Response type of Service.method1; has the same field layout as Request1.
  """

  # Field numbers 1-3 mirror Request1 so method1 can copy fields one-to-one.
  integer_field = messages.IntegerField(1)
  string_field = messages.StringField(2)
  enum_field = messages.EnumField(Enum1, 3)
78
79
class SuperMessage(messages.Message):
  """A test message with a nested message field."""

  # A single nested Request1.
  sub_message = messages.MessageField(Request1, 1)
  # A list of nested Request1 messages.
  sub_messages = messages.MessageField(Request1, 2, repeated=True)
85
86
class SuperSuperMessage(messages.Message):
  """A test message with two levels of nested messages."""

  # Nests a SuperMessage, which itself nests Request1 (two levels deep).
  sub_message = messages.MessageField(SuperMessage, 1)
  # A list of Request1 messages directly at this level.
  sub_messages = messages.MessageField(Request1, 2, repeated=True)
92
93
class RepeatedMessage(messages.Message):
  """A test message with one repeated field of each scalar kind.

  Request and response type of Service.repeated_method.
  """

  ints = messages.IntegerField(1, repeated=True)
  strings = messages.StringField(2, repeated=True)
  enums = messages.EnumField(Enum1, 3, repeated=True)
100
101
class Service(object):
  """A simple test service.

  method1 takes a Request1 and returns a Response1; repeated_method takes
  and returns a RepeatedMessage; not_remote is not exposed as an RPC.
  """

  @remote.method(Request1, Response1)
  def method1(self, request):
    """Echo the fields of a Request1 back in a new Response1."""
    response = Response1()
    if hasattr(request, 'integer_field'):
      response.integer_field = request.integer_field
    if hasattr(request, 'string_field'):
      response.string_field = request.string_field
    if hasattr(request, 'enum_field'):
      response.enum_field = request.enum_field
    return response

  @remote.method(RepeatedMessage, RepeatedMessage)
  def repeated_method(self, request):
    """Echo the repeated ints field back in a new RepeatedMessage."""
    response = RepeatedMessage()
    if hasattr(request, 'ints'):
      # Copy onto the response message.  Rebinding ``response`` itself would
      # return a bare list, violating the declared RepeatedMessage response.
      response.ints = request.ints
    return response

  def not_remote(self):
    """Plain method without @remote.method; must not be callable as an RPC."""
    pass
125
126
def VerifyResponse(test,
                   response,
                   expected_status,
                   expected_status_message,
                   expected_content,
                   expected_content_type='application/x-www-form-urlencoded'):
  """Replay a webapp response through WSGI and assert on what it emits.

  Args:
    test: TestCase whose assertion methods perform the checks.
    response: Object exposing wsgi_write(start_response).
    expected_status: Expected status code as a string, e.g. '200'.
    expected_status_message: Expected text following the status code.
    expected_content: Substring expected in the body.  The empty string means
      the body must equal util.pad_string('') exactly.
    expected_content_type: Expected value of any content-type header.
  """

  def start_response(status_line, headers):
    code, reason = status_line.split(' ', 1)
    test.assertEquals(expected_status, code)
    test.assertEquals(expected_status_message, reason)

    # First pass: each content-type header must carry the expected type.
    for header_name, header_value in headers:
      if header_name.lower() == 'content-type':
        test.assertEquals(expected_content_type, header_value)

    # Second pass: sniffing must be disabled and HTML is never served.
    for header_name, header_value in headers:
      lowered = header_name.lower()
      if lowered == 'x-content-type-options':
        test.assertEquals('nosniff', header_value)
      elif lowered == 'content-type':
        test.assertFalse(header_value.lower().startswith('text/html'))

    return check_body

  def check_body(content):
    if expected_content != '':
      position = content.find(expected_content)
      test.assertNotEquals(-1, position,
                           'Expected to find:\n%s\n\nActual content: \n%s' % (
                             expected_content, content))
    else:
      test.assertEquals(util.pad_string(''), content)

  response.wsgi_write(start_response)
156
157
class ServiceHandlerFactoryTest(test_util.TestCase):
  """Tests for the service handler factory."""

  def testAllRequestMappers(self):
    """Test all_request_mappers method."""
    configuration = service_handlers.ServiceHandlerFactory(Service)
    mapper1 = service_handlers.RPCMapper(['whatever'], 'whatever', None)
    mapper2 = service_handlers.RPCMapper(['whatever'], 'whatever', None)

    configuration.add_request_mapper(mapper1)
    self.assertEquals([mapper1], list(configuration.all_request_mappers()))

    configuration.add_request_mapper(mapper2)
    self.assertEquals([mapper1, mapper2],
                      list(configuration.all_request_mappers()))

  def testServiceFactory(self):
    """Test that service_factory attribute is set."""
    handler_factory = service_handlers.ServiceHandlerFactory(Service)
    self.assertEquals(Service, handler_factory.service_factory)

  def testFactoryMethod(self):
    """Test that factory creates correct instance of class."""
    factory = service_handlers.ServiceHandlerFactory(Service)
    handler = factory()

    self.assertTrue(isinstance(handler, service_handlers.ServiceHandler))
    self.assertTrue(isinstance(handler.service, Service))

  def testMapping(self):
    """Test the mapping method for flat and nested service paths."""
    factory = service_handlers.ServiceHandlerFactory(Service)
    path, mapped_factory = factory.mapping('/my_service')

    self.assertEquals(r'(/my_service)' + service_handlers._METHOD_PATTERN, path)
    self.assertEquals(id(factory), id(mapped_factory))
    match = re.match(path, '/my_service.my_method')
    self.assertEquals('/my_service', match.group(1))
    self.assertEquals('my_method', match.group(2))

    path, mapped_factory = factory.mapping('/my_service/nested')
    self.assertEquals('(/my_service/nested)' +
                      service_handlers._METHOD_PATTERN, path)
    # The nested mapping must hand back the same factory as well.
    self.assertEquals(id(factory), id(mapped_factory))
    match = re.match(path, '/my_service/nested.my_method')
    self.assertEquals('/my_service/nested', match.group(1))
    self.assertEquals('my_method', match.group(2))

  def testRegexMapping(self):
    """Test the mapping method using a regex."""
    factory = service_handlers.ServiceHandlerFactory(Service)
    path, mapped_factory = factory.mapping('.*/my_service')

    self.assertEquals(r'(.*/my_service)' + service_handlers._METHOD_PATTERN, path)
    self.assertEquals(id(factory), id(mapped_factory))
    match = re.match(path, '/whatever_preceeds/my_service.my_method')
    self.assertEquals('/whatever_preceeds/my_service', match.group(1))
    self.assertEquals('my_method', match.group(2))
    match = re.match(path, '/something_else/my_service.my_other_method')
    self.assertEquals('/something_else/my_service', match.group(1))
    self.assertEquals('my_other_method', match.group(2))

  def testMapping_BadPath(self):
    """Test bad parameters to the mapping method."""
    factory = service_handlers.ServiceHandlerFactory(Service)
    self.assertRaises(ValueError, factory.mapping, '/my_service/')

  def testDefault(self):
    """Test the default factory convenience method."""
    handler_factory = service_handlers.ServiceHandlerFactory.default(
        Service,
        parameter_prefix='my_prefix.')

    self.assertEquals(Service, handler_factory.service_factory)

    mappers = handler_factory.all_request_mappers()

    # Verify Protobuf encoded mapper.
    protobuf_mapper = next(mappers)
    self.assertTrue(isinstance(protobuf_mapper,
                               service_handlers.ProtobufRPCMapper))

    # Verify JSON encoded mapper.
    json_mapper = next(mappers)
    self.assertTrue(isinstance(json_mapper,
                               service_handlers.JSONRPCMapper))

    # Should have no more mappers.  Use the next() builtin, as above, rather
    # than the Python 2-only iterator .next method.
    self.assertRaises(StopIteration, next, mappers)
246
247
class ServiceHandlerTest(webapp_test_util.RequestHandlerTestBase):
  """Test the ServiceHandler class."""

  def setUp(self):
    """Configure a Service-backed handler and a canonical Request1."""
    self.mox = mox.Mox()
    self.service_factory = Service
    self.remote_host = 'remote.host.com'
    self.server_host = 'server.host.com'
    self.ResetRequestHandler()

    self.request = Request1()
    self.request.integer_field = 1
    self.request.string_field = 'a'
    self.request.enum_field = Enum1.VAL1

  def ResetRequestHandler(self):
    """Re-run the base class setUp to rebuild the handler under test.

    Tests call this after changing service_factory, remote_host or
    server_host so that the new values take effect.
    """
    super(ServiceHandlerTest, self).setUp()

  def CreateService(self):
    """Return a new instance of the configured service_factory."""
    return self.service_factory()

  def CreateRequestHandler(self):
    """Build a handler whose factory holds two mock RPC mappers.

    rpc_mapper1 accepts POST with application/x-www-form-urlencoded;
    rpc_mapper2 accepts GET with application/json.
    """
    self.rpc_mapper1 = self.mox.CreateMock(service_handlers.RPCMapper)
    self.rpc_mapper1.http_methods = set(['POST'])
    self.rpc_mapper1.content_types = set(['application/x-www-form-urlencoded'])
    self.rpc_mapper1.default_content_type = 'application/x-www-form-urlencoded'
    self.rpc_mapper2 = self.mox.CreateMock(service_handlers.RPCMapper)
    self.rpc_mapper2.http_methods = set(['GET'])
    self.rpc_mapper2.content_types = set(['application/json'])
    self.rpc_mapper2.default_content_type = 'application/json'
    self.factory = service_handlers.ServiceHandlerFactory(
        self.CreateService)
    self.factory.add_request_mapper(self.rpc_mapper1)
    self.factory.add_request_mapper(self.rpc_mapper2)
    return self.factory()

  def GetEnvironment(self):
    """Return the base WSGI environ plus REMOTE_HOST/SERVER_HOST when set."""
    environ = super(ServiceHandlerTest, self).GetEnvironment()
    if self.remote_host:
      environ['REMOTE_HOST'] = self.remote_host
    if self.server_host:
      environ['SERVER_HOST'] = self.server_host
    return environ

  def VerifyResponse(self, *args, **kwargs):
    """Check self.response with the module-level VerifyResponse helper."""
    VerifyResponse(self,
                   self.response,
                   *args, **kwargs)

  def ExpectRpcError(self, mapper, state, error_message, error_name=None):
    """Record that mapper.build_response will receive the given RpcStatus."""
    mapper.build_response(self.handler,
                          remote.RpcStatus(state=state,
                                           error_message=error_message,
                                           error_name=error_name))

  def testRedirect(self):
    """Test that redirection is disabled."""
    self.assertRaises(NotImplementedError, self.handler.redirect, '/')

  def testFirstMapper(self):
    """Make sure service attribute works when matches first RPCMapper."""
    self.rpc_mapper1.build_request(
        self.handler, Request1).AndReturn(self.request)

    def build_response(handler, response):
      output = '%s %s %s' % (response.integer_field,
                             response.string_field,
                             response.enum_field)
      handler.response.headers['content-type'] = (
        'application/x-www-form-urlencoded')
      handler.response.out.write(output)
    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', '1 a VAL1')

    self.mox.VerifyAll()

  def testSecondMapper(self):
    """Make sure service attribute works when matches second RPCMapper.

    Demonstrates the multiplicity of the RPCMapper configuration.
    """
    self.rpc_mapper2.build_request(
        self.handler, Request1).AndReturn(self.request)

    def build_response(handler, response):
      output = '%s %s %s' % (response.integer_field,
                             response.string_field,
                             response.enum_field)
      handler.response.headers['content-type'] = (
        'application/x-www-form-urlencoded')
      handler.response.out.write(output)
    self.rpc_mapper2.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    self.mox.ReplayAll()

    self.handler.request.headers['Content-Type'] = 'application/json'
    self.handler.handle('GET', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', '1 a VAL1')

    self.mox.VerifyAll()

  def testCaseInsensitiveContentType(self):
    """Ensure that matching content-type is case insensitive."""
    self.rpc_mapper1.build_request(self.handler,
                                   Request1).AndReturn(self.request)

    def build_response(handler, response):
      output = '%s %s %s' % (response.integer_field,
                             response.string_field,
                             response.enum_field)
      handler.response.out.write(output)
      handler.response.headers['content-type'] = 'text/plain'
    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    self.mox.ReplayAll()

    self.handler.request.headers['Content-Type'] = ('ApPlIcAtIoN/'
                                                    'X-wWw-FoRm-UrLeNcOdEd')

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', '1 a VAL1', 'text/plain')

    self.mox.VerifyAll()

  def testContentTypeWithParameters(self):
    """Test that content types have parameters parsed out."""
    self.rpc_mapper1.build_request(self.handler,
                                   Request1).AndReturn(self.request)

    def build_response(handler, response):
      output = '%s %s %s' % (response.integer_field,
                             response.string_field,
                             response.enum_field)
      handler.response.headers['content-type'] = (
        'application/x-www-form-urlencoded')
      handler.response.out.write(output)
    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    self.mox.ReplayAll()

    self.handler.request.headers['Content-Type'] = ('application/'
                                                    'x-www-form-urlencoded' +
                                                    '; a=b; c=d')

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', '1 a VAL1')

    self.mox.VerifyAll()

  def testContentFromHeaderOnly(self):
    """Test getting content-type from HTTP_CONTENT_TYPE directly.

    Some bad web server implementations might decide not to set CONTENT_TYPE for
    POST requests where there is an empty body.  In these cases, need to get
    content-type directly from webob environ key HTTP_CONTENT_TYPE.
    """
    self.rpc_mapper1.build_request(self.handler,
                                   Request1).AndReturn(self.request)

    def build_response(handler, response):
      output = '%s %s %s' % (response.integer_field,
                             response.string_field,
                             response.enum_field)
      handler.response.headers['Content-Type'] = (
        'application/x-www-form-urlencoded')
      handler.response.out.write(output)
    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    self.mox.ReplayAll()

    self.handler.request.headers['Content-Type'] = None
    self.handler.request.environ['HTTP_CONTENT_TYPE'] = (
      'application/x-www-form-urlencoded')

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', '1 a VAL1',
                        'application/x-www-form-urlencoded')

    self.mox.VerifyAll()

  def testRequestState(self):
    """Make sure request state is passed in to handler that supports it."""
    class ServiceWithState(object):

      initialize_request_state = self.mox.CreateMockAnything()

      @remote.method(Request1, Response1)
      def method1(self, request):
        return Response1()

    self.service_factory = ServiceWithState

    # Reset handler with new service type.
    self.ResetRequestHandler()

    self.rpc_mapper1.build_request(
        self.handler, Request1).AndReturn(Request1())

    def build_response(handler, response):
      handler.response.headers['Content-Type'] = (
        'application/x-www-form-urlencoded')
      handler.response.out.write('whatever')
    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    def verify_state(state):
      return (
        'remote.host.com' ==  state.remote_host and
        '127.0.0.1' == state.remote_address and
        'server.host.com' == state.server_host and
        8080 == state.server_port and
        'POST' == state.http_method and
        '/my_service' == state.service_path and
        'application/x-www-form-urlencoded' == state.headers['content-type'] and
        'dev_appserver_login="test:test@example.com:True"' ==
        state.headers['cookie'])
    ServiceWithState.initialize_request_state(mox.Func(verify_state))

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', 'whatever')

    self.mox.VerifyAll()

  def testRequestState_MissingHosts(self):
    """Make sure missing state environment values are handled gracefully."""
    class ServiceWithState(object):

      initialize_request_state = self.mox.CreateMockAnything()

      @remote.method(Request1, Response1)
      def method1(self, request):
        return Response1()

    self.service_factory = ServiceWithState
    self.remote_host = None
    self.server_host = None

    # Reset handler with new service type.
    self.ResetRequestHandler()

    self.rpc_mapper1.build_request(
        self.handler, Request1).AndReturn(Request1())

    def build_response(handler, response):
      handler.response.headers['Content-Type'] = (
        'application/x-www-form-urlencoded')
      handler.response.out.write('whatever')
    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).WithSideEffects(build_response)

    def verify_state(state):
      return (None is state.remote_host and
              '127.0.0.1' == state.remote_address and
              None is state.server_host and
              8080 == state.server_port)
    ServiceWithState.initialize_request_state(mox.Func(verify_state))

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('200', 'OK', 'whatever')

    self.mox.VerifyAll()

  def testNoMatch_UnknownHTTPMethod(self):
    """Test what happens when no RPCMapper supports the HTTP method."""
    self.mox.ReplayAll()

    self.handler.handle('UNKNOWN', '/my_service', 'does_not_matter')

    self.VerifyResponse('405',
                        'Unsupported HTTP method: UNKNOWN',
                        'Method Not Allowed',
                        'text/plain; charset=utf-8')

    self.mox.VerifyAll()

  def testNoMatch_GetNotSupported(self):
    """Test what happens when GET is not supported."""
    self.mox.ReplayAll()

    self.handler.handle('GET', '/my_service', 'method1')

    self.VerifyResponse('405',
                        'Method Not Allowed',
                        '/my_service.method1 is a ProtoRPC method.\n\n'
                        'Service %s.Service\n\n'
                        'More about ProtoRPC: '
                        'http://code.google.com/p/google-protorpc' %
                        (__name__,),
                        'text/plain; charset=utf-8')

    self.mox.VerifyAll()

  def testNoMatch_UnknownContentType(self):
    """Test what happens when no RPCMapper supports the content-type."""
    self.mox.ReplayAll()

    self.handler.request.headers['Content-Type'] = 'image/png'
    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('415',
                        'Unsupported content-type: image/png',
                        'Unsupported Media Type',
                        'text/plain; charset=utf-8')

    self.mox.VerifyAll()

  def testNoMatch_NoContentType(self):
    """Test what happens when the request has no content-type at all."""
    self.mox.ReplayAll()

    self.handler.request.environ.pop('HTTP_CONTENT_TYPE', None)
    self.handler.request.headers.pop('Content-Type', None)
    # Arguments are (http_method, service_path, remote_method), as in every
    # other handle() call in this class.
    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('400', 'Invalid RPC request: missing content-type',
                        'Bad Request',
                        'text/plain; charset=utf-8')

    self.mox.VerifyAll()

  def testNoSuchMethod(self):
    """When service method not found."""
    self.ExpectRpcError(self.rpc_mapper1,
                        remote.RpcState.METHOD_NOT_FOUND_ERROR,
                        'Unrecognized RPC method: no_such_method')

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'no_such_method')

    self.VerifyResponse('400', 'Unrecognized RPC method: no_such_method', '')

    self.mox.VerifyAll()

  def testNoSuchRemoteMethod(self):
    """When service method exists but is not remote."""
    self.ExpectRpcError(self.rpc_mapper1,
                        remote.RpcState.METHOD_NOT_FOUND_ERROR,
                        'Unrecognized RPC method: not_remote')

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'not_remote')

    self.VerifyResponse('400', 'Unrecognized RPC method: not_remote', '')

    self.mox.VerifyAll()

  def testRequestError(self):
    """RequestError handling."""
    def build_request(handler, request):
      raise service_handlers.RequestError('This is a request error')
    self.rpc_mapper1.build_request(
        self.handler, Request1).WithSideEffects(build_request)

    self.ExpectRpcError(self.rpc_mapper1,
                        remote.RpcState.REQUEST_ERROR,
                        'Error parsing ProtoRPC request '
                        '(This is a request error)')

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('400',
                        'Error parsing ProtoRPC request '
                        '(This is a request error)',
                        '')

    self.mox.VerifyAll()

  def testDecodeError(self):
    """DecodeError handling."""
    def build_request(handler, request):
      raise messages.DecodeError('This is a decode error')
    self.rpc_mapper1.build_request(
        self.handler, Request1).WithSideEffects(build_request)

    self.ExpectRpcError(self.rpc_mapper1,
                        remote.RpcState.REQUEST_ERROR,
                        r'Error parsing ProtoRPC request '
                        r'(This is a decode error)')

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('400',
                        'Error parsing ProtoRPC request '
                        '(This is a decode error)',
                        '')

    self.mox.VerifyAll()

  def testResponseException(self):
    """Test what happens when build_response raises ResponseError."""
    self.rpc_mapper1.build_request(
        self.handler, Request1).AndReturn(self.request)

    self.rpc_mapper1.build_response(
        self.handler, mox.IsA(Response1)).AndRaise(
        service_handlers.ResponseError)

    self.ExpectRpcError(self.rpc_mapper1,
                        remote.RpcState.SERVER_ERROR,
                        'Internal Server Error')

    self.mox.ReplayAll()

    self.handler.handle('POST', '/my_service', 'method1')

    self.VerifyResponse('500', 'Internal Server Error', '')

    self.mox.VerifyAll()

  def testGet(self):
    """Test that GET goes to 'handle' properly."""
    self.handler.handle = self.mox.CreateMockAnything()
    self.handler.handle('GET', '/my_service', 'method1')
    self.handler.handle('GET', '/my_other_service', 'method2')

    self.mox.ReplayAll()

    self.handler.get('/my_service', 'method1')
    self.handler.get('/my_other_service', 'method2')

    self.mox.VerifyAll()

  def testPost(self):
    """Test that POST goes to 'handle' properly."""
    self.handler.handle = self.mox.CreateMockAnything()
    self.handler.handle('POST', '/my_service', 'method1')
    self.handler.handle('POST', '/my_other_service', 'method2')

    self.mox.ReplayAll()

    self.handler.post('/my_service', 'method1')
    self.handler.post('/my_other_service', 'method2')

    self.mox.VerifyAll()

  def testGetNoMethod(self):
    """GET on the bare service path returns the 405 service info page."""
    self.handler.get('/my_service', '')
    self.assertEquals(405, self.handler.response.status)
    self.assertEquals(
      util.pad_string('/my_service is a ProtoRPC service.\n\n'
                      'Service %s.Service\n\n'
                      'More about ProtoRPC: '
                      'http://code.google.com/p/google-protorpc\n' %
                      __name__),
      self.handler.response.out.getvalue())
    self.assertEquals(
        'nosniff',
        self.handler.response.headers['x-content-type-options'])

  def testGetNotSupported(self):
    """GET on a method path returns the 405 method info page."""
    self.handler.get('/my_service', 'method1')
    self.assertEquals(405, self.handler.response.status)
    expected_message = ('/my_service.method1 is a ProtoRPC method.\n\n'
                        'Service %s.Service\n\n'
                        'More about ProtoRPC: '
                        'http://code.google.com/p/google-protorpc\n' %
                        __name__)
    self.assertEquals(util.pad_string(expected_message),
                      self.handler.response.out.getvalue())
    self.assertEquals(
        'nosniff',
        self.handler.response.headers['x-content-type-options'])

  def testGetUnknownContentType(self):
    """GET with an unsupported content-type returns 415 with the info page."""
    self.handler.request.headers['content-type'] = 'image/png'
    self.handler.get('/my_service', 'method1')
    self.assertEquals(415, self.handler.response.status)
    self.assertEquals(
      util.pad_string('/my_service.method1 is a ProtoRPC method.\n\n'
                      'Service %s.Service\n\n'
                      'More about ProtoRPC: '
                      'http://code.google.com/p/google-protorpc\n' %
                      __name__),
      self.handler.response.out.getvalue())
    self.assertEquals(
        'nosniff',
        self.handler.response.headers['x-content-type-options'])
765
766
class MissingContentLengthTests(ServiceHandlerTest):
  """Re-run ServiceHandlerTest with content-length only in a header.

  CONTENT_LENGTH is removed from the WSGI environment and its value (or '0'
  when absent) is exposed as HTTP_CONTENT_LENGTH instead.
  """

  def GetEnvironment(self):
    env = super(MissingContentLengthTests, self).GetEnvironment()
    # The pop on the right-hand side runs before the new key is stored.
    env['HTTP_CONTENT_LENGTH'] = str(env.pop('CONTENT_LENGTH', '0'))
    return env
779
780
class MissingContentTypeTests(ServiceHandlerTest):
  """Re-run ServiceHandlerTest with content-type only in a header.

  CONTENT_TYPE is removed from the WSGI environment and its value (or ''
  when absent) is exposed as HTTP_CONTENT_TYPE instead.
  """

  def GetEnvironment(self):
    env = super(MissingContentTypeTests, self).GetEnvironment()
    # The pop on the right-hand side runs before the new key is stored.
    env['HTTP_CONTENT_TYPE'] = str(env.pop('CONTENT_TYPE', ''))
    return env
793
794
class RPCMapperTestBase(test_util.TestCase):
  """Base for tests that need a ServiceHandler bound to a synthetic request."""

  def setUp(self):
    """Set up test framework."""
    self.Reinitialize()

  def Reinitialize(self, input='',
                   get=False,
                   path_method='method1',
                   content_type='text/plain'):
    """Rebuild the service handler around a fresh webapp request/response.

    Args:
      input: Query string (GET) or request body (POST).
      get: Use GET method if True.  Use POST if False.
      path_method: Remote method name.  Stored as the handler's remote_method
        and, when non-empty, appended to the request path.
      content_type: Value written to the request's Content-Type header.
    """
    self.factory = service_handlers.ServiceHandlerFactory(Service)
    self.service_handler = service_handlers.ServiceHandler(self.factory,
                                                           Service())
    self.service_handler.remote_method = path_method

    path = '/servicepath'
    if path_method:
      path += '/' + path_method

    if get:
      http_method = 'GET'
      path += '?' + input
      environ = {'wsgi.input': cStringIO.StringIO(''),
                 'CONTENT_LENGTH': '0',
                 'QUERY_STRING': input}
    else:
      http_method = 'POST'
      environ = {'wsgi.input': cStringIO.StringIO(input),
                 'CONTENT_LENGTH': str(len(input)),
                 'QUERY_STRING': ''}
    environ['REQUEST_METHOD'] = http_method
    environ['PATH_INFO'] = path
    self.service_handler.method = http_method

    self.request = webapp.Request(environ)
    self.response = webapp.Response()
    self.service_handler.initialize(self.request, self.response)

    self.service_handler.request.headers['Content-Type'] = content_type
846
847
class RPCMapperTest(RPCMapperTestBase, webapp_test_util.RequestHandlerTestBase):
  """Test the RPCMapper base class.

  The protocol handed to RPCMapper is a mox mock: each test records the
  encode/decode calls it expects, replays them, exercises the mapper and
  finally verifies that every recorded call was made.
  """

  def setUp(self):
    RPCMapperTestBase.setUp(self)
    webapp_test_util.RequestHandlerTestBase.setUp(self)
    self.mox = mox.Mox()
    self.protocol = self.mox.CreateMockAnything()

  def GetEnvironment(self):
    """Get environment.

    Return bogus content in body.

    Returns:
      dict of CGI environment.
    """
    body = 'my body'
    environment = super(RPCMapperTest, self).GetEnvironment()
    environment['wsgi.input'] = cStringIO.StringIO(body)
    # CGI/WSGI define CONTENT_LENGTH as a string, not an int.
    environment['CONTENT_LENGTH'] = str(len(body))
    return environment

  def testContentTypes_JustDefault(self):
    """Test content type attributes."""
    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['GET', 'POST'],
                                        'my-content-type',
                                        self.protocol)

    self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods)
    self.assertEquals('my-content-type', mapper.default_content_type)
    self.assertEquals(frozenset(['my-content-type']),
                      mapper.content_types)

    self.mox.VerifyAll()

  def testContentTypes_Extended(self):
    """Test content type attributes."""
    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['GET', 'POST'],
                                        'my-content-type',
                                        self.protocol,
                                        content_types=['a', 'b'])

    self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods)
    self.assertEquals('my-content-type', mapper.default_content_type)
    self.assertEquals(frozenset(['my-content-type', 'a', 'b']),
                      mapper.content_types)

    self.mox.VerifyAll()

  def testBuildRequest(self):
    """Test building a request."""
    expected_request = Request1()
    self.protocol.decode_message(Request1,
                                 'my body').AndReturn(expected_request)

    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['POST'],
                                        'my-content-type',
                                        self.protocol)

    request = mapper.build_request(self.handler, Request1)

    self.assertTrue(expected_request is request)

    self.mox.VerifyAll()

  def testBuildRequest_ValidationError(self):
    """Test building a request generating a validation error."""
    self.protocol.decode_message(
        Request1, 'my body').AndRaise(messages.ValidationError('xyz'))

    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['POST'],
                                        'my-content-type',
                                        self.protocol)

    self.assertRaisesWithRegexpMatch(
        service_handlers.RequestError,
        'Unable to parse request content: xyz',
        mapper.build_request,
        self.handler,
        Request1)

    self.mox.VerifyAll()

  def testBuildRequest_DecodeError(self):
    """Test building a request generating a decode error."""
    self.protocol.decode_message(
        Request1, 'my body').AndRaise(messages.DecodeError('xyz'))

    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['POST'],
                                        'my-content-type',
                                        self.protocol)

    self.assertRaisesWithRegexpMatch(
        service_handlers.RequestError,
        'Unable to parse request content: xyz',
        mapper.build_request,
        self.handler,
        Request1)

    self.mox.VerifyAll()

  def testBuildResponse(self):
    """Test building a response."""
    response = Response1()
    self.protocol.encode_message(response).AndReturn('encoded')

    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['POST'],
                                        'my-content-type',
                                        self.protocol)

    mapper.build_response(self.handler, response)

    self.assertEquals('my-content-type',
                      self.handler.response.headers['Content-Type'])
    self.assertEquals('encoded', self.handler.response.out.getvalue())

    self.mox.VerifyAll()

  def testBuildResponse_ValidationError(self):
    """Test building a response whose encoding raises a validation error."""
    # Distinct name so this test does not shadow testBuildResponse above.
    response = Response1()
    self.protocol.encode_message(response).AndRaise(
        messages.ValidationError('xyz'))

    self.mox.ReplayAll()

    mapper = service_handlers.RPCMapper(['POST'],
                                        'my-content-type',
                                        self.protocol)

    self.assertRaisesWithRegexpMatch(service_handlers.ResponseError,
                                     'Unable to encode message: xyz',
                                     mapper.build_response,
                                     self.handler,
                                     response)

    self.mox.VerifyAll()
989
990
class ProtocolMapperTestBase(object):
  """Mixin with round-trip tests shared by the concrete protocol mappers.

  Concrete subclasses list this class ahead of RPCMapperTestBase and define:
    content_type: Content type used when reinitializing requests and
      expected on the response of a whole request.
    protocol: Module providing encode_message, used to produce request
      bodies and expected response bodies.
    mapper: Callable that returns the mapper under test when invoked with
      no arguments.
  """

  def setUp(self):
    """Reset the handler and build sample request and response messages."""
    super(ProtocolMapperTestBase, self).setUp()
    self.Reinitialize(path_method='my_method',
                      content_type='application/x-google-protobuf')

    self.request_message = Request1(integer_field=1,
                                    string_field=u'something',
                                    enum_field=Enum1.VAL1)
    self.response_message = Response1(integer_field=1,
                                      string_field=u'something',
                                      enum_field=Enum1.VAL1)

  def testBuildRequest(self):
    """Decoding an encoded request body yields an equal message."""
    encoded_request = self.protocol.encode_message(self.request_message)
    self.Reinitialize(encoded_request, content_type=self.content_type)

    under_test = self.mapper()
    decoded = under_test.build_request(self.service_handler, Request1)
    self.assertEquals(self.request_message, decoded)

  def testBuildResponse(self):
    """Building a response writes the protocol encoding to the output."""
    under_test = self.mapper()
    under_test.build_response(self.service_handler, self.response_message)

    expected_body = self.protocol.encode_message(self.response_message)
    written_body = self.service_handler.response.out.getvalue()
    self.assertEquals(expected_body, written_body)

  def testWholeRequest(self):
    """Run a complete POST through the handler using the mapper under test."""
    request_body = self.protocol.encode_message(self.request_message)
    self.Reinitialize(input=request_body, content_type=self.content_type)

    self.factory.add_request_mapper(self.mapper())
    self.service_handler.handle('POST', '/my_service', 'method1')

    expected_body = self.protocol.encode_message(self.response_message)
    VerifyResponse(self,
                   self.service_handler.response,
                   '200',
                   'OK',
                   expected_body,
                   self.content_type)
1041
1042
class URLEncodedRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase):
  """Test the URL encoded RPC mapper."""

  content_type = 'application/x-www-form-urlencoded'
  protocol = protourlencode
  mapper = service_handlers.URLEncodedRPCMapper

  def testBuildRequest_Prefix(self):
    """Test building request with parameter prefix."""
    # content_type must be passed by keyword: the second positional parameter
    # of Reinitialize is `get`, so passing it positionally would silently turn
    # this into a GET request carrying the default 'text/plain' content type.
    self.Reinitialize(urllib.urlencode([('prefix_integer_field', '10'),
                                        ('prefix_string_field', 'a string'),
                                        ('prefix_enum_field', 'VAL1'),
                                       ]),
                      content_type=self.content_type)

    url_encoded_mapper = service_handlers.URLEncodedRPCMapper(
        parameter_prefix='prefix_')
    request = url_encoded_mapper.build_request(self.service_handler,
                                               Request1)
    self.assertEquals(10, request.integer_field)
    self.assertEquals('a string', request.string_field)
    self.assertEquals(Enum1.VAL1, request.enum_field)

  def testBuildRequest_DecodeError(self):
    """Test trying to build request that causes a decode error."""
    # A repeated value for a non-repeated field cannot be decoded.
    self.Reinitialize(urllib.urlencode((('integer_field', '10'),
                                        ('integer_field', '20'),
                                        )),
                      content_type=self.content_type)

    url_encoded_mapper = service_handlers.URLEncodedRPCMapper()

    self.assertRaises(service_handlers.RequestError,
                      url_encoded_mapper.build_request,
                      self.service_handler,
                      Service.method1.remote.request_type)

  def testBuildResponse_Prefix(self):
    """Test building a response with parameter prefix."""
    response = Response1()
    response.integer_field = 10
    response.string_field = u'a string'
    response.enum_field = Enum1.VAL3

    url_encoded_mapper = service_handlers.URLEncodedRPCMapper(
        parameter_prefix='prefix_')

    url_encoded_mapper.build_response(self.service_handler, response)
    self.assertEquals('application/x-www-form-urlencoded',
                      self.response.headers['content-type'])
    self.assertEquals(cgi.parse_qs(self.response.out.getvalue(), True, True),
                      {'prefix_integer_field': ['10'],
                       'prefix_string_field': [u'a string'],
                       'prefix_enum_field': ['VAL3'],
                      })
1098
1099
class ProtobufRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase):
  """Run the shared protocol mapper tests against the protobuf mapper."""

  # Hooks read by ProtocolMapperTestBase.
  mapper = service_handlers.ProtobufRPCMapper
  protocol = protobuf
  content_type = 'application/octet-stream'
1106
1107
class JSONRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase):
  """Test the JSON encoded RPC mapper."""

  # Hooks read by ProtocolMapperTestBase.
  content_type = 'application/json'
  protocol = protojson
  mapper = service_handlers.JSONRPCMapper
1114
1115
class MyService(remote.Service):
  """Minimal service used by the service_mapping tests."""

  def __init__(self, value='default'):
    """Store the constructor argument (e.g. one given via new_factory)."""
    self.value = value
1120
1121
class ServiceMappingTest(test_util.TestCase):
  """Tests for service_handlers.service_mapping."""

  def CheckFormMappings(self, mapping, registry_path='/protorpc'):
    """Check to make sure that form mapping is configured as expected.

    Args:
      mapping: Mapping that should contain forms handlers.
      registry_path: Registry path the forms handlers are expected to be
        mounted under.
    """
    pattern, factory = mapping[0]
    self.assertEquals('%s/form(?:/)?' % registry_path, pattern)
    handler = factory()
    self.assertTrue(isinstance(handler, forms.FormsHandler))
    self.assertEquals(registry_path, handler.registry_path)

    pattern, factory = mapping[1]
    self.assertEquals('%s/form/(.+)' % registry_path, pattern)
    self.assertEquals(forms.ResourceHandler, factory)

  def DoMappingTest(self,
                    services,
                    registry_path='/myreg',
                    expected_paths=None):
    """Build a service mapping and verify its structure.

    With a registry path the mapping is expected to be laid out as
    [forms handler, forms resource handler, <services...>, registry handler];
    without one, every entry is a service.

    Args:
      services: Dict or list of (path, service class or factory) pairs passed
        to service_mapping.
      registry_path: Path of the registry service, or None for no registry.
      expected_paths: Unused; kept for backward compatibility.
    """
    mapping = service_handlers.service_mapping(services, registry_path)
    mapped_services = mapping
    if registry_path:
      form_mapping = mapping[:2]
      mapped_registry_path, mapped_registry_factory = mapping[-1]
      mapped_services = mapping[2:-1]
      self.CheckFormMappings(form_mapping, registry_path=registry_path)

      self.assertEquals(r'(%s)%s' % (registry_path,
                                     service_handlers._METHOD_PATTERN),
                        mapped_registry_path)
      self.assertEquals(registry.RegistryService,
                        mapped_registry_factory.service_factory.service_class)

      # Verify registry knows about other services.
      expected_registry = {registry_path: registry.RegistryService}
      for path, factory in dict(services).items():
        if isinstance(factory, type) and issubclass(factory, remote.Service):
          expected_registry[path] = factory
        else:
          expected_registry[path] = factory.service_class
      self.assertEquals(expected_registry,
                        mapped_registry_factory().service.registry)

    # Verify that services are mapped to URL.
    self.assertEquals(len(services), len(mapped_services))
    mapped_by_path = dict(mapped_services)
    for path, service in dict(services).items():
      mapped_path = r'(%s)%s' % (path, service_handlers._METHOD_PATTERN)
      mapped_factory = mapped_by_path[mapped_path]
      self.assertEquals(service, mapped_factory.service_factory)

  def testServiceMapping_Empty(self):
    """Test an empty service mapping."""
    self.DoMappingTest({})

  def testServiceMapping_ByClass(self):
    """Test mapping a service by class."""
    self.DoMappingTest({'/my-service': MyService})

  def testServiceMapping_ByFactory(self):
    """Test mapping a service by factory."""
    self.DoMappingTest({'/my-service': MyService.new_factory('new-value')})

  def testServiceMapping_ByList(self):
    """Test mapping services given as a list of (path, factory) pairs."""
    self.DoMappingTest(
      [('/my-service1', MyService.new_factory('service1')),
       ('/my-service2', MyService.new_factory('service2')),
      ])

  def testServiceMapping_NoRegistry(self):
    """Test mapping a service without a registry."""
    self.DoMappingTest({'/my-service': MyService}, None)

  def testDefaultMappingWithClass(self):
    """Test setting path just from the class.

    Path of the mapping will be the fully qualified ProtoRPC service name with
    '.' replaced with '/'.  For example:

      com.nowhere.service.TheService -> /com/nowhere/service/TheService
    """
    mapping = service_handlers.service_mapping([MyService])
    mapped_services = mapping[2:-1]
    self.assertEquals(1, len(mapped_services))
    path, factory = mapped_services[0]

    self.assertEquals(
      r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN,
      path)
    self.assertEquals(MyService, factory.service_factory)

  def testDefaultMappingWithFactory(self):
    """Test that a factory without a path maps to its class's default path."""
    mapping = service_handlers.service_mapping(
      [MyService.new_factory('service1')])
    mapped_services = mapping[2:-1]
    self.assertEquals(1, len(mapped_services))
    path, factory = mapped_services[0]

    self.assertEquals(
      r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN,
      path)
    self.assertEquals(MyService, factory.service_factory.service_class)

  def testMappingDuplicateExplicitServiceName(self):
    """Test that mapping two services to the same explicit path fails."""
    self.assertRaisesWithRegexpMatch(
      service_handlers.ServiceConfigurationError,
      "Path '/my_path' is already defined in service mapping",
      service_handlers.service_mapping,
      [('/my_path', MyService),
       ('/my_path', MyService),
       ])

  def testMappingDuplicateServiceName(self):
    """Test that two services defaulting to the same path fail."""
    self.assertRaisesWithRegexpMatch(
      service_handlers.ServiceConfigurationError,
      "Path '/test_package/MyService' is already defined in service mapping",
      service_handlers.service_mapping,
      [MyService, MyService])
1244
1245
class GetCalled(remote.Service):
  """Service that records the request it receives on the owning test."""

  def __init__(self, test):
    self.test = test

  @remote.method(Request1, Response1)
  def my_method(self, request):
    """Save request as test.request and reply with a fixed response."""
    owner = self.test
    owner.request = request
    reply = Response1(string_field='a response')
    return reply
1255
1256
class TestRunServices(test_util.TestCase):
  """Tests for service_handlers.run_services driven through CGI stdio."""

  def DoRequest(self,
                path,
                request,
                response_type,
                reg_path='/protorpc'):
    """Run one JSON request through run_services and decode the reply.

    sys.stdin, sys.stdout and os.environ are temporarily replaced to simulate
    a CGI invocation and are restored afterwards, even on failure.

    Args:
      path: PATH_INFO of the simulated request.
      request: Message JSON-encoded as the request body.
      response_type: Message class used to decode the response body.
      reg_path: Registry path passed to run_services, or None for no registry.

    Returns:
      Tuple (first response header line, e.g. 'Status: 200 OK', decoded
      response message).
    """
    stdin = sys.stdin
    stdout = sys.stdout
    environ = os.environ
    try:
      body = protojson.encode_message(request)
      sys.stdin = cStringIO.StringIO(body)
      sys.stdout = cStringIO.StringIO()

      os.environ = webapp_test_util.GetDefaultEnvironment()
      os.environ['PATH_INFO'] = path
      os.environ['REQUEST_METHOD'] = 'POST'
      os.environ['CONTENT_TYPE'] = 'application/json'
      os.environ['wsgi.input'] = sys.stdin
      os.environ['wsgi.output'] = sys.stdout
      # CGI/WSGI define CONTENT_LENGTH as a string, not an int.
      os.environ['CONTENT_LENGTH'] = str(len(body))

      service_handlers.run_services(
        [('/my_service', GetCalled.new_factory(self))], reg_path)

      header, response_body = sys.stdout.getvalue().split('\n\n', 1)

      return (header.split('\n')[0],
              protojson.decode_message(response_type, response_body))
    finally:
      sys.stdin = stdin
      sys.stdout = stdout
      os.environ = environ

  def testRequest(self):
    """Test that a request reaches the service and its response is returned."""
    request = Request1(string_field='request value')

    status, response = self.DoRequest('/my_service.my_method',
                                      request,
                                      Response1)
    self.assertEquals('Status: 200 OK', status)
    self.assertEquals(request, self.request)
    self.assertEquals(Response1(string_field='a response'), response)

  def testRegistry(self):
    """Test that the registry lists itself and the mapped service."""
    status, response = self.DoRequest('/protorpc.services',
                                      message_types.VoidMessage(),
                                      registry.ServicesResponse)

    self.assertEquals('Status: 200 OK', status)
    self.assertIterEqual([
        registry.ServiceMapping(
            name='/protorpc',
            definition='protorpc.registry.RegistryService'),
        registry.ServiceMapping(
            name='/my_service',
            definition='test_package.GetCalled'),
        ], response.services)

  def testRunServicesWithOutRegistry(self):
    """Test that the registry path is not served when reg_path is None."""
    status, response = self.DoRequest('/protorpc.services',
                                      message_types.VoidMessage(),
                                      registry.ServicesResponse,
                                      reg_path=None)
    self.assertEquals('Status: 404 Not Found', status)
1325
1326
def main():
  """Run every test case in this module via the unittest runner."""
  unittest.main()


if __name__ == '__main__':
  main()
1333