• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for protorpc.remote."""
19
20__author__ = 'rafek@google.com (Rafe Kaplan)'
21
22
23import sys
24import types
25import unittest
26from wsgiref import headers
27
28from protorpc import descriptor
29from protorpc import message_types
30from protorpc import messages
31from protorpc import protobuf
32from protorpc import protojson
33from protorpc import remote
34from protorpc import test_util
35from protorpc import transport
36
37import mox
38
39
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
  """Applies test_util's shared module-interface checks to `remote`."""

  MODULE = remote
44
45
class Request(messages.Message):
  """Request message for MyService; holds one string field (tag 1)."""

  value = messages.StringField(1)
50
51
class Response(messages.Message):
  """Response message for MyService; holds one string field (tag 1)."""

  value = messages.StringField(1)
56
57
class MyService(remote.Service):
  """Service exposing a single remote method that echoes its input."""

  @remote.method(Request, Response)
  def remote_method(self, request):
    """Return a new Response whose value is copied from request.value."""
    return Response(value=request.value)
65
66
class SimpleRequest(messages.Message):
  """Two-string request message shared by the method and stub tests."""

  # Field tags 1 and 2.
  param1 = messages.StringField(1)
  param2 = messages.StringField(2)
72
73
class SimpleResponse(messages.Message):
  """Field-less response message shared by the method and stub tests."""
76
77
class BasicService(remote.Service):
  """Service whose remote method remembers which request objects it saw."""

  def __init__(self):
    # id() of each request passed to remote_method, in call order.
    self.request_ids = []

  @remote.method(SimpleRequest, SimpleResponse)
  def remote_method(self, request):
    """Log the identity of `request` and reply with an empty response."""
    self.request_ids.extend([id(request)])
    return SimpleResponse()
88
89
class RpcErrorTest(test_util.TestCase):
  """Tests for remote.RpcError."""

  def testFromStatus(self):
    """from_state maps every RpcState to None or an RpcError subclass."""
    for state in remote.RpcState:
      error_class = remote.RpcError.from_state(state)
      # from_state may return None (presumably for non-error states); any
      # class it does return must be an RpcError subclass.
      if error_class is not None:
        self.assertTrue(issubclass(error_class, remote.RpcError),
                        'Unexpected error class %r for state %s' %
                        (error_class, state))

    # Lookup by state name.
    self.assertEqual(remote.ServerError,
                     remote.RpcError.from_state('SERVER_ERROR'))
97
98
class ApplicationErrorTest(test_util.TestCase):
  """Tests for remote.ApplicationError."""

  def testErrorCode(self):
    """error_name comes from the second constructor argument."""
    error = remote.ApplicationError('an error', 'blam')
    self.assertEqual('blam', error.error_name)

  def testStr(self):
    """str() yields only the message text."""
    error = remote.ApplicationError('an error', 1)
    self.assertEqual('an error', str(error))

  def testRepr(self):
    """repr() shows error_name only when one was supplied."""
    with_name = remote.ApplicationError('an error', 1)
    without_name = remote.ApplicationError('an error')
    self.assertEqual("ApplicationError('an error', 1)", repr(with_name))
    self.assertEqual("ApplicationError('an error')", repr(without_name))
