• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2# Protocol Buffers - Google's data interchange format
3# Copyright 2008 Google Inc.  All rights reserved.
4#
5# Use of this source code is governed by a BSD-style
6# license that can be found in the LICENSE file or at
7# https://developers.google.com/open-source/licenses/bsd
8
9"""Tests Nextgen Pythonic protobuf APIs."""
10
11import io
12import unittest
13
14from google.protobuf import proto
15from google.protobuf.internal import encoder
16from google.protobuf.internal import test_util
17from google.protobuf.internal import testing_refleaks
18
19from google.protobuf.internal import _parameterized
20from google.protobuf import unittest_pb2
21from google.protobuf import unittest_proto3_arena_pb2
22
23
@_parameterized.named_parameters(('_proto2', unittest_pb2),
                                ('_proto3', unittest_proto3_arena_pb2))
@testing_refleaks.TestCase
class ProtoTest(unittest.TestCase):
  """Serialization and length-prefixed framing tests for `proto`.

  Each test runs once per parameterization. `message_module` is either the
  proto2 generated module (`unittest_pb2`) or the proto3 one
  (`unittest_proto3_arena_pb2`).
  """

  def test_simple_serialize_parse(self, message_module):
    """A fully populated message survives serialize() followed by parse()."""
    msg = message_module.TestAllTypes()
    test_util.SetAllFields(msg)
    serialized_data = proto.serialize(msg)
    parsed_msg = proto.parse(message_module.TestAllTypes, serialized_data)
    self.assertEqual(msg, parsed_msg)

  def test_serialize_parse_length_prefixed_empty(self, message_module):
    """An empty message round-trips through the length-prefixed API."""
    empty_alltypes = message_module.TestAllTypes()
    out = io.BytesIO()
    proto.serialize_length_prefixed(empty_alltypes, out)

    input_bytes = io.BytesIO(out.getvalue())
    msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes)

    self.assertEqual(msg, empty_alltypes)

  def test_parse_length_prefixed_truncated(self, message_module):
    """A length prefix larger than the available payload raises ValueError."""
    out = io.BytesIO()
    # Announce a 9999-byte payload but only write one small message after it.
    encoder._VarintEncoder()(out.write, 9999)
    msg = message_module.TestAllTypes(optional_int32=1)
    out.write(proto.serialize(msg))

    input_bytes = io.BytesIO(out.getvalue())
    with self.assertRaises(ValueError) as context:
      proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes)
    # The message check must live outside the `with` block: any statement
    # placed after the raising call inside it is never executed.
    self.assertEqual(
        str(context.exception),
        'Truncated message or non-buffered input_bytes: '
        'Expected 9999 bytes but only 2 bytes parsed for '
        'TestAllTypes.',
    )

  def test_serialize_length_prefixed_fake_io(self, message_module):
    """A short write reported by the output stream raises TypeError."""

    class FakeBytesIO(io.BytesIO):
      """BytesIO whose write() always reports that 0 bytes were written."""

      def write(self, b: bytes) -> int:
        return 0

    msg = message_module.TestAllTypes(optional_int32=123)
    out = FakeBytesIO()
    with self.assertRaises(TypeError) as context:
      proto.serialize_length_prefixed(msg, out)
    self.assertIn(
        'Failed to write complete message (wrote: 0, expected: 2)',
        str(context.exception),
    )
76
77
78_EXPECTED_PROTO3 = b'\x04r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'
79_EXPECTED_PROTO2 = b'\x06\x08\x00r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'
80
81
@_parameterized.named_parameters(
    ('_proto2', unittest_pb2, _EXPECTED_PROTO2),
    ('_proto3', unittest_proto3_arena_pb2, _EXPECTED_PROTO3),
)
@testing_refleaks.TestCase
class LengthPrefixedWithGolden(unittest.TestCase):
  """Compares length-prefixed framing against precomputed golden bytes.

  Each parameterization supplies a generated message module and the golden
  stream for three TestAllTypes messages (optional_int32 = 0, 1, 2 and
  optional_string = 'hi').
  """

  def test_serialize_length_prefixed(self, message_module, expected):
    """Framing three messages into one buffer reproduces the golden stream."""
    sink = io.BytesIO()
    for value in range(3):
      frame_msg = message_module.TestAllTypes(
          optional_int32=value, optional_string='hi'
      )
      proto.serialize_length_prefixed(frame_msg, sink)

    self.assertEqual(sink.getvalue(), expected)

  def test_parse_length_prefixed(self, message_module, input_bytes):
    """Reading the golden stream yields exactly three messages, in order."""
    stream = io.BytesIO(input_bytes)
    seen = 0
    # parse_length_prefixed returns None once the stream is exhausted.
    current = proto.parse_length_prefixed(message_module.TestAllTypes, stream)
    while current is not None:
      self.assertEqual(current.optional_int32, seen)
      self.assertEqual(current.optional_string, 'hi')
      seen += 1
      current = proto.parse_length_prefixed(
          message_module.TestAllTypes, stream
      )

    self.assertEqual(seen, 3)
116
if __name__ == '__main__':
  # When run as a script, discover and run every TestCase in this module.
  unittest.main()
119