• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Test utilities for message testing.
19
20Includes module interface test to ensure that public parts of module are
21correctly declared in __all__.
22
23Includes message types that correspond to those defined in
24services_test.proto.
25
26Includes additional test utilities to make sure encoding/decoding libraries
27conform.
28"""
29import cgi
30import datetime
31import inspect
32import os
33import re
34import socket
35import types
36import unittest
37
38import six
39from six.moves import range  # pylint: disable=redefined-builtin
40
41from apitools.base.protorpclite import message_types
42from apitools.base.protorpclite import messages
43from apitools.base.protorpclite import util
44
# The word "Russian" ("русский") written with Cyrillic code points.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'

# Every byte value 0x00..0xff in ascending order, each one followed by a
# NUL byte (512 bytes in total).
BINARY = bytes(bytearray(
    octet for code in range(256) for octet in (code, 0)))
50
51
class TestCase(unittest.TestCase):
    """unittest.TestCase extended with a few extra assertion helpers."""

    def assertRaisesWithRegexpMatch(self,
                                    exception,
                                    regexp,
                                    function,
                                    *params,
                                    **kwargs):
        """Check that exception is raised and text matches regular expression.

        The expression is applied with re.match, so it must match at the
        start of the string form of the raised exception.

        Args:
          exception: Exception type that is expected.
          regexp: String regular expression that is expected in error message.
          function: Callable to test.
          params: Parameters to forward to function.
          kwargs: Keyword arguments to forward to function.

        Raises:
          self.failureException: If function returns without raising
            exception, or the raised exception's text does not match regexp.
        """
        try:
            function(*params, **kwargs)
        except exception as err:
            match = bool(re.match(regexp, str(err)))
            self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
                                                                        err))
        else:
            # Fail outside the try block: self.fail raises
            # self.failureException (AssertionError by default), which would
            # otherwise be caught and regexp-matched whenever `exception` is
            # AssertionError or one of its bases such as Exception.
            self.fail('Expected exception %s was not raised' %
                      exception.__name__)

    def assertHeaderSame(self, header1, header2):
        """Check that two HTTP headers are the same.

        Both headers are split with cgi.parse_header into a main value and a
        parameter dict, and each part is compared for equality.

        NOTE: the cgi module was removed from the standard library in
        Python 3.13 (PEP 594); this helper needs a replacement parser to run
        there.

        Args:
          header1: Header value string 1.
          header2: Header value string 2.
        """
        value1, params1 = cgi.parse_header(header1)
        value2, params2 = cgi.parse_header(header2)
        self.assertEqual(value1, value2)
        self.assertEqual(params1, params2)

    def assertIterEqual(self, iter1, iter2):
        """Check two iterators or iterables are equal independent of order.

        Similar to Python 2.7 assertItemsEqual.  Named differently in order to
        avoid potential conflict.  Items are compared with == only, so they
        do not need to be hashable; the cost is O(len(iter1) * len(iter2)).

        Args:
          iter1: An iterator or iterable.
          iter2: An iterator or iterable.
        """
        list1 = list(iter1)
        list2 = list(iter2)

        # Cancel each item of list1 against the first equal item still left
        # in list2; whatever survives on either side is a mismatch.
        unmatched1 = []
        for item1 in list1:
            for index, item2 in enumerate(list2):
                if item1 == item2:
                    del list2[index]
                    break
            else:
                unmatched1.append(item1)

        # Wrap each item in a 1-tuple so tuple items are formatted with %r
        # instead of being unpacked as format arguments.
        error_message = [
            '  Item from iter1 not found in iter2: %r' % (item,)
            for item in unmatched1]
        error_message.extend(
            '  Item from iter2 not found in iter1: %r' % (item,)
            for item in list2)
        if error_message:
            self.fail('Collections not equivalent:\n' +
                      '\n'.join(error_message))
125
126
class ModuleInterfaceTest(object):
    """Test to ensure module interface is carefully constructed.

    A module interface is the set of public objects listed in the
    module __all__ attribute. Modules that are considered public
    should have this interface carefully declared. At all times, the
    __all__ attribute should have objects intended to be publicly
    used and all other objects in the module should be considered
    private.

    Protected attributes (those beginning with '_') and other imported
    modules should not be part of this set of variables. An exception
    is for variables that begin and end with '__' which are implicitly
    part of the interface (eg. __name__, __file__, __all__ itself,
    etc.).

    Modules that are imported into the tested modules are an
    exception and may be left out of the __all__ definition. The test
    is done by checking the value of what would otherwise be a public
    name and not allowing it to be exported if it is an instance of a
    module. Modules that are explicitly exported are for the time
    being not permitted. Names bound by ``from __future__ import ...``
    statements are likewise exempt from being listed in __all__.

    To use this test class a module should define a new class that
    inherits first from ModuleInterfaceTest and then from
    test_util.TestCase. No other tests should be added to this test
    case, making the order of inheritance less important, but if setUp
    for some reason is overridden, it is important that
    ModuleInterfaceTest is first in the list so that its setUp method
    is invoked.

    Multiple inheritance is required so that ModuleInterfaceTest is
    not itself a test, and is not itself executed as one.

    The test class is expected to have the following class attributes
    defined:

      MODULE: A reference to the module that is being validated for interface
        correctness. The module must define __all__.

    Example:
      Module definition (hello.py):

        import sys

        __all__ = ['hello']

        def _get_outputter():
          return sys.stdout

        def hello():
          _get_outputter().write('Hello\n')

      Test definition:

        import unittest
        from protorpc import test_util

        import hello

        class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                  test_util.TestCase):

          MODULE = hello


        class HelloTest(test_util.TestCase):
          ... Test 'hello' module ...


        if __name__ == '__main__':
          unittest.main()

    """

    def setUp(self):
        """Make sure MODULE is defined and declares __all__.

        This is a basic configuration test for the test itself so does not
        get its own test case.
        """
        if not hasattr(self, 'MODULE'):
            self.fail(
                "You must define 'MODULE' on ModuleInterfaceTest sub-class "
                "%s." % type(self).__name__)
        # Every test below reads MODULE.__all__; report a missing one once,
        # clearly, instead of erroring with AttributeError in each test.
        if not hasattr(self.MODULE, '__all__'):
            self.fail('Module %s does not define __all__.' %
                      getattr(self.MODULE, '__name__', self.MODULE))

    def testAllExist(self):
        """Test that all attributes defined in __all__ exist."""
        missing_attributes = [attribute for attribute in self.MODULE.__all__
                              if not hasattr(self.MODULE, attribute)]
        if missing_attributes:
            self.fail('%s of __all__ are not defined in module.' %
                      missing_attributes)

    def testAllExported(self):
        """Test that all public attributes not imported are in __all__.

        Imported modules and names bound by ``from __future__ import ...``
        statements are exempt.
        """
        import __future__  # Local import: only needed for the exemption.

        missing_attributes = []
        for attribute in dir(self.MODULE):
            if attribute.startswith('_') or attribute in self.MODULE.__all__:
                continue
            value = getattr(self.MODULE, attribute)
            if isinstance(value, types.ModuleType):
                continue
            # A `from __future__ import x` statement binds the feature object
            # as a public-looking module attribute; it is not part of the API.
            if (attribute in __future__.all_feature_names and
                    value is getattr(__future__, attribute)):
                continue
            missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s are not modules and not defined in __all__.' %
                      missing_attributes)

    def testNoExportedProtectedVariables(self):
        """Test that there are no protected variables listed in __all__."""
        protected_variables = [attribute for attribute in self.MODULE.__all__
                               if attribute.startswith('_')]
        if protected_variables:
            self.fail('%s are protected variables and may not be exported.' %
                      protected_variables)

    def testNoExportedModules(self):
        """Test that no modules exist in __all__."""
        exported_modules = []
        for attribute in self.MODULE.__all__:
            try:
                value = getattr(self.MODULE, attribute)
            except AttributeError:
                # This is a different error case tested for in testAllExist.
                continue
            if isinstance(value, types.ModuleType):
                exported_modules.append(attribute)
        if exported_modules:
            self.fail('%s are modules and may not be exported.' %
                      exported_modules)
262
263
class NestedMessage(messages.Message):
    """Simple message that gets nested in another message.

    Its only field is declared required, so it is also used to exercise
    encoding of messages with a missing required field.
    """

    # Required string field, tag 1.
    a_value = messages.StringField(1, required=True)
268
269
class HasNestedMessage(messages.Message):
    """Message that has another message nested in it.

    Holds NestedMessage both as a single field and as a repeated field.
    """

    # A single optional NestedMessage, tag 1.
    nested = messages.MessageField(NestedMessage, 1)
    # A list of NestedMessage values, tag 2.
    repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)
275
276
class HasDefault(messages.Message):
    """Has a default value.

    Used to check that a field's default is encoded only when it has been
    explicitly assigned.
    """

    # String field, tag 1, whose declared default is u'a default'.
    a_value = messages.StringField(1, default=u'a default')
281
282
class OptionalMessage(messages.Message):
    """Contains one optional field of each scalar field type and variant.

    Field number 9 is not used.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
    float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
    int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
    uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
    int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
    bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
    string_value = messages.StringField(7, variant=messages.Variant.STRING)
    bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
    enum_value = messages.EnumField(SimpleEnum, 10)