114
115
class MethodTest(test_util.TestCase):
  """Tests for the remote.method decorator."""

  def testMethod(self):
    """Decoration attaches request/response types and the raw function."""
    info = BasicService.remote_method.remote
    self.assertEqual(SimpleRequest, info.request_type)
    self.assertEqual(SimpleResponse, info.response_type)
    self.assertTrue(isinstance(info.method, types.FunctionType))

  def testMethodMessageResolution(self):
    """Message types given by name resolve to the message classes."""
    class OtherService(remote.Service):

      @remote.method('SimpleRequest', 'SimpleResponse')
      def remote_method(self, request):
        pass

    info = OtherService.remote_method.remote
    self.assertEqual(SimpleRequest, info.request_type)
    self.assertEqual(SimpleResponse, info.response_type)

  def testMethodMessageResolution_NotFound(self):
    """Unknown type names raise DefinitionNotFoundError when read."""
    class OtherService(remote.Service):

      @remote.method('NoSuchRequest', 'NoSuchResponse')
      def remote_method(self, request):
        pass

    info = OtherService.remote_method.remote
    for attribute, type_name in (('request_type', 'NoSuchRequest'),
                                 ('response_type', 'NoSuchResponse')):
      self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for %s' % type_name,
        getattr,
        info,
        attribute)

  def testInvocation(self):
    """The very same request object reaches the wrapped method."""
    service = BasicService()
    request = SimpleRequest()
    result = service.remote_method(request)
    self.assertEqual(SimpleResponse(), result)
    self.assertEqual([id(request)], service.request_ids)

  def testInvocation_WrongRequestType(self):
    """A string, None or a response message as request raise RequestError."""
    service = BasicService()
    for bad_request in ('wrong', None, SimpleResponse()):
      self.assertRaises(remote.RequestError,
                        service.remote_method,
                        bad_request)

  def testInvocation_WrongResponseType(self):
    """Returning a string, None or a request message raises ServerError."""
    class AnotherService(object):

      @remote.method(SimpleRequest, SimpleResponse)
      def remote_method(self, unused_request):
        return self.return_this

    service = AnotherService()
    for bad_response in ('wrong', None, SimpleRequest()):
      service.return_this = bad_response
      self.assertRaises(remote.ServerError,
                        service.remote_method,
                        SimpleRequest())

  def testBadRequestType(self):
    """Declaring a method with an invalid request type raises TypeError."""
    def declare(bad_type):
      class BadService(object):

        @remote.method(bad_type, SimpleResponse)
        def remote_method(self, request):
          pass

    for bad_type in (None, 1020, messages.Message, str):
      self.assertRaises(TypeError, declare, bad_type)

  def testBadResponseType(self):
    """Declaring a method with an invalid response type raises TypeError."""
    def declare(bad_type):
      class BadService(object):

        @remote.method(SimpleRequest, bad_type)
        def remote_method(self, request):
          pass

    for bad_type in (None, 1020, messages.Message, str):
      self.assertRaises(TypeError, declare, bad_type)
237
238
class GetRemoteMethodTest(test_util.TestCase):
  """Tests for remote.get_remote_method_info."""

  def testGetRemoteMethod(self):
    """Remote method info is found on both unbound and bound methods."""

    class Service(object):

      @remote.method(Request, Response)
      def remote_method(self, request):
        pass

    self.assertEqual(Service.remote_method.remote,
                     remote.get_remote_method_info(Service.remote_method))
    # assertEqual rather than assertTrue: with assertTrue the second argument
    # is only a failure message, so the bound-method case was never compared.
    self.assertEqual(Service.remote_method.remote,
                     remote.get_remote_method_info(Service().remote_method))

  def testGetNotRemoteMethod(self):
    """Test negative results for values that are not remote methods."""

    class NotService(object):

      def not_remote_method(self, request):
        pass

    def fn(self):
      pass

    class NotReallyRemote(object):
      """Method carries a `remote` attribute that is not remote method info."""

      def not_really(self, request):
        pass

      not_really.remote = 'something else'

    for not_remote in [NotService.not_remote_method,
                       NotService().not_remote_method,
                       NotReallyRemote.not_really,
                       NotReallyRemote().not_really,
                       None,
                       1,
                       'a string',
                       fn]:
      self.assertEqual(None, remote.get_remote_method_info(not_remote))
