• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for protorpc.protojson."""
19
20__author__ = 'rafek@google.com (Rafe Kaplan)'
21
22
23import datetime
24import imp
25import sys
26import unittest
27
28from protorpc import message_types
29from protorpc import messages
30from protorpc import protojson
31from protorpc import test_util
32
33import simplejson
34
35
class CustomField(messages.MessageField):
  """Custom MessageField class whose Python-side values are plain ints.

  The field is declared against VoidMessage; converting any value to a
  message always yields a fresh, empty VoidMessage.
  """

  message_type = message_types.VoidMessage
  type = int

  def __init__(self, number, **kwargs):
    declared_message_class = self.message_type
    super(CustomField, self).__init__(declared_message_class, number, **kwargs)

  def value_to_message(self, value):
    # The value is deliberately ignored; only an empty message is produced.
    message_class = self.message_type
    return message_class()
47
48
class MyMessage(messages.Message):
  """Test message containing various types.

  Combines scalar, enum, nested, repeated, datetime and custom (CustomField)
  fields so protojson's handling of each field kind can be exercised.
  """

  class Color(messages.Enum):
    # Enum type of an_enum; decoded from both numeric and name values in
    # the tests of this file.

    RED = 1
    GREEN = 2
    BLUE = 3

  class Nested(messages.Message):
    # Message type of a_nested.

    nested_value = messages.StringField(1)

  a_string = messages.StringField(2)
  an_integer = messages.IntegerField(3)
  a_float = messages.FloatField(4)
  a_boolean = messages.BooleanField(5)
  an_enum = messages.EnumField(Color, 6)
  a_nested = messages.MessageField(Nested, 7)
  a_repeated = messages.IntegerField(8, repeated=True)
  a_repeated_float = messages.FloatField(9, repeated=True)
  # Datetime fields round-trip through ISO-8601-style strings in JSON.
  a_datetime = message_types.DateTimeField(10)
  a_repeated_datetime = message_types.DateTimeField(11, repeated=True)
  # CustomField values are plain ints on the Python side.
  a_custom = CustomField(12)
  a_repeated_custom = CustomField(13, repeated=True)
74
75
class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Runs the test_util.ModuleInterfaceTest checks against protojson."""

  # Module under test, read by the inherited interface checks.
  MODULE = protojson
