• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
"""Test utilities for message testing.

Includes a module interface test (ModuleInterfaceTest) that checks that the
public parts of a module are correctly declared in __all__.

Includes message types that correspond to those defined in
services_test.proto.

Includes additional test utilities (TestCase assertions and
ProtoConformanceTestBase) to make sure encoding/decoding libraries conform.
"""
29import cgi
30import datetime
31import inspect
32import os
33import re
34import socket
35import types
36import unittest
37
38import six
39from six.moves import range  # pylint: disable=redefined-builtin
40
41from apitools.base.protorpclite import message_types
42from apitools.base.protorpclite import messages
43from apitools.base.protorpclite import util
44
# The word "Russian" written in Cyrillic. Escapes are used so the source
# file itself stays ASCII-only.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'

# Every byte value from 0 to 255 in ascending order, each immediately
# followed by a NUL byte (512 bytes in total).
BINARY = bytes(bytearray(
    octet for code in range(256) for octet in (code, 0)))
50
51
class TestCase(unittest.TestCase):
    """unittest.TestCase with extra assertions shared by the message tests."""

    def assertRaisesWithRegexpMatch(self,
                                    exception,
                                    regexp,
                                    function,
                                    *params,
                                    **kwargs):
        """Check that exception is raised and text matches regular expression.

        The expression is applied with re.match, so it must match at the
        beginning of str(err).

        Args:
          exception: Exception type that is expected.
          regexp: String regular expression that must match the start of the
            error message.
          function: Callable to test.
          params: Parameters to forward to function.
          kwargs: Keyword arguments to forward to function.
        """
        try:
            function(*params, **kwargs)
        except exception as err:
            match = bool(re.match(regexp, str(err)))
            self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
                                                                        err))
        else:
            # Fail outside the try statement: self.fail raises
            # AssertionError, which `except exception` would otherwise
            # swallow whenever `exception` is AssertionError or one of its
            # base classes (e.g. Exception).
            self.fail('Expected exception %s was not raised' %
                      exception.__name__)

    def assertHeaderSame(self, header1, header2):
        """Check that two HTTP headers are the same.

        Both headers are split with cgi.parse_header into a main value and
        a parameter dict, and the two parts are compared separately.

        Args:
          header1: Header value string 1.
          header2: Header value string 2.
        """
        value1, params1 = cgi.parse_header(header1)
        value2, params2 = cgi.parse_header(header2)
        self.assertEqual(value1, value2)
        self.assertEqual(params1, params2)

    def assertIterEqual(self, iter1, iter2):
        """Check two iterators or iterables are equal independent of order.

        Similar to Python 2.7 assertItemsEqual.  Named differently in order to
        avoid potential conflict.  Items are compared with == only, so
        unhashable items are supported; duplicates must occur the same number
        of times in both inputs.

        Args:
          iter1: An iterator or iterable.
          iter2: An iterator or iterable.
        """
        list1 = list(iter1)
        remaining = list(iter2)

        unmatched1 = []
        for item1 in list1:
            for index, item2 in enumerate(remaining):
                if item1 == item2:
                    # Consume the match so duplicates are counted correctly.
                    del remaining[index]
                    break
            else:
                unmatched1.append(item1)

        # Wrap each item in a 1-tuple so tuple items are rendered with %r
        # instead of being unpacked as format arguments.
        error_message = [
            '  Item from iter1 not found in iter2: %r' % (item,)
            for item in unmatched1]
        error_message.extend(
            '  Item from iter2 not found in iter1: %r' % (item,)
            for item in remaining)
        if error_message:
            self.fail('Collections not equivalent:\n' +
                      '\n'.join(error_message))