284
285
class RequestStateTest(test_util.TestCase):
  """Tests for remote.RequestState; subclasses swap in STATE_CLASS."""

  STATE_CLASS = remote.RequestState

  def testConstructor(self):
    """Keywords become attributes; omitted ones read back as None."""
    expected = [('remote_host', 'remote-host'),
                ('remote_address', 'remote-address'),
                ('server_host', 'server-host'),
                ('server_port', 10)]

    state = self.STATE_CLASS(**dict(expected))
    for attribute, value in expected:
      self.assertEqual(value, getattr(state, attribute))

    state = self.STATE_CLASS()
    for attribute, _ in expected:
      self.assertEqual(None, getattr(state, attribute))

  def testConstructorError(self):
    """An unknown keyword is rejected with TypeError."""
    self.assertRaises(TypeError, self.STATE_CLASS, x=10)

  def testRepr(self):
    """repr lists only the attributes that were set, in a fixed order."""
    cases = [
      ({}, '<%s>'),
      ({'remote_host': 'abc'}, "<%s remote_host='abc'>"),
      ({'remote_host': 'abc', 'remote_address': 'def'},
       "<%s remote_host='abc' remote_address='def'>"),
      ({'remote_host': 'abc', 'remote_address': 'def', 'server_host': 'ghi'},
       "<%s remote_host='abc' remote_address='def' server_host='ghi'>"),
      ({'remote_host': 'abc', 'remote_address': 'def', 'server_host': 'ghi',
        'server_port': 102},
       "<%s remote_host='abc' remote_address='def' server_host='ghi' "
       "server_port=102>"),
    ]
    class_name = self.STATE_CLASS.__name__
    for kwargs, template in cases:
      self.assertEqual(template % class_name,
                       repr(self.STATE_CLASS(**kwargs)))
338
339
class HttpRequestStateTest(RequestStateTest):
  """Runs the RequestState tests on HttpRequestState, plus HTTP fields."""

  STATE_CLASS = remote.HttpRequestState

  def testHttpMethod(self):
    """http_method is stored as given."""
    state = remote.HttpRequestState(http_method='GET')
    self.assertEqual('GET', state.http_method)

  def testServicePath(self):
    """service_path is stored as given."""
    # Must not share a name with testHttpMethod, or it would shadow it.
    state = remote.HttpRequestState(service_path='/bar')
    self.assertEqual('/bar', state.service_path)

  def testHeadersList(self):
    """Headers given as (name, value) pairs keep repeated names."""
    state = remote.HttpRequestState(
      headers=[('a', 'b'), ('c', 'd'), ('c', 'e')])

    self.assertEqual(['a', 'c', 'c'], list(state.headers.keys()))
    self.assertEqual(['b'], state.headers.get_all('a'))
    self.assertEqual(['d', 'e'], state.headers.get_all('c'))

  def testHeadersDict(self):
    """Headers given as a dict expand list values into repeated headers."""
    state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']})

    self.assertEqual(['a', 'c', 'c'], sorted(state.headers.keys()))
    self.assertEqual(['b'], state.headers.get_all('a'))
    self.assertEqual(['d', 'e'], state.headers.get_all('c'))

  def testRepr(self):
    """repr also includes the HTTP-specific attributes."""
    super(HttpRequestStateTest, self).testRepr()

    self.assertEqual("<%s remote_host='abc' "
                     "remote_address='def' "
                     "server_host='ghi' "
                     'server_port=102 '
                     "http_method='POST' "
                     "service_path='/bar' "
                     "headers=[('a', 'b'), ('c', 'd')]>" %
                     self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS(remote_host='abc',
                                           remote_address='def',
                                           server_host='ghi',
                                           server_port=102,
                                           http_method='POST',
                                           service_path='/bar',
                                           headers={'a': 'b', 'c': 'd'},
                                           )))
