#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for protorpc.protojson."""

__author__ = 'rafek@google.com (Rafe Kaplan)'


import datetime
import imp
import sys
import unittest

from protorpc import message_types
from protorpc import messages
from protorpc import protojson
from protorpc import test_util

import simplejson


class CustomField(messages.MessageField):
  """Custom MessageField class.

  Declares ``int`` as its Python value type while using VoidMessage as the
  underlying message type; value_to_message always yields an empty
  VoidMessage.
  """

  type = int
  message_type = message_types.VoidMessage

  def __init__(self, number, **kwargs):
    super(CustomField, self).__init__(self.message_type, number, **kwargs)

  def value_to_message(self, value):
    return self.message_type()


class MyMessage(messages.Message):
  """Test message containing various types."""

  class Color(messages.Enum):

    RED = 1
    GREEN = 2
    BLUE = 3

  class Nested(messages.Message):

    nested_value = messages.StringField(1)

  a_string = messages.StringField(2)
  an_integer = messages.IntegerField(3)
  a_float = messages.FloatField(4)
  a_boolean = messages.BooleanField(5)
  an_enum = messages.EnumField(Color, 6)
  a_nested = messages.MessageField(Nested, 7)
  a_repeated = messages.IntegerField(8, repeated=True)
  a_repeated_float = messages.FloatField(9, repeated=True)
  a_datetime = message_types.DateTimeField(10)
  a_repeated_datetime = message_types.DateTimeField(11, repeated=True)
  a_custom = CustomField(12)
  a_repeated_custom = CustomField(13, repeated=True)


