# SPDX-License-Identifier: GPL-2.0-only
# This file is part of Scapy
# See https://scapy.net/ for more information
# Copyright (C) Philippe Biondi <phil@secdev.org>
# Acknowledgment: Maxence Tury <maxence.tury@ssi.gouv.fr>

"""
Classes that implement ASN.1 data structures.
"""

import copy

from functools import reduce

from scapy.asn1.asn1 import (
    ASN1_BIT_STRING,
    ASN1_BOOLEAN,
    ASN1_Class,
    ASN1_Class_UNIVERSAL,
    ASN1_Error,
    ASN1_INTEGER,
    ASN1_NULL,
    ASN1_OID,
    ASN1_Object,
    ASN1_STRING,
)
from scapy.asn1.ber import (
    BER_Decoding_Error,
    BER_id_dec,
    BER_tagging_dec,
    BER_tagging_enc,
)
from scapy.base_classes import BasePacket
from scapy.compat import raw
from scapy.volatile import (
    GeneralizedTime,
    RandChoice,
    RandInt,
    RandNum,
    RandOID,
    RandString,
    RandField,
)

from scapy import packet

from typing import (
    Any,
    AnyStr,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    TYPE_CHECKING,
)

if TYPE_CHECKING:
    from scapy.asn1packet import ASN1_Packet


class ASN1F_badsequence(Exception):
    pass


class ASN1F_element(object):
    pass


##########################
#    Basic ASN1 Field    #
##########################

_I = TypeVar('_I')  # Internal storage
_A = TypeVar('_A')  # ASN.1 object


