• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for protorpc.messages."""
19import six
20
21__author__ = 'rafek@google.com (Rafe Kaplan)'
22
23
24import pickle
25import re
26import sys
27import types
28import unittest
29
30from protorpc import descriptor
31from protorpc import message_types
32from protorpc import messages
33from protorpc import test_util
34
35
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
  """Runs the shared test_util module-interface checks against messages."""

  MODULE = messages
40
41
class ValidationErrorTest(test_util.TestCase):
  """Tests for the string form of messages.ValidationError."""

  def testStr_NoFieldName(self):
    """Test string version of ValidationError when no name provided."""
    self.assertEqual('Validation error',
                     str(messages.ValidationError('Validation error')))

  def testStr_FieldName(self):
    """Test that setting field_name does not change the string version."""
    validation_error = messages.ValidationError('Validation error')
    validation_error.field_name = 'a_field'
    self.assertEqual('Validation error', str(validation_error))
54
55
class EnumTest(test_util.TestCase):
  """Tests for messages.Enum."""

  def setUp(self):
    """Set up tests."""
    # Redefine the Color class for every test so that a change made to it
    # (an error) in one test does not leak into other tests.
    global Color
    class Color(messages.Enum):
      RED = 20
      ORANGE = 2
      YELLOW = 40
      GREEN = 4
      BLUE = 50
      INDIGO = 5
      VIOLET = 80

  def testNames(self):
    """Test that names iterates over enum names."""
    self.assertEqual(
        set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']),
        set(Color.names()))

  def testNumbers(self):
    """Test that numbers iterates over enum numbers."""
    self.assertEqual(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers()))

  def testIterate(self):
    """Test that __iter__ iterates over all enum values."""
    self.assertEqual(set(Color),
                     set([Color.RED,
                          Color.ORANGE,
                          Color.YELLOW,
                          Color.GREEN,
                          Color.BLUE,
                          Color.INDIGO,
                          Color.VIOLET]))

  def testNaturalOrder(self):
    """Test that natural order enumeration is in numeric order."""
    self.assertEqual([Color.ORANGE,
                      Color.GREEN,
                      Color.INDIGO,
                      Color.RED,
                      Color.YELLOW,
                      Color.BLUE,
                      Color.VIOLET],
                     sorted(Color))

  def testByName(self):
    """Test look-up by name."""
    self.assertEqual(Color.RED, Color.lookup_by_name('RED'))
    self.assertRaises(KeyError, Color.lookup_by_name, 20)
    self.assertRaises(KeyError, Color.lookup_by_name, Color.RED)

  def testByNumber(self):
    """Test look-up by number."""
    self.assertRaises(KeyError, Color.lookup_by_number, 'RED')
    self.assertEqual(Color.RED, Color.lookup_by_number(20))
    self.assertRaises(KeyError, Color.lookup_by_number, Color.RED)

  def testConstructor(self):
    """Test that the constructor looks up values by name or number."""
    self.assertEqual(Color.RED, Color('RED'))
    self.assertEqual(Color.RED, Color(u'RED'))
    self.assertEqual(Color.RED, Color(20))
    if six.PY2:
      # The `long` builtin only exists on Python 2.
      self.assertEqual(Color.RED, Color(long(20)))
    self.assertEqual(Color.RED, Color(Color.RED))
    self.assertRaises(TypeError, Color, 'Not exists')
    self.assertRaises(TypeError, Color, 'Red')
    self.assertRaises(TypeError, Color, 100)
    self.assertRaises(TypeError, Color, 10.0)

  def testLen(self):
    """Test that len function works to count enums."""
    self.assertEqual(7, len(Color))

  def testNoSubclasses(self):
    """Test that it is not possible to sub-class enum classes."""
    def declare_subclass():
      class MoreColor(Color):
        pass
    self.assertRaises(messages.EnumDefinitionError,
                      declare_subclass)

  def testClassNotMutable(self):
    """Test that enum classes themselves are not mutable."""
    self.assertRaises(AttributeError,
                      setattr,
                      Color,
                      'something_new',
                      10)

  def testInstancesMutable(self):
    """Test that enum instances are not mutable."""
    self.assertRaises(TypeError,
                      setattr,
                      Color.RED,
                      'something_new',
                      10)

  def testDefEnum(self):
    """Test def_enum works by building enum class from dict."""
    WeekDay = messages.Enum.def_enum({'Monday': 1,
                                      'Tuesday': 2,
                                      'Wednesday': 3,
                                      'Thursday': 4,
                                      'Friday': 6,
                                      'Saturday': 7,
                                      'Sunday': 8},
                                     'WeekDay')
    self.assertEqual('Wednesday', WeekDay(3).name)
    self.assertEqual(6, WeekDay('Friday').number)
    self.assertEqual(WeekDay.Sunday, WeekDay('Sunday'))

  def testNonInt(self):
    """Test that enum definitions reject non-integer values."""
    self.assertRaises(messages.EnumDefinitionError,
                      messages.Enum.def_enum,
                      {'Bad': '1'},
                      'BadEnum')

  def testNegativeInt(self):
    """Test that enum definitions reject negative numbers."""
    self.assertRaises(messages.EnumDefinitionError,
                      messages.Enum.def_enum,
                      {'Bad': -1},
                      'BadEnum')

  def testLowerBound(self):
    """Test that zero is accepted by enum def."""
    class NotImportant(messages.Enum):
      """Testing for value zero"""
      VALUE = 0

    self.assertEqual(0, int(NotImportant.VALUE))

  def testTooLargeInt(self):
    """Test that numbers too large are rejected."""
    self.assertRaises(messages.EnumDefinitionError,
                      messages.Enum.def_enum,
                      {'Bad': (2 ** 29)},
                      'BadEnum')

  def testRepeatedInt(self):
    """Test duplicated numbers are forbidden."""
    self.assertRaises(messages.EnumDefinitionError,
                      messages.Enum.def_enum,
                      {'Ok': 1, 'Repeated': 1},
                      'BadEnum')

  def testStr(self):
    """Test converting to string."""
    self.assertEqual('RED', str(Color.RED))
    self.assertEqual('ORANGE', str(Color.ORANGE))

  def testInt(self):
    """Test converting to int."""
    self.assertEqual(20, int(Color.RED))
    self.assertEqual(2, int(Color.ORANGE))

  def testRepr(self):
    """Test enum representation."""
    self.assertEqual('Color(RED, 20)', repr(Color.RED))
    self.assertEqual('Color(YELLOW, 40)', repr(Color.YELLOW))

  def testDocstring(self):
    """Test that docstring is supported ok."""
    class NotImportant(messages.Enum):
      """I have a docstring."""

      VALUE1 = 1

    self.assertEqual('I have a docstring.', NotImportant.__doc__)

  def testDeleteEnumValue(self):
    """Test that enum values cannot be deleted."""
    self.assertRaises(TypeError, delattr, Color, 'RED')

  def testEnumName(self):
    """Test enum name."""
    module_name = test_util.get_module_name(EnumTest)
    self.assertEqual('%s.Color' % module_name, Color.definition_name())
    self.assertEqual(module_name, Color.outer_definition_name())
    self.assertEqual(module_name, Color.definition_package())

  def testDefinitionName_OverrideModule(self):
    """Test enum module is overridden by module package name."""
    global package
    try:
      package = 'my.package'
      self.assertEqual('my.package.Color', Color.definition_name())
      self.assertEqual('my.package', Color.outer_definition_name())
      self.assertEqual('my.package', Color.definition_package())
    finally:
      del package

  def testDefinitionName_NoModule(self):
    """Test what happens when there is no module for enum."""
    class Enum1(messages.Enum):
      pass

    # Work on a copy of sys.modules; the original mapping is restored below.
    original_modules = sys.modules
    sys.modules = dict(sys.modules)
    try:
      del sys.modules[__name__]
      self.assertEqual('Enum1', Enum1.definition_name())
      self.assertIsNone(Enum1.outer_definition_name())
      self.assertIsNone(Enum1.definition_package())
      self.assertEqual(six.text_type, type(Enum1.definition_name()))
    finally:
      sys.modules = original_modules

  def testDefinitionName_Nested(self):
    """Test nested Enum names."""
    class MyMessage(messages.Message):

      class NestedEnum(messages.Enum):

        pass

      class NestedMessage(messages.Message):

        class NestedEnum(messages.Enum):

          pass

    module_name = test_util.get_module_name(EnumTest)
    self.assertEqual('%s.MyMessage.NestedEnum' % module_name,
                     MyMessage.NestedEnum.definition_name())
    self.assertEqual('%s.MyMessage' % module_name,
                     MyMessage.NestedEnum.outer_definition_name())
    self.assertEqual(module_name,
                     MyMessage.NestedEnum.definition_package())

    self.assertEqual('%s.MyMessage.NestedMessage.NestedEnum' % module_name,
                     MyMessage.NestedMessage.NestedEnum.definition_name())
    self.assertEqual(
        '%s.MyMessage.NestedMessage' % module_name,
        MyMessage.NestedMessage.NestedEnum.outer_definition_name())
    self.assertEqual(module_name,
                     MyMessage.NestedMessage.NestedEnum.definition_package())

  def testMessageDefinition(self):
    """Test that enumeration knows its enclosing message definition."""
    class OuterEnum(messages.Enum):
      pass

    self.assertIsNone(OuterEnum.message_definition())

    class OuterMessage(messages.Message):

      class InnerEnum(messages.Enum):
        pass

    self.assertEqual(OuterMessage, OuterMessage.InnerEnum.message_definition())

  def testComparison(self):
    """Test comparing various enums to different types."""
    class Enum1(messages.Enum):
      VAL1 = 1
      VAL2 = 2

    class Enum2(messages.Enum):
      VAL1 = 1

    self.assertEqual(Enum1.VAL1, Enum1.VAL1)
    self.assertNotEqual(Enum1.VAL1, Enum1.VAL2)
    # Same name and number, but a different enum class.
    self.assertNotEqual(Enum1.VAL1, Enum2.VAL1)
    # Enum values never compare equal to their name, number or None.
    self.assertNotEqual(Enum1.VAL1, 'VAL1')
    self.assertNotEqual(Enum1.VAL1, 1)
    self.assertNotEqual(Enum1.VAL1, 2)
    self.assertNotEqual(Enum1.VAL1, None)

    self.assertLess(Enum1.VAL1, Enum1.VAL2)
    self.assertGreater(Enum1.VAL2, Enum1.VAL1)

    self.assertNotEqual(1, Enum2.VAL1)

  def testPickle(self):
    """Testing pickling and unpickling of Enum instances."""
    colors = list(Color)
    unpickled = pickle.loads(pickle.dumps(colors))
    self.assertEqual(colors, unpickled)
    # Unpickling shouldn't create new enum instances.
    for color, unpickled_color in zip(colors, unpickled):
      self.assertIs(color, unpickled_color)
344
345
class FieldListTest(test_util.TestCase):
  """Tests for messages.FieldList, a list that validates items for a field."""

  def setUp(self):
    self.integer_field = messages.IntegerField(1, repeated=True)

  def _InvalidIntegerRegexp(self, found):
    """Return an escaped regexp for the error raised on a non-integer item.

    Args:
      found: Text the error message is expected to show for the bad value.

    Returns:
      re.escape()d error message for an invalid IntegerField item whose
      type is str.
    """
    return re.escape('Expected type %r for IntegerField, found %s (type %r)'
                     % (six.integer_types, found, str))

  def testConstructor(self):
    self.assertEqual([1, 2, 3],
                     messages.FieldList(self.integer_field, [1, 2, 3]))
    self.assertEqual([1, 2, 3],
                     messages.FieldList(self.integer_field, (1, 2, 3)))
    self.assertEqual([], messages.FieldList(self.integer_field, []))

  def testNone(self):
    self.assertRaises(TypeError, messages.FieldList, self.integer_field, None)

  def testDoNotAutoConvertString(self):
    string_field = messages.StringField(1, repeated=True)
    self.assertRaises(messages.ValidationError,
                      messages.FieldList, string_field, 'abc')

  def testConstructorCopies(self):
    a_list = [1, 3, 6]
    field_list = messages.FieldList(self.integer_field, a_list)
    self.assertIsNot(a_list, field_list)
    self.assertIsNot(field_list,
                     messages.FieldList(self.integer_field, field_list))

  def testNonRepeatedField(self):
    self.assertRaisesWithRegexpMatch(
        messages.FieldDefinitionError,
        'FieldList may only accept repeated fields',
        messages.FieldList,
        messages.IntegerField(1),
        [])

  def testConstructor_InvalidValues(self):
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        self._InvalidIntegerRegexp('1'),
        messages.FieldList, self.integer_field, ["1", "2", "3"])

  def testConstructor_Scalars(self):
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        "IntegerField is repeated. Found: 3",
        messages.FieldList, self.integer_field, 3)

    # The iterator's repr differs between Python versions.
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        "IntegerField is repeated. Found: <(list[_]?|sequence)iterator object",
        messages.FieldList, self.integer_field, iter([1, 2, 3]))

  def testSetSlice(self):
    field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
    field_list[1:3] = [10, 20]
    self.assertEqual([1, 10, 20, 4, 5], field_list)

  def testSetSlice_InvalidValues(self):
    field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])

    def setslice():
      field_list[1:3] = ['10', '20']

    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        self._InvalidIntegerRegexp('10'),
        setslice)

  def testSetItem(self):
    field_list = messages.FieldList(self.integer_field, [2])
    field_list[0] = 10
    self.assertEqual([10], field_list)

  def testSetItem_InvalidValues(self):
    field_list = messages.FieldList(self.integer_field, [2])

    def setitem():
      field_list[0] = '10'
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        self._InvalidIntegerRegexp('10'),
        setitem)

  def testAppend(self):
    field_list = messages.FieldList(self.integer_field, [2])
    field_list.append(10)
    self.assertEqual([2, 10], field_list)

  def testAppend_InvalidValues(self):
    field_list = messages.FieldList(self.integer_field, [2])
    field_list.name = 'a_field'

    def append():
      field_list.append('10')
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        self._InvalidIntegerRegexp('10'),
        append)

  def testExtend(self):
    field_list = messages.FieldList(self.integer_field, [2])
    field_list.extend([10])
    self.assertEqual([2, 10], field_list)

  def testExtend_InvalidValues(self):
    field_list = messages.FieldList(self.integer_field, [2])

    def extend():
      field_list.extend(['10'])
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        self._InvalidIntegerRegexp('10'),
        extend)

  def testInsert(self):
    field_list = messages.FieldList(self.integer_field, [2, 3])
    field_list.insert(1, 10)
    self.assertEqual([2, 10, 3], field_list)

  def testInsert_InvalidValues(self):
    field_list = messages.FieldList(self.integer_field, [2, 3])

    def insert():
      field_list.insert(1, '10')
    self.assertRaisesWithRegexpMatch(
        messages.ValidationError,
        self._InvalidIntegerRegexp('10'),
        insert)

  def testPickle(self):
    """Testing pickling and unpickling of disconnected FieldList instances."""
    field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
    unpickled = pickle.loads(pickle.dumps(field_list))
    self.assertEqual(field_list, unpickled)
    self.assertIsInstance(unpickled.field, messages.IntegerField)
    self.assertEqual(1, unpickled.field.number)
    self.assertTrue(unpickled.field.repeated)
496
497
498class FieldTest(test_util.TestCase):
499
500  def ActionOnAllFieldClasses(self, action):
501    """Test all field classes except Message and Enum.
502
503    Message and Enum require separate tests.
504
505    Args:
506      action: Callable that takes the field class as a parameter.
507    """
508    for field_class in (messages.IntegerField,
509                        messages.FloatField,
510                        messages.BooleanField,
511                        messages.BytesField,
512                        messages.StringField,
513                       ):
514      action(field_class)
515
516  def testNumberAttribute(self):
517    """Test setting the number attribute."""
518    def action(field_class):
519      # Check range.
520      self.assertRaises(messages.InvalidNumberError,
521                        field_class,
522                        0)
523      self.assertRaises(messages.InvalidNumberError,
524                        field_class,
525                        -1)
526      self.assertRaises(messages.InvalidNumberError,
527                        field_class,
528                        messages.MAX_FIELD_NUMBER + 1)
529
530      # Check reserved.
531      self.assertRaises(messages.InvalidNumberError,
532                        field_class,
533                        messages.FIRST_RESERVED_FIELD_NUMBER)
534      self.assertRaises(messages.InvalidNumberError,
535                        field_class,
536                        messages.LAST_RESERVED_FIELD_NUMBER)
537      self.assertRaises(messages.InvalidNumberError,
538                        field_class,
539                        '1')
540
541      # This one should work.
542      field_class(number=1)
543    self.ActionOnAllFieldClasses(action)
544
545  def testRequiredAndRepeated(self):
546    """Test setting the required and repeated fields."""
547    def action(field_class):
548      field_class(1, required=True)
549      field_class(1, repeated=True)
550      self.assertRaises(messages.FieldDefinitionError,
551                        field_class,
552                        1,
553                        required=True,
554                        repeated=True)
555    self.ActionOnAllFieldClasses(action)
556
557  def testInvalidVariant(self):
558    """Test field with invalid variants."""
559    def action(field_class):
560      if field_class is not message_types.DateTimeField:
561        self.assertRaises(messages.InvalidVariantError,
562                          field_class,
563                          1,
564                          variant=messages.Variant.ENUM)
565    self.ActionOnAllFieldClasses(action)
566
567  def testDefaultVariant(self):
568    """Test that default variant is used when not set."""
569    def action(field_class):
570      field = field_class(1)
571      self.assertEquals(field_class.DEFAULT_VARIANT, field.variant)
572
573    self.ActionOnAllFieldClasses(action)
574
575  def testAlternateVariant(self):
576    """Test that default variant is used when not set."""
577    field = messages.IntegerField(1, variant=messages.Variant.UINT32)
578    self.assertEquals(messages.Variant.UINT32, field.variant)
579
580  def testDefaultFields_Single(self):
581    """Test default field is correct type (single)."""
582    defaults = {messages.IntegerField: 10,
583                messages.FloatField: 1.5,
584                messages.BooleanField: False,
585                messages.BytesField: b'abc',
586                messages.StringField: u'abc',
587               }
588
589    def action(field_class):
590      field_class(1, default=defaults[field_class])
591    self.ActionOnAllFieldClasses(action)
592
593    # Run defaults test again checking for str/unicode compatiblity.
594    defaults[messages.StringField] = 'abc'
595    self.ActionOnAllFieldClasses(action)
596
597  def testStringField_BadUnicodeInDefault(self):
598    """Test binary values in string field."""
599    self.assertRaisesWithRegexpMatch(
600      messages.InvalidDefaultError,
601      r"Invalid default value for StringField:.*: "
602      r"Field encountered non-ASCII string .*: "
603      r"'ascii' codec can't decode byte 0x89 in position 0: "
604      r"ordinal not in range",
605      messages.StringField, 1, default=b'\x89')
606
607  def testDefaultFields_InvalidSingle(self):
608    """Test default field is correct type (invalid single)."""
609    def action(field_class):
610      self.assertRaises(messages.InvalidDefaultError,
611                        field_class,
612                        1,
613                        default=object())
614    self.ActionOnAllFieldClasses(action)
615
616  def testDefaultFields_InvalidRepeated(self):
617    """Test default field does not accept defaults."""
618    self.assertRaisesWithRegexpMatch(
619      messages.FieldDefinitionError,
620      'Repeated fields may not have defaults',
621      messages.StringField, 1, repeated=True, default=[1, 2, 3])
622
623  def testDefaultFields_None(self):
624    """Test none is always acceptable."""
625    def action(field_class):
626      field_class(1, default=None)
627      field_class(1, required=True, default=None)
628      field_class(1, repeated=True, default=None)
629    self.ActionOnAllFieldClasses(action)
630
631  def testDefaultFields_Enum(self):
632    """Test the default for enum fields."""
633    class Symbol(messages.Enum):
634
635      ALPHA = 1
636      BETA = 2
637      GAMMA = 3
638
639    field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA)
640
641    self.assertEquals(Symbol.ALPHA, field.default)
642
643  def testDefaultFields_EnumStringDelayedResolution(self):
644    """Test that enum fields resolve default strings."""
645    field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
646                               1,
647                               default='OPTIONAL')
648
649    self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default)
650
651  def testDefaultFields_EnumIntDelayedResolution(self):
652    """Test that enum fields resolve default integers."""
653    field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
654                               1,
655                               default=2)
656
657    self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default)
658
659  def testDefaultFields_EnumOkIfTypeKnown(self):
660    """Test that enum fields accept valid default values when type is known."""
661    field = messages.EnumField(descriptor.FieldDescriptor.Label,
662                               1,
663                               default='REPEATED')
664
665    self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default)
666
667  def testDefaultFields_EnumForceCheckIfTypeKnown(self):
668    """Test that enum fields validate default values if type is known."""
669    self.assertRaisesWithRegexpMatch(TypeError,
670                                     'No such value for NOT_A_LABEL in '
671                                     'Enum Label',
672                                     messages.EnumField,
673                                     descriptor.FieldDescriptor.Label,
674                                     1,
675                                     default='NOT_A_LABEL')
676
677  def testDefaultFields_EnumInvalidDelayedResolution(self):
678    """Test that enum fields raise errors upon delayed resolution error."""
679    field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
680                               1,
681                               default=200)
682
683    self.assertRaisesWithRegexpMatch(TypeError,
684                                     'No such value for 200 in Enum Label',
685                                     getattr,
686                                     field,
687                                     'default')
688
689  def testValidate_Valid(self):
690    """Test validation of valid values."""
691    values = {messages.IntegerField: 10,
692              messages.FloatField: 1.5,
693              messages.BooleanField: False,
694              messages.BytesField: b'abc',
695              messages.StringField: u'abc',
696             }
697    def action(field_class):
698      # Optional.
699      field = field_class(1)
700      field.validate(values[field_class])
701
702      # Required.
703      field = field_class(1, required=True)
704      field.validate(values[field_class])
705
706      # Repeated.
707      field = field_class(1, repeated=True)
708      field.validate([])
709      field.validate(())
710      field.validate([values[field_class]])
711      field.validate((values[field_class],))
712
713      # Right value, but not repeated.
714      self.assertRaises(messages.ValidationError,
715                        field.validate,
716                        values[field_class])
717      self.assertRaises(messages.ValidationError,
718                        field.validate,
719                        values[field_class])
720
721    self.ActionOnAllFieldClasses(action)
722
723  def testValidate_Invalid(self):
724    """Test validation of valid values."""
725    values = {messages.IntegerField: "10",
726              messages.FloatField: 1,
727              messages.BooleanField: 0,
728              messages.BytesField: 10.20,
729              messages.StringField: 42,
730             }
731    def action(field_class):
732      # Optional.
733      field = field_class(1)
734      self.assertRaises(messages.ValidationError,
735                        field.validate,
736                        values[field_class])
737
738      # Required.
739      field = field_class(1, required=True)
740      self.assertRaises(messages.ValidationError,
741                        field.validate,
742                        values[field_class])
743
744      # Repeated.
745      field = field_class(1, repeated=True)
746      self.assertRaises(messages.ValidationError,
747                        field.validate,
748                        [values[field_class]])
749      self.assertRaises(messages.ValidationError,
750                        field.validate,
751                        (values[field_class],))
752    self.ActionOnAllFieldClasses(action)
753
754  def testValidate_None(self):
755    """Test that None is valid for non-required fields."""
756    def action(field_class):
757      # Optional.
758      field = field_class(1)
759      field.validate(None)
760
761      # Required.
762      field = field_class(1, required=True)
763      self.assertRaisesWithRegexpMatch(messages.ValidationError,
764                                       'Required field is missing',
765                                       field.validate,
766                                       None)
767
768      # Repeated.
769      field = field_class(1, repeated=True)
770      field.validate(None)
771      self.assertRaisesWithRegexpMatch(messages.ValidationError,
772                                       'Repeated values for %s may '
773                                       'not be None' % field_class.__name__,
774                                       field.validate,
775                                       [None])
776      self.assertRaises(messages.ValidationError,
777                        field.validate,
778                        (None,))
779    self.ActionOnAllFieldClasses(action)
780
781  def testValidateElement(self):
782    """Test validation of valid values."""
783    values = {messages.IntegerField: 10,
784              messages.FloatField: 1.5,
785              messages.BooleanField: False,
786              messages.BytesField: 'abc',
787              messages.StringField: u'abc',
788             }
789    def action(field_class):
790      # Optional.
791      field = field_class(1)
792      field.validate_element(values[field_class])
793
794      # Required.
795      field = field_class(1, required=True)
796      field.validate_element(values[field_class])
797
798      # Repeated.
799      field = field_class(1, repeated=True)
800      self.assertRaises(message.VAlidationError,
801                        field.validate_element,
802                        [])
803      self.assertRaises(message.VAlidationError,
804                        field.validate_element,
805                        ())
806      field.validate_element(values[field_class])
807      field.validate_element(values[field_class])
808
809      # Right value, but repeated.
810      self.assertRaises(messages.ValidationError,
811                        field.validate_element,
812                        [values[field_class]])
813      self.assertRaises(messages.ValidationError,
814                        field.validate_element,
815                        (values[field_class],))
816
817  def testReadOnly(self):
818    """Test that objects are all read-only."""
819    def action(field_class):
820      field = field_class(10)
821      self.assertRaises(AttributeError,
822                        setattr,
823                        field,
824                        'number',
825                        20)
826      self.assertRaises(AttributeError,
827                        setattr,
828                        field,
829                        'anything_else',
830                        'whatever')
831    self.ActionOnAllFieldClasses(action)
832
833  def testMessageField(self):
834    """Test the construction of message fields."""
835    self.assertRaises(messages.FieldDefinitionError,
836                      messages.MessageField,
837                      str,
838                      10)
839
840    self.assertRaises(messages.FieldDefinitionError,
841                      messages.MessageField,
842                      messages.Message,
843                      10)
844
845    class MyMessage(messages.Message):
846      pass
847
848    field = messages.MessageField(MyMessage, 10)
849    self.assertEquals(MyMessage, field.type)
850
851  def testMessageField_ForwardReference(self):
852    """Test the construction of forward reference message fields."""
853    global MyMessage
854    global ForwardMessage
855    try:
856      class MyMessage(messages.Message):
857
858        self_reference = messages.MessageField('MyMessage', 1)
859        forward = messages.MessageField('ForwardMessage', 2)
860        nested = messages.MessageField('ForwardMessage.NestedMessage', 3)
861        inner = messages.MessageField('Inner', 4)
862
863        class Inner(messages.Message):
864
865          sibling = messages.MessageField('Sibling', 1)
866
867        class Sibling(messages.Message):
868
869          pass
870
871      class ForwardMessage(messages.Message):
872
873        class NestedMessage(messages.Message):
874
875          pass
876
877      self.assertEquals(MyMessage,
878                        MyMessage.field_by_name('self_reference').type)
879
880      self.assertEquals(ForwardMessage,
881                        MyMessage.field_by_name('forward').type)
882
883      self.assertEquals(ForwardMessage.NestedMessage,
884                        MyMessage.field_by_name('nested').type)
885
886      self.assertEquals(MyMessage.Inner,
887                        MyMessage.field_by_name('inner').type)
888
889      self.assertEquals(MyMessage.Sibling,
890                        MyMessage.Inner.field_by_name('sibling').type)
891    finally:
892      try:
893        del MyMessage
894        del ForwardMessage
895      except:
896        pass
897
898  def testMessageField_WrongType(self):
899    """Test that forward referencing the wrong type raises an error."""
900    global AnEnum
901    try:
902      class AnEnum(messages.Enum):
903        pass
904
905      class AnotherMessage(messages.Message):
906
907        a_field = messages.MessageField('AnEnum', 1)
908
909      self.assertRaises(messages.FieldDefinitionError,
910                        getattr,
911                        AnotherMessage.field_by_name('a_field'),
912                        'type')
913    finally:
914      del AnEnum
915
916  def testMessageFieldValidate(self):
917    """Test validation on message field."""
918    class MyMessage(messages.Message):
919      pass
920
921    class AnotherMessage(messages.Message):
922      pass
923
924    field = messages.MessageField(MyMessage, 10)
925    field.validate(MyMessage())
926
927    self.assertRaises(messages.ValidationError,
928                      field.validate,
929                      AnotherMessage())
930
931  def testMessageFieldMessageType(self):
932    """Test message_type property."""
933    class MyMessage(messages.Message):
934      pass
935
936    class HasMessage(messages.Message):
937      field = messages.MessageField(MyMessage, 1)
938
939    self.assertEqual(HasMessage.field.type, HasMessage.field.message_type)
940
941  def testMessageFieldValueFromMessage(self):
942    class MyMessage(messages.Message):
943      pass
944
945    class HasMessage(messages.Message):
946      field = messages.MessageField(MyMessage, 1)
947
948    instance = MyMessage()
949
950    self.assertTrue(instance is HasMessage.field.value_from_message(instance))
951
952  def testMessageFieldValueFromMessageWrongType(self):
953    class MyMessage(messages.Message):
954      pass
955
956    class HasMessage(messages.Message):
957      field = messages.MessageField(MyMessage, 1)
958
959    self.assertRaisesWithRegexpMatch(
960        messages.DecodeError,
961        'Expected type MyMessage, got int: 10',
962        HasMessage.field.value_from_message, 10)
963
964  def testMessageFieldValueToMessage(self):
965    class MyMessage(messages.Message):
966      pass
967
968    class HasMessage(messages.Message):
969      field = messages.MessageField(MyMessage, 1)
970
971    instance = MyMessage()
972
973    self.assertTrue(instance is HasMessage.field.value_to_message(instance))
974
975  def testMessageFieldValueToMessageWrongType(self):
976    class MyMessage(messages.Message):
977      pass
978
979    class MyOtherMessage(messages.Message):
980      pass
981
982    class HasMessage(messages.Message):
983      field = messages.MessageField(MyMessage, 1)
984
985    instance = MyOtherMessage()
986
987    self.assertRaisesWithRegexpMatch(
988        messages.EncodeError,
989        'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>',
990        HasMessage.field.value_to_message, instance)
991
992  def testIntegerField_AllowLong(self):
993    """Test that the integer field allows for longs."""
994    if six.PY2:
995        messages.IntegerField(10, default=long(10))
996
997  def testMessageFieldValidate_Initialized(self):
998    """Test validation on message field."""
999    class MyMessage(messages.Message):
1000      field1 = messages.IntegerField(1, required=True)
1001
1002    field = messages.MessageField(MyMessage, 10)
1003
1004    # Will validate messages where is_initialized() is False.
1005    message = MyMessage()
1006    field.validate(message)
1007    message.field1 = 20
1008    field.validate(message)
1009
1010  def testEnumField(self):
1011    """Test the construction of enum fields."""
1012    self.assertRaises(messages.FieldDefinitionError,
1013                      messages.EnumField,
1014                      str,
1015                      10)
1016
1017    self.assertRaises(messages.FieldDefinitionError,
1018                      messages.EnumField,
1019                      messages.Enum,
1020                      10)
1021
1022    class Color(messages.Enum):
1023      RED = 1
1024      GREEN = 2
1025      BLUE = 3
1026
1027    field = messages.EnumField(Color, 10)
1028    self.assertEquals(Color, field.type)
1029
1030    class Another(messages.Enum):
1031      VALUE = 1
1032
1033    self.assertRaises(messages.InvalidDefaultError,
1034                      messages.EnumField,
1035                      Color,
1036                      10,
1037                      default=Another.VALUE)
1038
1039  def testEnumField_ForwardReference(self):
1040    """Test the construction of forward reference enum fields."""
1041    global MyMessage
1042    global ForwardEnum
1043    global ForwardMessage
1044    try:
1045      class MyMessage(messages.Message):
1046
1047        forward = messages.EnumField('ForwardEnum', 1)
1048        nested = messages.EnumField('ForwardMessage.NestedEnum', 2)
1049        inner = messages.EnumField('Inner', 3)
1050
1051        class Inner(messages.Enum):
1052          pass
1053
1054      class ForwardEnum(messages.Enum):
1055        pass
1056
1057      class ForwardMessage(messages.Message):
1058
1059        class NestedEnum(messages.Enum):
1060          pass
1061
1062      self.assertEquals(ForwardEnum,
1063                        MyMessage.field_by_name('forward').type)
1064
1065      self.assertEquals(ForwardMessage.NestedEnum,
1066                        MyMessage.field_by_name('nested').type)
1067
1068      self.assertEquals(MyMessage.Inner,
1069                        MyMessage.field_by_name('inner').type)
1070    finally:
1071      try:
1072        del MyMessage
1073        del ForwardEnum
1074        del ForwardMessage
1075      except:
1076        pass
1077
1078  def testEnumField_WrongType(self):
1079    """Test that forward referencing the wrong type raises an error."""
1080    global AMessage
1081    try:
1082      class AMessage(messages.Message):
1083        pass
1084
1085      class AnotherMessage(messages.Message):
1086
1087        a_field = messages.EnumField('AMessage', 1)
1088
1089      self.assertRaises(messages.FieldDefinitionError,
1090                        getattr,
1091                        AnotherMessage.field_by_name('a_field'),
1092                        'type')
1093    finally:
1094      del AMessage
1095
1096  def testMessageDefinition(self):
1097    """Test that message definition is set on fields."""
1098    class MyMessage(messages.Message):
1099
1100      my_field = messages.StringField(1)
1101
1102    self.assertEquals(MyMessage,
1103                      MyMessage.field_by_name('my_field').message_definition())
1104
1105  def testNoneAssignment(self):
1106    """Test that assigning None does not change comparison."""
1107    class MyMessage(messages.Message):
1108
1109      my_field = messages.StringField(1)
1110
1111    m1 = MyMessage()
1112    m2 = MyMessage()
1113    m2.my_field = None
1114    self.assertEquals(m1, m2)
1115
1116  def testNonAsciiStr(self):
1117    """Test validation fails for non-ascii StringField values."""
1118    class Thing(messages.Message):
1119      string_field = messages.StringField(2)
1120
1121    thing = Thing()
1122    self.assertRaisesWithRegexpMatch(
1123      messages.ValidationError,
1124      'Field string_field encountered non-ASCII string',
1125      setattr, thing, 'string_field', test_util.BINARY)
1126
1127
1128class MessageTest(test_util.TestCase):
1129  """Tests for message class."""
1130
  def CreateMessageClass(self):
    """Creates a simple message class with 3 fields.

    Fields are defined in alphabetical order but with conflicting numeric
    order: the names a3, b1 and c2 carry field numbers 3, 1 and 2. Tests can
    therefore tell lookups by name apart from lookups by number.

    Returns:
      A ``ComplexMessage`` class; the class statement runs on every call, so
      each call returns a new class object.
    """
    class ComplexMessage(messages.Message):
      a3 = messages.IntegerField(3)
      b1 = messages.StringField(1)
      c2 = messages.StringField(2)

    return ComplexMessage
1143
1144  def testSameNumbers(self):
1145    """Test that cannot assign two fields with same numbers."""
1146
1147    def action():
1148      class BadMessage(messages.Message):
1149        f1 = messages.IntegerField(1)
1150        f2 = messages.IntegerField(1)
1151    self.assertRaises(messages.DuplicateNumberError,
1152                      action)
1153
1154  def testStrictAssignment(self):
1155    """Tests that cannot assign to unknown or non-reserved attributes."""
1156    class SimpleMessage(messages.Message):
1157      field = messages.IntegerField(1)
1158
1159    simple_message = SimpleMessage()
1160    self.assertRaises(AttributeError,
1161                      setattr,
1162                      simple_message,
1163                      'does_not_exist',
1164                      10)
1165
1166  def testListAssignmentDoesNotCopy(self):
1167    class SimpleMessage(messages.Message):
1168      repeated = messages.IntegerField(1, repeated=True)
1169
1170    message = SimpleMessage()
1171    original = message.repeated
1172    message.repeated = []
1173    self.assertFalse(original is message.repeated)
1174
1175  def testValidate_Optional(self):
1176    """Tests validation of optional fields."""
1177    class SimpleMessage(messages.Message):
1178      non_required = messages.IntegerField(1)
1179
1180    simple_message = SimpleMessage()
1181    simple_message.check_initialized()
1182    simple_message.non_required = 10
1183    simple_message.check_initialized()
1184
1185  def testValidate_Required(self):
1186    """Tests validation of required fields."""
1187    class SimpleMessage(messages.Message):
1188      required = messages.IntegerField(1, required=True)
1189
1190    simple_message = SimpleMessage()
1191    self.assertRaises(messages.ValidationError,
1192                      simple_message.check_initialized)
1193    simple_message.required = 10
1194    simple_message.check_initialized()
1195
1196  def testValidate_Repeated(self):
1197    """Tests validation of repeated fields."""
1198    class SimpleMessage(messages.Message):
1199      repeated = messages.IntegerField(1, repeated=True)
1200
1201    simple_message = SimpleMessage()
1202
1203    # Check valid values.
1204    for valid_value in [], [10], [10, 20], (), (10,), (10, 20):
1205      simple_message.repeated = valid_value
1206      simple_message.check_initialized()
1207
1208    # Check cleared.
1209    simple_message.repeated = []
1210    simple_message.check_initialized()
1211
1212    # Check invalid values.
1213    for invalid_value in 10, ['10', '20'], [None], (None,):
1214      self.assertRaises(messages.ValidationError,
1215                        setattr, simple_message, 'repeated', invalid_value)
1216
1217  def testIsInitialized(self):
1218    """Tests is_initialized."""
1219    class SimpleMessage(messages.Message):
1220      required = messages.IntegerField(1, required=True)
1221
1222    simple_message = SimpleMessage()
1223    self.assertFalse(simple_message.is_initialized())
1224
1225    simple_message.required = 10
1226
1227    self.assertTrue(simple_message.is_initialized())
1228
1229  def testIsInitializedNestedField(self):
1230    """Tests is_initialized for nested fields."""
1231    class SimpleMessage(messages.Message):
1232      required = messages.IntegerField(1, required=True)
1233
1234    class NestedMessage(messages.Message):
1235      simple = messages.MessageField(SimpleMessage, 1)
1236
1237    simple_message = SimpleMessage()
1238    self.assertFalse(simple_message.is_initialized())
1239    nested_message = NestedMessage(simple=simple_message)
1240    self.assertFalse(nested_message.is_initialized())
1241
1242    simple_message.required = 10
1243
1244    self.assertTrue(simple_message.is_initialized())
1245    self.assertTrue(nested_message.is_initialized())
1246
1247  def testInitializeNestedFieldFromDict(self):
1248    """Tests initializing nested fields from dict."""
1249    class SimpleMessage(messages.Message):
1250      required = messages.IntegerField(1, required=True)
1251
1252    class NestedMessage(messages.Message):
1253      simple = messages.MessageField(SimpleMessage, 1)
1254
1255    class RepeatedMessage(messages.Message):
1256      simple = messages.MessageField(SimpleMessage, 1, repeated=True)
1257
1258    nested_message1 = NestedMessage(simple={'required': 10})
1259    self.assertTrue(nested_message1.is_initialized())
1260    self.assertTrue(nested_message1.simple.is_initialized())
1261
1262    nested_message2 = NestedMessage()
1263    nested_message2.simple = {'required': 10}
1264    self.assertTrue(nested_message2.is_initialized())
1265    self.assertTrue(nested_message2.simple.is_initialized())
1266
1267    repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)]
1268
1269    repeated_message1 = RepeatedMessage(simple=repeated_values)
1270    self.assertEquals(3, len(repeated_message1.simple))
1271    self.assertFalse(repeated_message1.is_initialized())
1272
1273    repeated_message1.simple[0].required = 0
1274    self.assertTrue(repeated_message1.is_initialized())
1275
1276    repeated_message2 = RepeatedMessage()
1277    repeated_message2.simple = repeated_values
1278    self.assertEquals(3, len(repeated_message2.simple))
1279    self.assertFalse(repeated_message2.is_initialized())
1280
1281    repeated_message2.simple[0].required = 0
1282    self.assertTrue(repeated_message2.is_initialized())
1283
1284  def testNestedMethodsNotAllowed(self):
1285    """Test that method definitions on Message classes are not allowed."""
1286    def action():
1287      class WithMethods(messages.Message):
1288        def not_allowed(self):
1289          pass
1290
1291    self.assertRaises(messages.MessageDefinitionError,
1292                      action)
1293
1294  def testNestedAttributesNotAllowed(self):
1295    """Test that attribute assignment on Message classes are not allowed."""
1296    def int_attribute():
1297      class WithMethods(messages.Message):
1298        not_allowed = 1
1299
1300    def string_attribute():
1301      class WithMethods(messages.Message):
1302        not_allowed = 'not allowed'
1303
1304    def enum_attribute():
1305      class WithMethods(messages.Message):
1306        not_allowed = Color.RED
1307
1308    for action in (int_attribute, string_attribute, enum_attribute):
1309      self.assertRaises(messages.MessageDefinitionError,
1310                        action)
1311
1312  def testNameIsSetOnFields(self):
1313    """Make sure name is set on fields after Message class init."""
1314    class HasNamedFields(messages.Message):
1315      field = messages.StringField(1)
1316
1317    self.assertEquals('field', HasNamedFields.field_by_number(1).name)
1318
1319  def testSubclassingMessageDisallowed(self):
1320    """Not permitted to create sub-classes of message classes."""
1321    class SuperClass(messages.Message):
1322      pass
1323
1324    def action():
1325      class SubClass(SuperClass):
1326        pass
1327
1328    self.assertRaises(messages.MessageDefinitionError,
1329                      action)
1330
1331  def testAllFields(self):
1332    """Test all_fields method."""
1333    ComplexMessage = self.CreateMessageClass()
1334    fields = list(ComplexMessage.all_fields())
1335
1336    # Order does not matter, so sort now.
1337    fields = sorted(fields, key=lambda f: f.name)
1338
1339    self.assertEquals(3, len(fields))
1340    self.assertEquals('a3', fields[0].name)
1341    self.assertEquals('b1', fields[1].name)
1342    self.assertEquals('c2', fields[2].name)
1343
1344  def testFieldByName(self):
1345    """Test getting field by name."""
1346    ComplexMessage = self.CreateMessageClass()
1347
1348    self.assertEquals(3, ComplexMessage.field_by_name('a3').number)
1349    self.assertEquals(1, ComplexMessage.field_by_name('b1').number)
1350    self.assertEquals(2, ComplexMessage.field_by_name('c2').number)
1351
1352    self.assertRaises(KeyError,
1353                      ComplexMessage.field_by_name,
1354                      'unknown')
1355
1356  def testFieldByNumber(self):
1357    """Test getting field by number."""
1358    ComplexMessage = self.CreateMessageClass()
1359
1360    self.assertEquals('a3', ComplexMessage.field_by_number(3).name)
1361    self.assertEquals('b1', ComplexMessage.field_by_number(1).name)
1362    self.assertEquals('c2', ComplexMessage.field_by_number(2).name)
1363
1364    self.assertRaises(KeyError,
1365                      ComplexMessage.field_by_number,
1366                      4)
1367
1368  def testGetAssignedValue(self):
1369    """Test getting the assigned value of a field."""
1370    class SomeMessage(messages.Message):
1371      a_value = messages.StringField(1, default=u'a default')
1372
1373    message = SomeMessage()
1374    self.assertEquals(None, message.get_assigned_value('a_value'))
1375
1376    message.a_value = u'a string'
1377    self.assertEquals(u'a string', message.get_assigned_value('a_value'))
1378
1379    message.a_value = u'a default'
1380    self.assertEquals(u'a default', message.get_assigned_value('a_value'))
1381
1382    self.assertRaisesWithRegexpMatch(
1383        AttributeError,
1384        'Message SomeMessage has no field no_such_field',
1385        message.get_assigned_value,
1386        'no_such_field')
1387
1388  def testReset(self):
1389    """Test resetting a field value."""
1390    class SomeMessage(messages.Message):
1391      a_value = messages.StringField(1, default=u'a default')
1392      repeated = messages.IntegerField(2, repeated=True)
1393
1394    message = SomeMessage()
1395
1396    self.assertRaises(AttributeError, message.reset, 'unknown')
1397
1398    self.assertEquals(u'a default', message.a_value)
1399    message.reset('a_value')
1400    self.assertEquals(u'a default', message.a_value)
1401
1402    message.a_value = u'a new value'
1403    self.assertEquals(u'a new value', message.a_value)
1404    message.reset('a_value')
1405    self.assertEquals(u'a default', message.a_value)
1406
1407    message.repeated = [1, 2, 3]
1408    self.assertEquals([1, 2, 3], message.repeated)
1409    saved = message.repeated
1410    message.reset('repeated')
1411    self.assertEquals([], message.repeated)
1412    self.assertIsInstance(message.repeated, messages.FieldList)
1413    self.assertEquals([1, 2, 3], saved)
1414
1415  def testAllowNestedEnums(self):
1416    """Test allowing nested enums in a message definition."""
1417    class Trade(messages.Message):
1418      class Duration(messages.Enum):
1419        GTC = 1
1420        DAY = 2
1421
1422      class Currency(messages.Enum):
1423        USD = 1
1424        GBP = 2
1425        INR = 3
1426
1427    # Sorted by name order seems to be the only feasible option.
1428    self.assertEquals(['Currency', 'Duration'], Trade.__enums__)
1429
1430    # Message definition will now be set on Enumerated objects.
1431    self.assertEquals(Trade, Trade.Duration.message_definition())
1432
1433  def testAllowNestedMessages(self):
1434    """Test allowing nested messages in a message definition."""
1435    class Trade(messages.Message):
1436      class Lot(messages.Message):
1437        pass
1438
1439      class Agent(messages.Message):
1440        pass
1441
1442    # Sorted by name order seems to be the only feasible option.
1443    self.assertEquals(['Agent', 'Lot'], Trade.__messages__)
1444    self.assertEquals(Trade, Trade.Agent.message_definition())
1445    self.assertEquals(Trade, Trade.Lot.message_definition())
1446
1447    # But not Message itself.
1448    def action():
1449      class Trade(messages.Message):
1450        NiceTry = messages.Message
1451    self.assertRaises(messages.MessageDefinitionError, action)
1452
1453  def testDisallowClassAssignments(self):
1454    """Test setting class attributes may not happen."""
1455    class MyMessage(messages.Message):
1456      pass
1457
1458    self.assertRaises(AttributeError,
1459                      setattr,
1460                      MyMessage,
1461                      'x',
1462                      'do not assign')
1463
1464  def testEquality(self):
1465    """Test message class equality."""
1466    # Comparison against enums must work.
1467    class MyEnum(messages.Enum):
1468      val1 = 1
1469      val2 = 2
1470
1471    # Comparisons against nested messages must work.
1472    class AnotherMessage(messages.Message):
1473      string = messages.StringField(1)
1474
1475    class MyMessage(messages.Message):
1476      field1 = messages.IntegerField(1)
1477      field2 = messages.EnumField(MyEnum, 2)
1478      field3 = messages.MessageField(AnotherMessage, 3)
1479
1480    message1 = MyMessage()
1481
1482    self.assertNotEquals('hi', message1)
1483    self.assertNotEquals(AnotherMessage(), message1)
1484    self.assertEquals(message1, message1)
1485
1486    message2 = MyMessage()
1487
1488    self.assertEquals(message1, message2)
1489
1490    message1.field1 = 10
1491    self.assertNotEquals(message1, message2)
1492
1493    message2.field1 = 20
1494    self.assertNotEquals(message1, message2)
1495
1496    message2.field1 = 10
1497    self.assertEquals(message1, message2)
1498
1499    message1.field2 = MyEnum.val1
1500    self.assertNotEquals(message1, message2)
1501
1502    message2.field2 = MyEnum.val2
1503    self.assertNotEquals(message1, message2)
1504
1505    message2.field2 = MyEnum.val1
1506    self.assertEquals(message1, message2)
1507
1508    message1.field3 = AnotherMessage()
1509    message1.field3.string = 'value1'
1510    self.assertNotEquals(message1, message2)
1511
1512    message2.field3 = AnotherMessage()
1513    message2.field3.string = 'value2'
1514    self.assertNotEquals(message1, message2)
1515
1516    message2.field3.string = 'value1'
1517    self.assertEquals(message1, message2)
1518
1519  def testEqualityWithUnknowns(self):
1520    """Test message class equality with unknown fields."""
1521
1522    class MyMessage(messages.Message):
1523      field1 = messages.IntegerField(1)
1524
1525    message1 = MyMessage()
1526    message2 = MyMessage()
1527    self.assertEquals(message1, message2)
1528    message1.set_unrecognized_field('unknown1', 'value1',
1529                                    messages.Variant.STRING)
1530    self.assertEquals(message1, message2)
1531
1532    message1.set_unrecognized_field('unknown2', ['asdf', 3],
1533                                    messages.Variant.STRING)
1534    message1.set_unrecognized_field('unknown3', 4.7,
1535                                    messages.Variant.DOUBLE)
1536    self.assertEquals(message1, message2)
1537
1538  def testUnrecognizedFieldInvalidVariant(self):
1539    class MyMessage(messages.Message):
1540      field1 = messages.IntegerField(1)
1541
1542    message1 = MyMessage()
1543    self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4',
1544                      {'unhandled': 'type'}, None)
1545    self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4',
1546                      {'unhandled': 'type'}, 123)
1547
1548  def testRepr(self):
1549    """Test represtation of Message object."""
1550    class MyMessage(messages.Message):
1551      integer_value = messages.IntegerField(1)
1552      string_value = messages.StringField(2)
1553      unassigned = messages.StringField(3)
1554      unassigned_with_default = messages.StringField(4, default=u'a default')
1555
1556    my_message = MyMessage()
1557    my_message.integer_value = 42
1558    my_message.string_value = u'A string'
1559
1560    pat = re.compile(r"<MyMessage\n integer_value: 42\n"
1561                      " string_value: [u]?'A string'>")
1562    self.assertTrue(pat.match(repr(my_message)) is not None)
1563
1564  def testValidation(self):
1565    """Test validation of message values."""
1566    # Test optional.
1567    class SubMessage(messages.Message):
1568      pass
1569
1570    class Message(messages.Message):
1571      val = messages.MessageField(SubMessage, 1)
1572
1573    message = Message()
1574
1575    message_field = messages.MessageField(Message, 1)
1576    message_field.validate(message)
1577    message.val = SubMessage()
1578    message_field.validate(message)
1579    self.assertRaises(messages.ValidationError,
1580                      setattr, message, 'val', [SubMessage()])
1581
1582    # Test required.
1583    class Message(messages.Message):
1584      val = messages.MessageField(SubMessage, 1, required=True)
1585
1586    message = Message()
1587
1588    message_field = messages.MessageField(Message, 1)
1589    message_field.validate(message)
1590    message.val = SubMessage()
1591    message_field.validate(message)
1592    self.assertRaises(messages.ValidationError,
1593                      setattr, message, 'val', [SubMessage()])
1594
1595    # Test repeated.
1596    class Message(messages.Message):
1597      val = messages.MessageField(SubMessage, 1, repeated=True)
1598
1599    message = Message()
1600
1601    message_field = messages.MessageField(Message, 1)
1602    message_field.validate(message)
1603    self.assertRaisesWithRegexpMatch(
1604      messages.ValidationError,
1605      "Field val is repeated. Found: <SubMessage>",
1606      setattr, message, 'val', SubMessage())
1607    message.val = [SubMessage()]
1608    message_field.validate(message)
1609
1610  def testDefinitionName(self):
1611    """Test message name."""
1612    class MyMessage(messages.Message):
1613      pass
1614
1615    module_name = test_util.get_module_name(FieldTest)
1616    self.assertEquals('%s.MyMessage' % module_name,
1617                      MyMessage.definition_name())
1618    self.assertEquals(module_name, MyMessage.outer_definition_name())
1619    self.assertEquals(module_name, MyMessage.definition_package())
1620
1621    self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1622    self.assertEquals(six.text_type, type(MyMessage.outer_definition_name()))
1623    self.assertEquals(six.text_type, type(MyMessage.definition_package()))
1624
1625  def testDefinitionName_OverrideModule(self):
1626    """Test message module is overriden by module package name."""
1627    class MyMessage(messages.Message):
1628      pass
1629
1630    global package
1631    package = 'my.package'
1632
1633    try:
1634      self.assertEquals('my.package.MyMessage', MyMessage.definition_name())
1635      self.assertEquals('my.package', MyMessage.outer_definition_name())
1636      self.assertEquals('my.package', MyMessage.definition_package())
1637
1638      self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1639      self.assertEquals(six.text_type, type(MyMessage.outer_definition_name()))
1640      self.assertEquals(six.text_type, type(MyMessage.definition_package()))
1641    finally:
1642      del package
1643
1644  def testDefinitionName_NoModule(self):
1645    """Test what happens when there is no module for message."""
1646    class MyMessage(messages.Message):
1647      pass
1648
1649    original_modules = sys.modules
1650    sys.modules = dict(sys.modules)
1651    try:
1652      del sys.modules[__name__]
1653      self.assertEquals('MyMessage', MyMessage.definition_name())
1654      self.assertEquals(None, MyMessage.outer_definition_name())
1655      self.assertEquals(None, MyMessage.definition_package())
1656
1657      self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1658    finally:
1659      sys.modules = original_modules
1660
1661  def testDefinitionName_Nested(self):
1662    """Test nested message names."""
1663    class MyMessage(messages.Message):
1664
1665      class NestedMessage(messages.Message):
1666
1667        class NestedMessage(messages.Message):
1668
1669          pass
1670
1671    module_name = test_util.get_module_name(MessageTest)
1672    self.assertEquals('%s.MyMessage.NestedMessage' % module_name,
1673                      MyMessage.NestedMessage.definition_name())
1674    self.assertEquals('%s.MyMessage' % module_name,
1675                      MyMessage.NestedMessage.outer_definition_name())
1676    self.assertEquals(module_name,
1677                      MyMessage.NestedMessage.definition_package())
1678
1679    self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name,
1680                      MyMessage.NestedMessage.NestedMessage.definition_name())
1681    self.assertEquals(
1682      '%s.MyMessage.NestedMessage' % module_name,
1683      MyMessage.NestedMessage.NestedMessage.outer_definition_name())
1684    self.assertEquals(
1685      module_name,
1686      MyMessage.NestedMessage.NestedMessage.definition_package())
1687
1688
1689  def testMessageDefinition(self):
1690    """Test that enumeration knows its enclosing message definition."""
1691    class OuterMessage(messages.Message):
1692
1693      class InnerMessage(messages.Message):
1694        pass
1695
1696    self.assertEquals(None, OuterMessage.message_definition())
1697    self.assertEquals(OuterMessage,
1698                      OuterMessage.InnerMessage.message_definition())
1699
1700  def testConstructorKwargs(self):
1701    """Test kwargs via constructor."""
1702    class SomeMessage(messages.Message):
1703      name = messages.StringField(1)
1704      number = messages.IntegerField(2)
1705
1706    expected = SomeMessage()
1707    expected.name = 'my name'
1708    expected.number = 200
1709    self.assertEquals(expected, SomeMessage(name='my name', number=200))
1710
1711  def testConstructorNotAField(self):
1712    """Test kwargs via constructor with wrong names."""
1713    class SomeMessage(messages.Message):
1714      pass
1715
1716    self.assertRaisesWithRegexpMatch(
1717      AttributeError,
1718      'May not assign arbitrary value does_not_exist to message SomeMessage',
1719      SomeMessage,
1720      does_not_exist=10)
1721
1722  def testGetUnsetRepeatedValue(self):
1723    class SomeMessage(messages.Message):
1724      repeated = messages.IntegerField(1, repeated=True)
1725
1726    instance = SomeMessage()
1727    self.assertEquals([], instance.repeated)
1728    self.assertTrue(isinstance(instance.repeated, messages.FieldList))
1729
1730  def testCompareAutoInitializedRepeatedFields(self):
1731    class SomeMessage(messages.Message):
1732      repeated = messages.IntegerField(1, repeated=True)
1733
1734    message1 = SomeMessage(repeated=[])
1735    message2 = SomeMessage()
1736    self.assertEquals(message1, message2)
1737
1738  def testUnknownValues(self):
1739    """Test message class equality with unknown fields."""
1740    class MyMessage(messages.Message):
1741      field1 = messages.IntegerField(1)
1742
1743    message = MyMessage()
1744    self.assertEquals([], message.all_unrecognized_fields())
1745    self.assertEquals((None, None),
1746                      message.get_unrecognized_field_info('doesntexist'))
1747    self.assertEquals((None, None),
1748                      message.get_unrecognized_field_info(
1749                          'doesntexist', None, None))
1750    self.assertEquals(('defaultvalue', 'defaultwire'),
1751                      message.get_unrecognized_field_info(
1752                          'doesntexist', 'defaultvalue', 'defaultwire'))
1753    self.assertEquals((3, None),
1754                      message.get_unrecognized_field_info(
1755                          'doesntexist', value_default=3))
1756
1757    message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE)
1758    self.assertEquals(1, len(message.all_unrecognized_fields()))
1759    self.assertTrue('exists' in message.all_unrecognized_fields())
1760    self.assertEquals((9.5, messages.Variant.DOUBLE),
1761                      message.get_unrecognized_field_info('exists'))
1762    self.assertEquals((9.5, messages.Variant.DOUBLE),
1763                      message.get_unrecognized_field_info('exists', 'type',
1764                                                          1234))
1765    self.assertEquals((1234, None),
1766                      message.get_unrecognized_field_info('doesntexist', 1234))
1767
1768    message.set_unrecognized_field('another', 'value', messages.Variant.STRING)
1769    self.assertEquals(2, len(message.all_unrecognized_fields()))
1770    self.assertTrue('exists' in message.all_unrecognized_fields())
1771    self.assertTrue('another' in message.all_unrecognized_fields())
1772    self.assertEquals((9.5, messages.Variant.DOUBLE),
1773                      message.get_unrecognized_field_info('exists'))
1774    self.assertEquals(('value', messages.Variant.STRING),
1775                      message.get_unrecognized_field_info('another'))
1776
1777    message.set_unrecognized_field('typetest1', ['list', 0, ('test',)],
1778                                   messages.Variant.STRING)
1779    self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
1780                      message.get_unrecognized_field_info('typetest1'))
1781    message.set_unrecognized_field('typetest2', '', messages.Variant.STRING)
1782    self.assertEquals(('', messages.Variant.STRING),
1783                      message.get_unrecognized_field_info('typetest2'))
1784
1785  def testPickle(self):
1786    """Testing pickling and unpickling of Message instances."""
1787    global MyEnum
1788    global AnotherMessage
1789    global MyMessage
1790
1791    class MyEnum(messages.Enum):
1792      val1 = 1
1793      val2 = 2
1794
1795    class AnotherMessage(messages.Message):
1796      string = messages.StringField(1, repeated=True)
1797
1798    class MyMessage(messages.Message):
1799      field1 = messages.IntegerField(1)
1800      field2 = messages.EnumField(MyEnum, 2)
1801      field3 = messages.MessageField(AnotherMessage, 3)
1802
1803    message = MyMessage(field1=1, field2=MyEnum.val2,
1804                        field3=AnotherMessage(string=['a', 'b', 'c']))
1805    message.set_unrecognized_field('exists', 'value', messages.Variant.STRING)
1806    message.set_unrecognized_field('repeated', ['list', 0, ('test',)],
1807                                   messages.Variant.STRING)
1808    unpickled = pickle.loads(pickle.dumps(message))
1809    self.assertEquals(message, unpickled)
1810    self.assertTrue(AnotherMessage.string is unpickled.field3.string.field)
1811    self.assertTrue('exists' in message.all_unrecognized_fields())
1812    self.assertEquals(('value', messages.Variant.STRING),
1813                      message.get_unrecognized_field_info('exists'))
1814    self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
1815                      message.get_unrecognized_field_info('repeated'))
1816
1817
class FindDefinitionTest(test_util.TestCase):
  """Test finding definitions relative to various definitions and modules."""

  def setUp(self):
    """Set up module-space.  Starts off empty."""
    # Fully qualified module name -> fake module object.  Only modules
    # registered here are visible through Importer.
    self.modules = {}

  def DefineModule(self, name):
    """Define a module and its parents in module space.

    Modules that are already defined in self.modules are not re-created.

    Args:
      name: Fully qualified name of modules to create.

    Returns:
      Deepest nested module.  For example:

        DefineModule('a.b.c')  # Returns c.
    """
    name_path = name.split('.')
    for depth in range(1, len(name_path) + 1):
      full_name = '.'.join(name_path[:depth])
      # Only construct a module object when one is actually missing, so an
      # existing module (and whatever is attached to it) is kept.
      if full_name not in self.modules:
        self.modules[full_name] = types.ModuleType(full_name)
    return self.modules[name]

  def DefineMessage(self, module, name, children=None, add_to_module=True):
    """Define a new Message class in the context of a module.

    Used for easily describing complex Message hierarchy.  Message is defined
    including all child definitions.

    Args:
      module: Fully qualified name of module to place Message class in.
      name: Name of Message to define within module.
      children: Define any level of nesting of children definitions.  To define
        a message, map the name to another dictionary.  The dictionary can
        itself contain additional definitions, and so on.  To map to an Enum,
        define the Enum class separately and map it by name.  Defaults to no
        children.  The mapping passed in is not modified.
      add_to_module: If True, new Message class is added to module.  If False,
        new Message is not added.

    Returns:
      The newly created Message class.
    """
    # Make sure module exists.
    module_instance = self.DefineModule(module)

    # Build the class namespace in a fresh dict.  The previous version wrote
    # into `children` directly, mutating both caller dicts and the shared
    # mutable default argument.
    class_dict = {}
    for attribute, value in (children or {}).items():
      if isinstance(value, dict):
        # Nested messages are reachable only through their parent class.
        value = self.DefineMessage(module, attribute, value, False)
      class_dict[attribute] = value

    # Override default __module__ variable.
    class_dict['__module__'] = module

    # Instantiate and possibly add to module.
    message_class = type(name, (messages.Message,), class_dict)
    if add_to_module:
      setattr(module_instance, name, message_class)
    return message_class

  def Importer(self, module, globals='', locals='', fromlist=None):
    """Importer function.

    Acts like __import__.  Only loads modules from self.modules.  Does not
    try to load real modules defined elsewhere.  Does not try to handle relative
    imports.

    Args:
      module: Fully qualified name of module to load from self.modules.
      globals: Ignored; present only to match the __import__ signature.
      locals: Ignored; present only to match the __import__ signature.
      fromlist: When None, the top-level package of `module` is returned
        (as __import__ does).  Otherwise `module` itself is returned.

    Returns:
      The fake module registered in self.modules.

    Raises:
      ImportError: If the requested module is not in self.modules.
    """
    if fromlist is None:
      module = module.split('.')[0]
    try:
      return self.modules[module]
    except KeyError:
      raise ImportError('No module named %s' % module)

  def testNoSuchModule(self):
    """Test searching for definitions that do not exist."""
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'does.not.exist',
                      importer=self.Importer)

  def testRefersToModule(self):
    """Test that referring to a module does not return that module."""
    self.DefineModule('i.am.a.module')
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'i.am.a.module',
                      importer=self.Importer)

  def testNoDefinition(self):
    """Test not finding a definition in an existing module."""
    self.DefineModule('i.am.a.module')
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'i.am.a.module.MyMessage',
                      importer=self.Importer)

  def testNotADefinition(self):
    """Test trying to fetch something that is not a definition."""
    module = self.DefineModule('i.am.a.module')
    setattr(module, 'A', 'a string')
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'i.am.a.module.A',
                      importer=self.Importer)

  def testGlobalFind(self):
    """Test finding definitions from fully qualified module names."""
    A = self.DefineMessage('a.b.c', 'A', {})
    self.assertEqual(A, messages.find_definition('a.b.c.A',
                                                 importer=self.Importer))
    B = self.DefineMessage('a.b.c', 'B', {'C': {}})
    self.assertEqual(B.C, messages.find_definition('a.b.c.B.C',
                                                   importer=self.Importer))

  def testRelativeToModule(self):
    """Test finding definitions relative to modules."""
    # Define modules.
    a = self.DefineModule('a')
    b = self.DefineModule('a.b')
    c = self.DefineModule('a.b.c')

    # Define messages.  Module a.b.d is created implicitly by DefineMessage.
    A = self.DefineMessage('a', 'A')
    B = self.DefineMessage('a.b', 'B')
    C = self.DefineMessage('a.b.c', 'C')
    D = self.DefineMessage('a.b.d', 'D')

    # Find A, B, C and D relative to a.
    self.assertEqual(A, messages.find_definition(
        'A', a, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'b.B', a, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'b.c.C', a, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'b.d.D', a, importer=self.Importer))

    # Find A, B, C and D relative to b.
    self.assertEqual(A, messages.find_definition(
        'A', b, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', b, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'c.C', b, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'd.D', b, importer=self.Importer))

    # Find A, B, C and D relative to c.  Module d is the same case as c.
    self.assertEqual(A, messages.find_definition(
        'A', c, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', c, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'C', c, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'd.D', c, importer=self.Importer))

  def testRelativeToMessages(self):
    """Test finding definitions relative to Message definitions."""
    A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}})
    B = A.B
    C = A.B.C
    D = A.B.D

    # Find relative to A.
    self.assertEqual(A, messages.find_definition(
        'A', A, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', A, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'B.C', A, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'B.D', A, importer=self.Importer))

    # Find relative to B.
    self.assertEqual(A, messages.find_definition(
        'A', B, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', B, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'C', B, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'D', B, importer=self.Importer))

    # Find relative to C.
    self.assertEqual(A, messages.find_definition(
        'A', C, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', C, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'C', C, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'D', C, importer=self.Importer))

    # Find relative to C using names qualified by module b (i.e. a.b).
    self.assertEqual(A, messages.find_definition(
        'b.A', C, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'b.A.B', C, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'b.A.B.C', C, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'b.A.B.D', C, importer=self.Importer))

  def testAbsoluteReference(self):
    """Test finding absolute definition names."""
    # Define modules.  Module a.a provides a second, shadowing A.
    a = self.DefineModule('a')
    self.DefineModule('a.a')

    # Define messages.
    aA = self.DefineMessage('a', 'A')
    aaA = self.DefineMessage('a.a', 'A')

    # A leading '.' always resolves from the root, so a.A is always found.
    self.assertEqual(aA, messages.find_definition('.a.A', None,
                                                  importer=self.Importer))
    self.assertEqual(aA, messages.find_definition('.a.A', a,
                                                  importer=self.Importer))
    self.assertEqual(aA, messages.find_definition('.a.A', aA,
                                                  importer=self.Importer))
    self.assertEqual(aA, messages.find_definition('.a.A', aaA,
                                                  importer=self.Importer))

  def testFindEnum(self):
    """Test that Enums are found."""
    class Color(messages.Enum):
      pass
    A = self.DefineMessage('a', 'A', {'Color': Color})

    self.assertEqual(
        Color,
        messages.find_definition('Color', A, importer=self.Importer))

  def testFalseScope(self):
    """Test that Message definitions nested in strange objects are hidden."""
    global X
    class X(object):
      class A(messages.Message):
        pass

    self.assertRaises(TypeError, messages.find_definition, 'A', X)
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'X.A', sys.modules[__name__])

  def testSearchAttributeFirst(self):
    """Make sure not faked out by module, but continues searching."""
    A = self.DefineMessage('a', 'A')
    # A module named a.A competes with the message a.A; the message must win.
    self.DefineModule('a.A')

    self.assertEqual(A, messages.find_definition(
        'a.A', None, importer=self.Importer))
2077
2078
class FindDefinitionUnicodeTests(test_util.TestCase):
  """Tests for find_definition when the definition name is a unicode string."""

  # TODO(craigcitro): Fix this test and re-enable it.
  def notatestUnicodeString(self):
    """Test using unicode names."""
    # Import the module that defines the ServiceMapping being looked up.
    from protorpc import registry
    definition = messages.find_definition(
        u'protorpc.registry.ServiceMapping', None)
    # assertEqual rather than the assertEquals alias (removed in Python 3.12).
    self.assertEqual('ServiceMapping', definition.__name__)
2089
2090
def main():
  """Run the tests in this module with the standard unittest runner."""
  unittest.main()
2093
2094
# Run the tests only when this file is executed directly, not when imported.
if __name__ == '__main__':
  main()
2097