class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Module interface test for protojson (checks come from test_util)."""

  MODULE = protojson


# TODO(rafek): Convert this test to the compliance test in test_util.
class ProtojsonTest(test_util.TestCase,
                    test_util.ProtoConformanceTestBase):
  """Test JSON encoding and decoding."""

  PROTOLIB = protojson

  def CompareEncoded(self, expected_encoded, actual_encoded):
    """Compare two JSON documents by their parsed values.

    Both strings are decoded before comparison so that differences in key
    order, whitespace and string escaping are ignored.
    """
    self.assertEqual(simplejson.loads(expected_encoded),
                     simplejson.loads(actual_encoded))

  encoded_empty_message = '{}'

  encoded_partial = """{
    "double_value": 1.23,
    "int64_value": -100000000000,
    "int32_value": 1020,
    "string_value": "a string",
    "enum_value": "VAL2"
  }
  """

  encoded_full = """{
    "double_value": 1.23,
    "float_value": -2.5,
    "int64_value": -100000000000,
    "uint64_value": 102020202020,
    "int32_value": 1020,
    "bool_value": true,
    "string_value": "a string\u044f",
    "bytes_value": "YSBieXRlc//+",
    "enum_value": "VAL2"
  }
  """

  encoded_repeated = """{
    "double_value": [1.23, 2.3],
    "float_value": [-2.5, 0.5],
    "int64_value": [-100000000000, 20],
    "uint64_value": [102020202020, 10],
    "int32_value": [1020, 718],
    "bool_value": [true, false],
    "string_value": ["a string\u044f", "another string"],
    "bytes_value": ["YSBieXRlc//+", "YW5vdGhlciBieXRlcw=="],
    "enum_value": ["VAL2", "VAL1"]
  }
  """

  encoded_nested = """{
    "nested": {
      "a_value": "a string"
    }
  }
  """

  encoded_repeated_nested = """{
    "repeated_nested": [{"a_value": "a string"},
                        {"a_value": "another string"}]
  }
  """

  unexpected_tag_message = '{"unknown": "value"}'

  encoded_default_assigned = '{"a_value": "a default"}'

  encoded_nested_empty = '{"nested": {}}'

  encoded_repeated_nested_empty = '{"repeated_nested": [{}, {}]}'

  encoded_extend_message = '{"int64_value": [400, 50, 6000]}'

  encoded_string_types = '{"string_value": "Latin"}'

  encoded_invalid_enum = '{"enum_value": "undefined"}'

  def testConvertIntegerToFloat(self):
    """Test that integers passed in to float fields are converted.

    This is necessary because JSON outputs integers for numbers with 0 decimals.
    """
    message = protojson.decode_message(MyMessage, '{"a_float": 10}')

    self.assertTrue(isinstance(message.a_float, float))
    self.assertEqual(10.0, message.a_float)

  def testConvertStringToNumbers(self):
    """Test that strings passed to integer and float fields are converted."""
    message = protojson.decode_message(MyMessage,
                                       """{"an_integer": "10",
                                           "a_float": "3.5",
                                           "a_repeated": ["1", "2"],
                                           "a_repeated_float": ["1.5", "2", 10]
                                           }""")

    self.assertEqual(MyMessage(an_integer=10,
                               a_float=3.5,
                               a_repeated=[1, 2],
                               a_repeated_float=[1.5, 2.0, 10.0]),
                     message)

  def testWrongTypeAssignment(self):
    """Test when wrong type is assigned to a field."""
    self.assertRaises(messages.ValidationError,
                      protojson.decode_message,
                      MyMessage, '{"a_string": 10}')
    self.assertRaises(messages.ValidationError,
                      protojson.decode_message,
                      MyMessage, '{"an_integer": 10.2}')
    self.assertRaises(messages.ValidationError,
                      protojson.decode_message,
                      MyMessage, '{"an_integer": "10.2"}')

  def testNumericEnumeration(self):
    """Test that numbers work for enum values."""
    message = protojson.decode_message(MyMessage, '{"an_enum": 2}')

    expected_message = MyMessage()
    expected_message.an_enum = MyMessage.Color.GREEN

    self.assertEqual(expected_message, message)

  def testNumericEnumerationNegativeTest(self):
    """Test with an invalid number for the enum value."""
    self.assertRaisesRegexp(
        messages.DecodeError,
        'Invalid enum value "89"',
        protojson.decode_message,
        MyMessage,
        '{"an_enum": 89}')

  def testAlphaEnumeration(self):
    """Test that alpha enum values work."""
    message = protojson.decode_message(MyMessage, '{"an_enum": "RED"}')

    expected_message = MyMessage()
    expected_message.an_enum = MyMessage.Color.RED

    self.assertEqual(expected_message, message)

  def testAlphaEnumerationNegativeTest(self):
    """The alpha enum value is invalid."""
    self.assertRaisesRegexp(
        messages.DecodeError,
        'Invalid enum value "IAMINVALID"',
        protojson.decode_message,
        MyMessage,
        '{"an_enum": "IAMINVALID"}')

  def testEnumerationNegativeTestWithEmptyString(self):
    """The enum value is an empty string."""
    self.assertRaisesRegexp(
        messages.DecodeError,
        'Invalid enum value ""',
        protojson.decode_message,
        MyMessage,
        '{"an_enum": ""}')

  def testNullValues(self):
    """Test that null values decode to unset fields."""
    self.assertEqual(MyMessage(),
                     protojson.decode_message(MyMessage,
                                              ('{"an_integer": null,'
                                               ' "a_nested": null,'
                                               ' "an_enum": null'
                                               '}')))

  def testEmptyList(self):
    """Test that empty lists are ignored."""
    self.assertEqual(MyMessage(),
                     protojson.decode_message(MyMessage,
                                              '{"a_repeated": []}'))

  def testNotJSON(self):
    """Test error when string is not valid JSON."""
    self.assertRaises(ValueError,
                      protojson.decode_message, MyMessage, '{this is not json}')

  def testDoNotEncodeStrangeObjects(self):
    """Test trying to encode a strange object.

    The main purpose of this test is to complete coverage. It ensures that
    the default behavior of the JSON encoder is preserved when someone tries to
    serialize an unexpected type.
    """
    class BogusObject(object):

      def check_initialized(self):
        pass

    self.assertRaises(TypeError,
                      protojson.encode_message,
                      BogusObject())

  def testMergeEmptyString(self):
    """Test merging the empty or space only string."""
    message = protojson.decode_message(test_util.OptionalMessage, '')
    self.assertEqual(test_util.OptionalMessage(), message)

    message = protojson.decode_message(test_util.OptionalMessage, ' ')
    self.assertEqual(test_util.OptionalMessage(), message)

  def testProtojsonUnrecognizedFieldName(self):
    """Test that unrecognized fields are saved and can be accessed."""
    decoded = protojson.decode_message(MyMessage,
                                       ('{"an_integer": 1, "unknown_val": 2}'))
    self.assertEqual(1, decoded.an_integer)
    self.assertEqual(1, len(decoded.all_unrecognized_fields()))
    self.assertEqual('unknown_val', decoded.all_unrecognized_fields()[0])
    self.assertEqual((2, messages.Variant.INT64),
                     decoded.get_unrecognized_field_info('unknown_val'))

  def testProtojsonUnrecognizedFieldNumber(self):
    """Test that unrecognized numeric and mixed field names are saved.

    A purely numeric name ("1001") is stored under an int key, while names
    that are not plain non-negative integers keep their string key.
    """
    decoded = protojson.decode_message(
        MyMessage,
        '{"an_integer": 1, "1001": "unknown", "-123": "negative", '
        '"456_mixed": 2}')
    self.assertEqual(1, decoded.an_integer)
    self.assertEqual(3, len(decoded.all_unrecognized_fields()))
    self.assertTrue(1001 in decoded.all_unrecognized_fields())
    self.assertEqual(('unknown', messages.Variant.STRING),
                     decoded.get_unrecognized_field_info(1001))
    self.assertTrue('-123' in decoded.all_unrecognized_fields())
    self.assertEqual(('negative', messages.Variant.STRING),
                     decoded.get_unrecognized_field_info('-123'))
    self.assertTrue('456_mixed' in decoded.all_unrecognized_fields())
    self.assertEqual((2, messages.Variant.INT64),
                     decoded.get_unrecognized_field_info('456_mixed'))

  def testProtojsonUnrecognizedNull(self):
    """Test that unrecognized fields that are None are skipped."""
    decoded = protojson.decode_message(
        MyMessage,
        '{"an_integer": 1, "unrecognized_null": null}')
    self.assertEqual(1, decoded.an_integer)
    self.assertEqual([], decoded.all_unrecognized_fields())

  def testUnrecognizedFieldVariants(self):
    """Test that unrecognized fields are mapped to the right variants."""
    for encoded, expected_variant in (
        ('{"an_integer": 1, "unknown_val": 2}', messages.Variant.INT64),
        ('{"an_integer": 1, "unknown_val": 2.0}', messages.Variant.DOUBLE),
        ('{"an_integer": 1, "unknown_val": "string value"}',
         messages.Variant.STRING),
        ('{"an_integer": 1, "unknown_val": [1, 2, 3]}', messages.Variant.INT64),
        ('{"an_integer": 1, "unknown_val": [1, 2.0, 3]}',
         messages.Variant.DOUBLE),
        ('{"an_integer": 1, "unknown_val": [1, "foo", 3]}',
         messages.Variant.STRING),
        ('{"an_integer": 1, "unknown_val": true}', messages.Variant.BOOL)):
      decoded = protojson.decode_message(MyMessage, encoded)
      self.assertEqual(1, decoded.an_integer)
      self.assertEqual(1, len(decoded.all_unrecognized_fields()))
      self.assertEqual('unknown_val', decoded.all_unrecognized_fields()[0])
      _, decoded_variant = decoded.get_unrecognized_field_info('unknown_val')
      self.assertEqual(expected_variant, decoded_variant)

  def testDecodeDateTime(self):
    """Test decoding date-time strings with and without fractional seconds."""
    for datetime_string, datetime_vals in (
        ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)),
        ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))):
      message = protojson.decode_message(
          MyMessage, '{"a_datetime": "%s"}' % datetime_string)
      expected_message = MyMessage(
          a_datetime=datetime.datetime(*datetime_vals))

      self.assertEqual(expected_message, message)

  def testDecodeInvalidDateTime(self):
    """Test that a malformed date-time string raises DecodeError."""
    self.assertRaises(messages.DecodeError, protojson.decode_message,
                      MyMessage, '{"a_datetime": "invalid"}')

  def testEncodeDateTime(self):
    """Test encoding datetime values, with and without microseconds."""
    for datetime_string, datetime_vals in (
        ('2012-09-30T15:31:50.262000', (2012, 9, 30, 15, 31, 50, 262000)),
        ('2012-09-30T15:31:50.262123', (2012, 9, 30, 15, 31, 50, 262123)),
        ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))):
      encoded_message = protojson.encode_message(
          MyMessage(a_datetime=datetime.datetime(*datetime_vals)))
      expected_encoding = '{"a_datetime": "%s"}' % datetime_string
      self.CompareEncoded(expected_encoding, encoded_message)

  def testDecodeRepeatedDateTime(self):
    """Test decoding a list of date-time strings."""
    message = protojson.decode_message(
        MyMessage,
        '{"a_repeated_datetime": ["2012-09-30T15:31:50.262", '
        '"2010-01-21T09:52:00", "2000-01-01T01:00:59.999999"]}')
    expected_message = MyMessage(
        a_repeated_datetime=[
            datetime.datetime(2012, 9, 30, 15, 31, 50, 262000),
            datetime.datetime(2010, 1, 21, 9, 52),
            datetime.datetime(2000, 1, 1, 1, 0, 59, 999999)])

    self.assertEqual(expected_message, message)

  def testDecodeCustom(self):
    """Test decoding a value for a custom field."""
    message = protojson.decode_message(MyMessage, '{"a_custom": 1}')
    self.assertEqual(MyMessage(a_custom=1), message)

  def testDecodeInvalidCustom(self):
    """Test that a wrongly typed custom field value fails validation."""
    self.assertRaises(messages.ValidationError, protojson.decode_message,
                      MyMessage, '{"a_custom": "invalid"}')

  def testEncodeCustom(self):
    """Test encoding a value for a custom field."""
    encoded_message = protojson.encode_message(MyMessage(a_custom=1))
    self.CompareEncoded('{"a_custom": 1}', encoded_message)

  def testDecodeRepeatedCustom(self):
    """Test decoding a list of values for a repeated custom field."""
    message = protojson.decode_message(
        MyMessage, '{"a_repeated_custom": [1, 2, 3]}')
    self.assertEqual(MyMessage(a_repeated_custom=[1, 2, 3]), message)

  def testDecodeBadBase64BytesField(self):
    """Test decoding improperly encoded base64 bytes value."""
    self.assertRaisesWithRegexpMatch(
        messages.DecodeError,
        'Base64 decoding error: Incorrect padding',
        protojson.decode_message,
        test_util.OptionalMessage,
        '{"bytes_value": "abcdefghijklmnopq"}')


class CustomProtoJson(protojson.ProtoJson):
  """ProtoJson subclass that tags every field value it encodes or decodes.

  Only suitable for string values, since it concatenates a str prefix.
  """

  def encode_field(self, field, value):
    return '{encoded}' + value

  def decode_field(self, field, value):
    return '{decoded}' + value


class CustomProtoJsonTest(test_util.TestCase):
  """Tests for serialization overriding functionality."""

  def setUp(self):
    self.protojson = CustomProtoJson()

  def testEncode(self):
    """Test that the encode_field override is applied."""
    self.assertEqual('{"a_string": "{encoded}xyz"}',
                     self.protojson.encode_message(MyMessage(a_string='xyz')))

  def testDecode(self):
    """Test that the decode_field override is applied."""
    self.assertEqual(
        MyMessage(a_string='{decoded}xyz'),
        self.protojson.decode_message(MyMessage, '{"a_string": "xyz"}'))

  def testDecodeEmptyMessage(self):
    """Test that the decode_field override is applied to an empty string."""
    self.assertEqual(
        MyMessage(a_string='{decoded}'),
        self.protojson.decode_message(MyMessage, '{"a_string": ""}'))

  def testDefault(self):
    """Test getting and replacing the default ProtoJson instance."""
    original_default = protojson.ProtoJson.get_default()
    # Repeated calls must hand back the very same instance.
    self.assertTrue(original_default is protojson.ProtoJson.get_default())

    instance = CustomProtoJson()
    protojson.ProtoJson.set_default(instance)
    try:
      self.assertTrue(instance is protojson.ProtoJson.get_default())
    finally:
      # The default is process-wide state; restore it so other tests are not
      # affected by the custom instance installed above.
      protojson.ProtoJson.set_default(original_default)


class InvalidJsonModule(object):
  """Fake json module with no JSONEncoder; protojson should reject it."""


class ValidJsonModule(object):
  """Fake json module exposing a JSONEncoder; protojson should accept it."""

  class JSONEncoder(object):
    pass


class TestJsonDependencyLoading(test_util.TestCase):
  """Test loading various implementations of json."""

  def get_import(self):
    """Get __import__ method.

    Returns:
      The current __import__ method.
    """
    # __builtins__ is a dict in imported modules but a module in __main__.
    if isinstance(__builtins__, dict):
      return __builtins__['__import__']
    else:
      return __builtins__.__import__

  def set_import(self, new_import):
    """Set __import__ method.

    Args:
      new_import: Function to replace __import__.
    """
    if isinstance(__builtins__, dict):
      __builtins__['__import__'] = new_import
    else:
      __builtins__.__import__ = new_import

  def setUp(self):
    """Hide json modules and install an import hook that blocks them.

    The previously loaded json/simplejson modules and the real __import__
    are saved so tearDown can restore them.
    """
    self.simplejson = sys.modules.pop('simplejson', None)
    self.json = sys.modules.pop('json', None)
    self.original_import = self.get_import()

    def block_all_jsons(name, *args, **kwargs):
      # Any module whose name mentions 'json' is only importable when a test
      # has explicitly planted it in sys.modules; it is tagged with the name
      # it was imported under so tests can tell which one protojson picked.
      if 'json' in name:
        if name in sys.modules:
          module = sys.modules[name]
          module.name = name
          return module
        raise ImportError('Unable to find %s' % name)
      else:
        return self.original_import(name, *args, **kwargs)

    self.set_import(block_all_jsons)

  def tearDown(self):
    """Restore original import functions and any loaded modules."""
    # Put the real __import__ back first: otherwise the hook outlives the test
    # (and later setUp calls would capture it as the "original"), and the
    # reload below would run through it, tagging the real json modules.
    self.set_import(self.original_import)

    def reset_module(name, module):
      if module:
        sys.modules[name] = module
      else:
        sys.modules.pop(name, None)

    reset_module('simplejson', self.simplejson)
    reset_module('json', self.json)
    imp.reload(protojson)

  def testLoadProtojsonWithValidJsonModule(self):
    """Test loading protojson module with a valid json dependency."""
    sys.modules['json'] = ValidJsonModule

    # This will cause protojson to reload with the default json module
    # instead of simplejson.
    imp.reload(protojson)
    self.assertEqual('json', protojson.json.name)

  def testLoadProtojsonWithSimplejsonModule(self):
    """Test loading protojson module with simplejson dependency."""
    sys.modules['simplejson'] = ValidJsonModule

    # Only simplejson is importable, so protojson should reload with it.
    imp.reload(protojson)
    self.assertEqual('simplejson', protojson.json.name)

  def testLoadProtojsonWithInvalidJsonModule(self):
    """Loading protojson module with an invalid json defaults to simplejson."""
    sys.modules['json'] = InvalidJsonModule
    sys.modules['simplejson'] = ValidJsonModule

    # Ignore bad module and default back to simplejson.
    imp.reload(protojson)
    self.assertEqual('simplejson', protojson.json.name)

  def testLoadProtojsonWithInvalidJsonModuleAndNoSimplejson(self):
    """Loading protojson module with invalid json and no simplejson."""
    sys.modules['json'] = InvalidJsonModule

    # An incompatible json module with no simplejson to fall back on raises.
    self.assertRaisesWithRegexpMatch(
        ImportError,
        'json library "json" is not compatible with ProtoRPC',
        imp.reload,
        protojson)

  def testLoadProtojsonWithNoJsonModules(self):
    """Loading protojson module with no json modules available."""
    # No json modules raise the first exception.
    self.assertRaisesWithRegexpMatch(
        ImportError,
        'Unable to find json',
        imp.reload,
        protojson)


if __name__ == '__main__':
  unittest.main()