class ASN1F_field(ASN1F_element, Generic[_I, _A]):
    """
    Base class of all ASN.1 fields.

    ``_I`` is the internal storage type and ``_A`` the ASN.1 object type.
    Most subclasses only override ``ASN1_tag`` (the universal tag whose
    codec is used to encode/decode the value) and ``randval``.
    """
    holds_packets = 0
    islist = 0
    ASN1_tag = ASN1_Class_UNIVERSAL.ANY
    context = ASN1_Class_UNIVERSAL  # type: Type[ASN1_Class]

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[_A]
                 context=None,  # type: Optional[Type[ASN1_Class]]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 flexible_tag=False,  # type: Optional[bool]
                 size_len=None,  # type: Optional[int]
                 ):
        # type: (...) -> None
        """
        :param name: attribute name of the field on the owning packet
        :param default: default value; unless it is None or an ASN1_NULL
            it is wrapped with ``self.ASN1_tag.asn1_object``
        :param context: ASN.1 class context handed to the codec on decode
        :param implicit_tag: implicit tag (exclusive with explicit_tag)
        :param explicit_tag: explicit tag (exclusive with implicit_tag)
        :param flexible_tag: tolerate tag mismatches on decode (see m2i)
        :param size_len: forwarded as ``size_len`` to the codec's ``enc``
        :raises ASN1_Error: if both implicit_tag and explicit_tag are given
        """
        if context is not None:
            self.context = context
        self.name = name
        if default is None:
            self.default = default  # type: Optional[_A]
        elif isinstance(default, ASN1_NULL):
            self.default = default  # type: ignore
        else:
            self.default = self.ASN1_tag.asn1_object(default)  # type: ignore
        self.size_len = size_len
        self.flexible_tag = flexible_tag
        if (implicit_tag is not None) and (explicit_tag is not None):
            err_msg = "field cannot be both implicitly and explicitly tagged"
            raise ASN1_Error(err_msg)
        self.implicit_tag = implicit_tag and int(implicit_tag)
        self.explicit_tag = explicit_tag and int(explicit_tag)
        # network_tag gets useful for ASN1F_CHOICE
        self.network_tag = int(implicit_tag or explicit_tag or self.ASN1_tag)
        self.owners = []  # type: List[Type[ASN1_Packet]]

    def register_owner(self, cls):
        # type: (Type[ASN1_Packet]) -> None
        """Record a packet class that uses this field."""
        self.owners.append(cls)

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, _I) -> str
        return repr(x)

    def i2h(self, pkt, x):
        # type: (ASN1_Packet, _I) -> Any
        return x

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[_A, bytes]
        """
        The good thing about safedec is that it may still decode ASN1
        even if there is a mismatch between the expected tag (self.ASN1_tag)
        and the actual tag; the decoded ASN1 object will simply be put
        into an ASN1_BADTAG object. However, safedec prevents the raising of
        exceptions needed for ASN1F_optional processing.
        Thus we use 'flexible_tag', which should be False with ASN1F_optional.

        Regarding other fields, we might need to know whether encoding went
        as expected or not. Noticeably, input methods from cert.py expect
        certain exceptions to be raised. Hence default flexible_tag is False.

        :returns: a tuple (decoded ASN.1 object, remaining bytes)
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=self.name)
        if diff_tag is not None:
            # this implies that flexible_tag was True
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        if self.flexible_tag:
            return codec.safedec(s, context=self.context)  # type: ignore
        else:
            return codec.dec(s, context=self.context)  # type: ignore

    def i2m(self, pkt, x):
        # type: (ASN1_Packet, Union[bytes, _I, _A]) -> bytes
        """
        Encode ``x`` and apply the field's implicit/explicit tagging.

        ``None`` encodes to b"". An ASN1_Object is encoded with its own
        ``enc`` only if its tag is compatible with this field (or this
        field is ANY, or the object is RAW/ERROR); otherwise ASN1_Error is
        raised. Other values go through this field's codec.
        """
        if x is None:
            return b""
        if isinstance(x, ASN1_Object):
            if (self.ASN1_tag == ASN1_Class_UNIVERSAL.ANY or
                    x.tag == ASN1_Class_UNIVERSAL.RAW or
                    x.tag == ASN1_Class_UNIVERSAL.ERROR or
                    self.ASN1_tag == x.tag):
                s = x.enc(pkt.ASN1_codec)
            else:
                raise ASN1_Error("Encoding Error: got %r instead of an %r for field [%s]" % (x, self.ASN1_tag, self.name))  # noqa: E501
        else:
            s = self.ASN1_tag.get_codec(pkt.ASN1_codec).enc(x, size_len=self.size_len)  # noqa: E501
        return BER_tagging_enc(s,
                               implicit_tag=self.implicit_tag,
                               explicit_tag=self.explicit_tag)

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> _I
        return cast(_I, x)

    def extract_packet(self,
                       cls,  # type: Type[ASN1_Packet]
                       s,  # type: bytes
                       _underlayer=None  # type: Optional[ASN1_Packet]
                       ):
        # type: (...) -> Tuple[ASN1_Packet, bytes]
        """
        Dissect ``s`` as an instance of ``cls``.

        If ``cls`` raises ASN1F_badsequence, the bytes are kept in a Raw
        layer instead. Any trailing Raw layer is detached from the result
        and its bytes are returned as the remainder.
        """
        try:
            c = cls(s, _underlayer=_underlayer)
        except ASN1F_badsequence:
            c = packet.Raw(s, _underlayer=_underlayer)  # type: ignore
        cpad = c.getlayer(packet.Raw)
        s = b""
        if cpad is not None:
            s = cpad.load
            if cpad.underlayer:
                del cpad.underlayer.payload
        return c, s

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        return self.i2m(pkt, getattr(pkt, self.name))

    def dissect(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> bytes
        """Decode this field from ``s``, store it on ``pkt``, return rest."""
        v, s = self.m2i(pkt, s)
        self.set_val(pkt, v)
        return s

    def do_copy(self, x):
        # type: (Any) -> Any
        """
        Copy a field value: lists are shallow-copied with any packets they
        contain copied too; objects with a ``copy`` method are copied;
        anything else is returned as is.
        """
        if isinstance(x, list):
            return [elt.copy() if isinstance(elt, BasePacket) else elt
                    for elt in x]
        if hasattr(x, "copy"):
            return x.copy()
        return x

    def set_val(self, pkt, val):
        # type: (ASN1_Packet, Any) -> None
        setattr(pkt, self.name, val)

    def is_empty(self, pkt):
        # type: (ASN1_Packet) -> bool
        return getattr(pkt, self.name) is None

    def get_fields_list(self):
        # type: () -> List[ASN1F_field[Any, Any]]
        return [self]

    def __str__(self):
        # type: () -> str
        return repr(self)

    def randval(self):
        # type: () -> RandField[_I]
        return cast(RandField[_I], RandInt())

    def copy(self):
        # type: () -> ASN1F_field[_I, _A]
        return copy.copy(self)


############################
#    Simple ASN1 Fields    #
############################

class ASN1F_BOOLEAN(ASN1F_field[bool, ASN1_BOOLEAN]):
    ASN1_tag = ASN1_Class_UNIVERSAL.BOOLEAN

    def randval(self):
        # type: () -> RandChoice
        return RandChoice(True, False)


class ASN1F_INTEGER(ASN1F_field[int, ASN1_INTEGER]):
    ASN1_tag = ASN1_Class_UNIVERSAL.INTEGER

    def randval(self):
        # type: () -> RandNum
        return RandNum(-2**64, 2**64 - 1)


class ASN1F_enum_INTEGER(ASN1F_INTEGER):
    """
    INTEGER field with a name <-> value mapping.

    ``enum`` may be a list (index -> name) or a dict. If the dict keys are
    strings, it is read as name -> value instead of value -> name.
    """

    def __init__(self,
                 name,  # type: str
                 default,  # type: ASN1_INTEGER
                 enum,  # type: Dict[int, str]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[Any]
                 explicit_tag=None,  # type: Optional[Any]
                 ):
        # type: (...) -> None
        super(ASN1F_enum_INTEGER, self).__init__(
            name, default, context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )
        i2s = self.i2s = {}  # type: Dict[int, str]
        s2i = self.s2i = {}  # type: Dict[str, int]
        if isinstance(enum, list):
            keys = range(len(enum))
        else:
            keys = list(enum)
        if any(isinstance(x, str) for x in keys):
            # The mapping is given as name -> value: fill the tables the
            # other way round.
            i2s, s2i = s2i, i2s  # type: ignore
        for k in keys:
            i2s[k] = enum[k]
            s2i[enum[k]] = k

    def i2m(self,
            pkt,  # type: ASN1_Packet
            s,  # type: Union[bytes, str, int, ASN1_INTEGER]
            ):
        # type: (...) -> bytes
        """Encode ``s``; a str is first resolved to its value via s2i."""
        if not isinstance(s, str):
            vs = s
        else:
            vs = self.s2i[s]
        return super(ASN1F_enum_INTEGER, self).i2m(pkt, vs)

    def i2repr(self,
               pkt,  # type: ASN1_Packet
               x,  # type: Union[str, int]
               ):
        # type: (...) -> str
        # isinstance() already rules out None
        if isinstance(x, ASN1_INTEGER):
            r = self.i2s.get(x.val)
            if r:
                return "'%s' %s" % (r, repr(x))
        return repr(x)


class ASN1F_BIT_STRING(ASN1F_field[str, ASN1_BIT_STRING]):
    ASN1_tag = ASN1_Class_UNIVERSAL.BIT_STRING

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[Union[ASN1_BIT_STRING, AnyStr]]
                 default_readable=True,  # type: bool
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 ):
        # type: (...) -> None
        """
        :param default_readable: passed as ``readable`` to ASN1_BIT_STRING
            when ``default`` is given as bytes/str
        """
        super(ASN1F_BIT_STRING, self).__init__(
            name, None, context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )
        if isinstance(default, (bytes, str)):
            self.default = ASN1_BIT_STRING(default,
                                           readable=default_readable)
        else:
            self.default = default

    def randval(self):
        # type: () -> RandString
        return RandString(RandNum(0, 1000))


class ASN1F_STRING(ASN1F_field[str, ASN1_STRING]):
    ASN1_tag = ASN1_Class_UNIVERSAL.STRING

    def randval(self):
        # type: () -> RandString
        return RandString(RandNum(0, 1000))


class ASN1F_NULL(ASN1F_INTEGER):
    ASN1_tag = ASN1_Class_UNIVERSAL.NULL


class ASN1F_OID(ASN1F_field[str, ASN1_OID]):
    ASN1_tag = ASN1_Class_UNIVERSAL.OID

    def randval(self):
        # type: () -> RandOID
        return RandOID()


class ASN1F_ENUMERATED(ASN1F_enum_INTEGER):
    ASN1_tag = ASN1_Class_UNIVERSAL.ENUMERATED


class ASN1F_UTF8_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.UTF8_STRING


class ASN1F_NUMERIC_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.NUMERIC_STRING


class ASN1F_PRINTABLE_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.PRINTABLE_STRING


class ASN1F_T61_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.T61_STRING


class ASN1F_VIDEOTEX_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.VIDEOTEX_STRING


class ASN1F_IA5_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.IA5_STRING


class ASN1F_GENERAL_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.GENERAL_STRING


class ASN1F_UTC_TIME(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.UTC_TIME

    def randval(self):  # type: ignore
        # type: () -> GeneralizedTime
        return GeneralizedTime()


class ASN1F_GENERALIZED_TIME(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.GENERALIZED_TIME

    def randval(self):  # type: ignore
        # type: () -> GeneralizedTime
        return GeneralizedTime()


class ASN1F_ISO646_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.ISO646_STRING


class ASN1F_UNIVERSAL_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.UNIVERSAL_STRING


class ASN1F_BMP_STRING(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.BMP_STRING


class ASN1F_SEQUENCE(ASN1F_field[List[Any], List[Any]]):
    """
    SEQUENCE of heterogeneous fields given positionally (``*seq``).

    Each sub-field stores its own value on the packet; the sequence itself
    only handles the enclosing tag and length.
    """
    # Here is how you could decode a SEQUENCE
    # with an unknown, private high-tag prefix :
    # class PrivSeq(ASN1_Packet):
    #     ASN1_codec = ASN1_Codecs.BER
    #     ASN1_root = ASN1F_SEQUENCE(
    #                       <asn1 field #0>,
    #                       ...
    #                       <asn1 field #N>,
    #                       explicit_tag=0,
    #                       flexible_tag=True)
    # Because we use flexible_tag, the value of the explicit_tag does not matter.  # noqa: E501
    ASN1_tag = ASN1_Class_UNIVERSAL.SEQUENCE
    holds_packets = 1

    def __init__(self, *seq, **kwargs):
        # type: (*Any, **Any) -> None
        name = "dummy_seq_name"
        default = [field.default for field in seq]
        super(ASN1F_SEQUENCE, self).__init__(
            name, default, **kwargs
        )
        self.seq = seq
        self.islist = len(seq) > 1

    def __repr__(self):
        # type: () -> str
        return "<%s%r>" % (self.__class__.__name__, self.seq)

    def is_empty(self, pkt):
        # type: (ASN1_Packet) -> bool
        return all(f.is_empty(pkt) for f in self.seq)

    def get_fields_list(self):
        # type: () -> List[ASN1F_field[Any, Any]]
        return reduce(lambda x, y: x + y.get_fields_list(),
                      self.seq, [])

    def m2i(self, pkt, s):
        # type: (Any, bytes) -> Tuple[Any, bytes]
        """
        ASN1F_SEQUENCE behaves transparently, with nested ASN1_objects being
        dissected one by one. Because we use obj.dissect (see loop below)
        instead of obj.m2i (as we trust dissect to do the appropriate set_vals)
        we do not directly retrieve the list of nested objects.
        Thus m2i returns an empty list (along with the proper remainder).
        It is discarded by dissect() and should not be missed elsewhere.
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=pkt.name)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        i, s, remain = codec.check_type_check_len(s)
        if len(s) == 0:
            # Empty SEQUENCE content: every sub-field is unset
            for obj in self.seq:
                obj.set_val(pkt, None)
        else:
            for obj in self.seq:
                try:
                    s = obj.dissect(pkt, s)
                except ASN1F_badsequence:
                    break
            if len(s) > 0:
                raise BER_Decoding_Error("unexpected remainder", remaining=s)
        return [], remain

    def dissect(self, pkt, s):
        # type: (Any, bytes) -> bytes
        _, x = self.m2i(pkt, s)
        return x

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        """Concatenate the sub-fields' encodings, then tag the result."""
        s = reduce(lambda x, y: x + y.build(pkt),
                   self.seq, b"")
        return super(ASN1F_SEQUENCE, self).i2m(pkt, s)