80
81
# TODO(rafek): Convert this test to the compliance test in test_util.
class ProtojsonTest(test_util.TestCase,
                    test_util.ProtoConformanceTestBase):
  """Test JSON encoding and decoding.

  The encoded_* attributes hold JSON fixtures that are not referenced
  directly in this class (presumably consumed by the inherited
  test_util.ProtoConformanceTestBase checks). The test* methods below cover
  protojson-specific behavior using MyMessage.
  """

  PROTOLIB = protojson

  def CompareEncoded(self, expected_encoded, actual_encoded):
    """JSON encoding will be laundered to remove string differences.

    Both documents are parsed with simplejson before comparison, so key
    order and whitespace do not affect the result.
    """
    self.assertEqual(simplejson.loads(expected_encoded),
                     simplejson.loads(actual_encoded))

  encoded_empty_message = '{}'

  encoded_partial = """{
    "double_value": 1.23,
    "int64_value": -100000000000,
    "int32_value": 1020,
    "string_value": "a string",
    "enum_value": "VAL2"
  }
  """

  encoded_full = """{
    "double_value": 1.23,
    "float_value": -2.5,
    "int64_value": -100000000000,
    "uint64_value": 102020202020,
    "int32_value": 1020,
    "bool_value": true,
    "string_value": "a string\u044f",
    "bytes_value": "YSBieXRlc//+",
    "enum_value": "VAL2"
  }
  """

  encoded_repeated = """{
    "double_value": [1.23, 2.3],
    "float_value": [-2.5, 0.5],
    "int64_value": [-100000000000, 20],
    "uint64_value": [102020202020, 10],
    "int32_value": [1020, 718],
    "bool_value": [true, false],
    "string_value": ["a string\u044f", "another string"],
    "bytes_value": ["YSBieXRlc//+", "YW5vdGhlciBieXRlcw=="],
    "enum_value": ["VAL2", "VAL1"]
  }
  """

  encoded_nested = """{
    "nested": {
      "a_value": "a string"
    }
  }
  """

  encoded_repeated_nested = """{
    "repeated_nested": [{"a_value": "a string"},
                        {"a_value": "another string"}]
  }
  """

  unexpected_tag_message = '{"unknown": "value"}'

  encoded_default_assigned = '{"a_value": "a default"}'

  encoded_nested_empty = '{"nested": {}}'

  encoded_repeated_nested_empty = '{"repeated_nested": [{}, {}]}'

  encoded_extend_message = '{"int64_value": [400, 50, 6000]}'

  encoded_string_types = '{"string_value": "Latin"}'

  encoded_invalid_enum = '{"enum_value": "undefined"}'

  def testConvertIntegerToFloat(self):
    """Test that integers passed in to float fields are converted.

    This is necessary because JSON outputs integers for numbers with 0 decimals.
    """
    message = protojson.decode_message(MyMessage, '{"a_float": 10}')

    self.assertIsInstance(message.a_float, float)
    self.assertEqual(10.0, message.a_float)

  def testConvertStringToNumbers(self):
    """Test that numeric strings passed to integer and float fields convert."""
    message = protojson.decode_message(MyMessage,
                                       """{"an_integer": "10",
                                           "a_float": "3.5",
                                           "a_repeated": ["1", "2"],
                                           "a_repeated_float": ["1.5", "2", 10]
                                           }""")

    self.assertEqual(MyMessage(an_integer=10,
                               a_float=3.5,
                               a_repeated=[1, 2],
                               a_repeated_float=[1.5, 2.0, 10.0]),
                     message)

  def testWrongTypeAssignment(self):
    """Test when wrong type is assigned to a field."""
    self.assertRaises(messages.ValidationError,
                      protojson.decode_message,
                      MyMessage, '{"a_string": 10}')
    self.assertRaises(messages.ValidationError,
                      protojson.decode_message,
                      MyMessage, '{"an_integer": 10.2}')
    self.assertRaises(messages.ValidationError,
                      protojson.decode_message,
                      MyMessage, '{"an_integer": "10.2"}')

  def testNumericEnumeration(self):
    """Test that numbers work for enum values."""
    message = protojson.decode_message(MyMessage, '{"an_enum": 2}')

    expected_message = MyMessage()
    expected_message.an_enum = MyMessage.Color.GREEN

    self.assertEqual(expected_message, message)

  def testNumericEnumerationNegativeTest(self):
    """Test with an invalid number for the enum value."""
    # Same regexp helper as testDecodeBadBase64BytesField; avoids the
    # deprecated assertRaisesRegexp alias.
    self.assertRaisesWithRegexpMatch(
        messages.DecodeError,
        'Invalid enum value "89"',
        protojson.decode_message,
        MyMessage,
        '{"an_enum": 89}')

  def testAlphaEnumeration(self):
    """Test that alpha enum values work."""
    message = protojson.decode_message(MyMessage, '{"an_enum": "RED"}')

    expected_message = MyMessage()
    expected_message.an_enum = MyMessage.Color.RED

    self.assertEqual(expected_message, message)

  def testAlphaEnumerationNegativeTest(self):
    """The alpha enum value is invalid."""
    self.assertRaisesWithRegexpMatch(
        messages.DecodeError,
        'Invalid enum value "IAMINVALID"',
        protojson.decode_message,
        MyMessage,
        '{"an_enum": "IAMINVALID"}')

  def testEnumerationNegativeTestWithEmptyString(self):
    """The enum value is an empty string."""
    self.assertRaisesWithRegexpMatch(
        messages.DecodeError,
        'Invalid enum value ""',
        protojson.decode_message,
        MyMessage,
        '{"an_enum": ""}')

  def testNullValues(self):
    """Test that null values decode the same as absent fields."""
    self.assertEqual(MyMessage(),
                     protojson.decode_message(MyMessage,
                                              ('{"an_integer": null,'
                                               ' "a_nested": null,'
                                               ' "an_enum": null'
                                               '}')))

  def testEmptyList(self):
    """Test that empty lists are ignored."""
    self.assertEqual(MyMessage(),
                     protojson.decode_message(MyMessage,
                                              '{"a_repeated": []}'))

  def testNotJSON(self):
    """Test error when string is not valid JSON."""
    self.assertRaises(ValueError,
                      protojson.decode_message, MyMessage, '{this is not json}')

  def testDoNotEncodeStrangeObjects(self):
    """Test trying to encode a strange object.

    The main purpose of this test is to complete coverage.  It ensures that
    the default behavior of the JSON encoder is preserved when someone tries to
    serialize an unexpected type.
    """
    class BogusObject(object):

      def check_initialized(self):
        pass

    self.assertRaises(TypeError,
                      protojson.encode_message,
                      BogusObject())

  def testMergeEmptyString(self):
    """Test merging the empty or space only string."""
    message = protojson.decode_message(test_util.OptionalMessage, '')
    self.assertEqual(test_util.OptionalMessage(), message)

    message = protojson.decode_message(test_util.OptionalMessage, ' ')
    self.assertEqual(test_util.OptionalMessage(), message)

  def testProtojsonUnrecognizedFieldName(self):
    """Test that unrecognized fields are saved and can be accessed."""
    decoded = protojson.decode_message(MyMessage,
                                       ('{"an_integer": 1, "unknown_val": 2}'))
    self.assertEqual(1, decoded.an_integer)
    self.assertEqual(1, len(decoded.all_unrecognized_fields()))
    self.assertEqual('unknown_val', decoded.all_unrecognized_fields()[0])
    self.assertEqual((2, messages.Variant.INT64),
                     decoded.get_unrecognized_field_info('unknown_val'))

  def testProtojsonUnrecognizedFieldNumber(self):
    """Test how unrecognized numeric-looking field names are saved.

    A non-negative integer name is stored under its int value; a negative
    number or a mixed name keeps its string form.
    """
    decoded = protojson.decode_message(
        MyMessage,
        '{"an_integer": 1, "1001": "unknown", "-123": "negative", '
        '"456_mixed": 2}')
    self.assertEqual(1, decoded.an_integer)
    self.assertEqual(3, len(decoded.all_unrecognized_fields()))
    self.assertIn(1001, decoded.all_unrecognized_fields())
    self.assertEqual(('unknown', messages.Variant.STRING),
                     decoded.get_unrecognized_field_info(1001))
    self.assertIn('-123', decoded.all_unrecognized_fields())
    self.assertEqual(('negative', messages.Variant.STRING),
                     decoded.get_unrecognized_field_info('-123'))
    self.assertIn('456_mixed', decoded.all_unrecognized_fields())
    self.assertEqual((2, messages.Variant.INT64),
                     decoded.get_unrecognized_field_info('456_mixed'))

  def testProtojsonUnrecognizedNull(self):
    """Test that unrecognized fields that are None are skipped."""
    decoded = protojson.decode_message(
        MyMessage,
        '{"an_integer": 1, "unrecognized_null": null}')
    self.assertEqual(1, decoded.an_integer)
    self.assertEqual([], decoded.all_unrecognized_fields())

  def testUnrecognizedFieldVariants(self):
    """Test that unrecognized fields are mapped to the right variants."""
    for encoded, expected_variant in (
        ('{"an_integer": 1, "unknown_val": 2}', messages.Variant.INT64),
        ('{"an_integer": 1, "unknown_val": 2.0}', messages.Variant.DOUBLE),
        ('{"an_integer": 1, "unknown_val": "string value"}',
         messages.Variant.STRING),
        ('{"an_integer": 1, "unknown_val": [1, 2, 3]}', messages.Variant.INT64),
        ('{"an_integer": 1, "unknown_val": [1, 2.0, 3]}',
         messages.Variant.DOUBLE),
        ('{"an_integer": 1, "unknown_val": [1, "foo", 3]}',
         messages.Variant.STRING),
        ('{"an_integer": 1, "unknown_val": true}', messages.Variant.BOOL)):
      decoded = protojson.decode_message(MyMessage, encoded)
      self.assertEqual(1, decoded.an_integer)
      self.assertEqual(1, len(decoded.all_unrecognized_fields()))
      self.assertEqual('unknown_val', decoded.all_unrecognized_fields()[0])
      _, decoded_variant = decoded.get_unrecognized_field_info('unknown_val')
      self.assertEqual(expected_variant, decoded_variant)

  def testDecodeDateTime(self):
    """Datetime strings with and without fractional seconds decode."""
    for datetime_string, datetime_vals in (
        ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)),
        ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))):
      message = protojson.decode_message(
          MyMessage, '{"a_datetime": "%s"}' % datetime_string)
      expected_message = MyMessage(
          a_datetime=datetime.datetime(*datetime_vals))

      self.assertEqual(expected_message, message)

  def testDecodeInvalidDateTime(self):
    """A string that is not a datetime raises DecodeError."""
    self.assertRaises(messages.DecodeError, protojson.decode_message,
                      MyMessage, '{"a_datetime": "invalid"}')

  def testEncodeDateTime(self):
    """Datetimes encode with microseconds only when they are non-zero."""
    for datetime_string, datetime_vals in (
        ('2012-09-30T15:31:50.262000', (2012, 9, 30, 15, 31, 50, 262000)),
        ('2012-09-30T15:31:50.262123', (2012, 9, 30, 15, 31, 50, 262123)),
        ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))):
      encoded_message = protojson.encode_message(
          MyMessage(a_datetime=datetime.datetime(*datetime_vals)))
      expected_encoding = '{"a_datetime": "%s"}' % datetime_string
      self.CompareEncoded(expected_encoding, encoded_message)

  def testDecodeRepeatedDateTime(self):
    """A list of datetime strings decodes into a repeated datetime field."""
    message = protojson.decode_message(
        MyMessage,
        '{"a_repeated_datetime": ["2012-09-30T15:31:50.262", '
        '"2010-01-21T09:52:00", "2000-01-01T01:00:59.999999"]}')
    expected_message = MyMessage(
        a_repeated_datetime=[
            datetime.datetime(2012, 9, 30, 15, 31, 50, 262000),
            datetime.datetime(2010, 1, 21, 9, 52),
            datetime.datetime(2000, 1, 1, 1, 0, 59, 999999)])

    self.assertEqual(expected_message, message)

  def testDecodeCustom(self):
    """A CustomField value decodes to a plain int."""
    message = protojson.decode_message(MyMessage, '{"a_custom": 1}')
    self.assertEqual(MyMessage(a_custom=1), message)

  def testDecodeInvalidCustom(self):
    """A non-integer CustomField value fails validation."""
    self.assertRaises(messages.ValidationError, protojson.decode_message,
                      MyMessage, '{"a_custom": "invalid"}')

  def testEncodeCustom(self):
    """A CustomField value encodes as a JSON number."""
    encoded_message = protojson.encode_message(MyMessage(a_custom=1))
    self.CompareEncoded('{"a_custom": 1}', encoded_message)

  def testDecodeRepeatedCustom(self):
    """A list of ints decodes into a repeated CustomField."""
    message = protojson.decode_message(
        MyMessage, '{"a_repeated_custom": [1, 2, 3]}')
    self.assertEqual(MyMessage(a_repeated_custom=[1, 2, 3]), message)

  def testDecodeBadBase64BytesField(self):
    """Test decoding improperly encoded base64 bytes value."""
    self.assertRaisesWithRegexpMatch(
        messages.DecodeError,
        'Base64 decoding error: Incorrect padding',
        protojson.decode_message,
        test_util.OptionalMessage,
        '{"bytes_value": "abcdefghijklmnopq"}')
