# Copyright 2020 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Tests for proto_utils."""

import re
import unittest

from google.protobuf import field_mask_pb2
from google.protobuf import timestamp_pb2
from google.protobuf.descriptor import FieldDescriptor

from chromiumos.build.api import system_image_pb2
from chromiumos.config.payload.flat_config_pb2 import FlatConfig

from chromiumos.config.public_replication.public_replication_pb2 import (
    PublicReplication)
from chromiumos.config.public_replication.testdata.public_replication_testdata_pb2 import (
    PublicReplicationTestdata,
    SimpleTestdata,
    NestedTestdata,
    OneofTestdata,
    WrapperTestdata1,
    WrapperTestdata2,
    WrapperTestdata3,
    RecursiveMessage,
    PrivateMessage,
    NestedPrivateMessage,
    NestedRepeatedPrivateMessage,
)

from common import proto_utils


class ProtoUtilsTest(unittest.TestCase):
  """Tests for proto_utils."""

  def _assert_field_info(self, info, typeid, name, typename, repeated):
    """Asserts that one resolved field info has the expected attributes.

    Args:
      info: One entry of the list returned by proto_utils.resolve_field_path.
      typeid: Expected FieldDescriptor.TYPE_* constant.
      name: Expected field name.
      typename: Expected type name.
      repeated: Expected value of the info's `repeated` attribute.
    """
    self.assertIsNotNone(info)
    self.assertEqual(info.typeid, typeid)
    self.assertEqual(info.name, name)
    self.assertEqual(info.typename, typename)
    self.assertEqual(info.repeated, repeated)

  def test_resolve_field_path(self):
    """Tests resolving dotted field paths against a proto message class."""

    # A fully valid path yields one info per path component.
    infos = proto_utils.resolve_field_path(
        FlatConfig,
        'hw_components.soc.cores',
    )
    self.assertEqual(len(infos), 3)
    self.assertTrue(all(infos))
    self._assert_field_info(infos[0], FieldDescriptor.TYPE_MESSAGE,
                            'hw_components', 'Component', True)
    self._assert_field_info(infos[1], FieldDescriptor.TYPE_MESSAGE, 'soc',
                            'Soc', False)
    self._assert_field_info(infos[2], FieldDescriptor.TYPE_INT32, 'cores',
                            'int32', False)

    # Parsing stops at the first invalid component: that component and every
    # one after it are expected to be None.
    infos = proto_utils.resolve_field_path(
        FlatConfig,
        'hw_components.not_exist.cores',
    )
    self.assertEqual(len(infos), 3)
    self._assert_field_info(infos[0], FieldDescriptor.TYPE_MESSAGE,
                            'hw_components', 'Component', True)
    self.assertIsNone(infos[1])
    self.assertIsNone(infos[2])

  def test_get_all_fields(self):
    """Tests getting all fields on a proto."""
    timestamp = timestamp_pb2.Timestamp(seconds=1, nanos=2)
    self.assertSequenceEqual([1, 2], proto_utils.get_all_fields(timestamp))

  def test_get_dep_graph(self):
    """Tests getting the depgraph of a proto."""
    self.assertDictEqual(
        proto_utils.get_dep_graph(system_image_pb2.SystemImage.BuildTarget()), {
            'chromiumos.build.api.Portage.BuildTarget': [],
            'chromiumos.build.api.SystemImage.BuildTarget':
                ['chromiumos.build.api.Portage.BuildTarget']
        })

  def test_get_dep_order(self):
    """Tests getting the dependency order of a proto."""
    self.assertSequenceEqual(
        proto_utils.get_dep_order(system_image_pb2.SystemImage.BuildTarget()), [
            'chromiumos.build.api.Portage.BuildTarget',
            'chromiumos.build.api.SystemImage.BuildTarget'
        ])

  def test_apply_public_replication(self):
    """Tests applying the PublicReplication message in the case where it is a
    field on the src argument.

    In this case, since PublicReplication is a field on
    PublicReplicationTestdata, no recursion is needed. Fields not listed in
    public_fields (str2, and simple1, which is present but empty) must not
    appear in dst.
    """
    src = PublicReplicationTestdata(
        str1='abc',
        str2='def',
        public_replication=PublicReplication(
            public_fields=field_mask_pb2.FieldMask(paths=['str1'])),
    )
    src.simple1.SetInParent()

    expected = PublicReplicationTestdata(str1='abc')

    dst = PublicReplicationTestdata()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_empty_source(self):
    """Tests that a listed message field that is present but empty is kept.

    simple1 is marked present with SetInParent() but has no fields set. Because
    'simple1' is listed in public_fields, dst must also have simple1 set.
    str2 is not listed and must not appear in dst.
    """
    src = PublicReplicationTestdata(
        str1='abc',
        str2='def',
        public_replication=PublicReplication(
            public_fields=field_mask_pb2.FieldMask(paths=['str1', 'simple1'])),
    )
    src.simple1.SetInParent()

    expected = PublicReplicationTestdata(str1='abc')
    expected.simple1.SetInParent()

    dst = PublicReplicationTestdata()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_oneof(self):
    """Tests replicating a listed field whose type is OneofTestdata.

    The whole oneof1 message, including its set nested2 sub-message, must be
    copied to dst. str2 is not listed in public_fields and must not appear in
    dst.
    """
    src = PublicReplicationTestdata(
        str1='abc',
        str2='def',
        oneof1=OneofTestdata(
            nested2=NestedTestdata(simple1=SimpleTestdata(str1='ghi'))),
        public_replication=PublicReplication(
            public_fields=field_mask_pb2.FieldMask(paths=['str1', 'oneof1'])),
    )

    expected = PublicReplicationTestdata(
        str1='abc',
        oneof1=OneofTestdata(
            nested2=NestedTestdata(simple1=SimpleTestdata(str1='ghi'))))

    dst = PublicReplicationTestdata()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_nested_field(self):
    """Tests applying the PublicReplication message in the case where it is in
    a nested field on the src argument.

    In this case, recursion is required: src is a WrapperTestdata2, which has a
    WrapperTestdata1, which has a PublicReplicationTestdata.
    """
    src = WrapperTestdata2(
        wrapper_testdata1=WrapperTestdata1(
            n1=1,
            pr_testdata=PublicReplicationTestdata(
                str1='abc',
                str2='def',
                public_replication=PublicReplication(
                    public_fields=field_mask_pb2.FieldMask(paths=['str1'])),
            )))

    expected = WrapperTestdata2(
        wrapper_testdata1=WrapperTestdata1(
            pr_testdata=PublicReplicationTestdata(str1='abc',)))

    dst = WrapperTestdata2()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_nested_repeated_field(self):
    """Tests applying the PublicReplication message in the case where it is in
    a nested repeated field on the src argument.

    This case is similar to test_apply_public_replication_nested_field, but the
    PublicReplicationTestdata field is repeated. Note that the two
    PublicReplicationTestdata messages set different public_fields; this is
    possible, but not necessarily expected.
    """
    src = WrapperTestdata2(
        wrapper_testdata1=WrapperTestdata1(
            n1=1,
            repeated_pr_testdata=[
                PublicReplicationTestdata(
                    str1='abc',
                    str2='def',
                    public_replication=PublicReplication(
                        public_fields=field_mask_pb2.FieldMask(paths=['str1'])),
                ),
                PublicReplicationTestdata(
                    str1='abc',
                    str2='def',
                    public_replication=PublicReplication(
                        public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
                ),
            ]))

    expected = WrapperTestdata2(
        wrapper_testdata1=WrapperTestdata1(repeated_pr_testdata=[
            PublicReplicationTestdata(str1='abc'),
            PublicReplicationTestdata(str2='def'),
        ]))

    dst = WrapperTestdata2()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_unreplicated_repeated_field(self):
    """Tests applying the PublicReplication message in the case where a message
    in a repeated field is not replicated.

    Messages that have no fields replicated should not appear in the final
    output.
    """
    src = WrapperTestdata2(
        wrapper_testdata1=WrapperTestdata1(
            n1=1,
            repeated_pr_testdata=[
                PublicReplicationTestdata(
                    str1='abc',
                    str2='def',
                ),
                PublicReplicationTestdata(
                    str1='abc',
                    str2='def',
                    public_replication=PublicReplication(
                        public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
                ),
            ]))

    # Note that only the second PublicReplicationTestdata field appears in the
    # output; the first has no fields replicated, so is deleted.
    expected = WrapperTestdata2(
        wrapper_testdata1=WrapperTestdata1(repeated_pr_testdata=[
            PublicReplicationTestdata(str2='def'),
        ]))

    dst = WrapperTestdata2()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_stacked_messages(self):
    """Tests applying PublicReplication messages that appear on top of each
    other in the dependency tree.
    """
    src = WrapperTestdata3(
        public_replication=PublicReplication(
            public_fields=field_mask_pb2.FieldMask(paths=['b1'])),
        b1=True,
        pr_testdata=PublicReplicationTestdata(
            public_replication=PublicReplication(
                public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
            str1='teststr1',
            str2='teststr2',
        ),
    )

    expected = WrapperTestdata3(
        b1=True,
        pr_testdata=PublicReplicationTestdata(str2='teststr2',),
    )

    dst = WrapperTestdata3()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_stacked_messages_overlap(self):
    """Tests applying PublicReplication messages that appear on top of each
    other in the dependency tree, where there is overlap in the fields they
    specify.
    """
    # WrapperTestdata3 specifies that "pr_testdata" should be replicated. This
    # means the entire PublicReplicationTestdata message is replicated, even
    # though it specifies only "str2" should be replicated.
    src = WrapperTestdata3(
        public_replication=PublicReplication(
            public_fields=field_mask_pb2.FieldMask(paths=['pr_testdata'])),
        pr_testdata=PublicReplicationTestdata(
            public_replication=PublicReplication(
                public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
            str1='teststr1',
            str2='teststr2',
        ),
    )

    expected = WrapperTestdata3(
        pr_testdata=PublicReplicationTestdata(
            public_replication=PublicReplication(
                public_fields=field_mask_pb2.FieldMask(paths=['str2'])),
            str1='teststr1',
            str2='teststr2',
        ),)

    dst = WrapperTestdata3()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, expected)

  def test_apply_public_replication_empty_paths(self):
    """Tests applying the PublicReplication message with empty paths.

    If the FieldMask has empty paths, no fields should be replicated. Also test
    similar cases where public_replication and public_fields are not set.
    """
    srcs = [
        PublicReplicationTestdata(
            str1='abc',
            str2='def',
            public_replication=PublicReplication(
                public_fields=field_mask_pb2.FieldMask(paths=[])),
        ),
        PublicReplicationTestdata(
            str1='abc',
            str2='def',
        ),
        PublicReplicationTestdata(
            str1='abc',
            str2='def',
            public_replication=PublicReplication(),
        )
    ]

    expected = PublicReplicationTestdata()

    for i, src in enumerate(srcs):
      # subTest reports which source case failed instead of only the first.
      with self.subTest(case=i):
        dst = PublicReplicationTestdata()
        proto_utils.apply_public_replication(src, dst)
        self.assertEqual(dst, expected)

  def test_apply_public_replication_different_messages(self):
    """Tests that passing different messages to apply_public_replication raises
    a ValueError.
    """
    expected_message = (
        'src and dst must be the same message type. Got '
        'chromiumos.config.public_replication.testdata.PublicReplicationTestdata'
        ' and chromiumos.config.public_replication.testdata.WrapperTestdata1')
    # Escape the message so the '.' characters in the full names match
    # literally rather than as regex wildcards.
    with self.assertRaisesRegex(ValueError, re.escape(expected_message)):
      proto_utils.apply_public_replication(PublicReplicationTestdata(),
                                           WrapperTestdata1())

  def test_apply_public_replication_recursive_message(self):
    """Tests that a message that references itself as a field doesn't cause
    infinite recursion.

    No PublicReplication is set in src, so dst is expected to stay empty.
    """
    src = RecursiveMessage(
        b1=True,
        recursive_message=RecursiveMessage(b1=False),
    )
    dst = RecursiveMessage()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, RecursiveMessage())

  def test_create_symbol_db(self):
    """Test that we get a good symbol from the symbol database."""
    self.assertIsNotNone(proto_utils.create_symbol_db().GetSymbol(
        'chromiumos.config.payload.ConfigBundle'))

  def test_remove_emptymessage(self):
    """Tests that a singular message field with no public information is
    removed.

    No PublicReplication is set anywhere in src, so dst is expected to be an
    empty NestedPrivateMessage.
    """
    src = NestedPrivateMessage(
        nested_messages=PrivateMessage(
            config=PrivateMessage.Config(payload=[
                PrivateMessage.Config.Test(bools=True),
                PrivateMessage.Config.Test(bools=True)
            ])))
    dst = NestedPrivateMessage()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, NestedPrivateMessage())

  def test_remove_repeated_emptymessage(self):
    """Tests that messages in a repeated field that contain no public
    information are removed.

    No PublicReplication is set anywhere in src, so dst is expected to be an
    empty NestedRepeatedPrivateMessage with no repeated entries.
    """
    src = NestedRepeatedPrivateMessage(nested_messages=[
        PrivateMessage(
            config=PrivateMessage.Config(payload=[
                PrivateMessage.Config.Test(bools=True),
                PrivateMessage.Config.Test(bools=True)
            ]))
    ])
    dst = NestedRepeatedPrivateMessage()
    proto_utils.apply_public_replication(src, dst)
    self.assertEqual(dst, NestedRepeatedPrivateMessage())