300
301
class RepeatedMessage(messages.Message):
    """Contains all message types as repeated fields.

    Mirrors OptionalMessage: the same field names, numbers and variants, but
    every field is declared repeated=True. Field number 9 is not used.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1,
                                       variant=messages.Variant.DOUBLE,
                                       repeated=True)
    float_value = messages.FloatField(2,
                                      variant=messages.Variant.FLOAT,
                                      repeated=True)
    int64_value = messages.IntegerField(3,
                                        variant=messages.Variant.INT64,
                                        repeated=True)
    uint64_value = messages.IntegerField(4,
                                         variant=messages.Variant.UINT64,
                                         repeated=True)
    int32_value = messages.IntegerField(5,
                                        variant=messages.Variant.INT32,
                                        repeated=True)
    bool_value = messages.BooleanField(6,
                                       variant=messages.Variant.BOOL,
                                       repeated=True)
    string_value = messages.StringField(7,
                                        variant=messages.Variant.STRING,
                                        repeated=True)
    bytes_value = messages.BytesField(8,
                                      variant=messages.Variant.BYTES,
                                      repeated=True)
    enum_value = messages.EnumField(SimpleEnum,
                                    10,
                                    repeated=True)
337
338
class HasOptionalNestedMessage(messages.Message):
    """Message that nests OptionalMessage, singly and as a repeated field.

    Used to check encoding of nested messages that have no fields set.
    """

    # A single optional OptionalMessage, tag 1.
    nested = messages.MessageField(OptionalMessage, 1)
    # A list of OptionalMessage values, tag 2.
    repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)
343
344
345# pylint:disable=anomalous-unicode-escape-in-string
class ProtoConformanceTestBase(object):
    """Protocol conformance test base class.

    Each supported protocol should implement two functions that support
    encoding and decoding of Message objects in that format:

      encode_message(message) - Serialize to encoding.
      decode_message(message_type, encoded_message) - Deserialize from
        encoding.

    Tests for the modules where these functions are implemented should extend
    this class in order to support basic behavioral expectations.  This ensures
    that protocols correctly encode and decode message transparently to the
    caller.

    In order to support these tests, the subclass should also extend
    the TestCase class and define the class attribute

      PROTOLIB:
        The protocol module under test. It must provide encode_message,
        decode_message and a str CONTENT_TYPE.

    together with the following class attributes, which define the encoded
    version of certain protocol buffers:

      encoded_partial:
        <OptionalMessage
          double_value: 1.23
          int64_value: -100000000000
          int32_value: 1020
          string_value: u"a string"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_full:
        <OptionalMessage
          double_value: 1.23
          float_value: -2.5
          int64_value: -100000000000
          uint64_value: 102020202020
          int32_value: 1020
          bool_value: true
          string_value: u"a string\u044f"
          bytes_value: b"a bytes\xff\xfe"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

      encoded_repeated:
        <RepeatedMessage
          double_value: [1.23, 2.3]
          float_value: [-2.5, 0.5]
          int64_value: [-100000000000, 20]
          uint64_value: [102020202020, 10]
          int32_value: [1020, 718]
          bool_value: [true, false]
          string_value: [u"a string\u044f", u"another string"]
          bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
          enum_value: [RepeatedMessage.SimpleEnum.VAL2,
                       RepeatedMessage.SimpleEnum.VAL1]
          >

      encoded_nested:
        <HasNestedMessage
          nested: <NestedMessage
            a_value: "a string"
            >
          >

      encoded_repeated_nested:
        <HasNestedMessage
          repeated_nested: [
              <NestedMessage a_value: "a string">,
              <NestedMessage a_value: "another string">
            ]
          >

      unexpected_tag_message:
        An encoded message that has an undefined tag or number in the stream.

      encoded_default_assigned:
        <HasDefault
          a_value: "a default"
          >

      encoded_nested_empty:
        <HasOptionalNestedMessage
          nested: <OptionalMessage>
          >

      encoded_repeated_nested_empty:
        <HasOptionalNestedMessage
          repeated_nested: [<OptionalMessage>, <OptionalMessage>]
          >

      encoded_string_types:
        <OptionalMessage
          string_value: "Latin"
          >

      encoded_invalid_enum:
        <OptionalMessage
          enum_value: (invalid value for serialization type)
          >

    encoded_empty_message (the encoding of a message with no fields set)
    defaults to '' and can be overridden by the subclass.
    """

    encoded_empty_message = ''

    def testEncodeInvalidMessage(self):
        """Test that a message missing a required field cannot be encoded."""
        message = NestedMessage()
        self.assertRaises(messages.ValidationError,
                          self.PROTOLIB.encode_message, message)

    def CompareEncoded(self, expected_encoded, actual_encoded):
        """Compare two encoded protocol values.

        Can be overridden by sub-classes to special case comparison.
        For example, to eliminate white space from output that is not
        relevant to encoding.

        Args:
          expected_encoded: Expected string encoded value.
          actual_encoded: Actual string encoded value.
        """
        self.assertEqual(expected_encoded, actual_encoded)

    def EncodeDecode(self, encoded, expected_message):
        """Check that `encoded` round-trips through expected_message.

        Decodes `encoded` as type(expected_message), asserts the result
        equals expected_message, then re-encodes it and compares the output
        with `encoded` via CompareEncoded.

        Args:
          encoded: Encoded form of expected_message.
          expected_message: Message instance `encoded` must decode to.
        """
        message = self.PROTOLIB.decode_message(type(expected_message), encoded)
        self.assertEqual(expected_message, message)
        self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

    def testEmptyMessage(self):
        """Test a message with no fields set."""
        self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

    def testPartial(self):
        """Test message with a few values set."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.int64_value = -100000000000
        message.int32_value = 1020
        message.string_value = u'a string'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_partial, message)

    def testFull(self):
        """Test all types."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.float_value = -2.5
        message.int64_value = -100000000000
        message.uint64_value = 102020202020
        message.int32_value = 1020
        message.bool_value = True
        message.string_value = u'a string\u044f'
        message.bytes_value = b'a bytes\xff\xfe'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2

        self.EncodeDecode(self.encoded_full, message)

    def testRepeated(self):
        """Test repeated fields."""
        message = RepeatedMessage()
        message.double_value = [1.23, 2.3]
        message.float_value = [-2.5, 0.5]
        message.int64_value = [-100000000000, 20]
        message.uint64_value = [102020202020, 10]
        message.int32_value = [1020, 718]
        message.bool_value = [True, False]
        message.string_value = [u'a string\u044f', u'another string']
        message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                              RepeatedMessage.SimpleEnum.VAL1]

        self.EncodeDecode(self.encoded_repeated, message)

    def testNested(self):
        """Test nested messages."""
        nested_message = NestedMessage()
        nested_message.a_value = u'a string'

        message = HasNestedMessage()
        message.nested = nested_message

        self.EncodeDecode(self.encoded_nested, message)

    def testRepeatedNested(self):
        """Test repeated nested messages."""
        nested_message1 = NestedMessage()
        nested_message1.a_value = u'a string'
        nested_message2 = NestedMessage()
        nested_message2.a_value = u'another string'

        message = HasNestedMessage()
        message.repeated_nested = [nested_message1, nested_message2]

        self.EncodeDecode(self.encoded_repeated_nested, message)

    def testStringTypes(self):
        """Test that encoding str on StringField works."""
        message = OptionalMessage()
        message.string_value = 'Latin'
        self.EncodeDecode(self.encoded_string_types, message)

    def testEncodeUninitialized(self):
        """Test that cannot encode uninitialized message."""
        required = NestedMessage()
        self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                         "Message NestedMessage is missing "
                                         "required field a_value",
                                         self.PROTOLIB.encode_message,
                                         required)

    def testUnexpectedField(self):
        """Test decoding and encoding unexpected fields."""
        loaded_message = self.PROTOLIB.decode_message(
            OptionalMessage, self.unexpected_tag_message)
        # Message should be equal to an empty message, since unknown
        # values aren't included in equality.
        self.assertEqual(OptionalMessage(), loaded_message)
        # Verify that the encoded message matches the source, including the
        # unknown value.
        self.assertEqual(self.unexpected_tag_message,
                         self.PROTOLIB.encode_message(loaded_message))

    def testDoNotSendDefault(self):
        """Test that default is not sent when nothing is assigned."""
        self.EncodeDecode(self.encoded_empty_message, HasDefault())

    def testSendDefaultExplicitlyAssigned(self):
        """Test that default is sent when explicitly assigned."""
        message = HasDefault()

        message.a_value = HasDefault.a_value.default

        self.EncodeDecode(self.encoded_default_assigned, message)

    def testEncodingNestedEmptyMessage(self):
        """Test encoding a nested empty message."""
        message = HasOptionalNestedMessage()
        message.nested = OptionalMessage()

        self.EncodeDecode(self.encoded_nested_empty, message)

    def testEncodingRepeatedNestedEmptyMessage(self):
        """Test encoding a repeated field of nested empty messages."""
        message = HasOptionalNestedMessage()
        message.repeated_nested = [OptionalMessage(), OptionalMessage()]

        self.EncodeDecode(self.encoded_repeated_nested_empty, message)

    def testContentType(self):
        """Test that the protocol declares its content type as a str."""
        self.assertIsInstance(self.PROTOLIB.CONTENT_TYPE, str)

    def testDecodeInvalidEnumType(self):
        """Test that an unknown enum value survives a decode/encode cycle."""
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(OptionalMessage,
                                               self.encoded_invalid_enum)
        message = OptionalMessage()
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_enum, encoded)

    def testDateTimeNoTimeZone(self):
        """Test that DateTimeFields are encoded/decoded correctly."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)

    def testDateTimeWithTimeZone(self):
        """Test DateTimeFields with time zones."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                                  util.TimeZoneOffset(8 * 60))
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)
617
618
def pick_unused_port():
    """Return a TCP port on localhost that was free a moment ago.

    Derived from Damon Kohler's example:

      http://code.activestate.com/recipes/531822-pick-unused-port

    The probe socket is closed before the port is returned, so another
    process could claim the port before the caller binds it.

    Returns:
      An int port number.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the operating system choose the port.
        probe.bind(('localhost', 0))
        return probe.getsockname()[1]
    finally:
        probe.close()
633
634
def get_module_name(module_attribute):
    """Return the name of the module that defines module_attribute.

    Args:
      module_attribute: An attribute of the module.

    Returns:
      module_attribute.__module__, unless that is '__main__'. In that case,
      the base name of the defining file up to its first '.', e.g.
      '/path/foo_test.py' -> 'foo_test'.
    """
    module_name = module_attribute.__module__
    if module_name != '__main__':
        return module_name
    # Code run as a script reports '__main__'; recover a usable name from
    # the file the attribute was defined in instead.
    file_name = os.path.basename(inspect.getfile(module_attribute))
    return file_name.partition('.')[0]
650