125
126
class ModuleInterfaceTest(object):
    """Test to ensure module interface is carefully constructed.

    A module interface is the set of public objects listed in the
    module __all__ attribute. Modules that are considered public
    should have this interface carefully declared. At all times, the
    __all__ attribute should have objects intended to be publicly
    used and all other objects in the module should be considered
    unused.

    Names beginning with '_' are never required in __all__ and may not
    be listed in it. This covers protected attributes as well as names
    that begin and end with '__' (eg. __name__, __file__, __all__
    itself, etc.), which are implicitly part of every module.

    Modules that are imported into the tested module are an exception
    and may be left out of the __all__ definition. The test is done by
    checking the value of what would otherwise be a public name and
    not requiring it to be exported if it is an instance of a module.
    Modules that are explicitly exported are for the time being not
    permitted.

    To use this test class a module should define a new class that
    inherits first from ModuleInterfaceTest and then from
    test_util.TestCase. No other tests should be added to this test
    case, making the order of inheritance less important, but if setUp
    for some reason is overridden, it is important that
    ModuleInterfaceTest is first in the list so that its setUp method
    is invoked.

    Multiple inheritance is required so that ModuleInterfaceTest is
    not itself a test, and is not itself executed as one.

    The test class is expected to have the following class attributes
    defined:

      MODULE: A reference to the module that is being validated for interface
        correctness.

    Example:
      Module definition (hello.py):

        import sys

        __all__ = ['hello']

        def _get_outputter():
          return sys.stdout

        def hello():
          _get_outputter().write('Hello\n')

      Test definition:

        import unittest
        from protorpc import test_util

        import hello

        class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                  test_util.TestCase):

          MODULE = hello


        class HelloTest(test_util.TestCase):
          ... Test 'hello' module ...


        if __name__ == '__main__':
          unittest.main()

    """

    def setUp(self):
        """Fail early unless the subclass defines MODULE.

        This is a basic configuration check for the test itself, so it
        does not get its own test case.
        """
        if not hasattr(self, 'MODULE'):
            self.fail(
                "You must define 'MODULE' on ModuleInterfaceTest sub-class "
                "%s." % type(self).__name__)

    def testAllExist(self):
        """Test that all attributes defined in __all__ exist."""
        missing_attributes = [attribute for attribute in self.MODULE.__all__
                              if not hasattr(self.MODULE, attribute)]
        if missing_attributes:
            self.fail('%s of __all__ are not defined in module.' %
                      missing_attributes)

    def testAllExported(self):
        """Test that all public attributes not imported are in __all__."""
        exported = self.MODULE.__all__
        missing_attributes = []
        for attribute in dir(self.MODULE):
            if attribute.startswith('_') or attribute in exported:
                continue
            # Imported modules may stay out of __all__.
            if isinstance(getattr(self.MODULE, attribute), types.ModuleType):
                continue
            # 'with_statement' is the feature object bound by
            # `from __future__ import with_statement`; it is not a module
            # but must not be reported as a missing export either.
            if attribute == 'with_statement':
                continue
            missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s are not modules and not defined in __all__.' %
                      missing_attributes)

    def testNoExportedProtectedVariables(self):
        """Test that there are no protected variables listed in __all__."""
        protected_variables = [attribute
                               for attribute in self.MODULE.__all__
                               if attribute.startswith('_')]
        if protected_variables:
            self.fail('%s are protected variables and may not be exported.' %
                      protected_variables)

    def testNoExportedModules(self):
        """Test that no modules exist in __all__."""
        exported_modules = []
        for attribute in self.MODULE.__all__:
            try:
                value = getattr(self.MODULE, attribute)
            except AttributeError:
                # This is a different error case tested for in testAllExist.
                pass
            else:
                if isinstance(value, types.ModuleType):
                    exported_modules.append(attribute)
        if exported_modules:
            self.fail('%s are modules and may not be exported.' %
                      exported_modules)
262
263
class NestedMessage(messages.Message):
    """Simple message that gets nested in another message.

    Its only field is required, so an instance with a_value unset is the
    message the conformance tests expect encoding to reject.
    """

    # Field number 1; declared required=True.
    a_value = messages.StringField(1, required=True)
268
269
class HasNestedMessage(messages.Message):
    """Message that has another message nested in it.

    Provides both a singular and a repeated NestedMessage field.
    """

    # Singular NestedMessage field, number 1.
    nested = messages.MessageField(NestedMessage, 1)
    # Repeated NestedMessage field, number 2.
    repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)
275
276
class HasDefault(messages.Message):
    """Has a single string field with a default value.

    The conformance tests expect the default to be omitted from the
    encoding unless it is explicitly assigned.
    """

    # Field number 1; defaults to u'a default' when unset.
    a_value = messages.StringField(1, default=u'a default')