class ASN1F_SET(ASN1F_SEQUENCE):
    ASN1_tag = ASN1_Class_UNIVERSAL.SET


_SEQ_T = Union[
    'ASN1_Packet',
    Type[ASN1F_field[Any, Any]],
    'ASN1F_PACKET',
    ASN1F_field[Any, Any],
]


class ASN1F_SEQUENCE_OF(ASN1F_field[List[_SEQ_T],
                                    List[ASN1_Object[Any]]]):
    """
    Two types are allowed as cls: ASN1_Packet, ASN1F_field
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.SEQUENCE
    islist = 1

    def __init__(self,
                 name,  # type: str
                 default,  # type: Any
                 cls,  # type: _SEQ_T
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[Any]
                 explicit_tag=None,  # type: Optional[Any]
                 ):
        # type: (...) -> None
        """
        :param cls: element type; an ASN1F_field class or instance (items
            decoded with its m2i), or a packet class / callable (items
            decoded with extract_packet)
        :raises ValueError: if cls is neither
        """
        if isinstance(cls, type) and issubclass(cls, ASN1F_field) or \
                isinstance(cls, ASN1F_field):
            if isinstance(cls, type):
                self.fld = cls(name, b"")
            else:
                self.fld = cls
            self._extract_packet = lambda s, pkt: self.fld.m2i(pkt, s)
            self.holds_packets = 0
        elif hasattr(cls, "ASN1_root") or callable(cls):
            self.cls = cast("Type[ASN1_Packet]", cls)
            self._extract_packet = lambda s, pkt: self.extract_packet(
                self.cls, s, _underlayer=pkt)
            self.holds_packets = 1
        else:
            raise ValueError("cls should be an ASN1_Packet or ASN1_field")
        super(ASN1F_SEQUENCE_OF, self).__init__(
            name, None, context=context,
            implicit_tag=implicit_tag, explicit_tag=explicit_tag
        )
        self.default = default

    def is_empty(self,
                 pkt,  # type: ASN1_Packet
                 ):
        # type: (...) -> bool
        return ASN1F_field.is_empty(self, pkt)

    def m2i(self,
            pkt,  # type: ASN1_Packet
            s,  # type: bytes
            ):
        # type: (...) -> Tuple[List[Any], bytes]
        """
        Decode the enclosing tag/length, then decode items until the
        content is exhausted. Falsy items are not kept.

        :returns: a tuple (list of items, bytes after the SEQUENCE OF)
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=self.name)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        i, s, remain = codec.check_type_check_len(s)
        lst = []
        while s:
            c, s = self._extract_packet(s, pkt)  # type: ignore
            if c:
                lst.append(c)
        if len(s) > 0:
            raise BER_Decoding_Error("unexpected remainder", remaining=s)
        return lst, remain

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        """Encode the stored list (or a RAW ASN.1 object as is) and tag it."""
        val = getattr(pkt, self.name)
        if isinstance(val, ASN1_Object) and \
                val.tag == ASN1_Class_UNIVERSAL.RAW:
            s = cast(Union[List[_SEQ_T], bytes], val)
        elif val is None:
            s = b""
        else:
            s = b"".join(raw(i) for i in val)
        return self.i2m(pkt, s)

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, _I) -> str
        if self.holds_packets:
            return super(ASN1F_SEQUENCE_OF, self).i2repr(pkt, x)  # type: ignore
        else:
            return "[%s]" % ", ".join(
                self.fld.i2repr(pkt, item) for item in x  # type: ignore
            )

    def randval(self):
        # type: () -> Any
        if self.holds_packets:
            return packet.fuzz(self.cls())
        else:
            return self.fld.randval()

    def __repr__(self):
        # type: () -> str
        return "<%s %s>" % (self.__class__.__name__, self.name)


