• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for apitools.base.protorpclite.messages."""
19import pickle
20import re
21import sys
22import types
23import unittest
24
25import six
26
27from apitools.base.protorpclite import descriptor
28from apitools.base.protorpclite import message_types
29from apitools.base.protorpclite import messages
30from apitools.base.protorpclite import test_util
31
32# This package plays lots of games with modifying global variables inside
33# test cases. Hence:
34# pylint:disable=function-redefined
35# pylint:disable=global-variable-not-assigned
36# pylint:disable=global-variable-undefined
37# pylint:disable=redefined-outer-name
38# pylint:disable=undefined-variable
39# pylint:disable=unused-variable
40# pylint:disable=too-many-lines
41
42
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
    """Module-interface checks for the ``messages`` module.

    Binds ``messages`` as ``MODULE`` for the mixin inherited from
    ``test_util.ModuleInterfaceTest``.
    """

    MODULE = messages
47
48
class ValidationErrorTest(test_util.TestCase):
    """Tests for the string form of messages.ValidationError."""

    def testStr_NoFieldName(self):
        """Test string version of ValidationError when no name provided."""
        self.assertEqual('Validation error',
                         str(messages.ValidationError('Validation error')))

    def testStr_FieldName(self):
        """Test string version of ValidationError when a name is provided."""
        validation_error = messages.ValidationError('Validation error')
        validation_error.field_name = 'a_field'
        # Setting field_name does not change the string form of the error.
        self.assertEqual('Validation error', str(validation_error))
61
62
class EnumTest(test_util.TestCase):
    """Tests for messages.Enum."""

    def setUp(self):
        """Set up tests."""
        # Redefine the Color class for every test so that a change made to
        # it in one test (which would itself be an error) cannot affect
        # other tests.
        global Color  # pylint:disable=global-variable-not-assigned

        # pylint:disable=unused-variable
        class Color(messages.Enum):
            RED = 20
            ORANGE = 2
            YELLOW = 40
            GREEN = 4
            BLUE = 50
            INDIGO = 5
            VIOLET = 80

    def testNames(self):
        """Test that names iterates over enum names."""
        self.assertEqual(
            {'BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW'},
            set(Color.names()))

    def testNumbers(self):
        """Test that numbers iterates over enum numbers."""
        self.assertEqual({2, 4, 5, 20, 40, 50, 80}, set(Color.numbers()))

    def testIterate(self):
        """Test that __iter__ iterates over all enum values."""
        self.assertEqual(set(Color),
                         {Color.RED,
                          Color.ORANGE,
                          Color.YELLOW,
                          Color.GREEN,
                          Color.BLUE,
                          Color.INDIGO,
                          Color.VIOLET})

    def testNaturalOrder(self):
        """Test that natural order enumeration is in numeric order."""
        self.assertEqual([Color.ORANGE,
                          Color.GREEN,
                          Color.INDIGO,
                          Color.RED,
                          Color.YELLOW,
                          Color.BLUE,
                          Color.VIOLET],
                         sorted(Color))

    def testByName(self):
        """Test look-up by name."""
        self.assertEqual(Color.RED, Color.lookup_by_name('RED'))
        self.assertRaises(KeyError, Color.lookup_by_name, 20)
        self.assertRaises(KeyError, Color.lookup_by_name, Color.RED)

    def testByNumber(self):
        """Test look-up by number."""
        self.assertRaises(KeyError, Color.lookup_by_number, 'RED')
        self.assertEqual(Color.RED, Color.lookup_by_number(20))
        self.assertRaises(KeyError, Color.lookup_by_number, Color.RED)

    def testConstructor(self):
        """Test that the constructor looks up values by name or number."""
        self.assertEqual(Color.RED, Color('RED'))
        self.assertEqual(Color.RED, Color(u'RED'))
        self.assertEqual(Color.RED, Color(20))
        if six.PY2:
            self.assertEqual(Color.RED, Color(long(20)))
        self.assertEqual(Color.RED, Color(Color.RED))
        self.assertRaises(TypeError, Color, 'Not exists')
        self.assertRaises(TypeError, Color, 'Red')
        self.assertRaises(TypeError, Color, 100)
        self.assertRaises(TypeError, Color, 10.0)

    def testLen(self):
        """Test that len function works to count enums."""
        self.assertEqual(7, len(Color))

    def testNoSubclasses(self):
        """Test that it is not possible to sub-class enum classes."""
        def declare_subclass():
            class MoreColor(Color):
                pass
        self.assertRaises(messages.EnumDefinitionError,
                          declare_subclass)

    def testClassNotMutable(self):
        """Test that enum classes themselves are not mutable."""
        self.assertRaises(AttributeError,
                          setattr,
                          Color,
                          'something_new',
                          10)

    def testInstancesMutable(self):
        """Test that enum instances are not mutable."""
        self.assertRaises(TypeError,
                          setattr,
                          Color.RED,
                          'something_new',
                          10)

    def testDefEnum(self):
        """Test def_enum works by building enum class from dict."""
        WeekDay = messages.Enum.def_enum({'Monday': 1,
                                          'Tuesday': 2,
                                          'Wednesday': 3,
                                          'Thursday': 4,
                                          'Friday': 6,
                                          'Saturday': 7,
                                          'Sunday': 8},
                                         'WeekDay')
        self.assertEqual('Wednesday', WeekDay(3).name)
        self.assertEqual(6, WeekDay('Friday').number)
        self.assertEqual(WeekDay.Sunday, WeekDay('Sunday'))

    def testNonInt(self):
        """Test that non-integer values are rejected by enum def."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': '1'},
                          'BadEnum')

    def testNegativeInt(self):
        """Test that negative numbers are rejected by enum def."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': -1},
                          'BadEnum')

    def testLowerBound(self):
        """Test that zero is accepted by enum def."""
        class NotImportant(messages.Enum):
            """Testing for value zero"""
            VALUE = 0

        self.assertEqual(0, int(NotImportant.VALUE))

    def testTooLargeInt(self):
        """Test that numbers too large are rejected."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': (2 ** 29)},
                          'BadEnum')

    def testRepeatedInt(self):
        """Test duplicated numbers are forbidden."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Ok': 1, 'Repeated': 1},
                          'BadEnum')

    def testStr(self):
        """Test converting to string."""
        self.assertEqual('RED', str(Color.RED))
        self.assertEqual('ORANGE', str(Color.ORANGE))

    def testInt(self):
        """Test converting to int."""
        self.assertEqual(20, int(Color.RED))
        self.assertEqual(2, int(Color.ORANGE))

    def testRepr(self):
        """Test enum representation."""
        self.assertEqual('Color(RED, 20)', repr(Color.RED))
        self.assertEqual('Color(YELLOW, 40)', repr(Color.YELLOW))

    def testDocstring(self):
        """Test that docstring is supported ok."""
        class NotImportant(messages.Enum):
            """I have a docstring."""

            VALUE1 = 1

        self.assertEqual('I have a docstring.', NotImportant.__doc__)

    def testDeleteEnumValue(self):
        """Test that enum values cannot be deleted."""
        self.assertRaises(TypeError, delattr, Color, 'RED')

    def testEnumName(self):
        """Test enum name."""
        module_name = test_util.get_module_name(EnumTest)
        self.assertEqual('%s.Color' % module_name, Color.definition_name())
        self.assertEqual(module_name, Color.outer_definition_name())
        self.assertEqual(module_name, Color.definition_package())

    def testDefinitionName_OverrideModule(self):
        """Test enum module is overridden by module package name."""
        global package
        try:
            package = 'my.package'
            self.assertEqual('my.package.Color', Color.definition_name())
            self.assertEqual('my.package', Color.outer_definition_name())
            self.assertEqual('my.package', Color.definition_package())
        finally:
            del package

    def testDefinitionName_NoModule(self):
        """Test what happens when there is no module for enum."""
        class Enum1(messages.Enum):
            pass

        original_modules = sys.modules
        sys.modules = dict(sys.modules)
        try:
            del sys.modules[__name__]
            self.assertEqual('Enum1', Enum1.definition_name())
            self.assertIsNone(Enum1.outer_definition_name())
            self.assertIsNone(Enum1.definition_package())
            self.assertEqual(six.text_type, type(Enum1.definition_name()))
        finally:
            sys.modules = original_modules

    def testDefinitionName_Nested(self):
        """Test nested Enum names."""
        class MyMessage(messages.Message):

            class NestedEnum(messages.Enum):

                pass

            class NestedMessage(messages.Message):

                class NestedEnum(messages.Enum):

                    pass

        module_name = test_util.get_module_name(EnumTest)
        self.assertEqual('%s.MyMessage.NestedEnum' % module_name,
                         MyMessage.NestedEnum.definition_name())
        self.assertEqual('%s.MyMessage' % module_name,
                         MyMessage.NestedEnum.outer_definition_name())
        self.assertEqual(module_name,
                         MyMessage.NestedEnum.definition_package())

        self.assertEqual(
            '%s.MyMessage.NestedMessage.NestedEnum' % module_name,
            MyMessage.NestedMessage.NestedEnum.definition_name())
        self.assertEqual(
            '%s.MyMessage.NestedMessage' % module_name,
            MyMessage.NestedMessage.NestedEnum.outer_definition_name())
        self.assertEqual(
            module_name,
            MyMessage.NestedMessage.NestedEnum.definition_package())

    def testMessageDefinition(self):
        """Test that enumeration knows its enclosing message definition."""
        class OuterEnum(messages.Enum):
            pass

        self.assertIsNone(OuterEnum.message_definition())

        class OuterMessage(messages.Message):

            class InnerEnum(messages.Enum):
                pass

        self.assertEqual(
            OuterMessage, OuterMessage.InnerEnum.message_definition())

    def testComparison(self):
        """Test comparing various enums to different types."""
        class Enum1(messages.Enum):
            VAL1 = 1
            VAL2 = 2

        class Enum2(messages.Enum):
            VAL1 = 1

        self.assertEqual(Enum1.VAL1, Enum1.VAL1)
        self.assertNotEqual(Enum1.VAL1, Enum1.VAL2)
        self.assertNotEqual(Enum1.VAL1, Enum2.VAL1)
        self.assertNotEqual(Enum1.VAL1, 'VAL1')
        self.assertNotEqual(Enum1.VAL1, 1)
        self.assertNotEqual(Enum1.VAL1, 2)
        self.assertNotEqual(Enum1.VAL1, None)

        self.assertLess(Enum1.VAL1, Enum1.VAL2)
        self.assertGreater(Enum1.VAL2, Enum1.VAL1)

        # Inequality must also hold with the enum on the right-hand side.
        self.assertNotEqual(1, Enum2.VAL1)

    def testPickle(self):
        """Testing pickling and unpickling of Enum instances."""
        colors = list(Color)
        unpickled = pickle.loads(pickle.dumps(colors))
        self.assertEqual(colors, unpickled)
        # Unpickling shouldn't create new enum instances.
        for color, restored in zip(colors, unpickled):
            self.assertIs(color, restored)
357
358
class FieldListTest(test_util.TestCase):
    """Tests for messages.FieldList."""

    def setUp(self):
        self.integer_field = messages.IntegerField(1, repeated=True)

    def testConstructor(self):
        """Test construction from lists and tuples, including empty ones."""
        self.assertEqual([1, 2, 3],
                         messages.FieldList(self.integer_field, [1, 2, 3]))
        self.assertEqual([1, 2, 3],
                         messages.FieldList(self.integer_field, (1, 2, 3)))
        self.assertEqual([], messages.FieldList(self.integer_field, []))

    def testNone(self):
        """Test that None is rejected as the initial sequence."""
        self.assertRaises(TypeError, messages.FieldList,
                          self.integer_field, None)

    def testDoNotAutoConvertString(self):
        """Test that a bare string is rejected for a repeated string field."""
        string_field = messages.StringField(1, repeated=True)
        self.assertRaises(messages.ValidationError,
                          messages.FieldList, string_field, 'abc')

    def testConstructorCopies(self):
        """Test that the constructor does not return its input object."""
        a_list = [1, 3, 6]
        field_list = messages.FieldList(self.integer_field, a_list)
        self.assertIsNot(a_list, field_list)
        self.assertIsNot(field_list,
                         messages.FieldList(self.integer_field, field_list))

    def testNonRepeatedField(self):
        """Test that a non-repeated field definition is rejected."""
        self.assertRaisesWithRegexpMatch(
            messages.FieldDefinitionError,
            'FieldList may only accept repeated fields',
            messages.FieldList,
            messages.IntegerField(1),
            [])

    def testConstructor_InvalidValues(self):
        """Test that initial elements of the wrong type are rejected."""
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            re.escape("Expected type %r "
                      "for IntegerField, found 1 (type %r)"
                      % (six.integer_types, str)),
            messages.FieldList, self.integer_field, ["1", "2", "3"])

    def testConstructor_Scalars(self):
        """Test that a scalar or a bare iterator is rejected."""
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            "IntegerField is repeated. Found: 3",
            messages.FieldList, self.integer_field, 3)

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            ("IntegerField is repeated. Found: "
             "<(list[_]?|sequence)iterator object"),
            messages.FieldList, self.integer_field, iter([1, 2, 3]))

    def testSetSlice(self):
        """Test slice assignment with valid values."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
        field_list[1:3] = [10, 20]
        self.assertEqual([1, 10, 20, 4, 5], field_list)

    def testSetSlice_InvalidValues(self):
        """Test that slice assignment validates the new elements."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])

        def setslice():
            field_list[1:3] = ['10', '20']

        msg_re = re.escape("Expected type %r "
                           "for IntegerField, found 10 (type %r)"
                           % (six.integer_types, str))
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            msg_re,
            setslice)

    def testSetItem(self):
        """Test item assignment with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list[0] = 10
        self.assertEqual([10], field_list)

    def testSetItem_InvalidValues(self):
        """Test that item assignment validates the new element."""
        field_list = messages.FieldList(self.integer_field, [2])

        def setitem():
            field_list[0] = '10'
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            re.escape("Expected type %r "
                      "for IntegerField, found 10 (type %r)"
                      % (six.integer_types, str)),
            setitem)

    def testAppend(self):
        """Test appending a valid value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.append(10)
        self.assertEqual([2, 10], field_list)

    def testAppend_InvalidValues(self):
        """Test that append validates the new element."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.name = 'a_field'

        def append():
            field_list.append('10')
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            re.escape("Expected type %r "
                      "for IntegerField, found 10 (type %r)"
                      % (six.integer_types, str)),
            append)

    def testExtend(self):
        """Test extending with valid values."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.extend([10])
        self.assertEqual([2, 10], field_list)

    def testExtend_InvalidValues(self):
        """Test that extend validates the new elements."""
        field_list = messages.FieldList(self.integer_field, [2])

        def extend():
            field_list.extend(['10'])
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            re.escape("Expected type %r "
                      "for IntegerField, found 10 (type %r)"
                      % (six.integer_types, str)),
            extend)

    def testInsert(self):
        """Test inserting a valid value."""
        field_list = messages.FieldList(self.integer_field, [2, 3])
        field_list.insert(1, 10)
        self.assertEqual([2, 10, 3], field_list)

    def testInsert_InvalidValues(self):
        """Test that insert validates the new element."""
        field_list = messages.FieldList(self.integer_field, [2, 3])

        def insert():
            field_list.insert(1, '10')
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            re.escape("Expected type %r "
                      "for IntegerField, found 10 (type %r)"
                      % (six.integer_types, str)),
            insert)

    def testPickle(self):
        """Testing pickling and unpickling of FieldList instances."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
        unpickled = pickle.loads(pickle.dumps(field_list))
        self.assertEqual(field_list, unpickled)
        self.assertIsInstance(unpickled.field, messages.IntegerField)
        self.assertEqual(1, unpickled.field.number)
        self.assertTrue(unpickled.field.repeated)