386
387
class ServiceTest(test_util.TestCase):
  """Test Service class."""

  def testServiceBase_AllRemoteMethods(self):
    """Test that service base class has no remote methods."""
    self.assertEqual({}, remote.Service.all_remote_methods())

  def testAllRemoteMethods(self):
    """Test all_remote_methods on a Service subclass."""
    self.assertEqual({'remote_method': MyService.remote_method},
                     MyService.all_remote_methods())

  def testAllRemoteMethods_SubClass(self):
    """Test all_remote_methods on a sub-class of a service."""
    class SubClass(MyService):

      @remote.method(Request, Response)
      def sub_class_method(self, request):
        pass

    self.assertEqual({'remote_method': SubClass.remote_method,
                      'sub_class_method': SubClass.sub_class_method,
                     },
                     SubClass.all_remote_methods())

  def testOverrideMethod(self):
    """Test overriding a remote method without the remote decorator."""
    class SubClass(MyService):

      def remote_method(self, request):
        response = super(SubClass, self).remote_method(request)
        response.value = '(%s)' % response.value
        return response

    self.assertEqual({'remote_method': SubClass.remote_method,
                     },
                     SubClass.all_remote_methods())

    instance = SubClass()
    self.assertEqual('(Hello)',
                     instance.remote_method(Request(value='Hello')).value)
    # The override keeps the parent's remote method info.
    self.assertEqual(Request, SubClass.remote_method.remote.request_type)
    self.assertEqual(Response, SubClass.remote_method.remote.response_type)

  def testOverrideMethodWithRemote(self):
    """Test trying to override a remote method with remote decorator."""
    def do_override():
      class SubClass(MyService):

        @remote.method(Request, Response)
        def remote_method(self, request):
          pass

    self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                     'Do not use method decorator when '
                                     'overloading remote method remote_method '
                                     'on service SubClass',
                                     do_override)

  def testOverrideMethodWithInvalidValue(self):
    """Test overriding a remote method with a value that is not a method."""
    def do_override(bad_value):
      class SubClass(MyService):

        remote_method = bad_value

    for bad_value in [None, 1, 'string', {}]:
      self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                       'Must override remote_method in '
                                       'SubClass with a method',
                                       do_override, bad_value)

  def testCallingRemoteMethod(self):
    """Test invoking a remote method."""
    expected = Response()
    expected.value = 'what was passed in'

    request = Request()
    request.value = 'what was passed in'

    service = MyService()
    self.assertEqual(expected, service.remote_method(request))

  def testFactory(self):
    """Test using factory to pass in state."""
    class StatefulService(remote.Service):

      def __init__(self, a, b, c=None):
        self.a = a
        self.b = b
        self.c = c

    state = [1, 2, 3]

    factory = StatefulService.new_factory(1, state)

    module_name = ServiceTest.__module__
    pattern = ('Creates new instances of service StatefulService.\n\n'
               'Returns:\n'
               '  New instance of %s.StatefulService.' % module_name)
    self.assertEqual(pattern, factory.__doc__)
    self.assertEqual('StatefulService_service_factory', factory.__name__)
    self.assertEqual(StatefulService, factory.service_class)

    service = factory()
    self.assertEqual(1, service.a)
    # The factory must pass the same object through, not a copy.
    self.assertEqual(id(state), id(service.b))
    self.assertEqual(None, service.c)

    factory = StatefulService.new_factory(2, b=3, c=4)
    service = factory()
    self.assertEqual(2, service.a)
    self.assertEqual(3, service.b)
    self.assertEqual(4, service.c)

  def testFactoryError(self):
    """Test misusing a factory."""
    # Passing positional argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(1))

    # Passing keyword argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(x=1))

    class StatefulService(remote.Service):

      def __init__(self, a):
        pass

    # Missing required parameter.
    self.assertRaises(TypeError, StatefulService.new_factory())

  def testDefinitionName(self):
    """Test getting service definition name."""
    class TheService(remote.Service):
      pass

    module_name = test_util.get_module_name(ServiceTest)
    self.assertEqual(TheService.definition_name(),
                     '%s.TheService' % module_name)
    # assertEqual rather than assertTrue: with assertTrue the expected value
    # is only a failure message and nothing is compared.
    self.assertEqual(module_name, TheService.outer_definition_name())
    self.assertEqual(module_name, TheService.definition_package())

  def testDefinitionNameWithPackage(self):
    """Test getting service definition name when package defined."""
    global package
    package = 'my.package'
    try:
      class TheService(remote.Service):
        pass

      self.assertEqual('my.package.TheService', TheService.definition_name())
      self.assertEqual('my.package', TheService.outer_definition_name())
      self.assertEqual('my.package', TheService.definition_package())
    finally:
      del package

  def testDefinitionNameWithNoModule(self):
    """Test definition names when this module is absent from sys.modules."""
    module = sys.modules[__name__]
    try:
      del sys.modules[__name__]
      class TheService(remote.Service):
        pass

      self.assertEqual('TheService', TheService.definition_name())
      self.assertEqual(None, TheService.outer_definition_name())
      self.assertEqual(None, TheService.definition_package())
    finally:
      sys.modules[__name__] = module
