#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for protorpc.remote."""

__author__ = 'rafek@google.com (Rafe Kaplan)'


import sys
import types
import unittest
from wsgiref import headers

from protorpc import descriptor
from protorpc import message_types
from protorpc import messages
from protorpc import protobuf
from protorpc import protojson
from protorpc import remote
from protorpc import test_util
from protorpc import transport

import mox


class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Run test_util's generic module-interface checks against remote."""

  MODULE = remote


class Request(messages.Message):
  """Test request message."""

  value = messages.StringField(1)


class Response(messages.Message):
  """Test response message."""

  value = messages.StringField(1)


class MyService(remote.Service):
  """Service whose only remote method echoes the request's value."""

  @remote.method(Request, Response)
  def remote_method(self, request):
    response = Response()
    response.value = request.value
    return response


class SimpleRequest(messages.Message):
  """Simple request message type used for tests."""

  param1 = messages.StringField(1)
  param2 = messages.StringField(2)


class SimpleResponse(messages.Message):
  """Simple response message type used for tests."""


class BasicService(remote.Service):
  """A basic service with decorated remote method."""

  def __init__(self):
    # id() of every request passed to remote_method, in call order, so tests
    # can check that the very same request object reached the implementation.
    self.request_ids = []

  @remote.method(SimpleRequest, SimpleResponse)
  def remote_method(self, request):
    self.request_ids.append(id(request))
    return SimpleResponse()


class RpcErrorTest(test_util.TestCase):
  """Tests for remote.RpcError."""

  def testFromStatus(self):
    """from_state maps the 'SERVER_ERROR' state name to ServerError."""
    self.assertEqual(remote.ServerError,
                     remote.RpcError.from_state('SERVER_ERROR'))


class ApplicationErrorTest(test_util.TestCase):
  """Tests for remote.ApplicationError."""

  def testErrorCode(self):
    self.assertEqual('blam',
                     remote.ApplicationError('an error', 'blam').error_name)

  def testStr(self):
    self.assertEqual('an error', str(remote.ApplicationError('an error', 1)))

  def testRepr(self):
    self.assertEqual("ApplicationError('an error', 1)",
                     repr(remote.ApplicationError('an error', 1)))

    self.assertEqual("ApplicationError('an error')",
                     repr(remote.ApplicationError('an error')))


class MethodTest(test_util.TestCase):
  """Test remote method decorator."""

  def testMethod(self):
    """Test use of remote decorator."""
    self.assertEqual(SimpleRequest,
                     BasicService.remote_method.remote.request_type)
    self.assertEqual(SimpleResponse,
                     BasicService.remote_method.remote.response_type)
    self.assertTrue(isinstance(BasicService.remote_method.remote.method,
                               types.FunctionType))

  def testMethodMessageResolution(self):
    """Test use of remote decorator to resolve message types by name."""
    class OtherService(remote.Service):

      @remote.method('SimpleRequest', 'SimpleResponse')
      def remote_method(self, request):
        pass

    self.assertEqual(SimpleRequest,
                     OtherService.remote_method.remote.request_type)
    self.assertEqual(SimpleResponse,
                     OtherService.remote_method.remote.response_type)

  def testMethodMessageResolution_NotFound(self):
    """Test failure to find message types."""
    class OtherService(remote.Service):

      @remote.method('NoSuchRequest', 'NoSuchResponse')
      def remote_method(self, request):
        pass

    # Defining the class above succeeds; the unknown names only fail when
    # request_type / response_type are accessed.
    self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for NoSuchRequest',
        getattr,
        OtherService.remote_method.remote,
        'request_type')

    self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for NoSuchResponse',
        getattr,
        OtherService.remote_method.remote,
        'response_type')

  def testInvocation(self):
    """Test that invocation passes request through properly."""
    service = BasicService()
    request = SimpleRequest()
    self.assertEqual(SimpleResponse(), service.remote_method(request))
    self.assertEqual([id(request)], service.request_ids)

  def testInvocation_WrongRequestType(self):
    """Wrong request type passed to remote method."""
    service = BasicService()

    self.assertRaises(remote.RequestError,
                      service.remote_method,
                      'wrong')

    self.assertRaises(remote.RequestError,
                      service.remote_method,
                      None)

    self.assertRaises(remote.RequestError,
                      service.remote_method,
                      SimpleResponse())

  def testInvocation_WrongResponseType(self):
    """Wrong response type returned from remote method."""

    class AnotherService(object):

      @remote.method(SimpleRequest, SimpleResponse)
      def remote_method(self, unused_request):
        return self.return_this

    service = AnotherService()

    service.return_this = 'wrong'
    self.assertRaises(remote.ServerError,
                      service.remote_method,
                      SimpleRequest())
    service.return_this = None
    self.assertRaises(remote.ServerError,
                      service.remote_method,
                      SimpleRequest())
    service.return_this = SimpleRequest()
    self.assertRaises(remote.ServerError,
                      service.remote_method,
                      SimpleRequest())

  def testBadRequestType(self):
    """Test bad request types used in remote definition."""

    for request_type in (None, 1020, messages.Message, str):

      # declare() is called within the same iteration, so the closure sees
      # the current request_type.
      def declare():
        class BadService(object):

          @remote.method(request_type, SimpleResponse)
          def remote_method(self, request):
            pass

      self.assertRaises(TypeError, declare)

  def testBadResponseType(self):
    """Test bad response types used in remote definition."""

    for response_type in (None, 1020, messages.Message, str):

      def declare():
        class BadService(object):

          @remote.method(SimpleRequest, response_type)
          def remote_method(self, request):
            pass

      self.assertRaises(TypeError, declare)