403
404
class CustomProtoJson(protojson.ProtoJson):
  """ProtoJson subclass that marks every field value it converts.

  Encoded values gain a '{encoded}' prefix and decoded values a '{decoded}'
  prefix, so tests can see that the overrides were invoked.
  """

  def encode_field(self, field, value):
    """Return value prefixed with the encode marker; field is unused."""
    marker = '{encoded}'
    return marker + value

  def decode_field(self, field, value):
    """Return value prefixed with the decode marker; field is unused."""
    marker = '{decoded}'
    return marker + value
412
413
class CustomProtoJsonTest(test_util.TestCase):
  """Tests for serialization overriding functionality."""

  def setUp(self):
    self.protojson = CustomProtoJson()

  def testEncode(self):
    """encode_field overrides are applied when encoding a message."""
    self.assertEqual('{"a_string": "{encoded}xyz"}',
                     self.protojson.encode_message(MyMessage(a_string='xyz')))

  def testDecode(self):
    """decode_field overrides are applied when decoding a message."""
    self.assertEqual(
        MyMessage(a_string='{decoded}xyz'),
        self.protojson.decode_message(MyMessage, '{"a_string": "xyz"}'))

  def testDecodeEmptyMessage(self):
    """An empty string field value still goes through decode_field."""
    self.assertEqual(
        MyMessage(a_string='{decoded}'),
        self.protojson.decode_message(MyMessage, '{"a_string": ""}'))

  def testDefault(self):
    """get_default is stable across calls and set_default replaces it."""
    original_default = protojson.ProtoJson.get_default()
    # The previous assertTrue(a, b) only checked truthiness; the intent is
    # that repeated calls return the very same instance.
    self.assertIs(original_default, protojson.ProtoJson.get_default())

    instance = CustomProtoJson()
    protojson.ProtoJson.set_default(instance)
    try:
      self.assertIs(instance, protojson.ProtoJson.get_default())
    finally:
      # The default is global state; put it back so other tests that call
      # get_default() are not affected by this test.
      protojson.ProtoJson.set_default(original_default)