281
282
class OptionalMessage(messages.Message):
    """Contains one optional field of each scalar field type plus an enum.

    Each numeric/string/bytes field pins an explicit wire variant so the
    conformance tests exercise every variant listed here.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
    float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
    int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
    uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
    int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
    bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
    string_value = messages.StringField(7, variant=messages.Variant.STRING)
    bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
    # NOTE(review): field number 9 is skipped here — presumably to mirror
    # services_test.proto; confirm before reusing it.
    enum_value = messages.EnumField(SimpleEnum, 10)
300
301
class RepeatedMessage(messages.Message):
    """Contains one repeated field of each scalar field type plus an enum.

    Uses the same field names, numbers and variants as OptionalMessage, but
    every field is declared with repeated=True.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1,
                                       variant=messages.Variant.DOUBLE,
                                       repeated=True)
    float_value = messages.FloatField(2,
                                      variant=messages.Variant.FLOAT,
                                      repeated=True)
    int64_value = messages.IntegerField(3,
                                        variant=messages.Variant.INT64,
                                        repeated=True)
    uint64_value = messages.IntegerField(4,
                                         variant=messages.Variant.UINT64,
                                         repeated=True)
    int32_value = messages.IntegerField(5,
                                        variant=messages.Variant.INT32,
                                        repeated=True)
    bool_value = messages.BooleanField(6,
                                       variant=messages.Variant.BOOL,
                                       repeated=True)
    string_value = messages.StringField(7,
                                        variant=messages.Variant.STRING,
                                        repeated=True)
    bytes_value = messages.BytesField(8,
                                      variant=messages.Variant.BYTES,
                                      repeated=True)
    # Field number 9 is skipped, matching OptionalMessage.
    enum_value = messages.EnumField(SimpleEnum,
                                    10,
                                    repeated=True)
337
338
class HasOptionalNestedMessage(messages.Message):
    """Message that nests OptionalMessage, singly and repeated.

    Used to check encoding of nested messages that have no fields set.
    """

    # Singular OptionalMessage field, number 1.
    nested = messages.MessageField(OptionalMessage, 1)
    # Repeated OptionalMessage field, number 2.
    repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)
