#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Test utilities for message testing.

Includes module interface test to ensure that public parts of module are
correctly declared in __all__.

Includes message types that correspond to those defined in
services_test.proto.

Includes additional test utilities to make sure encoding/decoding libraries
conform.
"""
import datetime
import inspect
import os
import re
import socket
import types
import unittest

import six
from six.moves import range  # pylint: disable=redefined-builtin

from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util

# Unicode of the word "Russian" in cyrillic.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'

# Every byte value 0-255, each one followed by a null byte.
BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256))


def _iter_header_params(text):
    """Yield the ';'-separated segments of a header string.

    Port of the private ``cgi._parseparam`` helper. ``text`` is expected to
    start with ';'. Semicolons that appear inside a double-quoted value do
    not end a segment.

    Args:
      text: Header text prefixed with ';'.

    Yields:
      Each segment with surrounding whitespace stripped.
    """
    while text[:1] == ';':
        text = text[1:]
        end = text.find(';')
        # An odd number of unescaped quotes before this ';' means it is
        # inside a quoted string, so keep looking for the next one.
        while end > 0 and (text.count('"', 0, end) -
                           text.count('\\"', 0, end)) % 2:
            end = text.find(';', end + 1)
        if end < 0:
            end = len(text)
        yield text[:end].strip()
        text = text[end:]


def _parse_header(line):
    """Parse a Content-type-like header into its value and parameters.

    Behaves like ``cgi.parse_header``. The ``cgi`` module was deprecated in
    Python 3.11 and removed in Python 3.13 (PEP 594), so it is
    reimplemented here.

    Args:
      line: Header value string, e.g. 'text/plain; charset="utf-8"'.

    Returns:
      Tuple (value, params). ``value`` is the main header value (case is
      preserved). ``params`` maps lower-cased parameter names to their
      values, with one level of surrounding double quotes removed and
      backslash escapes inside those quotes resolved.
    """
    parts = _iter_header_params(';' + line)
    value = next(parts)
    params = {}
    for part in parts:
        index = part.find('=')
        if index >= 0:
            name = part[:index].strip().lower()
            param_value = part[index + 1:].strip()
            if (len(param_value) >= 2 and
                    param_value[0] == param_value[-1] == '"'):
                param_value = param_value[1:-1]
                param_value = param_value.replace(
                    '\\\\', '\\').replace('\\"', '"')
            params[name] = param_value
    return value, params


class TestCase(unittest.TestCase):
    """unittest.TestCase with additional assertion helpers."""

    def assertRaisesWithRegexpMatch(self,
                                    exception,
                                    regexp,
                                    function,
                                    *params,
                                    **kwargs):
        """Check that exception is raised and text matches regular expression.

        The regular expression is applied with re.match, so it must match
        from the start of the exception's string form.

        Args:
          exception: Exception type that is expected.
          regexp: String regular expression that is expected in error message.
          function: Callable to test.
          params: Parameters to forward to function.
          kwargs: Keyword arguments to forward to function.
        """
        try:
            function(*params, **kwargs)
        except exception as err:
            match = bool(re.match(regexp, str(err)))
            self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
                                                                        err))
        else:
            # The failure is raised outside the try block. Otherwise it
            # would be swallowed when `exception` is AssertionError or one
            # of its base classes.
            self.fail('Expected exception %s was not raised' %
                      exception.__name__)

    def assertHeaderSame(self, header1, header2):
        """Check that two HTTP headers are the same.

        The headers are compared after parsing, so parameter order and
        parameter-name case do not matter.

        Args:
          header1: Header value string 1.
          header2: header value string 2.
        """
        value1, params1 = _parse_header(header1)
        value2, params2 = _parse_header(header2)
        self.assertEqual(value1, value2)
        self.assertEqual(params1, params2)

    def assertIterEqual(self, iter1, iter2):
        """Check two iterators or iterables are equal independent of order.

        Similar to Python 2.7 assertItemsEqual. Named differently in order to
        avoid potential conflict. Items only need to support ==; they do not
        need to be hashable or orderable.

        Args:
          iter1: An iterator or iterable.
          iter2: An iterator or iterable.
        """
        list1 = list(iter1)
        list2 = list(iter2)

        unmatched1 = []
        for item1 in list1:
            # Remove the first equal item from list2 so that duplicates are
            # matched one-to-one.
            for index, item2 in enumerate(list2):
                if item1 == item2:
                    del list2[index]
                    break
            else:
                unmatched1.append(item1)

        # Whatever is left in list2 had no counterpart in list1.
        error_message = []
        for item in unmatched1:
            error_message.append(
                '  Item from iter1 not found in iter2: %r' % item)
        for item in list2:
            error_message.append(
                '  Item from iter2 not found in iter1: %r' % item)
        if error_message:
            self.fail('Collections not equivalent:\n' +
                      '\n'.join(error_message))


class ModuleInterfaceTest(object):
    """Test to ensure module interface is carefully constructed.

    A module interface is the set of public objects listed in the
    module __all__ attribute. Modules that are considered public
    should have this interface carefully declared. At all times, the
    __all__ attribute should have objects intended to be publicly
    used and all other objects in the module should be considered
    unused.

    Protected attributes (those beginning with '_') and other imported
    modules should not be part of this set of variables. An exception
    is for variables that begin and end with '__' which are implicitly
    part of the interface (eg. __name__, __file__, __all__ itself,
    etc.).

    Modules that are imported into the tested modules are an
    exception and may be left out of the __all__ definition. The test
    is done by checking the value of what would otherwise be a public
    name and not allowing it to be exported if it is an instance of a
    module. Modules that are explicitly exported are for the time
    being not permitted.

    To use this test class a module should define a new class that
    inherits first from ModuleInterfaceTest and then from
    test_util.TestCase. No other tests should be added to this test
    case, making the order of inheritance less important, but if setUp
    for some reason is overridden, it is important that
    ModuleInterfaceTest is first in the list so that its setUp method
    is invoked.

    Multiple inheritance is required so that ModuleInterfaceTest is
    not itself a test, and is not itself executed as one.

    The test class is expected to have the following class attributes
    defined:

      MODULE: A reference to the module that is being validated for interface
        correctness.

    Example:
      Module definition (hello.py):

        import sys

        __all__ = ['hello']

        def _get_outputter():
          return sys.stdout

        def hello():
          _get_outputter().write('Hello\n')

      Test definition:

        import unittest
        from protorpc import test_util

        import hello

        class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                  test_util.TestCase):

          MODULE = hello


        class HelloTest(test_util.TestCase):
          ... Test 'hello' module ...


        if __name__ == '__main__':
          unittest.main()

    """

    def setUp(self):
        """Make sure that MODULE is defined.

        This is a basic configuration test for the test itself so does not
        get its own test case.
        """
        if not hasattr(self, 'MODULE'):
            self.fail(
                "You must define 'MODULE' on ModuleInterfaceTest sub-class "
                "%s." % type(self).__name__)

    def testAllExist(self):
        """Test that all attributes defined in __all__ exist."""
        missing_attributes = []
        for attribute in self.MODULE.__all__:
            if not hasattr(self.MODULE, attribute):
                missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s of __all__ are not defined in module.' %
                      missing_attributes)

    def testAllExported(self):
        """Test that all public attributes not imported are in __all__."""
        missing_attributes = []
        for attribute in dir(self.MODULE):
            if not attribute.startswith('_'):
                # Imported modules and the `from __future__ import
                # with_statement` feature object are exempt.
                if (attribute not in self.MODULE.__all__ and
                        not isinstance(getattr(self.MODULE, attribute),
                                       types.ModuleType) and
                        attribute != 'with_statement'):
                    missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s are not modules and not defined in __all__.' %
                      missing_attributes)

    def testNoExportedProtectedVariables(self):
        """Test that there are no protected variables listed in __all__."""
        protected_variables = []
        for attribute in self.MODULE.__all__:
            if attribute.startswith('_'):
                protected_variables.append(attribute)
        if protected_variables:
            self.fail('%s are protected variables and may not be exported.' %
                      protected_variables)

    def testNoExportedModules(self):
        """Test that no modules exist in __all__."""
        exported_modules = []
        for attribute in self.MODULE.__all__:
            try:
                value = getattr(self.MODULE, attribute)
            except AttributeError:
                # This is a different error case tested for in testAllExist.
                pass
            else:
                if isinstance(value, types.ModuleType):
                    exported_modules.append(attribute)
        if exported_modules:
            self.fail('%s are modules and may not be exported.' %
                      exported_modules)


class NestedMessage(messages.Message):
    """Simple message that gets nested in another message."""

    a_value = messages.StringField(1, required=True)


class HasNestedMessage(messages.Message):
    """Message that has another message nested in it."""

    nested = messages.MessageField(NestedMessage, 1)
    repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)


class HasDefault(messages.Message):
    """Has a default value."""

    a_value = messages.StringField(1, default=u'a default')


class OptionalMessage(messages.Message):
    """Contains all message types."""

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
    float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
    int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
    uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
    int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
    bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
    string_value = messages.StringField(7, variant=messages.Variant.STRING)
    bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
    enum_value = messages.EnumField(SimpleEnum, 10)


class RepeatedMessage(messages.Message):
    """Contains all message types as repeated fields."""

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1,
                                       variant=messages.Variant.DOUBLE,
                                       repeated=True)
    float_value = messages.FloatField(2,
                                      variant=messages.Variant.FLOAT,
                                      repeated=True)
    int64_value = messages.IntegerField(3,
                                        variant=messages.Variant.INT64,
                                        repeated=True)
    uint64_value = messages.IntegerField(4,
                                         variant=messages.Variant.UINT64,
                                         repeated=True)
    int32_value = messages.IntegerField(5,
                                        variant=messages.Variant.INT32,
                                        repeated=True)
    bool_value = messages.BooleanField(6,
                                       variant=messages.Variant.BOOL,
                                       repeated=True)
    string_value = messages.StringField(7,
                                        variant=messages.Variant.STRING,
                                        repeated=True)
    bytes_value = messages.BytesField(8,
                                      variant=messages.Variant.BYTES,
                                      repeated=True)
    enum_value = messages.EnumField(SimpleEnum,
                                    10,
                                    repeated=True)


class HasOptionalNestedMessage(messages.Message):
    """Message whose nested fields hold OptionalMessage instances."""

    nested = messages.MessageField(OptionalMessage, 1)
    repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)


# pylint:disable=anomalous-unicode-escape-in-string
class ProtoConformanceTestBase(object):
    """Protocol conformance test base class.

    Each supported protocol should implement two functions that support
    encoding and decoding of Message objects in that format. These are
    looked up on the module referenced by the PROTOLIB class attribute:

      encode_message(message) - Serialize to encoding.
      decode_message(message_type, encoded_message) - Deserialize from
        encoding.

    PROTOLIB must also define a CONTENT_TYPE string.

    Tests for the modules where these functions are implemented should extend
    this class in order to support basic behavioral expectations. This ensures
    that protocols correctly encode and decode message transparently to the
    caller.

    In order to support these tests, the subclass should also extend
    the TestCase class and implement the following class attributes
    which define the encoded version of certain protocol buffers:

      encoded_empty_message:
        Encoding of a message with no fields set. Defaults to ''; override
        when the protocol encodes an empty message differently.

      encoded_partial:
        <OptionalMessage
          double_value: 1.23
          int64_value: -100000000000
          int32_value: 1020
          string_value: u"a string"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_full:
        <OptionalMessage
          double_value: 1.23
          float_value: -2.5
          int64_value: -100000000000
          uint64_value: 102020202020
          int32_value: 1020
          bool_value: true
          string_value: u"a string\u044f"
          bytes_value: b"a bytes\xff\xfe"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_repeated:
        <RepeatedMessage
          double_value: [1.23, 2.3]
          float_value: [-2.5, 0.5]
          int64_value: [-100000000000, 20]
          uint64_value: [102020202020, 10]
          int32_value: [1020, 718]
          bool_value: [true, false]
          string_value: [u"a string\u044f", u"another string"]
          bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
          enum_value: [RepeatedMessage.SimpleEnum.VAL2,
                       RepeatedMessage.SimpleEnum.VAL1]
          >

      encoded_nested:
        <HasNestedMessage
          nested: <NestedMessage
            a_value: "a string"
            >
          >

      encoded_repeated_nested:
        <HasNestedMessage
          repeated_nested: [
              <NestedMessage a_value: "a string">,
              <NestedMessage a_value: "another string">
            ]
          >

      encoded_string_types:
        <OptionalMessage
          string_value: "Latin"
          >

      unexpected_tag_message:
        An encoded message that has an undefined tag or number in the stream.

      encoded_default_assigned:
        <HasDefault
          a_value: "a default"
          >

      encoded_nested_empty:
        <HasOptionalNestedMessage
          nested: <OptionalMessage>
          >

      encoded_repeated_nested_empty:
        <HasOptionalNestedMessage
          repeated_nested: [<OptionalMessage>, <OptionalMessage>]
          >

      encoded_invalid_enum:
        <OptionalMessage
          enum_value: (invalid value for serialization type)
          >

      encoded_invalid_repeated_enum:
        <RepeatedMessage
          enum_value: RepeatedMessage.SimpleEnum.VAL1 together with
                      (invalid value for serialization type)
          >
    """

    encoded_empty_message = ''

    def testEncodeInvalidMessage(self):
        message = NestedMessage()
        self.assertRaises(messages.ValidationError,
                          self.PROTOLIB.encode_message, message)

    def CompareEncoded(self, expected_encoded, actual_encoded):
        """Compare two encoded protocol values.

        Can be overridden by sub-classes to special case comparison.
        For example, to eliminate white space from output that is not
        relevant to encoding.

        Args:
          expected_encoded: Expected string encoded value.
          actual_encoded: Actual string encoded value.
        """
        self.assertEqual(expected_encoded, actual_encoded)

    def EncodeDecode(self, encoded, expected_message):
        """Check that `encoded` decodes to `expected_message` and back again."""
        message = self.PROTOLIB.decode_message(type(expected_message), encoded)
        self.assertEqual(expected_message, message)
        self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

    def testEmptyMessage(self):
        self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

    def testPartial(self):
        """Test message with a few values set."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.int64_value = -100000000000
        message.int32_value = 1020
        message.string_value = u'a string'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_partial, message)

    def testFull(self):
        """Test all types."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.float_value = -2.5
        message.int64_value = -100000000000
        message.uint64_value = 102020202020
        message.int32_value = 1020
        message.bool_value = True
        message.string_value = u'a string\u044f'
        message.bytes_value = b'a bytes\xff\xfe'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_full, message)

    def testRepeated(self):
        """Test repeated fields."""
        message = RepeatedMessage()
        message.double_value = [1.23, 2.3]
        message.float_value = [-2.5, 0.5]
        message.int64_value = [-100000000000, 20]
        message.uint64_value = [102020202020, 10]
        message.int32_value = [1020, 718]
        message.bool_value = [True, False]
        message.string_value = [u'a string\u044f', u'another string']
        message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                              RepeatedMessage.SimpleEnum.VAL1]

        self.EncodeDecode(self.encoded_repeated, message)

    def testNested(self):
        """Test nested messages."""
        nested_message = NestedMessage()
        nested_message.a_value = u'a string'

        message = HasNestedMessage()
        message.nested = nested_message

        self.EncodeDecode(self.encoded_nested, message)

    def testRepeatedNested(self):
        """Test repeated nested messages."""
        nested_message1 = NestedMessage()
        nested_message1.a_value = u'a string'
        nested_message2 = NestedMessage()
        nested_message2.a_value = u'another string'

        message = HasNestedMessage()
        message.repeated_nested = [nested_message1, nested_message2]

        self.EncodeDecode(self.encoded_repeated_nested, message)

    def testStringTypes(self):
        """Test that encoding str on StringField works."""
        message = OptionalMessage()
        message.string_value = 'Latin'
        self.EncodeDecode(self.encoded_string_types, message)

    def testEncodeUninitialized(self):
        """Test that cannot encode uninitialized message."""
        required = NestedMessage()
        self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                         "Message NestedMessage is missing "
                                         "required field a_value",
                                         self.PROTOLIB.encode_message,
                                         required)

    def testUnexpectedField(self):
        """Test decoding and encoding unexpected fields."""
        loaded_message = self.PROTOLIB.decode_message(
            OptionalMessage, self.unexpected_tag_message)
        # Message should be equal to an empty message, since unknown
        # values aren't included in equality.
        self.assertEqual(OptionalMessage(), loaded_message)
        # Verify that the encoded message matches the source, including the
        # unknown value.
        self.assertEqual(self.unexpected_tag_message,
                         self.PROTOLIB.encode_message(loaded_message))

    def testDoNotSendDefault(self):
        """Test that default is not sent when nothing is assigned."""
        self.EncodeDecode(self.encoded_empty_message, HasDefault())

    def testSendDefaultExplicitlyAssigned(self):
        """Test that default is sent when explicitly assigned."""
        message = HasDefault()

        message.a_value = HasDefault.a_value.default

        self.EncodeDecode(self.encoded_default_assigned, message)

    def testEncodingNestedEmptyMessage(self):
        """Test encoding a nested empty message."""
        message = HasOptionalNestedMessage()
        message.nested = OptionalMessage()

        self.EncodeDecode(self.encoded_nested_empty, message)

    def testEncodingRepeatedNestedEmptyMessage(self):
        """Test encoding a repeated field of nested empty messages."""
        message = HasOptionalNestedMessage()
        message.repeated_nested = [OptionalMessage(), OptionalMessage()]

        self.EncodeDecode(self.encoded_repeated_nested_empty, message)

    def testContentType(self):
        self.assertIsInstance(self.PROTOLIB.CONTENT_TYPE, str)

    def testDecodeInvalidEnumType(self):
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(OptionalMessage,
                                               self.encoded_invalid_enum)
        message = OptionalMessage()
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_enum, encoded)

    def testDecodeInvalidRepeatedEnumType(self):
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(RepeatedMessage,
                                               self.encoded_invalid_repeated_enum)
        message = RepeatedMessage()
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL1]
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_repeated_enum, encoded)

    def testDateTimeNoTimeZone(self):
        """Test that DateTimeFields are encoded/decoded correctly."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)

    def testDateTimeWithTimeZone(self):
        """Test DateTimeFields with time zones."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                                  util.TimeZoneOffset(8 * 60))
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)


def pick_unused_port():
    """Find an unused port to use in tests.

    Derived from Damon Kohler's example:

      http://code.activestate.com/recipes/531822-pick-unused-port

    The probe socket is closed before returning. The port is therefore only
    known to have been free at probe time, and another process could bind
    it before the caller does.

    Returns:
      An integer TCP port number on localhost.
    """
    temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the OS choose an ephemeral port.
        temp.bind(('localhost', 0))
        port = temp.getsockname()[1]
    finally:
        temp.close()
    return port


def get_module_name(module_attribute):
    """Get the module name.

    Args:
      module_attribute: An attribute of the module.

    Returns:
      The fully qualified module name or simple module name where
      'module_attribute' is defined if the module name is "__main__".
    """
    if module_attribute.__module__ == '__main__':
        # Running as a script: derive the name from the source file name,
        # dropping everything after the first '.'.
        module_file = inspect.getfile(module_attribute)
        default = os.path.basename(module_file).split('.')[0]
        return default
    return module_attribute.__module__