441
442
class InvalidJsonModule(object):
  """Stand-in json module with no JSONEncoder attribute (rejected)."""
445
446
class ValidJsonModule(object):
  """Stand-in json module that exposes a JSONEncoder attribute."""

  class JSONEncoder(object):
    """Empty placeholder encoder class."""
450
451
class TestJsonDependencyLoading(test_util.TestCase):
  """Test loading various implementations of json.

  setUp hides the real json/simplejson modules and installs an __import__
  hook under which a json-named module can only be found if a test planted
  it in sys.modules. Each test then reloads protojson and checks which
  module it bound to. tearDown undoes all of this.
  """

  def get_import(self):
    """Get __import__ method.

    __builtins__ may be either the builtins module or its dict depending on
    how this file was loaded (a CPython implementation detail), so both
    forms are handled.

    Returns:
      The current __import__ method.
    """
    if isinstance(__builtins__, dict):
      return __builtins__['__import__']
    else:
      return __builtins__.__import__

  def set_import(self, new_import):
    """Set __import__ method.

    Args:
      new_import: Function to replace __import__.
    """
    if isinstance(__builtins__, dict):
      __builtins__['__import__'] = new_import
    else:
      __builtins__.__import__ = new_import

  def setUp(self):
    """Hide json modules and install an import hook that blocks json."""
    self.simplejson = sys.modules.pop('simplejson', None)
    self.json = sys.modules.pop('json', None)
    self.original_import = self.get_import()

    def block_all_jsons(name, *args, **kwargs):
      # A json-named module is importable only if a test planted it in
      # sys.modules. It is tagged with the name it was imported under so the
      # test can see which one protojson picked.
      if 'json' in name:
        if name in sys.modules:
          module = sys.modules[name]
          module.name = name
          return module
        raise ImportError('Unable to find %s' % name)
      else:
        return self.original_import(name, *args, **kwargs)
    self.set_import(block_all_jsons)

  def tearDown(self):
    """Restore original import functions and any loaded modules."""
    # Uninstall the hook first. Otherwise it stays active for every later
    # test, and the next setUp would capture the hook as "original_import".
    self.set_import(self.original_import)

    def reset_module(name, module):
      if module:
        sys.modules[name] = module
      else:
        sys.modules.pop(name, None)
    reset_module('simplejson', self.simplejson)
    reset_module('json', self.json)
    imp.reload(protojson)

  def testLoadProtojsonWithValidJsonModule(self):
    """Test loading protojson module with a valid json dependency."""
    sys.modules['json'] = ValidJsonModule

    # This will cause protojson to reload with the default json module
    # instead of simplejson.
    imp.reload(protojson)
    self.assertEqual('json', protojson.json.name)

  def testLoadProtojsonWithSimplejsonModule(self):
    """Test loading protojson module with simplejson dependency."""
    sys.modules['simplejson'] = ValidJsonModule

    # With no json module available, protojson falls back to simplejson.
    imp.reload(protojson)
    self.assertEqual('simplejson', protojson.json.name)

  def testLoadProtojsonWithInvalidJsonModule(self):
    """Loading protojson module with an invalid json defaults to simplejson."""
    sys.modules['json'] = InvalidJsonModule
    sys.modules['simplejson'] = ValidJsonModule

    # Ignore bad module and default back to simplejson.
    imp.reload(protojson)
    self.assertEqual('simplejson', protojson.json.name)

  def testLoadProtojsonWithInvalidJsonModuleAndNoSimplejson(self):
    """Loading protojson module with invalid json and no simplejson."""
    sys.modules['json'] = InvalidJsonModule

    # An incompatible json module with no simplejson fallback is an error.
    self.assertRaisesWithRegexpMatch(
        ImportError,
        'json library "json" is not compatible with ProtoRPC',
        imp.reload,
        protojson)

  def testLoadProtojsonWithNoJsonModules(self):
    """Loading protojson module with no json modules available at all."""
    # With nothing importable, the ImportError from the first attempt (json)
    # is the one raised.
    self.assertRaisesWithRegexpMatch(
        ImportError,
        'Unable to find json',
        imp.reload,
        protojson)
551
552
if __name__ == '__main__':
  # Run the TestCase classes defined in this module when executed directly.
  unittest.main()
555