559
560
class StubTest(test_util.TestCase):
  """Tests for the BasicService.Stub client class."""

  def setUp(self):
    self.mox = mox.Mox()
    self.transport = self.mox.CreateMockAnything()

  def _new_request(self):
    """Return the SimpleRequest that every stub test sends."""
    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    return request

  def _async(self, stub):
    """Return the stub's `async` attribute.

    `async` is a reserved keyword from Python 3.7 on, so writing
    `stub.async` is a SyntaxError there; getattr reads the same attribute
    on every Python version.
    """
    return getattr(stub, 'async')

  def testDefinitionName(self):
    """The stub reports the same definition names as its service."""
    self.assertEqual(BasicService.definition_name(),
                     BasicService.Stub.definition_name())
    self.assertEqual(BasicService.outer_definition_name(),
                     BasicService.Stub.outer_definition_name())
    self.assertEqual(BasicService.definition_package(),
                     BasicService.Stub.definition_package())

  def testRemoteMethods(self):
    """The stub exposes the same remote methods as its service."""
    self.assertEqual(BasicService.all_remote_methods(),
                     BasicService.Stub.all_remote_methods())

  def testSync_WithRequest(self):
    """A synchronous call with a request object returns the RPC response."""
    stub = BasicService.Stub(self.transport)
    request = self._new_request()

    rpc = transport.Rpc(request)
    rpc.set_response(SimpleResponse())
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(SimpleResponse(), stub.remote_method(request))

    self.mox.VerifyAll()

  def testSync_WithKwargs(self):
    """A synchronous call with field keywords builds the request itself."""
    stub = BasicService.Stub(self.transport)
    request = self._new_request()

    rpc = transport.Rpc(request)
    rpc.set_response(SimpleResponse())
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(SimpleResponse(), stub.remote_method(param1='val1',
                                                          param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequest(self):
    """An async call with a request object returns the transport's Rpc."""
    stub = BasicService.Stub(self.transport)
    request = self._new_request()

    rpc = transport.Rpc(request)

    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(rpc, self._async(stub).remote_method(request))

    self.mox.VerifyAll()

  def testAsync_WithKwargs(self):
    """An async call with field keywords returns the transport's Rpc."""
    stub = BasicService.Stub(self.transport)
    request = self._new_request()

    rpc = transport.Rpc(request)

    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(rpc, self._async(stub).remote_method(param1='val1',
                                                          param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequestAndKwargs(self):
    """Mixing a request object with field keywords raises TypeError."""
    stub = BasicService.Stub(self.transport)
    request = self._new_request()

    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
      TypeError,
      r'May not provide both args and kwargs',
      self._async(stub).remote_method,
      request,
      param1='val1',
      param2='val2')

    self.mox.VerifyAll()

  def testAsync_WithTooManyPositionals(self):
    """More than one positional argument raises TypeError."""
    stub = BasicService.Stub(self.transport)
    request = self._new_request()

    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
      TypeError,
      r'remote_method\(\) takes at most 2 positional arguments \(3 given\)',
      self._async(stub).remote_method,
      request, 'another value')

    self.mox.VerifyAll()
695
696
class IsErrorStatusTest(test_util.TestCase):
  """Tests for remote.is_error_status."""

  def testIsError(self):
    """Every state ordered after RUNNING counts as an error."""
    running = remote.RpcState.RUNNING
    for error_state in [s for s in remote.RpcState if s > running]:
      self.assertTrue(
        remote.is_error_status(remote.RpcStatus(state=error_state)))

  def testIsNotError(self):
    """RUNNING and every state before it are not errors."""
    running = remote.RpcState.RUNNING
    for ok_state in [s for s in remote.RpcState if s <= running]:
      self.assertFalse(
        remote.is_error_status(remote.RpcStatus(state=ok_state)))

  def testStateNone(self):
    """A status with no state fails validation."""
    empty_status = remote.RpcStatus()
    self.assertRaises(messages.ValidationError,
                      remote.is_error_status,
                      empty_status)
712
713
class CheckRpcStatusTest(test_util.TestCase):
  """Tests for remote.check_rpc_status."""

  def testStateNone(self):
    """A status with no state fails validation."""
    empty_status = remote.RpcStatus()
    self.assertRaises(messages.ValidationError,
                      remote.check_rpc_status,
                      empty_status)

  def testNoError(self):
    """OK and RUNNING statuses pass without raising."""
    for healthy_state in (remote.RpcState.OK, remote.RpcState.RUNNING):
      status = remote.RpcStatus(state=healthy_state)
      remote.check_rpc_status(status)

  def testErrorState(self):
    """REQUEST_ERROR raises RequestError carrying the error message."""
    status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR,
                              error_message='a request error')
    self.assertRaisesWithRegexpMatch(remote.RequestError,
                                     'a request error',
                                     remote.check_rpc_status,
                                     status)

  def testApplicationErrorState(self):
    """APPLICATION_ERROR raises ApplicationError with message and name."""
    status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                              error_message='an application error',
                              error_name='blam')
    try:
      remote.check_rpc_status(status)
    except remote.ApplicationError as err:
      self.assertEqual('an application error', str(err))
      self.assertEqual('blam', err.error_name)
    else:
      self.fail('Should have raised application error.')