class ASN1F_SET_OF(ASN1F_SEQUENCE_OF):
    ASN1_tag = ASN1_Class_UNIVERSAL.SET


class ASN1F_IPADDRESS(ASN1F_STRING):
    ASN1_tag = ASN1_Class_UNIVERSAL.IPADDRESS


class ASN1F_TIME_TICKS(ASN1F_INTEGER):
    ASN1_tag = ASN1_Class_UNIVERSAL.TIME_TICKS


#############################
#    Complex ASN1 Fields    #
#############################

class ASN1F_optional(ASN1F_element):
    """
    Wrapper making a field optional.

    Decoding errors (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error)
    leave the field unset (None) and the input untouched instead of
    propagating; an empty field is omitted on build. Other attributes are
    forwarded to the wrapped field.
    """

    def __init__(self, field):
        # type: (ASN1F_field[Any, Any]) -> None
        # A tag mismatch must raise so that the field can be skipped
        field.flexible_tag = False
        self._field = field

    def __getattr__(self, attr):
        # type: (str) -> Optional[Any]
        return getattr(self._field, attr)

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[Any, bytes]
        try:
            return self._field.m2i(pkt, s)
        except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
            # ASN1_Error may be raised by ASN1F_CHOICE
            return None, s

    def dissect(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> bytes
        try:
            return self._field.dissect(pkt, s)
        except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
            self._field.set_val(pkt, None)
            return s

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        if self._field.is_empty(pkt):
            return b""
        return self._field.build(pkt)

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> Any
        return self._field.any2i(pkt, x)

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, Any) -> str
        return self._field.i2repr(pkt, x)