class GetRemoteMethodTest(test_util.TestCase):
  """Tests for remote.get_remote_method_info."""

  def testGetRemoteMethod(self):
    """Test valid remote method detection."""

    class Service(object):

      @remote.method(Request, Response)
      def remote_method(self, request):
        pass

    self.assertEqual(Service.remote_method.remote,
                     remote.get_remote_method_info(Service.remote_method))
    # A bound method must resolve to the same info as the class attribute.
    # (This was assertTrue, which treated the second argument as a message
    # and therefore could never fail.)
    self.assertEqual(Service.remote_method.remote,
                     remote.get_remote_method_info(Service().remote_method))

  def testGetNotRemoteMethod(self):
    """Test negative result on many bad values for remote methods."""

    class NotService(object):

      def not_remote_method(self, request):
        pass

    def fn(self):
      pass

    class NotReallyRemote(object):
      """Has a 'remote' attribute that is not remote method info."""

      def not_really(self, request):
        pass

      not_really.remote = 'something else'

    for not_remote in [NotService.not_remote_method,
                       NotService().not_remote_method,
                       NotReallyRemote.not_really,
                       NotReallyRemote().not_really,
                       None,
                       1,
                       'a string',
                       fn]:
      self.assertEqual(None, remote.get_remote_method_info(not_remote))


class RequestStateTest(test_util.TestCase):
  """Test request state."""

  # Subclasses override this to run the same tests against a subclass.
  STATE_CLASS = remote.RequestState

  def testConstructor(self):
    """Test constructor."""
    state = self.STATE_CLASS(remote_host='remote-host',
                             remote_address='remote-address',
                             server_host='server-host',
                             server_port=10)
    self.assertEqual('remote-host', state.remote_host)
    self.assertEqual('remote-address', state.remote_address)
    self.assertEqual('server-host', state.server_host)
    self.assertEqual(10, state.server_port)

    state = self.STATE_CLASS()
    self.assertEqual(None, state.remote_host)
    self.assertEqual(None, state.remote_address)
    self.assertEqual(None, state.server_host)
    self.assertEqual(None, state.server_port)

  def testConstructorError(self):
    """Test unexpected keyword argument."""
    self.assertRaises(TypeError,
                      self.STATE_CLASS,
                      x=10)

  def testRepr(self):
    """Test string representation."""
    self.assertEqual('<%s>' % self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS()))
    self.assertEqual("<%s remote_host='abc'>" % self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS(remote_host='abc')))
    self.assertEqual("<%s remote_host='abc' "
                     "remote_address='def'>" % self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS(remote_host='abc',
                                           remote_address='def')))
    self.assertEqual("<%s remote_host='abc' "
                     "remote_address='def' "
                     "server_host='ghi'>" % self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS(remote_host='abc',
                                           remote_address='def',
                                           server_host='ghi')))
    self.assertEqual("<%s remote_host='abc' "
                     "remote_address='def' "
                     "server_host='ghi' "
                     'server_port=102>' % self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS(remote_host='abc',
                                           remote_address='def',
                                           server_host='ghi',
                                           server_port=102)))


