#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Test utilities for message testing.

Includes module interface test to ensure that public parts of module are
correctly declared in __all__.

Includes message types that correspond to those defined in
services_test.proto.

Includes additional test utilities to make sure encoding/decoding libraries
conform.
"""
from six.moves import range

__author__ = 'rafek@google.com (Rafe Kaplan)'

import datetime
import inspect
import os
import re
import socket
import sys
import types
import unittest2 as unittest

import six

from . import message_types
from . import messages
from . import util

try:
  from cgi import parse_header as _parse_header
except ImportError:
  # The cgi module was removed from the standard library in Python 3.13
  # (PEP 594). The email package implements the same MIME header parameter
  # parsing, so fall back to it there.
  from email.message import Message as _HeaderMessage

  def _parse_header(line):
    """Split a MIME header value into its main value and a parameter dict."""
    header_message = _HeaderMessage()
    header_message['content-type'] = line
    params = header_message.get_params()
    return params[0][0], dict(params[1:])

# Unicode of the word "Russian" in cyrillic.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'

# Every byte value from 0 to 255, each one followed by a null byte.
BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256))


class TestCase(unittest.TestCase):
  """Test case base class providing extra assertion helpers."""

  def assertRaisesWithRegexpMatch(self,
                                  exception,
                                  regexp,
                                  function,
                                  *params,
                                  **kwargs):
    """Check that exception is raised and text matches regular expression.

    The expression is applied with re.match, so it must match starting at
    the beginning of the exception's string form.

    Args:
      exception: Exception type that is expected.
      regexp: String regular expression that is expected in error message.
      function: Callable to test.
      params: Parameters to forward to function.
      kwargs: Keyword arguments to forward to function.
    """
    try:
      function(*params, **kwargs)
    except exception as err:
      match = bool(re.match(regexp, str(err)))
      self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
                                                                  err))
    else:
      # Kept outside the try block so the AssertionError raised by fail()
      # can never be mistaken for the expected exception.
      self.fail('Expected exception %s was not raised' % exception.__name__)

  def assertHeaderSame(self, header1, header2):
    """Check that two HTTP headers are the same.

    Both the main value and the parameter dictionary must be equal.

    Args:
      header1: Header value string 1.
      header2: Header value string 2.
    """
    value1, params1 = _parse_header(header1)
    value2, params2 = _parse_header(header2)
    self.assertEqual(value1, value2)
    self.assertEqual(params1, params2)

  def assertIterEqual(self, iter1, iter2):
    """Check that two iterators or iterables are equal independent of order.

    Similar to Python 2.7 assertItemsEqual. Named differently in order to
    avoid potential conflict. Items are compared with == only, so they do
    not need to be hashable or orderable.

    Args:
      iter1: An iterator or iterable.
      iter2: An iterator or iterable.
    """
    list1 = list(iter1)
    list2 = list(iter2)

    unmatched1 = []
    for item1 in list1:
      for index, item2 in enumerate(list2):
        if item1 == item2:
          # Consume the match so duplicates must be matched one-for-one.
          # Safe while iterating because we break immediately.
          del list2[index]
          break
      else:
        unmatched1.append(item1)

    error_message = []
    for item in unmatched1:
      error_message.append(
          '  Item from iter1 not found in iter2: %r' % item)
    # Whatever is left in list2 had no counterpart in iter1.
    for item in list2:
      error_message.append(
          '  Item from iter2 not found in iter1: %r' % item)
    if error_message:
      self.fail('Collections not equivalent:\n' + '\n'.join(error_message))


class ModuleInterfaceTest(object):
  """Test to ensure module interface is carefully constructed.

  A module interface is the set of public objects listed in the module __all__
  attribute. Modules that are considered public should have this interface
  carefully declared. At all times, the __all__ attribute should have objects
  intended to be publicly used and all other objects in the module should be
  considered unused.

  Protected attributes (those beginning with '_') and other imported modules
  should not be part of this set of variables. An exception is for variables
  that begin and end with '__' which are implicitly part of the interface
  (eg. __name__, __file__, __all__ itself, etc.).

  Modules that are imported into the tested modules are an exception and may
  be left out of the __all__ definition. The test is done by checking the value
  of what would otherwise be a public name and not allowing it to be exported
  if it is an instance of a module. Modules that are explicitly exported are
  for the time being not permitted.

  To use this test class a module should define a new class that inherits first
  from ModuleInterfaceTest and then from test_util.TestCase. No other tests
  should be added to this test case, making the order of inheritance less
  important, but if setUp for some reason is overridden, it is important that
  ModuleInterfaceTest is first in the list so that its setUp method is
  invoked.

  Multiple inheritance is required so that ModuleInterfaceTest is not itself
  a test, and is not itself executed as one.

  The test class is expected to have the following class attributes defined:

    MODULE: A reference to the module that is being validated for interface
      correctness.

  Example:
    Module definition (hello.py):

      import sys

      __all__ = ['hello']

      def _get_outputter():
        return sys.stdout

      def hello():
        _get_outputter().write('Hello\\n')

    Test definition:

      import unittest
      from protorpc import test_util

      import hello

      class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                test_util.TestCase):

        MODULE = hello


      class HelloTest(test_util.TestCase):
        ... Test 'hello' module ...


      if __name__ == '__main__':
        unittest.main()
  """

  def setUp(self):
    """Set up makes sure that MODULE is defined.

    This is a basic configuration test for the test itself so does not
    get its own test case.
    """
    if not hasattr(self, 'MODULE'):
      self.fail(
          "You must define 'MODULE' on ModuleInterfaceTest sub-class %s." %
          type(self).__name__)

  def testAllExist(self):
    """Test that all attributes defined in __all__ exist."""
    missing_attributes = []
    for attribute in self.MODULE.__all__:
      if not hasattr(self.MODULE, attribute):
        missing_attributes.append(attribute)
    if missing_attributes:
      self.fail('%s of __all__ are not defined in module.' %
                missing_attributes)

  def testAllExported(self):
    """Test that all public attributes not imported are in __all__."""
    missing_attributes = []
    for attribute in dir(self.MODULE):
      if not attribute.startswith('_'):
        # 'with_statement' is skipped because it is the public name bound by
        # "from __future__ import with_statement".
        if (attribute not in self.MODULE.__all__ and
            not isinstance(getattr(self.MODULE, attribute),
                           types.ModuleType) and
            attribute != 'with_statement'):
          missing_attributes.append(attribute)
    if missing_attributes:
      self.fail('%s are not modules and not defined in __all__.' %
                missing_attributes)

  def testNoExportedProtectedVariables(self):
    """Test that there are no protected variables listed in __all__."""
    protected_variables = []
    for attribute in self.MODULE.__all__:
      if attribute.startswith('_'):
        protected_variables.append(attribute)
    if protected_variables:
      self.fail('%s are protected variables and may not be exported.' %
                protected_variables)

  def testNoExportedModules(self):
    """Test that no modules exist in __all__."""
    exported_modules = []
    for attribute in self.MODULE.__all__:
      try:
        value = getattr(self.MODULE, attribute)
      except AttributeError:
        # This is a different error case tested for in testAllExist.
        pass
      else:
        if isinstance(value, types.ModuleType):
          exported_modules.append(attribute)
    if exported_modules:
      self.fail('%s are modules and may not be exported.' % exported_modules)


class NestedMessage(messages.Message):
  """Simple message that gets nested in another message."""

  a_value = messages.StringField(1, required=True)


class HasNestedMessage(messages.Message):
  """Message that has another message nested in it."""

  nested = messages.MessageField(NestedMessage, 1)
  repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)


class HasDefault(messages.Message):
  """Has a default value."""

  a_value = messages.StringField(1, default=u'a default')


class OptionalMessage(messages.Message):
  """Contains all supported field types as optional fields."""

  class SimpleEnum(messages.Enum):
    """Simple enumeration type."""
    VAL1 = 1
    VAL2 = 2

  double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
  float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
  int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
  uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
  int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
  bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
  string_value = messages.StringField(7, variant=messages.Variant.STRING)
  bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
  enum_value = messages.EnumField(SimpleEnum, 10)

  # TODO(rafek): Add support for these variants.
  # uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32)
  # sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32)
  # sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64)


class RepeatedMessage(messages.Message):
  """Contains all supported field types as repeated fields."""

  class SimpleEnum(messages.Enum):
    """Simple enumeration type."""
    VAL1 = 1
    VAL2 = 2

  double_value = messages.FloatField(1,
                                     variant=messages.Variant.DOUBLE,
                                     repeated=True)
  float_value = messages.FloatField(2,
                                    variant=messages.Variant.FLOAT,
                                    repeated=True)
  int64_value = messages.IntegerField(3,
                                      variant=messages.Variant.INT64,
                                      repeated=True)
  uint64_value = messages.IntegerField(4,
                                       variant=messages.Variant.UINT64,
                                       repeated=True)
  int32_value = messages.IntegerField(5,
                                      variant=messages.Variant.INT32,
                                      repeated=True)
  bool_value = messages.BooleanField(6,
                                     variant=messages.Variant.BOOL,
                                     repeated=True)
  string_value = messages.StringField(7,
                                      variant=messages.Variant.STRING,
                                      repeated=True)
  bytes_value = messages.BytesField(8,
                                    variant=messages.Variant.BYTES,
                                    repeated=True)
  #uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32)
  enum_value = messages.EnumField(SimpleEnum,
                                  10,
                                  repeated=True)
  #sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32)
  #sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64)


class HasOptionalNestedMessage(messages.Message):
  """Message with single and repeated OptionalMessage fields."""

  nested = messages.MessageField(OptionalMessage, 1)
  repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)


class ProtoConformanceTestBase(object):
  """Protocol conformance test base class.

  Each supported protocol should implement two functions that support encoding
  and decoding of Message objects in that format:

    encode_message(message) - Serialize to encoding.
    decode_message(message_type, encoded_message) - Deserialize from encoding.

  Tests for the modules where these functions are implemented should extend
  this class in order to support basic behavioral expectations. This ensures
  that protocols correctly encode and decode message transparently to the
  caller.

  In order to support these tests, the sub-class should also extend the
  TestCase class, set PROTOLIB to the protocol module under test, and define
  the following class attributes which hold the encoded version of certain
  protocol buffers:

    encoded_partial:
      <OptionalMessage
        double_value: 1.23
        int64_value: -100000000000
        int32_value: 1020
        string_value: u"a string"
        enum_value: OptionalMessage.SimpleEnum.VAL2
        >

    encoded_full:
      <OptionalMessage
        double_value: 1.23
        float_value: -2.5
        int64_value: -100000000000
        uint64_value: 102020202020
        int32_value: 1020
        bool_value: true
        string_value: u"a string\u044f"
        bytes_value: b"a bytes\xff\xfe"
        enum_value: OptionalMessage.SimpleEnum.VAL2
        >

    encoded_repeated:
      <RepeatedMessage
        double_value: [1.23, 2.3]
        float_value: [-2.5, 0.5]
        int64_value: [-100000000000, 20]
        uint64_value: [102020202020, 10]
        int32_value: [1020, 718]
        bool_value: [true, false]
        string_value: [u"a string\u044f", u"another string"]
        bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
        enum_value: [RepeatedMessage.SimpleEnum.VAL2,
                     RepeatedMessage.SimpleEnum.VAL1]
        >

    encoded_nested:
      <HasNestedMessage
        nested: <NestedMessage
          a_value: "a string"
          >
        >

    encoded_repeated_nested:
      <HasNestedMessage
        repeated_nested: [
            <NestedMessage a_value: "a string">,
            <NestedMessage a_value: "another string">
          ]
        >

    unexpected_tag_message:
      An encoded message that has an undefined tag or number in the stream.

    encoded_default_assigned:
      <HasDefault
        a_value: "a default"
        >

    encoded_nested_empty:
      <HasOptionalNestedMessage
        nested: <OptionalMessage>
        >

    encoded_repeated_nested_empty:
      <HasOptionalNestedMessage
        repeated_nested: [<OptionalMessage>, <OptionalMessage>]
        >

    encoded_string_types:
      <OptionalMessage
        string_value: "Latin"
        >

    encoded_invalid_enum:
      <OptionalMessage
        enum_value: (invalid value for serialization type)
        >

  encoded_empty_message defaults to '' and may be overridden by protocols
  whose empty encoding differs.
  """

  encoded_empty_message = ''

  def testEncodeInvalidMessage(self):
    message = NestedMessage()
    self.assertRaises(messages.ValidationError,
                      self.PROTOLIB.encode_message, message)

  def CompareEncoded(self, expected_encoded, actual_encoded):
    """Compare two encoded protocol values.

    Can be overridden by sub-classes to special case comparison.
    For example, to eliminate white space from output that is not
    relevant to encoding.

    Args:
      expected_encoded: Expected string encoded value.
      actual_encoded: Actual string encoded value.
    """
    self.assertEqual(expected_encoded, actual_encoded)

  def EncodeDecode(self, encoded, expected_message):
    """Decode encoded, compare to expected_message, then re-encode it."""
    message = self.PROTOLIB.decode_message(type(expected_message), encoded)
    self.assertEqual(expected_message, message)
    self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

  def testEmptyMessage(self):
    self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

  def testPartial(self):
    """Test message with a few values set."""
    message = OptionalMessage()
    message.double_value = 1.23
    message.int64_value = -100000000000
    message.int32_value = 1020
    message.string_value = u'a string'
    message.enum_value = OptionalMessage.SimpleEnum.VAL2

    self.EncodeDecode(self.encoded_partial, message)

  def testFull(self):
    """Test all types."""
    message = OptionalMessage()
    message.double_value = 1.23
    message.float_value = -2.5
    message.int64_value = -100000000000
    message.uint64_value = 102020202020
    message.int32_value = 1020
    message.bool_value = True
    message.string_value = u'a string\u044f'
    message.bytes_value = b'a bytes\xff\xfe'
    message.enum_value = OptionalMessage.SimpleEnum.VAL2

    self.EncodeDecode(self.encoded_full, message)

  def testRepeated(self):
    """Test repeated fields."""
    message = RepeatedMessage()
    message.double_value = [1.23, 2.3]
    message.float_value = [-2.5, 0.5]
    message.int64_value = [-100000000000, 20]
    message.uint64_value = [102020202020, 10]
    message.int32_value = [1020, 718]
    message.bool_value = [True, False]
    message.string_value = [u'a string\u044f', u'another string']
    message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
    message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                          RepeatedMessage.SimpleEnum.VAL1]

    self.EncodeDecode(self.encoded_repeated, message)

  def testNested(self):
    """Test nested messages."""
    nested_message = NestedMessage()
    nested_message.a_value = u'a string'

    message = HasNestedMessage()
    message.nested = nested_message

    self.EncodeDecode(self.encoded_nested, message)

  def testRepeatedNested(self):
    """Test repeated nested messages."""
    nested_message1 = NestedMessage()
    nested_message1.a_value = u'a string'
    nested_message2 = NestedMessage()
    nested_message2.a_value = u'another string'

    message = HasNestedMessage()
    message.repeated_nested = [nested_message1, nested_message2]

    self.EncodeDecode(self.encoded_repeated_nested, message)

  def testStringTypes(self):
    """Test that encoding str on StringField works."""
    message = OptionalMessage()
    message.string_value = 'Latin'
    self.EncodeDecode(self.encoded_string_types, message)

  def testEncodeUninitialized(self):
    """Test that cannot encode uninitialized message."""
    required = NestedMessage()
    self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                     "Message NestedMessage is missing "
                                     "required field a_value",
                                     self.PROTOLIB.encode_message,
                                     required)

  def testUnexpectedField(self):
    """Test decoding and encoding unexpected fields."""
    loaded_message = self.PROTOLIB.decode_message(OptionalMessage,
                                                  self.unexpected_tag_message)
    # Message should be equal to an empty message, since unknown values aren't
    # included in equality.
    self.assertEqual(OptionalMessage(), loaded_message)
    # Verify that the encoded message matches the source, including the
    # unknown value.
    self.assertEqual(self.unexpected_tag_message,
                     self.PROTOLIB.encode_message(loaded_message))

  def testDoNotSendDefault(self):
    """Test that default is not sent when nothing is assigned."""
    self.EncodeDecode(self.encoded_empty_message, HasDefault())

  def testSendDefaultExplicitlyAssigned(self):
    """Test that default is sent when explicitly assigned."""
    message = HasDefault()

    message.a_value = HasDefault.a_value.default

    self.EncodeDecode(self.encoded_default_assigned, message)

  def testEncodingNestedEmptyMessage(self):
    """Test encoding a nested empty message."""
    message = HasOptionalNestedMessage()
    message.nested = OptionalMessage()

    self.EncodeDecode(self.encoded_nested_empty, message)

  def testEncodingRepeatedNestedEmptyMessage(self):
    """Test encoding repeated nested empty messages."""
    message = HasOptionalNestedMessage()
    message.repeated_nested = [OptionalMessage(), OptionalMessage()]

    self.EncodeDecode(self.encoded_repeated_nested_empty, message)

  def testContentType(self):
    self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str))

  def testDecodeInvalidEnumType(self):
    self.assertRaisesWithRegexpMatch(messages.DecodeError,
                                     'Invalid enum value ',
                                     self.PROTOLIB.decode_message,
                                     OptionalMessage,
                                     self.encoded_invalid_enum)

  def testDateTimeNoTimeZone(self):
    """Test that DateTimeFields are encoded/decoded correctly."""

    class MyMessage(messages.Message):
      value = message_types.DateTimeField(1)

    value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
    message = MyMessage(value=value)
    decoded = self.PROTOLIB.decode_message(
        MyMessage, self.PROTOLIB.encode_message(message))
    self.assertEqual(decoded.value, value)

  def testDateTimeWithTimeZone(self):
    """Test DateTimeFields with time zones."""

    class MyMessage(messages.Message):
      value = message_types.DateTimeField(1)

    value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                              util.TimeZoneOffset(8 * 60))
    message = MyMessage(value=value)
    decoded = self.PROTOLIB.decode_message(
        MyMessage, self.PROTOLIB.encode_message(message))
    self.assertEqual(decoded.value, value)


def do_with(context, function, *args, **kwargs):
  """Simulate a with statement.

  Avoids need to import with from future.

  Does not support simulation of 'as'.

  Follows the with-statement protocol: __exit__ is called exactly once. If
  function raises, the exception info is passed to __exit__ and the exception
  is re-raised unless __exit__ returns a true value.

  Args:
    context: Context object normally used with 'with'.
    function: Callable to evoke. Replaces with-block.
    args: Positional arguments forwarded to function.
    kwargs: Keyword arguments forwarded to function.
  """
  context.__enter__()
  try:
    function(*args, **kwargs)
  except BaseException:
    # A with statement hands every exception (not just Exception) to
    # __exit__. Capture exc_info first so it survives anything __exit__ does.
    exc_info = sys.exc_info()
    if not context.__exit__(*exc_info):
      six.reraise(*exc_info)
  else:
    context.__exit__(None, None, None)


def pick_unused_port():
  """Find an unused port to use in tests.

  Binds a TCP socket to port 0 on localhost so the OS assigns a free port,
  reads the assigned port number and closes the socket. The port is not
  reserved afterwards, so another process could still claim it.

  Derived from Damon Kohlers example:

    http://code.activestate.com/recipes/531822-pick-unused-port

  Returns:
    The port number assigned by the operating system.
  """
  temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  try:
    temp.bind(('localhost', 0))
    port = temp.getsockname()[1]
  finally:
    temp.close()
  return port


def get_module_name(module_attribute):
  """Get the module name.

  Args:
    module_attribute: An attribute of the module.

  Returns:
    The fully qualified module name where 'module_attribute' is defined. If
    that module is "__main__", returns the simple module name derived from
    the defining file's name (text before the first '.').
  """
  if module_attribute.__module__ == '__main__':
    module_file = inspect.getfile(module_attribute)
    default = os.path.basename(module_file).split('.')[0]
    return default
  else:
    return module_attribute.__module__