_CHOICE_T = Union['ASN1_Packet', Type[ASN1F_field[Any, Any]], 'ASN1F_PACKET']


class ASN1F_CHOICE(ASN1F_field[_CHOICE_T, ASN1_Object[Any]]):
    """
    Multiple types are allowed: ASN1_Packet, ASN1F_field and ASN1F_PACKET(),
    See layers/x509.py for examples.
    Other ASN1F_field instances than ASN1F_PACKET instances must not be used.
    """
    holds_packets = 1
    ASN1_tag = ASN1_Class_UNIVERSAL.ANY

    def __init__(self, name, default, *args, **kwargs):
        # type: (str, Any, *_CHOICE_T, **Any) -> None
        """
        :param args: the alternatives; each is indexed by its network tag
        :raises ASN1_Error: if an implicit_tag is given, or an alternative
            exposes no tag
        """
        if "implicit_tag" in kwargs:
            err_msg = "ASN1F_CHOICE has been called with an implicit_tag"
            raise ASN1_Error(err_msg)
        self.implicit_tag = None
        for kwarg in ["context", "explicit_tag"]:
            setattr(self, kwarg, kwargs.get(kwarg))
        super(ASN1F_CHOICE, self).__init__(
            name, None, context=self.context,
            explicit_tag=self.explicit_tag
        )
        self.default = default
        self.current_choice = None
        self.choices = {}  # type: Dict[int, _CHOICE_T]
        self.pktchoices = {}
        for p in args:
            if hasattr(p, "ASN1_root"):
                p = cast('ASN1_Packet', p)
                # should be ASN1_Packet
                if hasattr(p.ASN1_root, "choices"):
                    root = cast(ASN1F_CHOICE, p.ASN1_root)
                    for k, v in root.choices.items():
                        # ASN1F_CHOICE recursion
                        self.choices[k] = v
                else:
                    self.choices[p.ASN1_root.network_tag] = p
            elif hasattr(p, "ASN1_tag"):
                if isinstance(p, type):
                    # should be ASN1F_field class
                    self.choices[int(p.ASN1_tag)] = p
                else:
                    # should be ASN1F_field instance
                    self.choices[p.network_tag] = p
                    # remember its tagging so i2m can re-apply it
                    self.pktchoices[hash(p.cls)] = (p.implicit_tag, p.explicit_tag)  # noqa: E501
            else:
                raise ASN1_Error("ASN1F_CHOICE: no tag found for one field")

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[ASN1_Object[Any], bytes]
        """
        First we have to retrieve the appropriate choice.
        Then we extract the field/packet, according to this choice.
        """
        if len(s) == 0:
            raise ASN1_Error("ASN1F_CHOICE: got empty string")
        _, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                               explicit_tag=self.explicit_tag)
        tag, _ = BER_id_dec(s)
        if tag in self.choices:
            choice = self.choices[tag]
        else:
            if self.flexible_tag:
                choice = ASN1F_field
            else:
                raise ASN1_Error(
                    "ASN1F_CHOICE: unexpected field in '%s' "
                    "(tag %s not in possible tags %s)" % (
                        self.name, tag, list(self.choices.keys())
                    )
                )
        if hasattr(choice, "ASN1_root"):
            # we don't want to import ASN1_Packet in this module...
            return self.extract_packet(choice, s, _underlayer=pkt)  # type: ignore
        elif isinstance(choice, type):
            return choice(self.name, b"").m2i(pkt, s)
        else:
            # XXX check properly if this is an ASN1F_PACKET
            return choice.m2i(pkt, s)

    def i2m(self, pkt, x):
        # type: (ASN1_Packet, Any) -> bytes
        """
        Serialize ``x``, re-apply the tagging registered for its class (if
        it came from an ASN1F_PACKET alternative), then the explicit tag.
        """
        if x is None:
            s = b""
        else:
            s = raw(x)
            if hash(type(x)) in self.pktchoices:
                imp, exp = self.pktchoices[hash(type(x))]
                s = BER_tagging_enc(s,
                                    implicit_tag=imp,
                                    explicit_tag=exp)
        return BER_tagging_enc(s, explicit_tag=self.explicit_tag)

    def randval(self):
        # type: () -> RandChoice
        randchoices = []
        for p in self.choices.values():
            if hasattr(p, "ASN1_root"):
                # should be ASN1_Packet class
                randchoices.append(packet.fuzz(p()))  # type: ignore
            elif hasattr(p, "ASN1_tag"):
                if isinstance(p, type):
                    # should be (basic) ASN1F_field class
                    randchoices.append(p("dummy", None).randval())
                else:
                    # should be ASN1F_PACKET instance
                    randchoices.append(p.randval())
        return RandChoice(*randchoices)