class HttpRequestStateTest(RequestStateTest):
  """Run RequestStateTest against HttpRequestState plus its HTTP fields."""

  STATE_CLASS = remote.HttpRequestState

  def testHttpMethod(self):
    state = remote.HttpRequestState(http_method='GET')
    self.assertEqual('GET', state.http_method)

  def testServicePath(self):
    state = remote.HttpRequestState(service_path='/bar')
    self.assertEqual('/bar', state.service_path)

  def testHeadersList(self):
    state = remote.HttpRequestState(
        headers=[('a', 'b'), ('c', 'd'), ('c', 'e')])

    self.assertEqual(['a', 'c', 'c'], list(state.headers.keys()))
    self.assertEqual(['b'], state.headers.get_all('a'))
    self.assertEqual(['d', 'e'], state.headers.get_all('c'))

  def testHeadersDict(self):
    state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']})

    self.assertEqual(['a', 'c', 'c'], sorted(state.headers.keys()))
    self.assertEqual(['b'], state.headers.get_all('a'))
    self.assertEqual(['d', 'e'], state.headers.get_all('c'))

  def testRepr(self):
    super(HttpRequestStateTest, self).testRepr()

    self.assertEqual("<%s remote_host='abc' "
                     "remote_address='def' "
                     "server_host='ghi' "
                     'server_port=102 '
                     "http_method='POST' "
                     "service_path='/bar' "
                     "headers=[('a', 'b'), ('c', 'd')]>" %
                     self.STATE_CLASS.__name__,
                     repr(self.STATE_CLASS(remote_host='abc',
                                           remote_address='def',
                                           server_host='ghi',
                                           server_port=102,
                                           http_method='POST',
                                           service_path='/bar',
                                           headers={'a': 'b', 'c': 'd'},
                                          )))