343
344
345# pylint:disable=anomalous-unicode-escape-in-string
class ProtoConformanceTestBase(object):
    """Protocol conformance test base class.

    Each supported protocol should implement two functions that support
    encoding and decoding of Message objects in that format:

      encode_message(message) - Serialize to encoding.
      decode_message(message_type, encoded_message) - Deserialize from
        encoding.

    Tests for the modules where these functions are implemented should extend
    this class in order to support basic behavioral expectations.  This ensures
    that protocols correctly encode and decode message transparently to the
    caller.

    In order to support these tests, the subclass should also extend the
    TestCase class and define PROTOLIB, an object exposing the two functions
    above plus a str CONTENT_TYPE.  It must also implement the following
    class attributes, which define the encoded version of certain protocol
    buffers:

      encoded_empty_message (defaults to ''):
        The encoding of a message with no fields set; it is used for both
        <OptionalMessage> and <HasDefault>.

      encoded_partial:
        <OptionalMessage
          double_value: 1.23
          int64_value: -100000000000
          int32_value: 1020
          string_value: u"a string"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_full:
        <OptionalMessage
          double_value: 1.23
          float_value: -2.5
          int64_value: -100000000000
          uint64_value: 102020202020
          int32_value: 1020
          bool_value: true
          string_value: u"a string\u044f"
          bytes_value: b"a bytes\xff\xfe"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_repeated:
        <RepeatedMessage
          double_value: [1.23, 2.3]
          float_value: [-2.5, 0.5]
          int64_value: [-100000000000, 20]
          uint64_value: [102020202020, 10]
          int32_value: [1020, 718]
          bool_value: [true, false]
          string_value: [u"a string\u044f", u"another string"]
          bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
          enum_value: [RepeatedMessage.SimpleEnum.VAL2,
                       RepeatedMessage.SimpleEnum.VAL1]
          >

      encoded_nested:
        <HasNestedMessage
          nested: <NestedMessage
            a_value: "a string"
            >
          >

      encoded_repeated_nested:
        <HasNestedMessage
          repeated_nested: [
              <NestedMessage a_value: "a string">,
              <NestedMessage a_value: "another string">
            ]
          >

      encoded_string_types:
        <OptionalMessage
          string_value: "Latin"
          >

      unexpected_tag_message:
        An encoded message that has an undefined tag or number in the stream.
        Decoding it as OptionalMessage must give an empty message, and
        re-encoding that message must reproduce it exactly.

      encoded_default_assigned:
        <HasDefault
          a_value: "a default"
          >

      encoded_nested_empty:
        <HasOptionalNestedMessage
          nested: <OptionalMessage>
          >

      encoded_repeated_nested_empty:
        <HasOptionalNestedMessage
          repeated_nested: [<OptionalMessage>, <OptionalMessage>]
          >

      encoded_invalid_enum:
        <OptionalMessage
          enum_value: (invalid value for serialization type)
          >

      encoded_invalid_repeated_enum:
        <RepeatedMessage
          enum_value: RepeatedMessage.SimpleEnum.VAL1 together with an
            (invalid value for serialization type); it must decode to
            [RepeatedMessage.SimpleEnum.VAL1]
          >
    """

    encoded_empty_message = ''

    def testEncodeInvalidMessage(self):
        """Test that a message missing a required field cannot be encoded."""
        message = NestedMessage()
        self.assertRaises(messages.ValidationError,
                          self.PROTOLIB.encode_message, message)

    def CompareEncoded(self, expected_encoded, actual_encoded):
        """Compare two encoded protocol values.

        Can be overridden by sub-classes to special case comparison.
        For example, to eliminate white space from output that is not
        relevant to encoding.

        Args:
          expected_encoded: Expected string encoded value.
          actual_encoded: Actual string encoded value.
        """
        # assertEqual, not the assertEquals alias removed in Python 3.12.
        self.assertEqual(expected_encoded, actual_encoded)

    def EncodeDecode(self, encoded, expected_message):
        """Check that `encoded` round-trips through expected_message.

        Decodes `encoded` as type(expected_message), checks the result equals
        expected_message, then re-encodes it and compares with `encoded`
        through CompareEncoded.

        Args:
          encoded: Encoded form that must correspond to expected_message.
          expected_message: Message instance `encoded` must decode to.
        """
        message = self.PROTOLIB.decode_message(type(expected_message), encoded)
        self.assertEqual(expected_message, message)
        self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

    def testEmptyMessage(self):
        """Test a message with no fields set."""
        self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

    def testPartial(self):
        """Test message with a few values set."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.int64_value = -100000000000
        message.int32_value = 1020
        message.string_value = u'a string'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_partial, message)

    def testFull(self):
        """Test all types."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.float_value = -2.5
        message.int64_value = -100000000000
        message.uint64_value = 102020202020
        message.int32_value = 1020
        message.bool_value = True
        message.string_value = u'a string\u044f'
        message.bytes_value = b'a bytes\xff\xfe'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_full, message)

    def testRepeated(self):
        """Test repeated fields."""
        message = RepeatedMessage()
        message.double_value = [1.23, 2.3]
        message.float_value = [-2.5, 0.5]
        message.int64_value = [-100000000000, 20]
        message.uint64_value = [102020202020, 10]
        message.int32_value = [1020, 718]
        message.bool_value = [True, False]
        message.string_value = [u'a string\u044f', u'another string']
        message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                              RepeatedMessage.SimpleEnum.VAL1]

        self.EncodeDecode(self.encoded_repeated, message)

    def testNested(self):
        """Test nested messages."""
        nested_message = NestedMessage()
        nested_message.a_value = u'a string'

        message = HasNestedMessage()
        message.nested = nested_message

        self.EncodeDecode(self.encoded_nested, message)

    def testRepeatedNested(self):
        """Test repeated nested messages."""
        nested_message1 = NestedMessage()
        nested_message1.a_value = u'a string'
        nested_message2 = NestedMessage()
        nested_message2.a_value = u'another string'

        message = HasNestedMessage()
        message.repeated_nested = [nested_message1, nested_message2]

        self.EncodeDecode(self.encoded_repeated_nested, message)

    def testStringTypes(self):
        """Test that encoding str on StringField works."""
        message = OptionalMessage()
        message.string_value = 'Latin'
        self.EncodeDecode(self.encoded_string_types, message)

    def testEncodeUninitialized(self):
        """Test that cannot encode uninitialized message."""
        required = NestedMessage()
        self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                         "Message NestedMessage is missing "
                                         "required field a_value",
                                         self.PROTOLIB.encode_message,
                                         required)

    def testUnexpectedField(self):
        """Test decoding and encoding unexpected fields."""
        loaded_message = self.PROTOLIB.decode_message(
            OptionalMessage, self.unexpected_tag_message)
        # Message should be equal to an empty message, since unknown
        # values aren't included in equality.
        self.assertEqual(OptionalMessage(), loaded_message)
        # Verify that the encoded message matches the source, including the
        # unknown value.
        self.assertEqual(self.unexpected_tag_message,
                         self.PROTOLIB.encode_message(loaded_message))

    def testDoNotSendDefault(self):
        """Test that default is not sent when nothing is assigned."""
        self.EncodeDecode(self.encoded_empty_message, HasDefault())

    def testSendDefaultExplicitlyAssigned(self):
        """Test that default is sent when explicitly assigned."""
        message = HasDefault()

        message.a_value = HasDefault.a_value.default

        self.EncodeDecode(self.encoded_default_assigned, message)

    def testEncodingNestedEmptyMessage(self):
        """Test encoding a nested empty message."""
        message = HasOptionalNestedMessage()
        message.nested = OptionalMessage()

        self.EncodeDecode(self.encoded_nested_empty, message)

    def testEncodingRepeatedNestedEmptyMessage(self):
        """Test encoding repeated nested empty messages."""
        message = HasOptionalNestedMessage()
        message.repeated_nested = [OptionalMessage(), OptionalMessage()]

        self.EncodeDecode(self.encoded_repeated_nested_empty, message)

    def testContentType(self):
        """Test that the protocol declares a str content type."""
        self.assertIsInstance(self.PROTOLIB.CONTENT_TYPE, str)

    def testDecodeInvalidEnumType(self):
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(OptionalMessage,
                                               self.encoded_invalid_enum)
        message = OptionalMessage()
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_enum, encoded)

    def testDecodeInvalidRepeatedEnumType(self):
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(RepeatedMessage,
                                               self.encoded_invalid_repeated_enum)
        message = RepeatedMessage()
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL1]
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_repeated_enum, encoded)

    def testDateTimeNoTimeZone(self):
        """Test that DateTimeFields are encoded/decoded correctly."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)

    def testDateTimeWithTimeZone(self):
        """Test DateTimeFields with time zones."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                                  util.TimeZoneOffset(8 * 60))
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)
635
636
def pick_unused_port():
    """Return a TCP port number that was free on localhost a moment ago.

    A throwaway IPv4 stream socket is bound to port 0 on 'localhost', so the
    operating system assigns the port; that number is read back and the
    socket is closed before returning. The port is not held afterwards, so
    another process could take it before the caller binds to it.

    Derived from Damon Kohler's recipe:

      http://code.activestate.com/recipes/531822-pick-unused-port
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('localhost', 0))
        _host, chosen_port = probe.getsockname()
        return chosen_port
    finally:
        probe.close()
651
652
def get_module_name(module_attribute):
    """Get the module name.

    Args:
      module_attribute: An object defined in the module, such as a class or
        function (anything with a __module__ attribute).

    Returns:
      The fully qualified module name where 'module_attribute' is defined.
      If that name is "__main__", the simple module name is returned instead:
      the base name of the defining file with its final extension removed.
      Any error inspect.getfile raises when no file can be determined is
      propagated.
    """
    module_name = module_attribute.__module__
    if module_name != '__main__':
        return module_name
    # Scripts run directly report '__main__'; recover a name from the file.
    module_file = inspect.getfile(module_attribute)
    # splitext strips only the final extension, so 'run.v2.py' gives
    # 'run.v2' rather than being truncated at the first dot to 'run'.
    return os.path.splitext(os.path.basename(module_file))[0]
668