511
512
513class FieldTest(test_util.TestCase):
514
515    def ActionOnAllFieldClasses(self, action):
516        """Test all field classes except Message and Enum.
517
518        Message and Enum require separate tests.
519
520        Args:
521          action: Callable that takes the field class as a parameter.
522        """
523        classes = (messages.IntegerField,
524                   messages.FloatField,
525                   messages.BooleanField,
526                   messages.BytesField,
527                   messages.StringField)
528        for field_class in classes:
529            action(field_class)
530
531    def testNumberAttribute(self):
532        """Test setting the number attribute."""
533        def action(field_class):
534            # Check range.
535            self.assertRaises(messages.InvalidNumberError,
536                              field_class,
537                              0)
538            self.assertRaises(messages.InvalidNumberError,
539                              field_class,
540                              -1)
541            self.assertRaises(messages.InvalidNumberError,
542                              field_class,
543                              messages.MAX_FIELD_NUMBER + 1)
544
545            # Check reserved.
546            self.assertRaises(messages.InvalidNumberError,
547                              field_class,
548                              messages.FIRST_RESERVED_FIELD_NUMBER)
549            self.assertRaises(messages.InvalidNumberError,
550                              field_class,
551                              messages.LAST_RESERVED_FIELD_NUMBER)
552            self.assertRaises(messages.InvalidNumberError,
553                              field_class,
554                              '1')
555
556            # This one should work.
557            field_class(number=1)
558        self.ActionOnAllFieldClasses(action)
559
560    def testRequiredAndRepeated(self):
561        """Test setting the required and repeated fields."""
562        def action(field_class):
563            field_class(1, required=True)
564            field_class(1, repeated=True)
565            self.assertRaises(messages.FieldDefinitionError,
566                              field_class,
567                              1,
568                              required=True,
569                              repeated=True)
570        self.ActionOnAllFieldClasses(action)
571
572    def testInvalidVariant(self):
573        """Test field with invalid variants."""
574        def action(field_class):
575            if field_class is not message_types.DateTimeField:
576                self.assertRaises(messages.InvalidVariantError,
577                                  field_class,
578                                  1,
579                                  variant=messages.Variant.ENUM)
580        self.ActionOnAllFieldClasses(action)
581
582    def testDefaultVariant(self):
583        """Test that default variant is used when not set."""
584        def action(field_class):
585            field = field_class(1)
586            self.assertEquals(field_class.DEFAULT_VARIANT, field.variant)
587
588        self.ActionOnAllFieldClasses(action)
589
590    def testAlternateVariant(self):
591        """Test that default variant is used when not set."""
592        field = messages.IntegerField(1, variant=messages.Variant.UINT32)
593        self.assertEquals(messages.Variant.UINT32, field.variant)
594
595    def testDefaultFields_Single(self):
596        """Test default field is correct type (single)."""
597        defaults = {
598            messages.IntegerField: 10,
599            messages.FloatField: 1.5,
600            messages.BooleanField: False,
601            messages.BytesField: b'abc',
602            messages.StringField: u'abc',
603        }
604
605        def action(field_class):
606            field_class(1, default=defaults[field_class])
607        self.ActionOnAllFieldClasses(action)
608
609        # Run defaults test again checking for str/unicode compatiblity.
610        defaults[messages.StringField] = 'abc'
611        self.ActionOnAllFieldClasses(action)
612
613    def testStringField_BadUnicodeInDefault(self):
614        """Test binary values in string field."""
615        self.assertRaisesWithRegexpMatch(
616            messages.InvalidDefaultError,
617            r"Invalid default value for StringField:.*: "
618            r"Field encountered non-ASCII string .*: "
619            r"'ascii' codec can't decode byte 0x89 in position 0: "
620            r"ordinal not in range",
621            messages.StringField, 1, default=b'\x89')
622
623    def testDefaultFields_InvalidSingle(self):
624        """Test default field is correct type (invalid single)."""
625        def action(field_class):
626            self.assertRaises(messages.InvalidDefaultError,
627                              field_class,
628                              1,
629                              default=object())
630        self.ActionOnAllFieldClasses(action)
631
632    def testDefaultFields_InvalidRepeated(self):
633        """Test default field does not accept defaults."""
634        self.assertRaisesWithRegexpMatch(
635            messages.FieldDefinitionError,
636            'Repeated fields may not have defaults',
637            messages.StringField, 1, repeated=True, default=[1, 2, 3])
638
639    def testDefaultFields_None(self):
640        """Test none is always acceptable."""
641        def action(field_class):
642            field_class(1, default=None)
643            field_class(1, required=True, default=None)
644            field_class(1, repeated=True, default=None)
645        self.ActionOnAllFieldClasses(action)
646
647    def testDefaultFields_Enum(self):
648        """Test the default for enum fields."""
649        class Symbol(messages.Enum):
650
651            ALPHA = 1
652            BETA = 2
653            GAMMA = 3
654
655        field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA)
656
657        self.assertEquals(Symbol.ALPHA, field.default)
658
659    def testDefaultFields_EnumStringDelayedResolution(self):
660        """Test that enum fields resolve default strings."""
661        field = messages.EnumField(
662            'apitools.base.protorpclite.descriptor.FieldDescriptor.Label',
663            1,
664            default='OPTIONAL')
665
666        self.assertEquals(
667            descriptor.FieldDescriptor.Label.OPTIONAL, field.default)
668
669    def testDefaultFields_EnumIntDelayedResolution(self):
670        """Test that enum fields resolve default integers."""
671        field = messages.EnumField(
672            'apitools.base.protorpclite.descriptor.FieldDescriptor.Label',
673            1,
674            default=2)
675
676        self.assertEquals(
677            descriptor.FieldDescriptor.Label.REQUIRED, field.default)
678
679    def testDefaultFields_EnumOkIfTypeKnown(self):
680        """Test enum fields accept valid default values when type is known."""
681        field = messages.EnumField(descriptor.FieldDescriptor.Label,
682                                   1,
683                                   default='REPEATED')
684
685        self.assertEquals(
686            descriptor.FieldDescriptor.Label.REPEATED, field.default)
687
688    def testDefaultFields_EnumForceCheckIfTypeKnown(self):
689        """Test that enum fields validate default values if type is known."""
690        self.assertRaisesWithRegexpMatch(TypeError,
691                                         'No such value for NOT_A_LABEL in '
692                                         'Enum Label',
693                                         messages.EnumField,
694                                         descriptor.FieldDescriptor.Label,
695                                         1,
696                                         default='NOT_A_LABEL')
697
698    def testDefaultFields_EnumInvalidDelayedResolution(self):
699        """Test that enum fields raise errors upon delayed resolution error."""
700        field = messages.EnumField(
701            'apitools.base.protorpclite.descriptor.FieldDescriptor.Label',
702            1,
703            default=200)
704
705        self.assertRaisesWithRegexpMatch(TypeError,
706                                         'No such value for 200 in Enum Label',
707                                         getattr,
708                                         field,
709                                         'default')
710
711    def testValidate_Valid(self):
712        """Test validation of valid values."""
713        values = {
714            messages.IntegerField: 10,
715            messages.FloatField: 1.5,
716            messages.BooleanField: False,
717            messages.BytesField: b'abc',
718            messages.StringField: u'abc',
719        }
720
721        def action(field_class):
722            # Optional.
723            field = field_class(1)
724            field.validate(values[field_class])
725
726            # Required.
727            field = field_class(1, required=True)
728            field.validate(values[field_class])
729
730            # Repeated.
731            field = field_class(1, repeated=True)
732            field.validate([])
733            field.validate(())
734            field.validate([values[field_class]])
735            field.validate((values[field_class],))
736
737            # Right value, but not repeated.
738            self.assertRaises(messages.ValidationError,
739                              field.validate,
740                              values[field_class])
741            self.assertRaises(messages.ValidationError,
742                              field.validate,
743                              values[field_class])
744
745        self.ActionOnAllFieldClasses(action)
746
747    def testValidate_Invalid(self):
748        """Test validation of valid values."""
749        values = {
750            messages.IntegerField: "10",
751            messages.FloatField: "blah",
752            messages.BooleanField: 0,
753            messages.BytesField: 10.20,
754            messages.StringField: 42,
755        }
756
757        def action(field_class):
758            # Optional.
759            field = field_class(1)
760            self.assertRaises(messages.ValidationError,
761                              field.validate,
762                              values[field_class])
763
764            # Required.
765            field = field_class(1, required=True)
766            self.assertRaises(messages.ValidationError,
767                              field.validate,
768                              values[field_class])
769
770            # Repeated.
771            field = field_class(1, repeated=True)
772            self.assertRaises(messages.ValidationError,
773                              field.validate,
774                              [values[field_class]])
775            self.assertRaises(messages.ValidationError,
776                              field.validate,
777                              (values[field_class],))
778        self.ActionOnAllFieldClasses(action)
779
780    def testValidate_None(self):
781        """Test that None is valid for non-required fields."""
782        def action(field_class):
783            # Optional.
784            field = field_class(1)
785            field.validate(None)
786
787            # Required.
788            field = field_class(1, required=True)
789            self.assertRaisesWithRegexpMatch(messages.ValidationError,
790                                             'Required field is missing',
791                                             field.validate,
792                                             None)
793
794            # Repeated.
795            field = field_class(1, repeated=True)
796            field.validate(None)
797            self.assertRaisesWithRegexpMatch(
798                messages.ValidationError,
799                'Repeated values for %s may '
800                'not be None' % field_class.__name__,
801                field.validate,
802                [None])
803            self.assertRaises(messages.ValidationError,
804                              field.validate,
805                              (None,))
806        self.ActionOnAllFieldClasses(action)
807
808    def testValidateElement(self):
809        """Test validation of valid values."""
810        values = {
811            messages.IntegerField: (10, -1, 0),
812            messages.FloatField: (1.5, -1.5, 3),  # for json it is all a number
813            messages.BooleanField: (True, False),
814            messages.BytesField: (b'abc',),
815            messages.StringField: (u'abc',),
816        }
817
818        def action(field_class):
819            # Optional.
820            field = field_class(1)
821            for value in values[field_class]:
822                field.validate_element(value)
823
824            # Required.
825            field = field_class(1, required=True)
826            for value in values[field_class]:
827                field.validate_element(value)
828
829            # Repeated.
830            field = field_class(1, repeated=True)
831            self.assertRaises(messages.ValidationError,
832                              field.validate_element,
833                              [])
834            self.assertRaises(messages.ValidationError,
835                              field.validate_element,
836                              ())
837            for value in values[field_class]:
838                field.validate_element(value)
839
840            # Right value, but repeated.
841            self.assertRaises(messages.ValidationError,
842                              field.validate_element,
843                              list(values[field_class]))  # testing list
844            self.assertRaises(messages.ValidationError,
845                              field.validate_element,
846                              values[field_class])  # testing tuple
847
848        self.ActionOnAllFieldClasses(action)
849
850    def testValidateCastingElement(self):
851        field = messages.FloatField(1)
852        self.assertEquals(type(field.validate_element(12)), float)
853        self.assertEquals(type(field.validate_element(12.0)), float)
854        field = messages.IntegerField(1)
855        self.assertEquals(type(field.validate_element(12)), int)
856        self.assertRaises(messages.ValidationError,
857                          field.validate_element,
858                          12.0)  # should fails from float to int
859
860    def testReadOnly(self):
861        """Test that objects are all read-only."""
862        def action(field_class):
863            field = field_class(10)
864            self.assertRaises(AttributeError,
865                              setattr,
866                              field,
867                              'number',
868                              20)
869            self.assertRaises(AttributeError,
870                              setattr,
871                              field,
872                              'anything_else',
873                              'whatever')
874        self.ActionOnAllFieldClasses(action)
875
876    def testMessageField(self):
877        """Test the construction of message fields."""
878        self.assertRaises(messages.FieldDefinitionError,
879                          messages.MessageField,
880                          str,
881                          10)
882
883        self.assertRaises(messages.FieldDefinitionError,
884                          messages.MessageField,
885                          messages.Message,
886                          10)
887
888        class MyMessage(messages.Message):
889            pass
890
891        field = messages.MessageField(MyMessage, 10)
892        self.assertEquals(MyMessage, field.type)
893
894    def testMessageField_ForwardReference(self):
895        """Test the construction of forward reference message fields."""
896        global MyMessage
897        global ForwardMessage
898        try:
899            class MyMessage(messages.Message):
900
901                self_reference = messages.MessageField('MyMessage', 1)
902                forward = messages.MessageField('ForwardMessage', 2)
903                nested = messages.MessageField(
904                    'ForwardMessage.NestedMessage', 3)
905                inner = messages.MessageField('Inner', 4)
906
907                class Inner(messages.Message):
908
909                    sibling = messages.MessageField('Sibling', 1)
910
911                class Sibling(messages.Message):
912
913                    pass
914
915            class ForwardMessage(messages.Message):
916
917                class NestedMessage(messages.Message):
918
919                    pass
920
921            self.assertEquals(MyMessage,
922                              MyMessage.field_by_name('self_reference').type)
923
924            self.assertEquals(ForwardMessage,
925                              MyMessage.field_by_name('forward').type)
926
927            self.assertEquals(ForwardMessage.NestedMessage,
928                              MyMessage.field_by_name('nested').type)
929
930            self.assertEquals(MyMessage.Inner,
931                              MyMessage.field_by_name('inner').type)
932
933            self.assertEquals(MyMessage.Sibling,
934                              MyMessage.Inner.field_by_name('sibling').type)
935        finally:
936            try:
937                del MyMessage
938                del ForwardMessage
939            except:  # pylint:disable=bare-except
940                pass
941
942    def testMessageField_WrongType(self):
943        """Test that forward referencing the wrong type raises an error."""
944        global AnEnum
945        try:
946            class AnEnum(messages.Enum):
947                pass
948
949            class AnotherMessage(messages.Message):
950
951                a_field = messages.MessageField('AnEnum', 1)
952
953            self.assertRaises(messages.FieldDefinitionError,
954                              getattr,
955                              AnotherMessage.field_by_name('a_field'),
956                              'type')
957        finally:
958            del AnEnum
959
960    def testMessageFieldValidate(self):
961        """Test validation on message field."""
962        class MyMessage(messages.Message):
963            pass
964
965        class AnotherMessage(messages.Message):
966            pass
967
968        field = messages.MessageField(MyMessage, 10)
969        field.validate(MyMessage())
970
971        self.assertRaises(messages.ValidationError,
972                          field.validate,
973                          AnotherMessage())
974
975    def testMessageFieldMessageType(self):
976        """Test message_type property."""
977        class MyMessage(messages.Message):
978            pass
979
980        class HasMessage(messages.Message):
981            field = messages.MessageField(MyMessage, 1)
982
983        self.assertEqual(HasMessage.field.type, HasMessage.field.message_type)
984
985    def testMessageFieldValueFromMessage(self):
986        class MyMessage(messages.Message):
987            pass
988
989        class HasMessage(messages.Message):
990            field = messages.MessageField(MyMessage, 1)
991
992        instance = MyMessage()
993
994        self.assertTrue(
995            instance is HasMessage.field.value_from_message(instance))
996
997    def testMessageFieldValueFromMessageWrongType(self):
998        class MyMessage(messages.Message):
999            pass
1000
1001        class HasMessage(messages.Message):
1002            field = messages.MessageField(MyMessage, 1)
1003
1004        self.assertRaisesWithRegexpMatch(
1005            messages.DecodeError,
1006            'Expected type MyMessage, got int: 10',
1007            HasMessage.field.value_from_message, 10)
1008
1009    def testMessageFieldValueToMessage(self):
1010        class MyMessage(messages.Message):
1011            pass
1012
1013        class HasMessage(messages.Message):
1014            field = messages.MessageField(MyMessage, 1)
1015
1016        instance = MyMessage()
1017
1018        self.assertTrue(
1019            instance is HasMessage.field.value_to_message(instance))
1020
1021    def testMessageFieldValueToMessageWrongType(self):
1022        class MyMessage(messages.Message):
1023            pass
1024
1025        class MyOtherMessage(messages.Message):
1026            pass
1027
1028        class HasMessage(messages.Message):
1029            field = messages.MessageField(MyMessage, 1)
1030
1031        instance = MyOtherMessage()
1032
1033        self.assertRaisesWithRegexpMatch(
1034            messages.EncodeError,
1035            'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>',
1036            HasMessage.field.value_to_message, instance)
1037
1038    def testIntegerField_AllowLong(self):
1039        """Test that the integer field allows for longs."""
1040        if six.PY2:
1041            messages.IntegerField(10, default=long(10))
1042
1043    def testMessageFieldValidate_Initialized(self):
1044        """Test validation on message field."""
1045        class MyMessage(messages.Message):
1046            field1 = messages.IntegerField(1, required=True)
1047
1048        field = messages.MessageField(MyMessage, 10)
1049
1050        # Will validate messages where is_initialized() is False.
1051        message = MyMessage()
1052        field.validate(message)
1053        message.field1 = 20
1054        field.validate(message)
1055
1056    def testEnumField(self):
1057        """Test the construction of enum fields."""
1058        self.assertRaises(messages.FieldDefinitionError,
1059                          messages.EnumField,
1060                          str,
1061                          10)
1062
1063        self.assertRaises(messages.FieldDefinitionError,
1064                          messages.EnumField,
1065                          messages.Enum,
1066                          10)
1067
1068        class Color(messages.Enum):
1069            RED = 1
1070            GREEN = 2
1071            BLUE = 3
1072
1073        field = messages.EnumField(Color, 10)
1074        self.assertEquals(Color, field.type)
1075
1076        class Another(messages.Enum):
1077            VALUE = 1
1078
1079        self.assertRaises(messages.InvalidDefaultError,
1080                          messages.EnumField,
1081                          Color,
1082                          10,
1083                          default=Another.VALUE)
1084
1085    def testEnumField_ForwardReference(self):
1086        """Test the construction of forward reference enum fields."""
1087        global MyMessage
1088        global ForwardEnum
1089        global ForwardMessage
1090        try:
1091            class MyMessage(messages.Message):
1092
1093                forward = messages.EnumField('ForwardEnum', 1)
1094                nested = messages.EnumField('ForwardMessage.NestedEnum', 2)
1095                inner = messages.EnumField('Inner', 3)
1096
1097                class Inner(messages.Enum):
1098                    pass
1099
1100            class ForwardEnum(messages.Enum):
1101                pass
1102
1103            class ForwardMessage(messages.Message):
1104
1105                class NestedEnum(messages.Enum):
1106                    pass
1107
1108            self.assertEquals(ForwardEnum,
1109                              MyMessage.field_by_name('forward').type)
1110
1111            self.assertEquals(ForwardMessage.NestedEnum,
1112                              MyMessage.field_by_name('nested').type)
1113
1114            self.assertEquals(MyMessage.Inner,
1115                              MyMessage.field_by_name('inner').type)
1116        finally:
1117            try:
1118                del MyMessage
1119                del ForwardEnum
1120                del ForwardMessage
1121            except:  # pylint:disable=bare-except
1122                pass
1123
1124    def testEnumField_WrongType(self):
1125        """Test that forward referencing the wrong type raises an error."""
1126        global AMessage
1127        try:
1128            class AMessage(messages.Message):
1129                pass
1130
1131            class AnotherMessage(messages.Message):
1132
1133                a_field = messages.EnumField('AMessage', 1)
1134
1135            self.assertRaises(messages.FieldDefinitionError,
1136                              getattr,
1137                              AnotherMessage.field_by_name('a_field'),
1138                              'type')
1139        finally:
1140            del AMessage
1141
1142    def testMessageDefinition(self):
1143        """Test that message definition is set on fields."""
1144        class MyMessage(messages.Message):
1145
1146            my_field = messages.StringField(1)
1147
1148        self.assertEquals(
1149            MyMessage,
1150            MyMessage.field_by_name('my_field').message_definition())
1151
1152    def testNoneAssignment(self):
1153        """Test that assigning None does not change comparison."""
1154        class MyMessage(messages.Message):
1155
1156            my_field = messages.StringField(1)
1157
1158        m1 = MyMessage()
1159        m2 = MyMessage()
1160        m2.my_field = None
1161        self.assertEquals(m1, m2)
1162
1163    def testNonAsciiStr(self):
1164        """Test validation fails for non-ascii StringField values."""
1165        class Thing(messages.Message):
1166            string_field = messages.StringField(2)
1167
1168        thing = Thing()
1169        self.assertRaisesWithRegexpMatch(
1170            messages.ValidationError,
1171            'Field string_field encountered non-ASCII string',
1172            setattr, thing, 'string_field', test_util.BINARY)
1173
1174
1175class MessageTest(test_util.TestCase):
1176    """Tests for message class."""
1177
1178    def CreateMessageClass(self):
1179        """Creates a simple message class with 3 fields.
1180
1181        Fields are defined in alphabetical order but with conflicting numeric
1182        order.
1183        """
1184        class ComplexMessage(messages.Message):
1185            a3 = messages.IntegerField(3)
1186            b1 = messages.StringField(1)
1187            c2 = messages.StringField(2)
1188
1189        return ComplexMessage
1190
1191    def testSameNumbers(self):
1192        """Test that cannot assign two fields with same numbers."""
1193
1194        def action():
1195            class BadMessage(messages.Message):
1196                f1 = messages.IntegerField(1)
1197                f2 = messages.IntegerField(1)
1198        self.assertRaises(messages.DuplicateNumberError,
1199                          action)
1200
1201    def testStrictAssignment(self):
1202        """Tests that cannot assign to unknown or non-reserved attributes."""
1203        class SimpleMessage(messages.Message):
1204            field = messages.IntegerField(1)
1205
1206        simple_message = SimpleMessage()
1207        self.assertRaises(AttributeError,
1208                          setattr,
1209                          simple_message,
1210                          'does_not_exist',
1211                          10)
1212
1213    def testListAssignmentDoesNotCopy(self):
1214        class SimpleMessage(messages.Message):
1215            repeated = messages.IntegerField(1, repeated=True)
1216
1217        message = SimpleMessage()
1218        original = message.repeated
1219        message.repeated = []
1220        self.assertFalse(original is message.repeated)
1221
1222    def testValidate_Optional(self):
1223        """Tests validation of optional fields."""
1224        class SimpleMessage(messages.Message):
1225            non_required = messages.IntegerField(1)
1226
1227        simple_message = SimpleMessage()
1228        simple_message.check_initialized()
1229        simple_message.non_required = 10
1230        simple_message.check_initialized()
1231
1232    def testValidate_Required(self):
1233        """Tests validation of required fields."""
1234        class SimpleMessage(messages.Message):
1235            required = messages.IntegerField(1, required=True)
1236
1237        simple_message = SimpleMessage()
1238        self.assertRaises(messages.ValidationError,
1239                          simple_message.check_initialized)
1240        simple_message.required = 10
1241        simple_message.check_initialized()
1242
1243    def testValidate_Repeated(self):
1244        """Tests validation of repeated fields."""
1245        class SimpleMessage(messages.Message):
1246            repeated = messages.IntegerField(1, repeated=True)
1247
1248        simple_message = SimpleMessage()
1249
1250        # Check valid values.
1251        for valid_value in [], [10], [10, 20], (), (10,), (10, 20):
1252            simple_message.repeated = valid_value
1253            simple_message.check_initialized()
1254
1255        # Check cleared.
1256        simple_message.repeated = []
1257        simple_message.check_initialized()
1258
1259        # Check invalid values.
1260        for invalid_value in 10, ['10', '20'], [None], (None,):
1261            self.assertRaises(
1262                messages.ValidationError,
1263                setattr, simple_message, 'repeated', invalid_value)
1264
1265    def testIsInitialized(self):
1266        """Tests is_initialized."""
1267        class SimpleMessage(messages.Message):
1268            required = messages.IntegerField(1, required=True)
1269
1270        simple_message = SimpleMessage()
1271        self.assertFalse(simple_message.is_initialized())
1272
1273        simple_message.required = 10
1274
1275        self.assertTrue(simple_message.is_initialized())
1276
1277    def testIsInitializedNestedField(self):
1278        """Tests is_initialized for nested fields."""
1279        class SimpleMessage(messages.Message):
1280            required = messages.IntegerField(1, required=True)
1281
1282        class NestedMessage(messages.Message):
1283            simple = messages.MessageField(SimpleMessage, 1)
1284
1285        simple_message = SimpleMessage()
1286        self.assertFalse(simple_message.is_initialized())
1287        nested_message = NestedMessage(simple=simple_message)
1288        self.assertFalse(nested_message.is_initialized())
1289
1290        simple_message.required = 10
1291
1292        self.assertTrue(simple_message.is_initialized())
1293        self.assertTrue(nested_message.is_initialized())
1294
1295    def testInitializeNestedFieldFromDict(self):
1296        """Tests initializing nested fields from dict."""
1297        class SimpleMessage(messages.Message):
1298            required = messages.IntegerField(1, required=True)
1299
1300        class NestedMessage(messages.Message):
1301            simple = messages.MessageField(SimpleMessage, 1)
1302
1303        class RepeatedMessage(messages.Message):
1304            simple = messages.MessageField(SimpleMessage, 1, repeated=True)
1305
1306        nested_message1 = NestedMessage(simple={'required': 10})
1307        self.assertTrue(nested_message1.is_initialized())
1308        self.assertTrue(nested_message1.simple.is_initialized())
1309
1310        nested_message2 = NestedMessage()
1311        nested_message2.simple = {'required': 10}
1312        self.assertTrue(nested_message2.is_initialized())
1313        self.assertTrue(nested_message2.simple.is_initialized())
1314
1315        repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)]
1316
1317        repeated_message1 = RepeatedMessage(simple=repeated_values)
1318        self.assertEquals(3, len(repeated_message1.simple))
1319        self.assertFalse(repeated_message1.is_initialized())
1320
1321        repeated_message1.simple[0].required = 0
1322        self.assertTrue(repeated_message1.is_initialized())
1323
1324        repeated_message2 = RepeatedMessage()
1325        repeated_message2.simple = repeated_values
1326        self.assertEquals(3, len(repeated_message2.simple))
1327        self.assertFalse(repeated_message2.is_initialized())
1328
1329        repeated_message2.simple[0].required = 0
1330        self.assertTrue(repeated_message2.is_initialized())
1331
1332    def testNestedMethodsNotAllowed(self):
1333        """Test that method definitions on Message classes are not allowed."""
1334        def action():
1335            class WithMethods(messages.Message):
1336
1337                def not_allowed(self):
1338                    pass
1339
1340        self.assertRaises(messages.MessageDefinitionError,
1341                          action)
1342
1343    def testNestedAttributesNotAllowed(self):
1344        """Test attribute assignment on Message classes is not allowed."""
1345        def int_attribute():
1346            class WithMethods(messages.Message):
1347                not_allowed = 1
1348
1349        def string_attribute():
1350            class WithMethods(messages.Message):
1351                not_allowed = 'not allowed'
1352
1353        def enum_attribute():
1354            class WithMethods(messages.Message):
1355                not_allowed = Color.RED
1356
1357        for action in (int_attribute, string_attribute, enum_attribute):
1358            self.assertRaises(messages.MessageDefinitionError,
1359                              action)
1360
1361    def testNameIsSetOnFields(self):
1362        """Make sure name is set on fields after Message class init."""
1363        class HasNamedFields(messages.Message):
1364            field = messages.StringField(1)
1365
1366        self.assertEquals('field', HasNamedFields.field_by_number(1).name)
1367
1368    def testSubclassingMessageDisallowed(self):
1369        """Not permitted to create sub-classes of message classes."""
1370        class SuperClass(messages.Message):
1371            pass
1372
1373        def action():
1374            class SubClass(SuperClass):
1375                pass
1376
1377        self.assertRaises(messages.MessageDefinitionError,
1378                          action)
1379
1380    def testAllFields(self):
1381        """Test all_fields method."""
1382        ComplexMessage = self.CreateMessageClass()
1383        fields = list(ComplexMessage.all_fields())
1384
1385        # Order does not matter, so sort now.
1386        fields = sorted(fields, key=lambda f: f.name)
1387
1388        self.assertEquals(3, len(fields))
1389        self.assertEquals('a3', fields[0].name)
1390        self.assertEquals('b1', fields[1].name)
1391        self.assertEquals('c2', fields[2].name)
1392
1393    def testFieldByName(self):
1394        """Test getting field by name."""
1395        ComplexMessage = self.CreateMessageClass()
1396
1397        self.assertEquals(3, ComplexMessage.field_by_name('a3').number)
1398        self.assertEquals(1, ComplexMessage.field_by_name('b1').number)
1399        self.assertEquals(2, ComplexMessage.field_by_name('c2').number)
1400
1401        self.assertRaises(KeyError,
1402                          ComplexMessage.field_by_name,
1403                          'unknown')
1404
1405    def testFieldByNumber(self):
1406        """Test getting field by number."""
1407        ComplexMessage = self.CreateMessageClass()
1408
1409        self.assertEquals('a3', ComplexMessage.field_by_number(3).name)
1410        self.assertEquals('b1', ComplexMessage.field_by_number(1).name)
1411        self.assertEquals('c2', ComplexMessage.field_by_number(2).name)
1412
1413        self.assertRaises(KeyError,
1414                          ComplexMessage.field_by_number,
1415                          4)
1416
1417    def testGetAssignedValue(self):
1418        """Test getting the assigned value of a field."""
1419        class SomeMessage(messages.Message):
1420            a_value = messages.StringField(1, default=u'a default')
1421
1422        message = SomeMessage()
1423        self.assertEquals(None, message.get_assigned_value('a_value'))
1424
1425        message.a_value = u'a string'
1426        self.assertEquals(u'a string', message.get_assigned_value('a_value'))
1427
1428        message.a_value = u'a default'
1429        self.assertEquals(u'a default', message.get_assigned_value('a_value'))
1430
1431        self.assertRaisesWithRegexpMatch(
1432            AttributeError,
1433            'Message SomeMessage has no field no_such_field',
1434            message.get_assigned_value,
1435            'no_such_field')
1436
1437    def testReset(self):
1438        """Test resetting a field value."""
1439        class SomeMessage(messages.Message):
1440            a_value = messages.StringField(1, default=u'a default')
1441            repeated = messages.IntegerField(2, repeated=True)
1442
1443        message = SomeMessage()
1444
1445        self.assertRaises(AttributeError, message.reset, 'unknown')
1446
1447        self.assertEquals(u'a default', message.a_value)
1448        message.reset('a_value')
1449        self.assertEquals(u'a default', message.a_value)
1450
1451        message.a_value = u'a new value'
1452        self.assertEquals(u'a new value', message.a_value)
1453        message.reset('a_value')
1454        self.assertEquals(u'a default', message.a_value)
1455
1456        message.repeated = [1, 2, 3]
1457        self.assertEquals([1, 2, 3], message.repeated)
1458        saved = message.repeated
1459        message.reset('repeated')
1460        self.assertEquals([], message.repeated)
1461        self.assertIsInstance(message.repeated, messages.FieldList)
1462        self.assertEquals([1, 2, 3], saved)
1463
1464    def testAllowNestedEnums(self):
1465        """Test allowing nested enums in a message definition."""
1466        class Trade(messages.Message):
1467
1468            class Duration(messages.Enum):
1469                GTC = 1
1470                DAY = 2
1471
1472            class Currency(messages.Enum):
1473                USD = 1
1474                GBP = 2
1475                INR = 3
1476
1477        # Sorted by name order seems to be the only feasible option.
1478        self.assertEquals(['Currency', 'Duration'], Trade.__enums__)
1479
1480        # Message definition will now be set on Enumerated objects.
1481        self.assertEquals(Trade, Trade.Duration.message_definition())
1482
1483    def testAllowNestedMessages(self):
1484        """Test allowing nested messages in a message definition."""
1485        class Trade(messages.Message):
1486
1487            class Lot(messages.Message):
1488                pass
1489
1490            class Agent(messages.Message):
1491                pass
1492
1493        # Sorted by name order seems to be the only feasible option.
1494        self.assertEquals(['Agent', 'Lot'], Trade.__messages__)
1495        self.assertEquals(Trade, Trade.Agent.message_definition())
1496        self.assertEquals(Trade, Trade.Lot.message_definition())
1497
1498        # But not Message itself.
1499        def action():
1500            class Trade(messages.Message):
1501                NiceTry = messages.Message
1502        self.assertRaises(messages.MessageDefinitionError, action)
1503
1504    def testDisallowClassAssignments(self):
1505        """Test setting class attributes may not happen."""
1506        class MyMessage(messages.Message):
1507            pass
1508
1509        self.assertRaises(AttributeError,
1510                          setattr,
1511                          MyMessage,
1512                          'x',
1513                          'do not assign')
1514
1515    def testEquality(self):
1516        """Test message class equality."""
1517        # Comparison against enums must work.
1518        class MyEnum(messages.Enum):
1519            val1 = 1
1520            val2 = 2
1521
1522        # Comparisons against nested messages must work.
1523        class AnotherMessage(messages.Message):
1524            string = messages.StringField(1)
1525
1526        class MyMessage(messages.Message):
1527            field1 = messages.IntegerField(1)
1528            field2 = messages.EnumField(MyEnum, 2)
1529            field3 = messages.MessageField(AnotherMessage, 3)
1530
1531        message1 = MyMessage()
1532
1533        self.assertNotEquals('hi', message1)
1534        self.assertNotEquals(AnotherMessage(), message1)
1535        self.assertEquals(message1, message1)
1536
1537        message2 = MyMessage()
1538
1539        self.assertEquals(message1, message2)
1540
1541        message1.field1 = 10
1542        self.assertNotEquals(message1, message2)
1543
1544        message2.field1 = 20
1545        self.assertNotEquals(message1, message2)
1546
1547        message2.field1 = 10
1548        self.assertEquals(message1, message2)
1549
1550        message1.field2 = MyEnum.val1
1551        self.assertNotEquals(message1, message2)
1552
1553        message2.field2 = MyEnum.val2
1554        self.assertNotEquals(message1, message2)
1555
1556        message2.field2 = MyEnum.val1
1557        self.assertEquals(message1, message2)
1558
1559        message1.field3 = AnotherMessage()
1560        message1.field3.string = 'value1'
1561        self.assertNotEquals(message1, message2)
1562
1563        message2.field3 = AnotherMessage()
1564        message2.field3.string = 'value2'
1565        self.assertNotEquals(message1, message2)
1566
1567        message2.field3.string = 'value1'
1568        self.assertEquals(message1, message2)
1569
1570    def testEqualityWithUnknowns(self):
1571        """Test message class equality with unknown fields."""
1572
1573        class MyMessage(messages.Message):
1574            field1 = messages.IntegerField(1)
1575
1576        message1 = MyMessage()
1577        message2 = MyMessage()
1578        self.assertEquals(message1, message2)
1579        message1.set_unrecognized_field('unknown1', 'value1',
1580                                        messages.Variant.STRING)
1581        self.assertEquals(message1, message2)
1582
1583        message1.set_unrecognized_field('unknown2', ['asdf', 3],
1584                                        messages.Variant.STRING)
1585        message1.set_unrecognized_field('unknown3', 4.7,
1586                                        messages.Variant.DOUBLE)
1587        self.assertEquals(message1, message2)
1588
1589    def testUnrecognizedFieldInvalidVariant(self):
1590        class MyMessage(messages.Message):
1591            field1 = messages.IntegerField(1)
1592
1593        message1 = MyMessage()
1594        self.assertRaises(
1595            TypeError, message1.set_unrecognized_field, 'unknown4',
1596            {'unhandled': 'type'}, None)
1597        self.assertRaises(
1598            TypeError, message1.set_unrecognized_field, 'unknown4',
1599            {'unhandled': 'type'}, 123)
1600
1601    def testRepr(self):
1602        """Test represtation of Message object."""
1603        class MyMessage(messages.Message):
1604            integer_value = messages.IntegerField(1)
1605            string_value = messages.StringField(2)
1606            unassigned = messages.StringField(3)
1607            unassigned_with_default = messages.StringField(
1608                4, default=u'a default')
1609
1610        my_message = MyMessage()
1611        my_message.integer_value = 42
1612        my_message.string_value = u'A string'
1613
1614        pat = re.compile(r"<MyMessage\n integer_value: 42\n"
1615                         " string_value: [u]?'A string'>")
1616        self.assertTrue(pat.match(repr(my_message)) is not None)
1617
1618    def testValidation(self):
1619        """Test validation of message values."""
1620        # Test optional.
1621        class SubMessage(messages.Message):
1622            pass
1623
1624        class Message(messages.Message):
1625            val = messages.MessageField(SubMessage, 1)
1626
1627        message = Message()
1628
1629        message_field = messages.MessageField(Message, 1)
1630        message_field.validate(message)
1631        message.val = SubMessage()
1632        message_field.validate(message)
1633        self.assertRaises(messages.ValidationError,
1634                          setattr, message, 'val', [SubMessage()])
1635
1636        # Test required.
1637        class Message(messages.Message):
1638            val = messages.MessageField(SubMessage, 1, required=True)
1639
1640        message = Message()
1641
1642        message_field = messages.MessageField(Message, 1)
1643        message_field.validate(message)
1644        message.val = SubMessage()
1645        message_field.validate(message)
1646        self.assertRaises(messages.ValidationError,
1647                          setattr, message, 'val', [SubMessage()])
1648
1649        # Test repeated.
1650        class Message(messages.Message):
1651            val = messages.MessageField(SubMessage, 1, repeated=True)
1652
1653        message = Message()
1654
1655        message_field = messages.MessageField(Message, 1)
1656        message_field.validate(message)
1657        self.assertRaisesWithRegexpMatch(
1658            messages.ValidationError,
1659            "Field val is repeated. Found: <SubMessage>",
1660            setattr, message, 'val', SubMessage())
1661        message.val = [SubMessage()]
1662        message_field.validate(message)
1663
1664    def testDefinitionName(self):
1665        """Test message name."""
1666        class MyMessage(messages.Message):
1667            pass
1668
1669        module_name = test_util.get_module_name(FieldTest)
1670        self.assertEquals('%s.MyMessage' % module_name,
1671                          MyMessage.definition_name())
1672        self.assertEquals(module_name, MyMessage.outer_definition_name())
1673        self.assertEquals(module_name, MyMessage.definition_package())
1674
1675        self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1676        self.assertEquals(six.text_type, type(
1677            MyMessage.outer_definition_name()))
1678        self.assertEquals(six.text_type, type(MyMessage.definition_package()))
1679
1680    def testDefinitionName_OverrideModule(self):
1681        """Test message module is overriden by module package name."""
1682        class MyMessage(messages.Message):
1683            pass
1684
1685        global package
1686        package = 'my.package'
1687
1688        try:
1689            self.assertEquals('my.package.MyMessage',
1690                              MyMessage.definition_name())
1691            self.assertEquals('my.package', MyMessage.outer_definition_name())
1692            self.assertEquals('my.package', MyMessage.definition_package())
1693
1694            self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1695            self.assertEquals(six.text_type, type(
1696                MyMessage.outer_definition_name()))
1697            self.assertEquals(six.text_type, type(
1698                MyMessage.definition_package()))
1699        finally:
1700            del package
1701
1702    def testDefinitionName_NoModule(self):
1703        """Test what happens when there is no module for message."""
1704        class MyMessage(messages.Message):
1705            pass
1706
1707        original_modules = sys.modules
1708        sys.modules = dict(sys.modules)
1709        try:
1710            del sys.modules[__name__]
1711            self.assertEquals('MyMessage', MyMessage.definition_name())
1712            self.assertEquals(None, MyMessage.outer_definition_name())
1713            self.assertEquals(None, MyMessage.definition_package())
1714
1715            self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1716        finally:
1717            sys.modules = original_modules
1718
1719    def testDefinitionName_Nested(self):
1720        """Test nested message names."""
1721        class MyMessage(messages.Message):
1722
1723            class NestedMessage(messages.Message):
1724
1725                class NestedMessage(messages.Message):
1726
1727                    pass
1728
1729        module_name = test_util.get_module_name(MessageTest)
1730        self.assertEquals('%s.MyMessage.NestedMessage' % module_name,
1731                          MyMessage.NestedMessage.definition_name())
1732        self.assertEquals('%s.MyMessage' % module_name,
1733                          MyMessage.NestedMessage.outer_definition_name())
1734        self.assertEquals(module_name,
1735                          MyMessage.NestedMessage.definition_package())
1736
1737        self.assertEquals(
1738            '%s.MyMessage.NestedMessage.NestedMessage' % module_name,
1739            MyMessage.NestedMessage.NestedMessage.definition_name())
1740        self.assertEquals(
1741            '%s.MyMessage.NestedMessage' % module_name,
1742            MyMessage.NestedMessage.NestedMessage.outer_definition_name())
1743        self.assertEquals(
1744            module_name,
1745            MyMessage.NestedMessage.NestedMessage.definition_package())
1746
1747    def testMessageDefinition(self):
1748        """Test that enumeration knows its enclosing message definition."""
1749        class OuterMessage(messages.Message):
1750
1751            class InnerMessage(messages.Message):
1752                pass
1753
1754        self.assertEquals(None, OuterMessage.message_definition())
1755        self.assertEquals(OuterMessage,
1756                          OuterMessage.InnerMessage.message_definition())
1757
1758    def testConstructorKwargs(self):
1759        """Test kwargs via constructor."""
1760        class SomeMessage(messages.Message):
1761            name = messages.StringField(1)
1762            number = messages.IntegerField(2)
1763
1764        expected = SomeMessage()
1765        expected.name = 'my name'
1766        expected.number = 200
1767        self.assertEquals(expected, SomeMessage(name='my name', number=200))
1768
1769    def testConstructorNotAField(self):
1770        """Test kwargs via constructor with wrong names."""
1771        class SomeMessage(messages.Message):
1772            pass
1773
1774        self.assertRaisesWithRegexpMatch(
1775            AttributeError,
1776            ('May not assign arbitrary value does_not_exist to message '
1777             'SomeMessage'),
1778            SomeMessage,
1779            does_not_exist=10)
1780
1781    def testGetUnsetRepeatedValue(self):
1782        class SomeMessage(messages.Message):
1783            repeated = messages.IntegerField(1, repeated=True)
1784
1785        instance = SomeMessage()
1786        self.assertEquals([], instance.repeated)
1787        self.assertTrue(isinstance(instance.repeated, messages.FieldList))
1788
1789    def testCompareAutoInitializedRepeatedFields(self):
1790        class SomeMessage(messages.Message):
1791            repeated = messages.IntegerField(1, repeated=True)
1792
1793        message1 = SomeMessage(repeated=[])
1794        message2 = SomeMessage()
1795        self.assertEquals(message1, message2)
1796
1797    def testUnknownValues(self):
1798        """Test message class equality with unknown fields."""
1799        class MyMessage(messages.Message):
1800            field1 = messages.IntegerField(1)
1801
1802        message = MyMessage()
1803        self.assertEquals([], message.all_unrecognized_fields())
1804        self.assertEquals((None, None),
1805                          message.get_unrecognized_field_info('doesntexist'))
1806        self.assertEquals((None, None),
1807                          message.get_unrecognized_field_info(
1808                              'doesntexist', None, None))
1809        self.assertEquals(('defaultvalue', 'defaultwire'),
1810                          message.get_unrecognized_field_info(
1811                              'doesntexist', 'defaultvalue', 'defaultwire'))
1812        self.assertEquals((3, None),
1813                          message.get_unrecognized_field_info(
1814                              'doesntexist', value_default=3))
1815
1816        message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE)
1817        self.assertEquals(1, len(message.all_unrecognized_fields()))
1818        self.assertTrue('exists' in message.all_unrecognized_fields())
1819        self.assertEquals((9.5, messages.Variant.DOUBLE),
1820                          message.get_unrecognized_field_info('exists'))
1821        self.assertEquals((9.5, messages.Variant.DOUBLE),
1822                          message.get_unrecognized_field_info('exists', 'type',
1823                                                              1234))
1824        self.assertEquals(
1825            (1234, None),
1826            message.get_unrecognized_field_info('doesntexist', 1234))
1827
1828        message.set_unrecognized_field(
1829            'another', 'value', messages.Variant.STRING)
1830        self.assertEquals(2, len(message.all_unrecognized_fields()))
1831        self.assertTrue('exists' in message.all_unrecognized_fields())
1832        self.assertTrue('another' in message.all_unrecognized_fields())
1833        self.assertEquals((9.5, messages.Variant.DOUBLE),
1834                          message.get_unrecognized_field_info('exists'))
1835        self.assertEquals(('value', messages.Variant.STRING),
1836                          message.get_unrecognized_field_info('another'))
1837
1838        message.set_unrecognized_field('typetest1', ['list', 0, ('test',)],
1839                                       messages.Variant.STRING)
1840        self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
1841                          message.get_unrecognized_field_info('typetest1'))
1842        message.set_unrecognized_field(
1843            'typetest2', '', messages.Variant.STRING)
1844        self.assertEquals(('', messages.Variant.STRING),
1845                          message.get_unrecognized_field_info('typetest2'))
1846
1847    def testPickle(self):
1848        """Testing pickling and unpickling of Message instances."""
1849        global MyEnum
1850        global AnotherMessage
1851        global MyMessage
1852
1853        class MyEnum(messages.Enum):
1854            val1 = 1
1855            val2 = 2
1856
1857        class AnotherMessage(messages.Message):
1858            string = messages.StringField(1, repeated=True)
1859
1860        class MyMessage(messages.Message):
1861            field1 = messages.IntegerField(1)
1862            field2 = messages.EnumField(MyEnum, 2)
1863            field3 = messages.MessageField(AnotherMessage, 3)
1864
1865        message = MyMessage(field1=1, field2=MyEnum.val2,
1866                            field3=AnotherMessage(string=['a', 'b', 'c']))
1867        message.set_unrecognized_field(
1868            'exists', 'value', messages.Variant.STRING)
1869        message.set_unrecognized_field('repeated', ['list', 0, ('test',)],
1870                                       messages.Variant.STRING)
1871        unpickled = pickle.loads(pickle.dumps(message))
1872        self.assertEquals(message, unpickled)
1873        self.assertTrue(AnotherMessage.string is unpickled.field3.string.field)
1874        self.assertTrue('exists' in message.all_unrecognized_fields())
1875        self.assertEquals(('value', messages.Variant.STRING),
1876                          message.get_unrecognized_field_info('exists'))
1877        self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
1878                          message.get_unrecognized_field_info('repeated'))
1879
1880
class FindDefinitionTest(test_util.TestCase):
    """Test finding definitions relative to various definitions and modules."""

    def setUp(self):
        """Set up module-space.  Starts off empty."""
        self.modules = {}

    def DefineModule(self, name):
        """Define a module and its parents in module space.

        Modules that are already defined in self.modules are not re-created.
        Child modules are not attached as attributes of their parents; they
        are only registered in self.modules under their full dotted name.

        Args:
          name: Fully qualified name of modules to create.

        Returns:
          Deepest nested module.  For example:

            DefineModule('a.b.c')  # Returns c.
        """
        name_path = name.split('.')
        for depth in range(1, len(name_path) + 1):
            full_name = '.'.join(name_path[:depth])
            # Only build a module object when one is actually missing.
            if full_name not in self.modules:
                self.modules[full_name] = types.ModuleType(full_name)
        return self.modules[name]

    def DefineMessage(self, module, name, children=None, add_to_module=True):
        """Define a new Message class in the context of a module.

        Used for easily describing complex Message hierarchy. Message
        is defined including all child definitions.

        Args:
          module: Fully qualified name of module to place Message class in.
          name: Name of Message to define within module.
          children: Define any level of nesting of children
            definitions. To define a message, map the name to another
            dictionary. The dictionary can itself contain additional
            definitions, and so on. To map to an Enum, define the Enum
            class separately and map it by name. This mapping is not
            modified.
          add_to_module: If True, new Message class is added to
            module. If False, new Message is not added.

        Returns:
          The newly created Message class.
        """
        # Make sure module exists.
        module_instance = self.DefineModule(module)

        # Build the class namespace in a fresh dict so the caller's
        # ``children`` mapping is never mutated.  Nested dictionaries are
        # recursively turned into Message classes that are not themselves
        # added to the module.
        namespace = {}
        for attribute, value in (children or {}).items():
            if isinstance(value, dict):
                value = self.DefineMessage(module, attribute, value, False)
            namespace[attribute] = value

        # Override default __module__ variable.
        namespace['__module__'] = module

        # Instantiate and possibly add to module.
        message_class = type(name, (messages.Message,), namespace)
        if add_to_module:
            setattr(module_instance, name, message_class)
        return message_class

    # pylint:disable=unused-argument
    # pylint:disable=redefined-builtin
    def Importer(self, module, globals='', locals='', fromlist=None):
        """Importer function.

        Acts like __import__. Only loads modules from self.modules.
        Does not try to load real modules defined elsewhere. Does not
        try to handle relative imports.

        Args:
          module: Fully qualified name of module to load from self.modules.
          globals: Ignored; accepted for __import__ compatibility.
          locals: Ignored; accepted for __import__ compatibility.
          fromlist: When None, only the top-level package named by
            ``module`` is returned, as __import__ does; otherwise the
            named module itself is returned.

        Returns:
          The requested module object from self.modules.

        Raises:
          ImportError: If the module is not defined in self.modules.
        """
        if fromlist is None:
            module = module.split('.')[0]
        try:
            return self.modules[module]
        except KeyError:
            raise ImportError('No module named %s' % module)
    # pylint:enable=redefined-builtin
    # pylint:enable=unused-argument

    def testNoSuchModule(self):
        """Test searching for definitions that do not exist."""
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'does.not.exist',
                          importer=self.Importer)

    def testRefersToModule(self):
        """Test that referring to a module does not return that module."""
        self.DefineModule('i.am.a.module')
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'i.am.a.module',
                          importer=self.Importer)

    def testNoDefinition(self):
        """Test not finding a definition in an existing module."""
        self.DefineModule('i.am.a.module')
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'i.am.a.module.MyMessage',
                          importer=self.Importer)

    def testNotADefinition(self):
        """Test trying to fetch something that is not a definition."""
        module = self.DefineModule('i.am.a.module')
        setattr(module, 'A', 'a string')
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'i.am.a.module.A',
                          importer=self.Importer)

    def testGlobalFind(self):
        """Test finding definitions from fully qualified module names."""
        A = self.DefineMessage('a.b.c', 'A', {})
        self.assertEquals(A, messages.find_definition('a.b.c.A',
                                                      importer=self.Importer))
        B = self.DefineMessage('a.b.c', 'B', {'C': {}})
        self.assertEquals(
            B.C,
            messages.find_definition('a.b.c.B.C', importer=self.Importer))

    def testRelativeToModule(self):
        """Test finding definitions relative to modules."""
        # Define modules.
        a = self.DefineModule('a')
        b = self.DefineModule('a.b')
        c = self.DefineModule('a.b.c')

        # Define messages.
        A = self.DefineMessage('a', 'A')
        B = self.DefineMessage('a.b', 'B')
        C = self.DefineMessage('a.b.c', 'C')
        D = self.DefineMessage('a.b.d', 'D')

        # Find A, B, C and D relative to a.
        self.assertEquals(A, messages.find_definition(
            'A', a, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'b.B', a, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'b.c.C', a, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'b.d.D', a, importer=self.Importer))

        # Find A, B, C and D relative to b.
        self.assertEquals(A, messages.find_definition(
            'A', b, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'B', b, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'c.C', b, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'd.D', b, importer=self.Importer))

        # Find A, B, C and D relative to c.  Module d is the same case as c.
        self.assertEquals(A, messages.find_definition(
            'A', c, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'B', c, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'C', c, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'd.D', c, importer=self.Importer))

    def testRelativeToMessages(self):
        """Test finding definitions relative to Message definitions."""
        A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}})
        B = A.B
        C = A.B.C
        D = A.B.D

        # Find relative to A.
        self.assertEquals(A, messages.find_definition(
            'A', A, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'B', A, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'B.C', A, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'B.D', A, importer=self.Importer))

        # Find relative to B.
        self.assertEquals(A, messages.find_definition(
            'A', B, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'B', B, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'C', B, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'D', B, importer=self.Importer))

        # Find relative to C.
        self.assertEquals(A, messages.find_definition(
            'A', C, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'B', C, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'C', C, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'D', C, importer=self.Importer))

        # Find relative to C searching from c.
        self.assertEquals(A, messages.find_definition(
            'b.A', C, importer=self.Importer))
        self.assertEquals(B, messages.find_definition(
            'b.A.B', C, importer=self.Importer))
        self.assertEquals(C, messages.find_definition(
            'b.A.B.C', C, importer=self.Importer))
        self.assertEquals(D, messages.find_definition(
            'b.A.B.D', C, importer=self.Importer))

    def testAbsoluteReference(self):
        """Test finding absolute definition names."""
        # Define modules.
        a = self.DefineModule('a')
        b = self.DefineModule('a.a')

        # Define messages.
        aA = self.DefineMessage('a', 'A')
        aaA = self.DefineMessage('a.a', 'A')

        # Always find a.A.
        self.assertEquals(aA, messages.find_definition('.a.A', None,
                                                       importer=self.Importer))
        self.assertEquals(aA, messages.find_definition('.a.A', a,
                                                       importer=self.Importer))
        self.assertEquals(aA, messages.find_definition('.a.A', aA,
                                                       importer=self.Importer))
        self.assertEquals(aA, messages.find_definition('.a.A', aaA,
                                                       importer=self.Importer))

    def testFindEnum(self):
        """Test that Enums are found."""
        class Color(messages.Enum):
            pass
        A = self.DefineMessage('a', 'A', {'Color': Color})

        self.assertEquals(
            Color,
            messages.find_definition('Color', A, importer=self.Importer))

    def testFalseScope(self):
        """Test Message definitions nested in strange objects are hidden."""
        global X

        class X(object):

            class A(messages.Message):
                pass

        self.assertRaises(TypeError, messages.find_definition, 'A', X)
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'X.A', sys.modules[__name__])

    def testSearchAttributeFirst(self):
        """Make sure not faked out by module, but continues searching."""
        A = self.DefineMessage('a', 'A')
        module_A = self.DefineModule('a.A')

        self.assertEquals(A, messages.find_definition(
            'a.A', None, importer=self.Importer))
2150
2151
def main():
    """Run the tests of the ``__main__`` module via unittest's runner."""
    # unittest.main is an alias of unittest.TestProgram.
    unittest.TestProgram()


if __name__ == '__main__':
    main()
2158