class ServiceTest(test_util.TestCase):
  """Test Service class."""

  def testServiceBase_AllRemoteMethods(self):
    """Test that service base class has no remote methods."""
    self.assertEqual({}, remote.Service.all_remote_methods())

  def testAllRemoteMethods(self):
    """Test all_remote_methods with a proper Service subclass."""
    self.assertEqual({'remote_method': MyService.remote_method},
                     MyService.all_remote_methods())

  def testAllRemoteMethods_SubClass(self):
    """Test all_remote_methods on a sub-class of a service."""
    class SubClass(MyService):

      @remote.method(Request, Response)
      def sub_class_method(self, request):
        pass

    self.assertEqual({'remote_method': SubClass.remote_method,
                      'sub_class_method': SubClass.sub_class_method,
                     },
                     SubClass.all_remote_methods())

  def testOverrideMethod(self):
    """Test overriding a remote method without the remote decorator."""
    class SubClass(MyService):

      def remote_method(self, request):
        response = super(SubClass, self).remote_method(request)
        response.value = '(%s)' % response.value
        return response

    self.assertEqual({'remote_method': SubClass.remote_method,
                     },
                     SubClass.all_remote_methods())

    instance = SubClass()
    self.assertEqual('(Hello)',
                     instance.remote_method(Request(value='Hello')).value)
    # The override inherits the request/response types of the parent method.
    self.assertEqual(Request, SubClass.remote_method.remote.request_type)
    self.assertEqual(Response, SubClass.remote_method.remote.response_type)

  def testOverrideMethodWithRemote(self):
    """Test trying to override a remote method with remote decorator."""
    def do_override():
      class SubClass(MyService):

        @remote.method(Request, Response)
        def remote_method(self, request):
          pass

    self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                     'Do not use method decorator when '
                                     'overloading remote method remote_method '
                                     'on service SubClass',
                                     do_override)

  def testOverrideMethodWithInvalidValue(self):
    """Test trying to override a remote method with a non-method value."""
    def do_override(bad_value):
      class SubClass(MyService):

        remote_method = bad_value

    for bad_value in [None, 1, 'string', {}]:
      self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                       'Must override remote_method in '
                                       'SubClass with a method',
                                       do_override, bad_value)

  def testCallingRemoteMethod(self):
    """Test invoking a remote method."""
    expected = Response()
    expected.value = 'what was passed in'

    request = Request()
    request.value = 'what was passed in'

    service = MyService()
    self.assertEqual(expected, service.remote_method(request))

  def testFactory(self):
    """Test using factory to pass in state."""
    class StatefulService(remote.Service):

      def __init__(self, a, b, c=None):
        self.a = a
        self.b = b
        self.c = c

    state = [1, 2, 3]

    factory = StatefulService.new_factory(1, state)

    module_name = ServiceTest.__module__
    pattern = ('Creates new instances of service StatefulService.\n\n'
               'Returns:\n'
               '  New instance of %s.StatefulService.' % module_name)
    self.assertEqual(pattern, factory.__doc__)
    self.assertEqual('StatefulService_service_factory', factory.__name__)
    self.assertEqual(StatefulService, factory.service_class)

    service = factory()
    self.assertEqual(1, service.a)
    # The factory must pass the same object through, not a copy.
    self.assertEqual(id(state), id(service.b))
    self.assertEqual(None, service.c)

    factory = StatefulService.new_factory(2, b=3, c=4)
    service = factory()
    self.assertEqual(2, service.a)
    self.assertEqual(3, service.b)
    self.assertEqual(4, service.c)

  def testFactoryError(self):
    """Test misusing a factory."""
    # new_factory does not validate its arguments; the TypeError is raised
    # when assertRaises calls the returned factory.

    # Passing positional argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(1))

    # Passing keyword argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(x=1))

    class StatefulService(remote.Service):

      def __init__(self, a):
        pass

    # Missing required parameter.
    self.assertRaises(TypeError, StatefulService.new_factory())

  def testDefinitionName(self):
    """Test getting service definition name."""
    class TheService(remote.Service):
      pass

    module_name = test_util.get_module_name(ServiceTest)
    self.assertEqual(TheService.definition_name(),
                     '%s.TheService' % module_name)
    # These were assertTrue(value, module_name), which only checked that the
    # value was truthy and used module_name as the failure message.
    self.assertEqual(module_name, TheService.outer_definition_name())
    self.assertEqual(module_name, TheService.definition_package())

  def testDefinitionNameWithPackage(self):
    """Test getting service definition name when package defined."""
    # A module-level 'package' global overrides the module name as the
    # definition package; remove it afterwards so other tests are unaffected.
    global package
    package = 'my.package'
    try:
      class TheService(remote.Service):
        pass

      self.assertEqual('my.package.TheService', TheService.definition_name())
      self.assertEqual('my.package', TheService.outer_definition_name())
      self.assertEqual('my.package', TheService.definition_package())
    finally:
      del package

  def testDefinitionNameWithNoModule(self):
    """Test getting service definition name when module is not registered."""
    module = sys.modules[__name__]
    try:
      del sys.modules[__name__]
      class TheService(remote.Service):
        pass

      self.assertEqual('TheService', TheService.definition_name())
      self.assertEqual(None, TheService.outer_definition_name())
      self.assertEqual(None, TheService.definition_package())
    finally:
      sys.modules[__name__] = module


