#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Test utilities for message testing.

Includes module interface test to ensure that public parts of module are
correctly declared in __all__.

Includes message types that correspond to those defined in
services_test.proto.

Includes additional test utilities to make sure encoding/decoding libraries
conform.
"""
try:
    import cgi
except ImportError:
    # The cgi module was removed in Python 3.13 (PEP 594); header parsing
    # falls back to the email package in _parse_header below.
    cgi = None
import datetime
import email.message
import inspect
import os
import re
import socket
import types
import unittest

import six
from six.moves import range  # pylint: disable=redefined-builtin

from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util

# Unicode of the word "Russian" in cyrillic.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'

# Every byte value from 0 to 255, each followed by a null byte.
BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256))


def _parse_header(header):
    """Split a MIME-style header value into its main value and parameters.

    Uses cgi.parse_header when the cgi module is importable. Otherwise it
    falls back to the email package, which parses the same MIME parameter
    syntax (lower-cased parameter names, unquoted parameter values).

    Args:
      header: Header value string, e.g. 'text/html; charset="utf-8"'.

    Returns:
      Tuple (main_value, params_dict).
    """
    if cgi is not None:
        return cgi.parse_header(header)
    message = email.message.Message()
    message['content-type'] = header
    params = message.get_params()
    # The first entry is the main value with an empty parameter value.
    return params[0][0], dict(params[1:])


class TestCase(unittest.TestCase):
    """unittest.TestCase with extra assertion helpers used by these tests."""

    def assertRaisesWithRegexpMatch(self,
                                    exception,
                                    regexp,
                                    function,
                                    *params,
                                    **kwargs):
        """Check that exception is raised and text matches regular expression.

        The regular expression is applied with re.match, so it must match
        from the start of str(error).

        Args:
          exception: Exception type (or tuple of types) that is expected.
          regexp: String regular expression that is expected in error message.
          function: Callable to test.
          params: Parameters to forward to function.
          kwargs: Keyword arguments to forward to function.
        """
        try:
            function(*params, **kwargs)
        except exception as err:
            match = bool(re.match(regexp, str(err)))
            self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
                                                                       err))
        else:
            # Failing from the else branch keeps self.fail's AssertionError
            # from being swallowed when `exception` is AssertionError or one
            # of its base classes (e.g. Exception).
            self.fail('Expected exception %s was not raised' %
                      getattr(exception, '__name__', repr(exception)))

    def assertHeaderSame(self, header1, header2):
        """Check that two HTTP headers are the same.

        Both the main value and the parameter dictionary must be equal, so
        parameter order and quoting differences are ignored.

        Args:
          header1: Header value string 1.
          header2: Header value string 2.
        """
        value1, params1 = _parse_header(header1)
        value2, params2 = _parse_header(header2)
        self.assertEqual(value1, value2)
        self.assertEqual(params1, params2)

    def assertIterEqual(self, iter1, iter2):
        """Check two iterators or iterables are equal independent of order.

        Similar to Python 2.7 assertItemsEqual. Named differently in order to
        avoid potential conflict. Items are compared with ==, so they need
        not be hashable; each item in iter2 can match at most one item in
        iter1.

        Args:
          iter1: An iterator or iterable.
          iter2: An iterator or iterable.
        """
        list1 = list(iter1)
        list2 = list(iter2)

        unmatched1 = []
        for item1 in list1:
            for index, item2 in enumerate(list2):
                if item1 == item2:
                    # Safe to mutate list2 here: we stop iterating right away.
                    del list2[index]
                    break
            else:
                unmatched1.append(item1)

        # Anything still left in list2 had no counterpart in iter1.
        error_message = []
        for item in unmatched1:
            error_message.append(
                '  Item from iter1 not found in iter2: %r' % item)
        for item in list2:
            error_message.append(
                '  Item from iter2 not found in iter1: %r' % item)
        if error_message:
            self.fail('Collections not equivalent:\n' +
                      '\n'.join(error_message))


class ModuleInterfaceTest(object):
    """Test to ensure module interface is carefully constructed.

    A module interface is the set of public objects listed in the
    module __all__ attribute. Modules that are considered public
    should have this interface carefully declared. At all times, the
    __all__ attribute should have objects intended to be publicly
    used and all other objects in the module should be considered
    unused.

    Protected attributes (those beginning with '_') and other imported
    modules should not be part of this set of variables. An exception
    is for variables that begin and end with '__' which are implicitly
    part of the interface (eg. __name__, __file__, __all__ itself,
    etc.).

    Modules that are imported into the tested modules are an
    exception and may be left out of the __all__ definition. The test
    is done by checking the value of what would otherwise be a public
    name and not allowing it to be exported if it is an instance of a
    module. Modules that are explicitly exported are for the time
    being not permitted.

    To use this test class a module should define a new class that
    inherits first from ModuleInterfaceTest and then from
    test_util.TestCase. No other tests should be added to this test
    case, making the order of inheritance less important, but if setUp
    for some reason is overridden, it is important that
    ModuleInterfaceTest is first in the list so that its setUp method
    is invoked.

    Multiple inheritance is required so that ModuleInterfaceTest is
    not itself a test, and is not itself executed as one.

    The test class is expected to have the following class attributes
    defined:

      MODULE: A reference to the module that is being validated for interface
        correctness.

    Example:
      Module definition (hello.py):

        import sys

        __all__ = ['hello']

        def _get_outputter():
          return sys.stdout

        def hello():
          _get_outputter().write('Hello\n')

      Test definition:

        import unittest
        from protorpc import test_util

        import hello

        class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                  test_util.TestCase):

          MODULE = hello


        class HelloTest(test_util.TestCase):
          ... Test 'hello' module ...


        if __name__ == '__main__':
          unittest.main()

    """

    def setUp(self):
        """Make sure that MODULE is defined on the test class.

        This is a basic configuration test for the test itself so does not
        get its own test case.
        """
        if not hasattr(self, 'MODULE'):
            self.fail(
                "You must define 'MODULE' on ModuleInterfaceTest sub-class "
                "%s." % type(self).__name__)

    def testAllExist(self):
        """Test that all attributes defined in __all__ exist."""
        missing_attributes = []
        for attribute in self.MODULE.__all__:
            if not hasattr(self.MODULE, attribute):
                missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s of __all__ are not defined in module.' %
                      missing_attributes)

    def testAllExported(self):
        """Test that all public attributes not imported are in __all__."""
        missing_attributes = []
        for attribute in dir(self.MODULE):
            if not attribute.startswith('_'):
                # Imported modules and the legacy `from __future__ import
                # with_statement` name are allowed to stay out of __all__.
                if (attribute not in self.MODULE.__all__ and
                        not isinstance(getattr(self.MODULE, attribute),
                                       types.ModuleType) and
                        attribute != 'with_statement'):
                    missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s are not modules and not defined in __all__.' %
                      missing_attributes)

    def testNoExportedProtectedVariables(self):
        """Test that there are no protected variables listed in __all__."""
        protected_variables = []
        for attribute in self.MODULE.__all__:
            if attribute.startswith('_'):
                protected_variables.append(attribute)
        if protected_variables:
            self.fail('%s are protected variables and may not be exported.' %
                      protected_variables)

    def testNoExportedModules(self):
        """Test that no modules exist in __all__."""
        exported_modules = []
        for attribute in self.MODULE.__all__:
            try:
                value = getattr(self.MODULE, attribute)
            except AttributeError:
                # This is a different error case tested for in testAllExist.
                pass
            else:
                if isinstance(value, types.ModuleType):
                    exported_modules.append(attribute)
        if exported_modules:
            self.fail('%s are modules and may not be exported.' %
                      exported_modules)


class NestedMessage(messages.Message):
    """Simple message that gets nested in another message."""

    a_value = messages.StringField(1, required=True)


class HasNestedMessage(messages.Message):
    """Message that has another message nested in it."""

    nested = messages.MessageField(NestedMessage, 1)
    repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)


class HasDefault(messages.Message):
    """Has a default value."""

    a_value = messages.StringField(1, default=u'a default')


class OptionalMessage(messages.Message):
    """Contains all message types."""

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
    float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
    int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
    uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
    int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
    bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
    string_value = messages.StringField(7, variant=messages.Variant.STRING)
    bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
    enum_value = messages.EnumField(SimpleEnum, 10)


class RepeatedMessage(messages.Message):
    """Contains all message types as repeated fields."""

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1,
                                       variant=messages.Variant.DOUBLE,
                                       repeated=True)
    float_value = messages.FloatField(2,
                                      variant=messages.Variant.FLOAT,
                                      repeated=True)
    int64_value = messages.IntegerField(3,
                                        variant=messages.Variant.INT64,
                                        repeated=True)
    uint64_value = messages.IntegerField(4,
                                         variant=messages.Variant.UINT64,
                                         repeated=True)
    int32_value = messages.IntegerField(5,
                                        variant=messages.Variant.INT32,
                                        repeated=True)
    bool_value = messages.BooleanField(6,
                                       variant=messages.Variant.BOOL,
                                       repeated=True)
    string_value = messages.StringField(7,
                                        variant=messages.Variant.STRING,
                                        repeated=True)
    bytes_value = messages.BytesField(8,
                                      variant=messages.Variant.BYTES,
                                      repeated=True)
    enum_value = messages.EnumField(SimpleEnum,
                                    10,
                                    repeated=True)


class HasOptionalNestedMessage(messages.Message):
    """Message with singular and repeated OptionalMessage fields."""

    nested = messages.MessageField(OptionalMessage, 1)
    repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)


# pylint:disable=anomalous-unicode-escape-in-string
class ProtoConformanceTestBase(object):
    """Protocol conformance test base class.

    Each supported protocol should implement two methods that support encoding
    and decoding of Message objects in that format:

      encode_message(message) - Serialize to encoding.
      decode_message(message_type, encoded_message) - Deserialize from
        encoding.

    Tests for the modules where these functions are implemented should extend
    this class in order to support basic behavioral expectations.  This ensures
    that protocols correctly encode and decode message transparently to the
    caller.

    In order to support these tests, the subclass should also extend
    the TestCase class and implement the following class attributes:

      PROTOLIB: The protocol module under test. It must provide
        encode_message, decode_message and a str CONTENT_TYPE.

    plus the following class attributes, which define the encoded version of
    certain protocol buffers:

      encoded_partial:
        <OptionalMessage
          double_value: 1.23
          int64_value: -100000000000
          int32_value: 1020
          string_value: u"a string"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_full:
        <OptionalMessage
          double_value: 1.23
          float_value: -2.5
          int64_value: -100000000000
          uint64_value: 102020202020
          int32_value: 1020
          bool_value: true
          string_value: u"a string\u044f"
          bytes_value: b"a bytes\xff\xfe"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_repeated:
        <RepeatedMessage
          double_value: [1.23, 2.3]
          float_value: [-2.5, 0.5]
          int64_value: [-100000000000, 20]
          uint64_value: [102020202020, 10]
          int32_value: [1020, 718]
          bool_value: [true, false]
          string_value: [u"a string\u044f", u"another string"]
          bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
          enum_value: [RepeatedMessage.SimpleEnum.VAL2,
                       RepeatedMessage.SimpleEnum.VAL1]
          >

      encoded_nested:
        <HasNestedMessage
          nested: <NestedMessage
            a_value: "a string"
            >
          >

      encoded_repeated_nested:
        <HasNestedMessage
          repeated_nested: [
              <NestedMessage a_value: "a string">,
              <NestedMessage a_value: "another string">
            ]
          >

      encoded_string_types:
        <OptionalMessage
          string_value: "Latin"
          >

      unexpected_tag_message:
        An encoded message that has an undefined tag or number in the stream.

      encoded_default_assigned:
        <HasDefault
          a_value: "a default"
          >

      encoded_nested_empty:
        <HasOptionalNestedMessage
          nested: <OptionalMessage>
          >

      encoded_repeated_nested_empty:
        <HasOptionalNestedMessage
          repeated_nested: [<OptionalMessage>, <OptionalMessage>]
          >

      encoded_invalid_enum:
        <OptionalMessage
          enum_value: (invalid value for serialization type)
          >

    encoded_empty_message defaults to '' and may be overridden.
    """

    encoded_empty_message = ''

    def testEncodeInvalidMessage(self):
        """Test that a message missing a required field cannot be encoded."""
        message = NestedMessage()
        self.assertRaises(messages.ValidationError,
                          self.PROTOLIB.encode_message, message)

    def CompareEncoded(self, expected_encoded, actual_encoded):
        """Compare two encoded protocol values.

        Can be overridden by sub-classes to special case comparison.
        For example, to eliminate white space from output that is not
        relevant to encoding.

        Args:
          expected_encoded: Expected string encoded value.
          actual_encoded: Actual string encoded value.
        """
        self.assertEqual(expected_encoded, actual_encoded)

    def EncodeDecode(self, encoded, expected_message):
        """Check that `encoded` round-trips through PROTOLIB.

        Decodes `encoded` as type(expected_message), asserts the result equals
        expected_message, then re-encodes it and compares with `encoded`
        using CompareEncoded.

        Args:
          encoded: Encoded form of expected_message.
          expected_message: Message instance `encoded` should decode to.
        """
        message = self.PROTOLIB.decode_message(type(expected_message), encoded)
        self.assertEqual(expected_message, message)
        self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

    def testEmptyMessage(self):
        """Test a message with no fields set."""
        self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

    def testPartial(self):
        """Test message with a few values set."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.int64_value = -100000000000
        message.int32_value = 1020
        message.string_value = u'a string'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_partial, message)

    def testFull(self):
        """Test all types."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.float_value = -2.5
        message.int64_value = -100000000000
        message.uint64_value = 102020202020
        message.int32_value = 1020
        message.bool_value = True
        message.string_value = u'a string\u044f'
        message.bytes_value = b'a bytes\xff\xfe'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_full, message)

    def testRepeated(self):
        """Test repeated fields."""
        message = RepeatedMessage()
        message.double_value = [1.23, 2.3]
        message.float_value = [-2.5, 0.5]
        message.int64_value = [-100000000000, 20]
        message.uint64_value = [102020202020, 10]
        message.int32_value = [1020, 718]
        message.bool_value = [True, False]
        message.string_value = [u'a string\u044f', u'another string']
        message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                              RepeatedMessage.SimpleEnum.VAL1]

        self.EncodeDecode(self.encoded_repeated, message)

    def testNested(self):
        """Test nested messages."""
        nested_message = NestedMessage()
        nested_message.a_value = u'a string'

        message = HasNestedMessage()
        message.nested = nested_message

        self.EncodeDecode(self.encoded_nested, message)

    def testRepeatedNested(self):
        """Test repeated nested messages."""
        nested_message1 = NestedMessage()
        nested_message1.a_value = u'a string'
        nested_message2 = NestedMessage()
        nested_message2.a_value = u'another string'

        message = HasNestedMessage()
        message.repeated_nested = [nested_message1, nested_message2]

        self.EncodeDecode(self.encoded_repeated_nested, message)

    def testStringTypes(self):
        """Test that encoding str on StringField works."""
        message = OptionalMessage()
        message.string_value = 'Latin'
        self.EncodeDecode(self.encoded_string_types, message)

    def testEncodeUninitialized(self):
        """Test that cannot encode uninitialized message."""
        required = NestedMessage()
        self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                         "Message NestedMessage is missing "
                                         "required field a_value",
                                         self.PROTOLIB.encode_message,
                                         required)

    def testUnexpectedField(self):
        """Test decoding and encoding unexpected fields."""
        loaded_message = self.PROTOLIB.decode_message(
            OptionalMessage, self.unexpected_tag_message)
        # Message should be equal to an empty message, since unknown
        # values aren't included in equality.
        self.assertEqual(OptionalMessage(), loaded_message)
        # Verify that the encoded message matches the source, including the
        # unknown value.
        self.assertEqual(self.unexpected_tag_message,
                         self.PROTOLIB.encode_message(loaded_message))

    def testDoNotSendDefault(self):
        """Test that default is not sent when nothing is assigned."""
        self.EncodeDecode(self.encoded_empty_message, HasDefault())

    def testSendDefaultExplicitlyAssigned(self):
        """Test that default is sent when explicitly assigned."""
        message = HasDefault()

        message.a_value = HasDefault.a_value.default

        self.EncodeDecode(self.encoded_default_assigned, message)

    def testEncodingNestedEmptyMessage(self):
        """Test encoding a nested empty message."""
        message = HasOptionalNestedMessage()
        message.nested = OptionalMessage()

        self.EncodeDecode(self.encoded_nested_empty, message)

    def testEncodingRepeatedNestedEmptyMessage(self):
        """Test encoding repeated nested empty messages."""
        message = HasOptionalNestedMessage()
        message.repeated_nested = [OptionalMessage(), OptionalMessage()]

        self.EncodeDecode(self.encoded_repeated_nested_empty, message)

    def testContentType(self):
        """Test that the protocol declares a str CONTENT_TYPE."""
        self.assertIsInstance(self.PROTOLIB.CONTENT_TYPE, str)

    def testDecodeInvalidEnumType(self):
        """Test that an unknown enum value decodes and re-encodes unchanged."""
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(OptionalMessage,
                                               self.encoded_invalid_enum)
        message = OptionalMessage()
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_enum, encoded)

    def testDateTimeNoTimeZone(self):
        """Test that DateTimeFields are encoded/decoded correctly."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)

    def testDateTimeWithTimeZone(self):
        """Test DateTimeFields with time zones."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                                  util.TimeZoneOffset(8 * 60))
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)


def pick_unused_port():
    """Find an unused port to use in tests.

    Derived from Damon Kohler's example:

      http://code.activestate.com/recipes/531822-pick-unused-port

    Binds a TCP socket on localhost to port 0, so the OS picks a free
    port, and returns that port number. The socket is closed before
    returning, so another process could take the port before the
    caller binds it.

    Returns:
      An integer port number.
    """
    temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        temp.bind(('localhost', 0))
        port = temp.getsockname()[1]
    finally:
        temp.close()
    return port


def get_module_name(module_attribute):
    """Get the module name.

    Args:
      module_attribute: An attribute of the module.

    Returns:
      The fully qualified module name or simple module name where
      'module_attribute' is defined if the module name is "__main__".
    """
    if module_attribute.__module__ == '__main__':
        # Derive the name from the file, e.g. '/path/hello.py' -> 'hello'.
        module_file = inspect.getfile(module_attribute)
        default = os.path.basename(module_file).split('.')[0]
        return default
    return module_attribute.__module__