741
742
class ProtocolConfigTest(test_util.TestCase):
  """Tests for remote.ProtocolConfig."""

  def testConstructor(self):
    """Explicit content types are lower-cased and exposed as tuples."""
    alternates = iter(['text/Json', 'text/JavaScript'])
    config = remote.ProtocolConfig(protojson, 'proto1', 'application/X-Json',
                                   alternates)

    self.assertEqual(protojson, config.protocol)
    self.assertEqual('proto1', config.name)
    self.assertEqual('application/x-json', config.default_content_type)
    self.assertEqual(('text/json', 'text/javascript'),
                     config.alternate_content_types)
    self.assertEqual(('application/x-json', 'text/json', 'text/javascript'),
                     config.content_types)

  def testConstructorDefaults(self):
    """Without explicit content types, protojson gets the values below."""
    expected_alternates = ('application/x-javascript',
                           'text/javascript',
                           'text/x-javascript',
                           'text/x-json',
                           'text/json')
    config = remote.ProtocolConfig(protojson, 'proto2')

    self.assertEqual(protojson, config.protocol)
    self.assertEqual('proto2', config.name)
    self.assertEqual('application/json', config.default_content_type)
    self.assertEqual(expected_alternates, config.alternate_content_types)
    # content_types is the default type followed by the alternates.
    self.assertEqual(('application/json',) + expected_alternates,
                     config.content_types)

  def testEmptyAlternativeTypes(self):
    """An empty alternates list leaves only the default content type."""
    config = remote.ProtocolConfig(protojson, 'proto2',
                                   alternative_content_types=())

    self.assertEqual(protojson, config.protocol)
    self.assertEqual('proto2', config.name)
    self.assertEqual('application/json', config.default_content_type)
    self.assertEqual((), config.alternate_content_types)
    self.assertEqual(('application/json',), config.content_types)

  def testDuplicateContentTypes(self):
    """Repeated content types raise ServiceConfigurationError."""
    for alternates in (('text/plain',), ('text/html', 'text/html')):
      self.assertRaises(remote.ServiceConfigurationError,
                        remote.ProtocolConfig,
                        protojson, 'json', 'text/plain', alternates)

  def testEncodeMessage(self):
    """encode_message serializes with the configured protocol."""
    config = remote.ProtocolConfig(protojson, 'proto2')
    status = remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                              error_message='bad error')
    encoded_message = config.encode_message(status)

    # Round-trip through JSON to compare as a plain dictionary.
    decoded = protojson.json.loads(encoded_message)
    self.assertEqual({'state': 'SERVER_ERROR', 'error_message': 'bad error'},
                     decoded)

  def testDecodeMessage(self):
    """decode_message parses with the configured protocol."""
    config = remote.ProtocolConfig(protojson, 'proto2')
    expected = remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                                error_message="bad error")
    decoded = config.decode_message(
      remote.RpcStatus,
      '{"state": "SERVER_ERROR", "error_message": "bad error"}')
    self.assertEqual(expected, decoded)
