# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Tests Nextgen Pythonic protobuf APIs."""

import io
import unittest

from google.protobuf import proto
from google.protobuf.internal import encoder
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks

from google.protobuf.internal import _parameterized
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2


@_parameterized.named_parameters(('_proto2', unittest_pb2),
                                 ('_proto3', unittest_proto3_arena_pb2))
@testing_refleaks.TestCase
class ProtoTest(unittest.TestCase):
  """Round-trip and error-path tests for `proto`, run for proto2 and proto3."""

  def test_simple_serialize_parse(self, message_module):
    """A fully populated message survives serialize -> parse unchanged."""
    msg = message_module.TestAllTypes()
    test_util.SetAllFields(msg)
    serialized_data = proto.serialize(msg)
    parsed_msg = proto.parse(message_module.TestAllTypes, serialized_data)
    self.assertEqual(msg, parsed_msg)

  def test_serialize_parse_length_prefixed_empty(self, message_module):
    """An empty message round-trips through the length-prefixed APIs."""
    empty_alltypes = message_module.TestAllTypes()
    out = io.BytesIO()
    proto.serialize_length_prefixed(empty_alltypes, out)

    input_bytes = io.BytesIO(out.getvalue())
    msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes)

    self.assertEqual(msg, empty_alltypes)

  def test_parse_length_prefixed_truncated(self, message_module):
    """A length prefix larger than the available payload raises ValueError.

    The error message must report both the declared size and the number of
    bytes actually parsed.
    """
    out = io.BytesIO()
    # Declare a 9999-byte message, then write only a tiny real payload after
    # the prefix so the stream ends long before the declared length.
    encoder._VarintEncoder()(out.write, 9999)
    msg = message_module.TestAllTypes(optional_int32=1)
    out.write(proto.serialize(msg))

    input_bytes = io.BytesIO(out.getvalue())
    with self.assertRaises(ValueError) as context:
      proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes)
    self.assertEqual(
        str(context.exception),
        'Truncated message or non-buffered input_bytes: '
        'Expected 9999 bytes but only 2 bytes parsed for '
        'TestAllTypes.',
    )

  def test_serialize_length_prefixed_fake_io(self, message_module):
    """A short write to the output stream is reported as a TypeError."""

    class FakeBytesIO(io.BytesIO):
      # Stream double whose write() stores nothing and reports 0 bytes
      # written, simulating an output that cannot accept the payload.

      def write(self, b: bytes) -> int:
        return 0

    msg = message_module.TestAllTypes(optional_int32=123)
    out = FakeBytesIO()
    with self.assertRaises(TypeError) as context:
      proto.serialize_length_prefixed(msg, out)
    self.assertIn(
        'Failed to write complete message (wrote: 0, expected: 2)',
        str(context.exception),
    )


# Golden streams of three length-prefixed TestAllTypes messages with
# optional_int32 = 0, 1, 2 and optional_string = 'hi'. The proto3 stream's
# first record is 4 bytes rather than 6 because it omits the zero-valued
# optional_int32 (b'\x08\x00'), which the proto2 stream writes explicitly.
_EXPECTED_PROTO3 = b'\x04r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'
_EXPECTED_PROTO2 = b'\x06\x08\x00r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'


@_parameterized.named_parameters(
    ('_proto2', unittest_pb2, _EXPECTED_PROTO2),
    ('_proto3', unittest_proto3_arena_pb2, _EXPECTED_PROTO3),
)
@testing_refleaks.TestCase
class LengthPrefixedWithGolden(unittest.TestCase):
  """Checks length-prefixed streams against byte-exact golden data."""

  def test_serialize_length_prefixed(self, message_module, expected):
    """Serializing three messages back-to-back yields the golden bytes."""
    number_of_messages = 3

    out = io.BytesIO()
    for index in range(number_of_messages):
      msg = message_module.TestAllTypes(
          optional_int32=index, optional_string='hi'
      )
      proto.serialize_length_prefixed(msg, out)

    self.assertEqual(out.getvalue(), expected)

  def test_parse_length_prefixed(self, message_module, input_bytes):
    """Parsing the golden stream yields the three messages, then None.

    The loop is bounded by the expected message count. If
    parse_length_prefixed ever stopped returning None at end of stream, the
    test would fail instead of looping forever.
    """
    expected_number_of_messages = 3

    input_io = io.BytesIO(input_bytes)
    for index in range(expected_number_of_messages):
      msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_io)
      self.assertIsNotNone(msg)
      self.assertEqual(msg.optional_int32, index)
      self.assertEqual(msg.optional_string, 'hi')

    # Once every record is consumed, the stream must be reported as exhausted.
    self.assertIsNone(
        proto.parse_length_prefixed(message_module.TestAllTypes, input_io)
    )


if __name__ == '__main__':
  unittest.main()