• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The ChromiumOS Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""Tests for proto_utils."""
5
6import unittest
7
8from google.protobuf import field_mask_pb2
9from google.protobuf import timestamp_pb2
10from google.protobuf.descriptor import FieldDescriptor
11
12from chromiumos.build.api import system_image_pb2
13from chromiumos.config.payload.flat_config_pb2 import FlatConfig
14
15from chromiumos.config.public_replication.public_replication_pb2 import (
16    PublicReplication)
17from chromiumos.config.public_replication.testdata.public_replication_testdata_pb2 import (
18    PublicReplicationTestdata,
19    SimpleTestdata,
20    NestedTestdata,
21    OneofTestdata,
22    WrapperTestdata1,
23    WrapperTestdata2,
24    WrapperTestdata3,
25    RecursiveMessage,
26    PrivateMessage,
27    NestedPrivateMessage,
28    NestedRepeatedPrivateMessage,
29)
30
31from common import proto_utils
32
33
34class ProtoUtilsTest(unittest.TestCase):
35  """Tests for proto_utils."""
36
37  def test_resolve_field_path(self):
38    """Tests resolving paths in a proto."""
39
40    # Check fully valid path
41    infos = proto_utils.resolve_field_path(
42        FlatConfig,
43        'hw_components.soc.cores',
44    )
45    self.assertTrue(all(infos))
46    self.assertEqual(infos[0].typeid, FieldDescriptor.TYPE_MESSAGE)
47    self.assertEqual(infos[0].name, 'hw_components')
48    self.assertEqual(infos[0].typename, 'Component')
49    self.assertEqual(infos[0].repeated, True)
50
51    self.assertEqual(infos[1].typeid, FieldDescriptor.TYPE_MESSAGE)
52    self.assertEqual(infos[1].name, 'soc')
53    self.assertEqual(infos[1].typename, 'Soc')
54    self.assertEqual(infos[1].repeated, False)
55
56    self.assertEqual(infos[2].typeid, FieldDescriptor.TYPE_INT32)
57    self.assertEqual(infos[2].name, 'cores')
58    self.assertEqual(infos[2].typename, 'int32')
59    self.assertEqual(infos[2].repeated, False)
60
61    # Check parsing stops at first invalid field
62    infos = proto_utils.resolve_field_path(
63        FlatConfig,
64        'hw_components.not_exist.cores',
65    )
66    self.assertEqual(len(infos), 3)
67    self.assertEqual(infos[0].typeid, FieldDescriptor.TYPE_MESSAGE)
68    self.assertEqual(infos[0].name, 'hw_components')
69    self.assertEqual(infos[0].typename, 'Component')
70    self.assertEqual(infos[0].repeated, True)
71    self.assertEqual(infos[1], None)
72    self.assertEqual(infos[2], None)
73
74  def test_get_all_fields(self):
75    """Tests getting all fields on a proto."""
76    timestamp = timestamp_pb2.Timestamp(seconds=1, nanos=2)
77    self.assertSequenceEqual([1, 2], proto_utils.get_all_fields(timestamp))
78
79  def test_get_dep_graph(self):
80    """Tests getting the depgraph of a proto."""
81    self.assertDictEqual(
82        proto_utils.get_dep_graph(system_image_pb2.SystemImage.BuildTarget()), {
83            'chromiumos.build.api.Portage.BuildTarget': [],
84            'chromiumos.build.api.SystemImage.BuildTarget':
85                ['chromiumos.build.api.Portage.BuildTarget']
86        })
87
88  def test_get_dep_order(self):
89    """Tests getting the dependency order of a proto."""
90    self.assertSequenceEqual(
91        proto_utils.get_dep_order(system_image_pb2.SystemImage.BuildTarget()), [
92            'chromiumos.build.api.Portage.BuildTarget',
93            'chromiumos.build.api.SystemImage.BuildTarget'
94        ])
95
96  def test_apply_public_replication(self):
97    """Tests applying the PublicReplication message in the case where it is a
98    field on the src argument.
99
100    In this case, since PublicReplication is a field on
101    PublicReplicationTestdata, no recursion is needed.
102    """
103    src = PublicReplicationTestdata(
104        str1='abc',
105        str2='def',
106        public_replication=PublicReplication(
107            public_fields=field_mask_pb2.FieldMask(paths=['str1'])),
108    )
109    src.simple1.SetInParent()
110
111    expected = PublicReplicationTestdata(str1='abc')
112
113    dst = PublicReplicationTestdata()
114    proto_utils.apply_public_replication(src, dst)
115    self.assertEqual(dst, expected)
116
117  def test_apply_public_replication_empty_source(self):
118    """Tests applying the PublicReplication message in the case where it is a
119    field on the src argument.
120
121    In this case, since PublicReplication is a field on
122    PublicReplicationTestdata, no recursion is needed.
123    """
124    src = PublicReplicationTestdata(
125        str1='abc',
126        str2='def',
127        public_replication=PublicReplication(
128            public_fields=field_mask_pb2.FieldMask(paths=['str1', 'simple1'])),
129    )
130    src.simple1.SetInParent()
131
132    expected = PublicReplicationTestdata(str1='abc')
133    expected.simple1.SetInParent()
134
135    dst = PublicReplicationTestdata()
136    proto_utils.apply_public_replication(src, dst)
137    self.assertEqual(dst, expected)
138
139  def test_apply_public_replication_oneof(self):
140    """Tests applying the PublicReplication message in the case where it is a
141    field on the src argument.
142
143    In this case, since PublicReplication is a field on
144    PublicReplicationTestdata, no recursion is needed.
145    """
146    src = PublicReplicationTestdata(
147        str1='abc',
148        str2='def',
149        oneof1=OneofTestdata(
150            nested2=NestedTestdata(simple1=SimpleTestdata(str1='ghi'))),
151        public_replication=PublicReplication(
152            public_fields=field_mask_pb2.FieldMask(paths=['str1', 'oneof1'])),
153    )
154
155    expected = PublicReplicationTestdata(
156        str1='abc',
157        oneof1=OneofTestdata(
158            nested2=NestedTestdata(simple1=SimpleTestdata(str1='ghi'))))
159
160    dst = PublicReplicationTestdata()
161    proto_utils.apply_public_replication(src, dst)
162    self.assertEqual(dst, expected)
163
164  def test_apply_public_replication_nested_field(self):
165    """Tests applying the PublicReplication message in the case where it is in
166    a nested field on the src argument.
167
168    In this case, recursion is required: src is a WrapperTestdata2, which has a
169    WrapperTestdata1, which has a PublicReplicationTestdata.
170    """
171    src = WrapperTestdata2(
172        wrapper_testdata1=WrapperTestdata1(
173            n1=1,
174            pr_testdata=PublicReplicationTestdata(
175                str1='abc',
176                str2='def',
177                public_replication=PublicReplication(
178                    public_fields=field_mask_pb2.FieldMask(paths=['str1'])),
179            )))
180
181    expected = WrapperTestdata2(
182        wrapper_testdata1=WrapperTestdata1(
183            pr_testdata=PublicReplicationTestdata(str1='abc',)))
184
185    dst = WrapperTestdata2()
186    proto_utils.apply_public_replication(src, dst)
187    self.assertEqual(dst, expected)
188
189  def test_apply_public_replication_nested_repeated_field(self):
190    """Tests applying the PublicReplication message in the case where it is in
191    a nested repeated field on the src argument.
192
193    This case is similar to test_apply_public_replication_nested_field, but the
194    PublicReplicationTestdata field is repeated. Note that the two
195    PublicReplicationTestdata messages set different public_fields; this is
196    possible, but not necessarily expected.
197    """
198    src = WrapperTestdata2(
199        wrapper_testdata1=WrapperTestdata1(
200            n1=1,
201            repeated_pr_testdata=[
202                PublicReplicationTestdata(
203                    str1='abc',
204                    str2='def',
205                    public_replication=PublicReplication(
206                        public_fields=field_mask_pb2.FieldMask(paths=['str1'])),
207                ),
208                PublicReplicationTestdata(
209                    str1='abc',
210                    str2='def',
211                    public_replication=PublicReplication(
212                        public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
213                ),
214            ]))
215
216    expected = WrapperTestdata2(
217        wrapper_testdata1=WrapperTestdata1(repeated_pr_testdata=[
218            PublicReplicationTestdata(str1='abc'),
219            PublicReplicationTestdata(str2='def'),
220        ]))
221
222    dst = WrapperTestdata2()
223    proto_utils.apply_public_replication(src, dst)
224    self.assertEqual(dst, expected)
225
226  def test_apply_public_replication_unreplicated_repeated_field(self):
227    """Tests applying the PublicReplication message in the case where a message
228    in a repeated field is not replicated.
229
230    Messages that have no fields replicated should not appear in the final
231    output.
232    """
233    src = WrapperTestdata2(
234        wrapper_testdata1=WrapperTestdata1(
235            n1=1,
236            repeated_pr_testdata=[
237                PublicReplicationTestdata(
238                    str1='abc',
239                    str2='def',
240                ),
241                PublicReplicationTestdata(
242                    str1='abc',
243                    str2='def',
244                    public_replication=PublicReplication(
245                        public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
246                ),
247            ]))
248
249    # Note that only the second PublicReplicationTestdata field appears in the
250    # output; the first has no fields replicated, so is deleted.
251    expected = WrapperTestdata2(
252        wrapper_testdata1=WrapperTestdata1(repeated_pr_testdata=[
253            PublicReplicationTestdata(str2='def'),
254        ]))
255
256    dst = WrapperTestdata2()
257    proto_utils.apply_public_replication(src, dst)
258    self.assertEqual(dst, expected)
259
260  def test_apply_public_replication_stacked_messages(self):
261    """Tests applying PublicReplication messages that appear on top of each
262    other in the dependency tree.
263    """
264    src = WrapperTestdata3(
265        public_replication=PublicReplication(
266            public_fields=field_mask_pb2.FieldMask(paths=['b1'])),
267        b1=True,
268        pr_testdata=PublicReplicationTestdata(
269            public_replication=PublicReplication(
270                public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
271            str1='teststr1',
272            str2='teststr2',
273        ),
274    )
275
276    expected = WrapperTestdata3(
277        b1=True,
278        pr_testdata=PublicReplicationTestdata(str2='teststr2',),
279    )
280
281    dst = WrapperTestdata3()
282    proto_utils.apply_public_replication(src, dst)
283    self.assertEqual(dst, expected)
284
285  def test_apply_public_replication_stacked_messages_overlap(self):
286    """Tests applying PublicReplication messages that appear on top of each
287    other in the dependency tree, where there is overlap in the fields they
288    specify.
289    """
290    # WrapperTestdata3 specifies that "pr_testdata" should be replicated. This
291    # means the entire PublicReplicationTestdata message is replicated, even
292    # though it specifies only "str2" should be replicated.
293    src = WrapperTestdata3(
294        public_replication=PublicReplication(
295            public_fields=field_mask_pb2.FieldMask(paths=['pr_testdata'])),
296        pr_testdata=PublicReplicationTestdata(
297            public_replication=PublicReplication(
298                public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
299            str1='teststr1',
300            str2='teststr2',
301        ),
302    )
303
304    expected = WrapperTestdata3(
305        pr_testdata=PublicReplicationTestdata(
306            public_replication=PublicReplication(
307                public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
308            str1='teststr1',
309            str2='teststr2',
310        ),)
311
312    dst = WrapperTestdata3()
313    proto_utils.apply_public_replication(src, dst)
314    self.assertEqual(dst, expected)
315
316  def test_apply_public_replication_empty_paths(self):
317    """Tests applying the PublicReplication message with empty paths.
318
319    If the FieldMask has empty paths, no fields should be replicated. Also test
320    similar cases where public_replication and public_fields are not set.
321    """
322    srcs = [
323        PublicReplicationTestdata(
324            str1='abc',
325            str2='def',
326            public_replication=PublicReplication(
327                public_fields=field_mask_pb2.FieldMask(paths=[])),
328        ),
329        PublicReplicationTestdata(
330            str1='abc',
331            str2='def',
332        ),
333        PublicReplicationTestdata(
334            str1='abc',
335            str2='def',
336            public_replication=PublicReplication(),
337        )
338    ]
339
340    expected = PublicReplicationTestdata()
341
342    for src in srcs:
343      dst = PublicReplicationTestdata()
344      proto_utils.apply_public_replication(src, dst)
345      self.assertEqual(dst, expected)
346
347  def test_apply_public_replication_different_messages(self):
348    """Tests that passing different messages to apply_public_replication raises
349    an Error.
350    """
351    with self.assertRaisesRegex(
352        ValueError, 'src and dst must be the same message type. Got '
353        'chromiumos.config.public_replication.testdata.PublicReplicationTestdata'
354        ' and chromiumos.config.public_replication.testdata.WrapperTestdata1'):
355      proto_utils.apply_public_replication(PublicReplicationTestdata(),
356                                           WrapperTestdata1())
357
358  def test_apply_public_replication_recursive_message(self):
359    """Tests that a message that references itself as a field doesn't cause
360    infinite recursion.
361    """
362    src = RecursiveMessage(
363        b1=True,
364        recursive_message=RecursiveMessage(b1=False),
365    )
366    dst = RecursiveMessage()
367    proto_utils.apply_public_replication(src, dst)
368    self.assertEqual(dst, RecursiveMessage())
369
370  def test_create_symbol_db(self):
371    """Test that we get a good symbol from the symbol database."""
372    self.assertIsNot(
373        proto_utils.create_symbol_db().GetSymbol(
374            "chromiumos.config.payload.ConfigBundle"),
375        None,
376    )
377
378  def test_remove_emptymessage(self):
379    """Test that a message that does not contain any public information
380    is removed.
381    """
382    src = NestedPrivateMessage(
383        nested_messages=PrivateMessage(
384            config=PrivateMessage.Config(payload=[
385                PrivateMessage.Config.Test(bools=True),
386                PrivateMessage.Config.Test(bools=True)
387            ])))
388    dst = NestedPrivateMessage()
389    proto_utils.apply_public_replication(src, dst)
390    self.assertEqual(dst, NestedPrivateMessage())
391
392  def test_remove_repeated_emptymessage(self):
393    """Test that a message that does not contain any public information
394    is removed.
395    """
396    src = NestedRepeatedPrivateMessage(nested_messages=[
397        PrivateMessage(
398            config=PrivateMessage.Config(payload=[
399                PrivateMessage.Config.Test(bools=True),
400                PrivateMessage.Config.Test(bools=True)
401            ]))
402    ])
403    dst = NestedRepeatedPrivateMessage()
404    proto_utils.apply_public_replication(src, dst)
405    self.assertEqual(dst, NestedRepeatedPrivateMessage())
406