820
821
class ProtocolsTest(test_util.TestCase):
  """Tests for the remote.Protocols registry."""

  def setUp(self):
    self.protocols = remote.Protocols()

  def _checkStandardProtocols(self, protocols):
    """Assert `protocols` holds the stock protobuf and protojson entries."""
    self.assertEqual(('protobuf', 'protojson'), protocols.names)

    protobuf_protocol = protocols.lookup_by_name('protobuf')
    self.assertEqual(protobuf, protobuf_protocol.protocol)

    protojson_protocol = protocols.lookup_by_name('protojson')
    self.assertEqual(protojson.ProtoJson.get_default(),
                     protojson_protocol.protocol)

  def testEmpty(self):
    """A new registry has no names and no content types."""
    self.assertEqual((), self.protocols.names)
    self.assertEqual((), self.protocols.content_types)

  def testAddProtocolAllDefaults(self):
    """Default content types are registered and reported sorted."""
    self.protocols.add_protocol(protojson, 'json')
    self.assertEqual(('json',), self.protocols.names)
    self.assertEqual(('application/json',
                      'application/x-javascript',
                      'text/javascript',
                      'text/json',
                      'text/x-javascript',
                      'text/x-json'),
                     self.protocols.content_types)

  def testAddProtocolNoDefaultAlternatives(self):
    """A protocol with only CONTENT_TYPE registers just that type."""
    class Protocol(object):
      CONTENT_TYPE = 'text/plain'

    self.protocols.add_protocol(Protocol, 'text')
    self.assertEqual(('text',), self.protocols.names)
    self.assertEqual(('text/plain',), self.protocols.content_types)

  def testAddProtocolOverrideDefaults(self):
    """Explicit content types replace the protocol's defaults."""
    self.protocols.add_protocol(protojson, 'json',
                                default_content_type='text/blar',
                                alternative_content_types=('text/blam',
                                                           'text/blim'))
    self.assertEqual(('json',), self.protocols.names)
    self.assertEqual(('text/blam', 'text/blar', 'text/blim'),
                     self.protocols.content_types)

  def testLookupByName(self):
    """Lookup by protocol name ignores case."""
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEqual('json', self.protocols.lookup_by_name('JsOn').name)
    self.assertEqual('json2', self.protocols.lookup_by_name('Json2').name)

  def testLookupByContentType(self):
    """Lookup by default or alternate content type ignores case."""
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEqual(
      'json',
      self.protocols.lookup_by_content_type('AppliCation/Json').name)

    self.assertEqual(
      'json',
      self.protocols.lookup_by_content_type('text/x-Json').name)

    self.assertEqual(
      'json2',
      self.protocols.lookup_by_content_type('text/Plain').name)

  def testNewDefault(self):
    """new_default builds a registry with the stock protocols."""
    self._checkStandardProtocols(remote.Protocols.new_default())

  def testGetDefaultProtocols(self):
    """get_default returns the stock registry, the same object each call."""
    protocols = remote.Protocols.get_default()
    self._checkStandardProtocols(protocols)
    self.assertTrue(protocols is remote.Protocols.get_default())

  def testSetDefaultProtocols(self):
    """set_default replaces the registry returned by get_default."""
    original_default = remote.Protocols.get_default()
    protocols = remote.Protocols()
    try:
      remote.Protocols.set_default(protocols)
      self.assertTrue(protocols is remote.Protocols.get_default())
    finally:
      # The default is shared state; restore it so later tests that rely on
      # get_default() do not see this empty registry.
      remote.Protocols.set_default(original_default)

  def testSetDefaultWithoutProtocols(self):
    """set_default rejects values that are not Protocols instances."""
    self.assertRaises(TypeError, remote.Protocols.set_default, None)
    self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols')
    self.assertRaises(TypeError, remote.Protocols.set_default, {})
919
920
def main():
  """Run this module's tests via unittest's command-line entry point."""
  unittest.main()


if __name__ == '__main__':
  main()
927