class ASN1F_PACKET(ASN1F_field['ASN1_Packet', Optional['ASN1_Packet']]):
    """
    Field holding a sub-packet of class ``cls``.

    ``cls`` may also be a non-ASN.1 packet class, in which case it is
    dissected/built without ASN.1 tagging. If ``next_cls_cb`` is given, it
    is called with the parent packet during dissection to pick the class
    (falling back to ``cls`` when it returns a falsy value).
    """
    holds_packets = 1

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[ASN1_Packet]
                 cls,  # type: Type[ASN1_Packet]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 next_cls_cb=None,  # type: Optional[Callable[[ASN1_Packet], Type[ASN1_Packet]]]  # noqa: E501
                 ):
        # type: (...) -> None
        self.cls = cls
        self.next_cls_cb = next_cls_cb
        super(ASN1F_PACKET, self).__init__(
            name, None, context=context,
            implicit_tag=implicit_tag, explicit_tag=explicit_tag
        )
        if implicit_tag is None and explicit_tag is None and cls is not None:
            if cls.ASN1_root.ASN1_tag == ASN1_Class_UNIVERSAL.SEQUENCE:
                self.network_tag = 16 | 0x20  # 16 + CONSTRUCTED
        self.default = default

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[Any, bytes]
        if self.next_cls_cb:
            cls = self.next_cls_cb(pkt) or self.cls
        else:
            cls = self.cls
        if not hasattr(cls, "ASN1_root"):
            # A normal Packet (!= ASN1)
            return self.extract_packet(cls, s, _underlayer=pkt)
        diff_tag, s = BER_tagging_dec(s, hidden_tag=cls.ASN1_root.ASN1_tag,  # noqa: E501
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=self.name)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        if not s:
            return None, s
        return self.extract_packet(cls, s, _underlayer=pkt)

    def i2m(self,
            pkt,  # type: ASN1_Packet
            x  # type: Union[bytes, ASN1_Packet, None, ASN1_Object[Optional[ASN1_Packet]]]  # noqa: E501
            ):
        # type: (...) -> bytes
        """
        Serialize ``x`` and apply the field's implicit/explicit tagging.

        Only a non-ASN.1 packet value is returned untagged; None, bytes and
        ASN1_Object values are always tagged.
        """
        if x is None:
            s = b""
        elif isinstance(x, bytes):
            s = x
        elif isinstance(x, ASN1_Object):
            if x.val:
                s = raw(x.val)
            else:
                s = b""
        else:
            s = raw(x)
            if not hasattr(x, "ASN1_root"):
                # A normal Packet (!= ASN1)
                return s
        return BER_tagging_enc(s,
                               implicit_tag=self.implicit_tag,
                               explicit_tag=self.explicit_tag)

    def any2i(self,
              pkt,  # type: ASN1_Packet
              x  # type: Union[bytes, ASN1_Packet, None, ASN1_Object[Optional[ASN1_Packet]]]  # noqa: E501
              ):
        # type: (...) -> 'ASN1_Packet'
        if hasattr(x, "add_underlayer"):
            x.add_underlayer(pkt)  # type: ignore
        return super(ASN1F_PACKET, self).any2i(pkt, x)

    def randval(self):  # type: ignore
        # type: () -> ASN1_Packet
        return packet.fuzz(self.cls())