class StubTest(test_util.TestCase):
  """Tests for BasicService.Stub, the client stub for a service."""

  def setUp(self):
    self.mox = mox.Mox()
    self.transport = self.mox.CreateMockAnything()

  @staticmethod
  def _async(stub):
    """Return the asynchronous interface of stub.

    'async' is a reserved word from Python 3.7 on, so the attribute is
    fetched with getattr to keep this module parseable there.
    """
    return getattr(stub, 'async')

  def testDefinitionName(self):
    self.assertEqual(BasicService.definition_name(),
                     BasicService.Stub.definition_name())
    self.assertEqual(BasicService.outer_definition_name(),
                     BasicService.Stub.outer_definition_name())
    self.assertEqual(BasicService.definition_package(),
                     BasicService.Stub.definition_package())

  def testRemoteMethods(self):
    self.assertEqual(BasicService.all_remote_methods(),
                     BasicService.Stub.all_remote_methods())

  def testSync_WithRequest(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    rpc = transport.Rpc(request)
    rpc.set_response(response)
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(SimpleResponse(), stub.remote_method(request))

    self.mox.VerifyAll()

  def testSync_WithKwargs(self):
    stub = BasicService.Stub(self.transport)

    # The stub should build a request equal to this one from the kwargs.
    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    rpc = transport.Rpc(request)
    rpc.set_response(response)
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(SimpleResponse(), stub.remote_method(param1='val1',
                                                          param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequest(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'

    rpc = transport.Rpc(request)

    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    # The asynchronous interface returns the Rpc object itself.
    self.assertEqual(rpc, self._async(stub).remote_method(request))

    self.mox.VerifyAll()

  def testAsync_WithKwargs(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'

    rpc = transport.Rpc(request)

    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEqual(rpc, self._async(stub).remote_method(param1='val1',
                                                          param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequestAndKwargs(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'

    # No transport calls are recorded: the stub must reject the call first.
    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
        TypeError,
        r'May not provide both args and kwargs',
        self._async(stub).remote_method,
        request,
        param1='val1',
        param2='val2')

    self.mox.VerifyAll()

  def testAsync_WithTooManyPositionals(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'

    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
        TypeError,
        r'remote_method\(\) takes at most 2 positional arguments \(3 given\)',
        self._async(stub).remote_method,
        request, 'another value')

    self.mox.VerifyAll()


class IsErrorStatusTest(test_util.TestCase):
  """Tests for remote.is_error_status."""

  def testIsError(self):
    for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING):
      status = remote.RpcStatus(state=state)
      self.assertTrue(remote.is_error_status(status))

  def testIsNotError(self):
    for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING):
      status = remote.RpcStatus(state=state)
      self.assertFalse(remote.is_error_status(status))

  def testStateNone(self):
    self.assertRaises(messages.ValidationError,
                      remote.is_error_status, remote.RpcStatus())


class CheckRpcStatusTest(test_util.TestCase):
  """Tests for remote.check_rpc_status."""

  def testStateNone(self):
    self.assertRaises(messages.ValidationError,
                      remote.check_rpc_status, remote.RpcStatus())

  def testNoError(self):
    for state in (remote.RpcState.OK, remote.RpcState.RUNNING):
      remote.check_rpc_status(remote.RpcStatus(state=state))

  def testErrorState(self):
    status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR,
                              error_message='a request error')
    self.assertRaisesWithRegexpMatch(remote.RequestError,
                                     'a request error',
                                     remote.check_rpc_status, status)

  def testApplicationErrorState(self):
    status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                              error_message='an application error',
                              error_name='blam')
    try:
      remote.check_rpc_status(status)
      self.fail('Should have raised application error.')
    except remote.ApplicationError as err:
      self.assertEqual('an application error', str(err))
      self.assertEqual('blam', err.error_name)


class ProtocolConfigTest(test_util.TestCase):
  """Tests for remote.ProtocolConfig."""

  def testConstructor(self):
    # Content types are lower-cased and any iterable of alternatives is
    # accepted (an iterator here) and stored as a tuple.
    config = remote.ProtocolConfig(
        protojson,
        'proto1',
        'application/X-Json',
        iter(['text/Json', 'text/JavaScript']))
    self.assertEqual(protojson, config.protocol)
    self.assertEqual('proto1', config.name)
    self.assertEqual('application/x-json', config.default_content_type)
    self.assertEqual(('text/json', 'text/javascript'),
                     config.alternate_content_types)
    self.assertEqual(('application/x-json', 'text/json', 'text/javascript'),
                     config.content_types)

  def testConstructorDefaults(self):
    config = remote.ProtocolConfig(protojson, 'proto2')
    self.assertEqual(protojson, config.protocol)
    self.assertEqual('proto2', config.name)
    self.assertEqual('application/json', config.default_content_type)
    self.assertEqual(('application/x-javascript',
                      'text/javascript',
                      'text/x-javascript',
                      'text/x-json',
                      'text/json'),
                     config.alternate_content_types)
    self.assertEqual(('application/json',
                      'application/x-javascript',
                      'text/javascript',
                      'text/x-javascript',
                      'text/x-json',
                      'text/json'), config.content_types)

  def testEmptyAlternativeTypes(self):
    config = remote.ProtocolConfig(protojson, 'proto2',
                                   alternative_content_types=())
    self.assertEqual(protojson, config.protocol)
    self.assertEqual('proto2', config.name)
    self.assertEqual('application/json', config.default_content_type)
    self.assertEqual((), config.alternate_content_types)
    self.assertEqual(('application/json',), config.content_types)

  def testDuplicateContentTypes(self):
    self.assertRaises(remote.ServiceConfigurationError,
                      remote.ProtocolConfig,
                      protojson,
                      'json',
                      'text/plain',
                      ('text/plain',))

    self.assertRaises(remote.ServiceConfigurationError,
                      remote.ProtocolConfig,
                      protojson,
                      'json',
                      'text/plain',
                      ('text/html', 'text/html'))

  def testEncodeMessage(self):
    config = remote.ProtocolConfig(protojson, 'proto2')
    encoded_message = config.encode_message(
        remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                         error_message='bad error'))

    # Convert back to a dictionary from JSON.
    dict_message = protojson.json.loads(encoded_message)
    self.assertEqual({'state': 'SERVER_ERROR', 'error_message': 'bad error'},
                     dict_message)

  def testDecodeMessage(self):
    config = remote.ProtocolConfig(protojson, 'proto2')
    self.assertEqual(
        remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                         error_message="bad error"),
        config.decode_message(
            remote.RpcStatus,
            '{"state": "SERVER_ERROR", "error_message": "bad error"}'))


class ProtocolsTest(test_util.TestCase):
  """Tests for the remote.Protocols registry."""

  def setUp(self):
    self.protocols = remote.Protocols()

  def testEmpty(self):
    self.assertEqual((), self.protocols.names)
    self.assertEqual((), self.protocols.content_types)

  def testAddProtocolAllDefaults(self):
    self.protocols.add_protocol(protojson, 'json')
    self.assertEqual(('json',), self.protocols.names)
    self.assertEqual(('application/json',
                      'application/x-javascript',
                      'text/javascript',
                      'text/json',
                      'text/x-javascript',
                      'text/x-json'),
                     self.protocols.content_types)

  def testAddProtocolNoDefaultAlternatives(self):
    class Protocol(object):
      CONTENT_TYPE = 'text/plain'

    self.protocols.add_protocol(Protocol, 'text')
    self.assertEqual(('text',), self.protocols.names)
    self.assertEqual(('text/plain',), self.protocols.content_types)

  def testAddProtocolOverrideDefaults(self):
    self.protocols.add_protocol(protojson, 'json',
                                default_content_type='text/blar',
                                alternative_content_types=('text/blam',
                                                           'text/blim'))
    self.assertEqual(('json',), self.protocols.names)
    self.assertEqual(('text/blam', 'text/blar', 'text/blim'),
                     self.protocols.content_types)

  def testLookupByName(self):
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    # Lookup is case-insensitive.
    self.assertEqual('json', self.protocols.lookup_by_name('JsOn').name)
    self.assertEqual('json2', self.protocols.lookup_by_name('Json2').name)

  def testLookupByContentType(self):
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEqual(
        'json',
        self.protocols.lookup_by_content_type('AppliCation/Json').name)

    self.assertEqual(
        'json',
        self.protocols.lookup_by_content_type('text/x-Json').name)

    self.assertEqual(
        'json2',
        self.protocols.lookup_by_content_type('text/Plain').name)

  def testNewDefault(self):
    protocols = remote.Protocols.new_default()
    self.assertEqual(('protobuf', 'protojson'), protocols.names)

    protobuf_protocol = protocols.lookup_by_name('protobuf')
    self.assertEqual(protobuf, protobuf_protocol.protocol)

    protojson_protocol = protocols.lookup_by_name('protojson')
    self.assertEqual(protojson.ProtoJson.get_default(),
                     protojson_protocol.protocol)

  def testGetDefaultProtocols(self):
    protocols = remote.Protocols.get_default()
    self.assertEqual(('protobuf', 'protojson'), protocols.names)

    protobuf_protocol = protocols.lookup_by_name('protobuf')
    self.assertEqual(protobuf, protobuf_protocol.protocol)

    protojson_protocol = protocols.lookup_by_name('protojson')
    self.assertEqual(protojson.ProtoJson.get_default(),
                     protojson_protocol.protocol)

    # The default is created once and then reused.
    self.assertTrue(protocols is remote.Protocols.get_default())

  def testSetDefaultProtocols(self):
    original_default = remote.Protocols.get_default()
    protocols = remote.Protocols()
    remote.Protocols.set_default(protocols)
    try:
      self.assertTrue(protocols is remote.Protocols.get_default())
    finally:
      # The default is shared class-level state; restore it so tests that
      # rely on the standard protocols are unaffected by this one.
      remote.Protocols.set_default(original_default)

  def testSetDefaultWithoutProtocols(self):
    self.assertRaises(TypeError, remote.Protocols.set_default, None)
    self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols')
    self.assertRaises(TypeError, remote.Protocols.set_default, {})


def main():
  unittest.main()


if __name__ == '__main__':
  main()