class ASN1F_BIT_STRING_ENCAPS(ASN1F_BIT_STRING):
    """
    We may emulate simple string encapsulation with explicit_tag=0x04,
    but we need a specific class for bit strings because of unused bits, etc.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.BIT_STRING

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[ASN1_Packet]
                 cls,  # type: Type[ASN1_Packet]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 ):
        # type: (...) -> None
        self.cls = cls
        super(ASN1F_BIT_STRING_ENCAPS, self).__init__(  # type: ignore
            name,
            default and raw(default),
            context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )

    def m2i(self, pkt, s):  # type: ignore
        # type: (ASN1_Packet, bytes) -> Tuple[Optional[ASN1_Packet], bytes]
        """
        Decode the BIT STRING, then dissect its bytes as ``self.cls``.

        :returns: a tuple (packet or None if the BIT STRING is empty,
            bytes following the BIT STRING)
        :raises BER_Decoding_Error: if the bit length is not a multiple of
            8, or the packet does not consume all the BIT STRING bytes
        """
        bit_string, remain = super(ASN1F_BIT_STRING_ENCAPS, self).m2i(pkt, s)
        if len(bit_string.val) % 8 != 0:
            raise BER_Decoding_Error("wrong bit string", remaining=s)
        if not bit_string.val_readable:
            # Nothing to dissect, but the bytes following the BIT STRING
            # still belong to the caller: do not drop them.
            return None, remain
        p, s = self.extract_packet(self.cls, bit_string.val_readable,
                                   _underlayer=pkt)
        if len(s) > 0:
            raise BER_Decoding_Error("unexpected remainder", remaining=s)
        return p, remain

    def i2m(self, pkt, x):  # type: ignore
        # type: (ASN1_Packet, Optional[ASN1_BIT_STRING]) -> bytes
        if not isinstance(x, ASN1_BIT_STRING):
            x = ASN1_BIT_STRING(
                b"" if x is None else bytes(x),  # type: ignore
                readable=True,
            )
        return super(ASN1F_BIT_STRING_ENCAPS, self).i2m(pkt, x)


class ASN1F_FLAGS(ASN1F_BIT_STRING):
    """
    BIT STRING whose bit positions are named by ``mapping``.

    Values may be given as a string of '0'/'1' characters or as flag names
    joined with '+'.
    """

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[str]
                 mapping,  # type: List[str]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[Any]
                 ):
        # type: (...) -> None
        self.mapping = mapping
        super(ASN1F_FLAGS, self).__init__(
            name, default,
            default_readable=False,
            context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> str
        if isinstance(x, str):
            if any(y not in ["0", "1"] for y in x):
                # resolve the flags
                value = ["0"] * len(self.mapping)
                for i in x.split("+"):
                    value[self.mapping.index(i)] = "1"
                x = "".join(value)
            x = ASN1_BIT_STRING(x)
        return super(ASN1F_FLAGS, self).any2i(pkt, x)

    def get_flags(self, pkt):
        # type: (ASN1_Packet) -> List[str]
        """Return the names of the bits set in the packet's field value."""
        fbytes = getattr(pkt, self.name).val
        return [self.mapping[i] for i, positional in enumerate(fbytes)
                if positional == '1' and i < len(self.mapping)]

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, Any) -> str
        if x is not None:
            pretty_s = ", ".join(self.get_flags(pkt))
            return pretty_s + " " + repr(x)
        return repr(x)


class ASN1F_STRING_PacketField(ASN1F_STRING):
    """
    ASN1F_STRING that holds packets.
    """
    holds_packets = 1

    def i2m(self, pkt, val):
        # type: (ASN1_Packet, Any) -> bytes
        if hasattr(val, "ASN1_root"):
            val = ASN1_STRING(bytes(val))  # type: ignore
        return super(ASN1F_STRING_PacketField, self).i2m(pkt, val)

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> Any
        if hasattr(x, "add_underlayer"):
            x.add_underlayer(pkt)
        return super(ASN1F_STRING_PacketField, self).any2i(pkt, x)