• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) Philippe Biondi <phil@secdev.org>
5# Acknowledgment: Maxence Tury <maxence.tury@ssi.gouv.fr>
6
7"""
8Classes that implement ASN.1 data structures.
9"""
10
11import copy
12
13from functools import reduce
14
15from scapy.asn1.asn1 import (
16    ASN1_BIT_STRING,
17    ASN1_BOOLEAN,
18    ASN1_Class,
19    ASN1_Class_UNIVERSAL,
20    ASN1_Error,
21    ASN1_INTEGER,
22    ASN1_NULL,
23    ASN1_OID,
24    ASN1_Object,
25    ASN1_STRING,
26)
27from scapy.asn1.ber import (
28    BER_Decoding_Error,
29    BER_id_dec,
30    BER_tagging_dec,
31    BER_tagging_enc,
32)
33from scapy.base_classes import BasePacket
34from scapy.compat import raw
35from scapy.volatile import (
36    GeneralizedTime,
37    RandChoice,
38    RandInt,
39    RandNum,
40    RandOID,
41    RandString,
42    RandField,
43)
44
45from scapy import packet
46
47from typing import (
48    Any,
49    AnyStr,
50    Callable,
51    Dict,
52    Generic,
53    List,
54    Optional,
55    Tuple,
56    Type,
57    TypeVar,
58    Union,
59    cast,
60    TYPE_CHECKING,
61)
62
63if TYPE_CHECKING:
64    from scapy.asn1packet import ASN1_Packet
65
66
class ASN1F_badsequence(Exception):
    """Raised while dissecting when a SEQUENCE does not match its layout.

    ASN1F_SEQUENCE.m2i stops dissecting its members when it catches this,
    and ASN1F_field.extract_packet falls back to a Raw packet.
    """
69
70
class ASN1F_element(object):
    """Common marker base for ASN.1 field objects (fields and wrappers)."""
73
74
75##########################
76#    Basic ASN1 Field    #
77##########################
78
79_I = TypeVar('_I')  # Internal storage
80_A = TypeVar('_A')  # ASN.1 object
81
82
class ASN1F_field(ASN1F_element, Generic[_I, _A]):
    """
    Base class of the ASN.1 fields.

    A field is bound to a universal ASN.1 type (``ASN1_tag``), may carry
    either an implicit or an explicit tag, and converts values between the
    packet's internal representation and their BER-encoded form.
    """
    holds_packets = 0
    islist = 0
    ASN1_tag = ASN1_Class_UNIVERSAL.ANY
    context = ASN1_Class_UNIVERSAL  # type: Type[ASN1_Class]

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[_A]
                 context=None,  # type: Optional[Type[ASN1_Class]]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 flexible_tag=False,  # type: Optional[bool]
                 size_len=None,  # type: Optional[int]
                 ):
        # type: (...) -> None
        """
        :param name: name of the field
        :param default: default value; unless None or an ASN1_NULL, it is
            wrapped with ``self.ASN1_tag.asn1_object``
        :param context: ASN.1 class passed to the codec when decoding
        :param implicit_tag: implicit tag (exclusive with explicit_tag)
        :param explicit_tag: explicit tag (exclusive with implicit_tag)
        :param flexible_tag: tolerate tag mismatches when decoding
            (see m2i)
        :param size_len: forwarded to the codec's ``enc`` as ``size_len``
        :raises ASN1_Error: if both implicit_tag and explicit_tag are set
        """
        if context is not None:
            self.context = context
        self.name = name
        if default is None:
            self.default = default  # type: Optional[_A]
        elif isinstance(default, ASN1_NULL):
            self.default = default  # type: ignore
        else:
            self.default = self.ASN1_tag.asn1_object(default)  # type: ignore
        self.size_len = size_len
        self.flexible_tag = flexible_tag
        if (implicit_tag is not None) and (explicit_tag is not None):
            err_msg = "field cannot be both implicitly and explicitly tagged"
            raise ASN1_Error(err_msg)
        self.implicit_tag = implicit_tag and int(implicit_tag)
        self.explicit_tag = explicit_tag and int(explicit_tag)
        # network_tag gets useful for ASN1F_CHOICE. Test the tags with
        # "is not None" (like the decoding code does) so that a tag whose
        # value is 0 is not mistaken for "no tag".
        if implicit_tag is not None:
            self.network_tag = int(implicit_tag)
        elif explicit_tag is not None:
            self.network_tag = int(explicit_tag)
        else:
            self.network_tag = int(self.ASN1_tag)
        self.owners = []  # type: List[Type[ASN1_Packet]]

    def register_owner(self, cls):
        # type: (Type[ASN1_Packet]) -> None
        """Remember a packet class that uses this field."""
        self.owners.append(cls)

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, _I) -> str
        return repr(x)

    def i2h(self, pkt, x):
        # type: (ASN1_Packet, _I) -> Any
        return x

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[_A, bytes]
        """
        The good thing about safedec is that it may still decode ASN1
        even if there is a mismatch between the expected tag (self.ASN1_tag)
        and the actual tag; the decoded ASN1 object will simply be put
        into an ASN1_BADTAG object. However, safedec prevents the raising of
        exceptions needed for ASN1F_optional processing.
        Thus we use 'flexible_tag', which should be False with ASN1F_optional.

        Regarding other fields, we might need to know whether decoding went
        as expected or not. Noticeably, input methods from cert.py expect
        certain exceptions to be raised. Hence default flexible_tag is False.

        :returns: (decoded ASN.1 object, remaining bytes)
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=self.name)
        if diff_tag is not None:
            # this implies that flexible_tag was True: remember the tag
            # actually seen so that re-encoding reproduces it
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        if self.flexible_tag:
            return codec.safedec(s, context=self.context)  # type: ignore
        else:
            return codec.dec(s, context=self.context)  # type: ignore

    def i2m(self, pkt, x):
        # type: (ASN1_Packet, Union[bytes, _I, _A]) -> bytes
        """
        Encode ``x`` and apply this field's implicit/explicit tagging.

        :raises ASN1_Error: if ``x`` is an ASN1_Object whose tag does not
            match this field (ANY fields and RAW/ERROR objects are accepted)
        """
        if x is None:
            return b""
        if isinstance(x, ASN1_Object):
            if (self.ASN1_tag == ASN1_Class_UNIVERSAL.ANY or
                x.tag == ASN1_Class_UNIVERSAL.RAW or
                x.tag == ASN1_Class_UNIVERSAL.ERROR or
               self.ASN1_tag == x.tag):
                s = x.enc(pkt.ASN1_codec)
            else:
                raise ASN1_Error("Encoding Error: got %r instead of an %r for field [%s]" % (x, self.ASN1_tag, self.name))  # noqa: E501
        else:
            s = self.ASN1_tag.get_codec(pkt.ASN1_codec).enc(x, size_len=self.size_len)
        return BER_tagging_enc(s,
                               implicit_tag=self.implicit_tag,
                               explicit_tag=self.explicit_tag)

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> _I
        return cast(_I, x)

    def extract_packet(self,
                       cls,  # type: Type[ASN1_Packet]
                       s,  # type: bytes
                       _underlayer=None  # type: Optional[ASN1_Packet]
                       ):
        # type: (...) -> Tuple[ASN1_Packet, bytes]
        """
        Dissect ``s`` as a ``cls`` packet (Raw on ASN1F_badsequence).

        Any trailing Raw layer is detached from the packet and its bytes
        are returned as the remainder.
        """
        try:
            c = cls(s, _underlayer=_underlayer)
        except ASN1F_badsequence:
            c = packet.Raw(s, _underlayer=_underlayer)  # type: ignore
        cpad = c.getlayer(packet.Raw)
        s = b""
        if cpad is not None:
            s = cpad.load
            if cpad.underlayer:
                del cpad.underlayer.payload
        return c, s

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        return self.i2m(pkt, getattr(pkt, self.name))

    def dissect(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> bytes
        """Decode this field from ``s``, store it in ``pkt``, return the rest."""
        v, s = self.m2i(pkt, s)
        self.set_val(pkt, v)
        return s

    def do_copy(self, x):
        # type: (Any) -> Any
        """Copy a value; lists are copied shallowly, with packets copied."""
        if isinstance(x, list):
            x = x[:]
            for i, item in enumerate(x):
                if isinstance(item, BasePacket):
                    x[i] = item.copy()
            return x
        if hasattr(x, "copy"):
            return x.copy()
        return x

    def set_val(self, pkt, val):
        # type: (ASN1_Packet, Any) -> None
        setattr(pkt, self.name, val)

    def is_empty(self, pkt):
        # type: (ASN1_Packet) -> bool
        return getattr(pkt, self.name) is None

    def get_fields_list(self):
        # type: () -> List[ASN1F_field[Any, Any]]
        return [self]

    def __str__(self):
        # type: () -> str
        return repr(self)

    def randval(self):
        # type: () -> RandField[_I]
        return cast(RandField[_I], RandInt())

    def copy(self):
        # type: () -> ASN1F_field[_I, _A]
        return copy.copy(self)
247
248
249############################
250#    Simple ASN1 Fields    #
251############################
252
class ASN1F_BOOLEAN(ASN1F_field[bool, ASN1_BOOLEAN]):
    """BOOLEAN field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.BOOLEAN

    def randval(self):
        # type: () -> RandChoice
        """Pick True or False at random."""
        outcomes = (True, False)
        return RandChoice(*outcomes)
259
260
class ASN1F_INTEGER(ASN1F_field[int, ASN1_INTEGER]):
    """INTEGER field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.INTEGER

    def randval(self):
        # type: () -> RandNum
        """Random integer in [-2**64, 2**64 - 1]."""
        bound = 2 ** 64
        return RandNum(-bound, bound - 1)
267
268
class ASN1F_enum_INTEGER(ASN1F_INTEGER):
    """
    INTEGER field with named values.

    ``enum`` may be a list (the index of each name is its value) or a dict.
    Two lookup tables are built from it: ``i2s`` and ``s2i``.
    """

    def __init__(self,
                 name,  # type: str
                 default,  # type: ASN1_INTEGER
                 enum,  # type: Dict[int, str]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[Any]
                 explicit_tag=None,  # type: Optional[Any]
                 ):
        # type: (...) -> None
        super(ASN1F_enum_INTEGER, self).__init__(
            name, default, context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )
        self.i2s = {}  # type: Dict[int, str]
        self.s2i = {}  # type: Dict[str, int]
        keys = range(len(enum)) if isinstance(enum, list) else list(enum)
        # When the enum is keyed by names (str), its entries go into s2i
        # and the reversed entries into i2s; otherwise the other way round.
        keyed_by_name = any(isinstance(key, str) for key in keys)
        if keyed_by_name:
            forward, backward = self.s2i, self.i2s  # type: Any, Any
        else:
            forward, backward = self.i2s, self.s2i
        for key in keys:
            value = enum[key]
            forward[key] = value
            backward[value] = key

    def i2m(self,
            pkt,  # type: ASN1_Packet
            s,  # type: Union[bytes, str, int, ASN1_INTEGER]
            ):
        # type: (...) -> bytes
        """Encode ``s``, translating a name to its value first."""
        value = self.s2i[s] if isinstance(s, str) else s
        return super(ASN1F_enum_INTEGER, self).i2m(pkt, value)

    def i2repr(self,
               pkt,  # type: ASN1_Packet
               x,  # type: Union[str, int]
               ):
        # type: (...) -> str
        """Prefix the repr with the value's name when one is known."""
        if isinstance(x, ASN1_INTEGER):
            label = self.i2s.get(x.val)
            if label:
                return "'%s' %s" % (label, repr(x))
        return repr(x)
317
318
class ASN1F_BIT_STRING(ASN1F_field[str, ASN1_BIT_STRING]):
    """BIT STRING field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.BIT_STRING

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[Union[ASN1_BIT_STRING, AnyStr]]
                 default_readable=True,  # type: bool
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 ):
        # type: (...) -> None
        super(ASN1F_BIT_STRING, self).__init__(
            name, None, context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )
        # The default is wrapped here (not by the base class, which got
        # None) so that ``default_readable`` reaches ASN1_BIT_STRING.
        self.default = (
            ASN1_BIT_STRING(default, readable=default_readable)
            if isinstance(default, (bytes, str))
            else default
        )

    def randval(self):
        # type: () -> RandString
        """Random string of 0 to 1000 characters."""
        length = RandNum(0, 1000)
        return RandString(length)
345
346
class ASN1F_STRING(ASN1F_field[str, ASN1_STRING]):
    """STRING field; base class of the specialised string fields."""
    ASN1_tag = ASN1_Class_UNIVERSAL.STRING

    def randval(self):
        # type: () -> RandString
        """Random string of 0 to 1000 characters."""
        length = RandNum(0, 1000)
        return RandString(length)
353
354
class ASN1F_NULL(ASN1F_INTEGER):
    """NULL field (reuses ASN1F_INTEGER's behaviour with the NULL tag)."""
    ASN1_tag = ASN1_Class_UNIVERSAL.NULL
357
358
class ASN1F_OID(ASN1F_field[str, ASN1_OID]):
    """OBJECT IDENTIFIER field."""
    ASN1_tag = ASN1_Class_UNIVERSAL.OID

    def randval(self):
        # type: () -> RandOID
        """Random object identifier."""
        generator = RandOID()
        return generator
365
366
class ASN1F_ENUMERATED(ASN1F_enum_INTEGER):
    """ENUMERATED field: a named-value integer carrying the ENUMERATED tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.ENUMERATED
369
370
class ASN1F_UTF8_STRING(ASN1F_STRING):
    """String field carrying the UTF8_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.UTF8_STRING
373
374
class ASN1F_NUMERIC_STRING(ASN1F_STRING):
    """String field carrying the NUMERIC_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.NUMERIC_STRING
377
378
class ASN1F_PRINTABLE_STRING(ASN1F_STRING):
    """String field carrying the PRINTABLE_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.PRINTABLE_STRING
381
382
class ASN1F_T61_STRING(ASN1F_STRING):
    """String field carrying the T61_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.T61_STRING
385
386
class ASN1F_VIDEOTEX_STRING(ASN1F_STRING):
    """String field carrying the VIDEOTEX_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.VIDEOTEX_STRING
389
390
class ASN1F_IA5_STRING(ASN1F_STRING):
    """String field carrying the IA5_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.IA5_STRING
393
394
class ASN1F_GENERAL_STRING(ASN1F_STRING):
    """String field carrying the GENERAL_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.GENERAL_STRING
397
398
class ASN1F_UTC_TIME(ASN1F_STRING):
    """String field carrying the UTC_TIME tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.UTC_TIME

    def randval(self):  # type: ignore
        # type: () -> GeneralizedTime
        """Random time value produced by GeneralizedTime."""
        generator = GeneralizedTime()
        return generator
405
406
class ASN1F_GENERALIZED_TIME(ASN1F_STRING):
    """String field carrying the GENERALIZED_TIME tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.GENERALIZED_TIME

    def randval(self):  # type: ignore
        # type: () -> GeneralizedTime
        """Random time value produced by GeneralizedTime."""
        generator = GeneralizedTime()
        return generator
413
414
class ASN1F_ISO646_STRING(ASN1F_STRING):
    """String field carrying the ISO646_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.ISO646_STRING
417
418
class ASN1F_UNIVERSAL_STRING(ASN1F_STRING):
    """String field carrying the UNIVERSAL_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.UNIVERSAL_STRING
421
422
class ASN1F_BMP_STRING(ASN1F_STRING):
    """String field carrying the BMP_STRING tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.BMP_STRING
425
426
class ASN1F_SEQUENCE(ASN1F_field[List[Any], List[Any]]):
    """
    SEQUENCE of heterogeneous member fields.

    The sequence is transparent: its members are stored directly in the
    packet, and the sequence itself only contributes the outer tag/length.

    Here is how you could decode a SEQUENCE with an unknown, private
    high-tag prefix::

        class PrivSeq(ASN1_Packet):
            ASN1_codec = ASN1_Codecs.BER
            ASN1_root = ASN1F_SEQUENCE(
                              <asn1 field #0>,
                              ...
                              <asn1 field #N>,
                              explicit_tag=0,
                              flexible_tag=True)

    Because we use flexible_tag, the value of the explicit_tag does not
    matter.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.SEQUENCE
    holds_packets = 1

    def __init__(self, *seq, **kwargs):
        # type: (*Any, **Any) -> None
        super(ASN1F_SEQUENCE, self).__init__(
            "dummy_seq_name", [member.default for member in seq], **kwargs
        )
        self.seq = seq
        self.islist = len(seq) > 1

    def __repr__(self):
        # type: () -> str
        return "<%s%r>" % (type(self).__name__, self.seq)

    def is_empty(self, pkt):
        # type: (ASN1_Packet) -> bool
        """True when every member field is empty in ``pkt``."""
        for member in self.seq:
            if not member.is_empty(pkt):
                return False
        return True

    def get_fields_list(self):
        # type: () -> List[ASN1F_field[Any, Any]]
        """Flattened list of the member fields (recursing into members)."""
        return reduce(
            lambda collected, member: collected + member.get_fields_list(),
            self.seq,
            [],
        )

    def m2i(self, pkt, s):
        # type: (Any, bytes) -> Tuple[Any, bytes]
        """
        ASN1F_SEQUENCE behaves transparently, with nested ASN1_objects being
        dissected one by one. Because obj.dissect is used (it performs the
        appropriate set_val) instead of obj.m2i, the list of nested objects
        is never retrieved directly. Thus m2i returns an empty list (along
        with the proper remainder). It is discarded by dissect() and should
        not be missed elsewhere.

        :raises BER_Decoding_Error: if bytes are left inside the sequence
            after all members have been dissected
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=pkt.name)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        _, content, remain = codec.check_type_check_len(s)
        if not content:
            # empty sequence: every member is absent
            for member in self.seq:
                member.set_val(pkt, None)
            return [], remain
        for member in self.seq:
            try:
                content = member.dissect(pkt, content)
            except ASN1F_badsequence:
                break
        if content:
            raise BER_Decoding_Error("unexpected remainder", remaining=content)
        return [], remain

    def dissect(self, pkt, s):
        # type: (Any, bytes) -> bytes
        return self.m2i(pkt, s)[1]

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        """Concatenate the members' encodings and wrap them as a SEQUENCE."""
        content = b"".join(member.build(pkt) for member in self.seq)
        return super(ASN1F_SEQUENCE, self).i2m(pkt, content)
510
511
class ASN1F_SET(ASN1F_SEQUENCE):
    """ASN1F_SEQUENCE carrying the SET tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.SET
514
515
# Kinds of element description accepted as ``cls`` by ASN1F_SEQUENCE_OF:
# a packet class, a field class, or a field instance.
_SEQ_T = Union[
    "ASN1_Packet",
    Type[ASN1F_field[Any, Any]],
    "ASN1F_PACKET",
    ASN1F_field[Any, Any],
]
522
523
class ASN1F_SEQUENCE_OF(ASN1F_field[List[_SEQ_T],
                                    List[ASN1_Object[Any]]]):
    """
    Two types are allowed as cls: ASN1_Packet, ASN1F_field

    The field's value is a list of elements, each decoded either as a
    packet of class ``cls`` or with the field ``cls``.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.SEQUENCE
    islist = 1

    def __init__(self,
                 name,  # type: str
                 default,  # type: Any
                 cls,  # type: _SEQ_T
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[Any]
                 explicit_tag=None,  # type: Optional[Any]
                 ):
        # type: (...) -> None
        """
        :raises ValueError: if ``cls`` is neither a field (class or
            instance) nor something looking like an ASN1_Packet class
        """
        if isinstance(cls, type) and issubclass(cls, ASN1F_field) or \
                isinstance(cls, ASN1F_field):
            if isinstance(cls, type):
                self.fld = cls(name, b"")
            else:
                self.fld = cls
            self._extract_packet = lambda s, pkt: self.fld.m2i(pkt, s)
            self.holds_packets = 0
        elif hasattr(cls, "ASN1_root") or callable(cls):
            self.cls = cast("Type[ASN1_Packet]", cls)
            self._extract_packet = lambda s, pkt: self.extract_packet(
                self.cls, s, _underlayer=pkt)
            self.holds_packets = 1
        else:
            raise ValueError("cls should be an ASN1_Packet or ASN1_field")
        super(ASN1F_SEQUENCE_OF, self).__init__(
            name, None, context=context,
            implicit_tag=implicit_tag, explicit_tag=explicit_tag
        )
        self.default = default

    def is_empty(self,
                 pkt,  # type: ASN1_Packet
                 ):
        # type: (...) -> bool
        return ASN1F_field.is_empty(self, pkt)

    def m2i(self,
            pkt,  # type: ASN1_Packet
            s,  # type: bytes
            ):
        # type: (...) -> Tuple[List[Any], bytes]
        """
        Decode the sequence content element by element.

        :returns: (list of decoded elements, bytes following the sequence)
        """
        diff_tag, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=self.name)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        codec = self.ASN1_tag.get_codec(pkt.ASN1_codec)
        _, s, remain = codec.check_type_check_len(s)
        lst = []
        # Each extraction consumes a prefix of the content; the loop ends
        # only once the content is exhausted, so nothing can remain.
        while s:
            c, s = self._extract_packet(s, pkt)  # type: ignore
            if c:
                lst.append(c)
        return lst, remain

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        val = getattr(pkt, self.name)
        if isinstance(val, ASN1_Object) and \
                val.tag == ASN1_Class_UNIVERSAL.RAW:
            s = cast(Union[List[_SEQ_T], bytes], val)
        elif val is None:
            s = b""
        else:
            s = b"".join(raw(i) for i in val)
        return self.i2m(pkt, s)

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, _I) -> str
        if self.holds_packets:
            return super(ASN1F_SEQUENCE_OF, self).i2repr(pkt, x)  # type: ignore
        if x is None:
            # absent value (e.g. an unset or missing optional field)
            return "[]"
        if isinstance(x, ASN1_Object):
            # a RAW object stored as-is (see build); not iterable
            return repr(x)
        return "[%s]" % ", ".join(
            self.fld.i2repr(pkt, item) for item in x  # type: ignore
        )

    def randval(self):
        # type: () -> Any
        if self.holds_packets:
            return packet.fuzz(self.cls())
        else:
            return self.fld.randval()

    def __repr__(self):
        # type: () -> str
        return "<%s %s>" % (self.__class__.__name__, self.name)
624
625
class ASN1F_SET_OF(ASN1F_SEQUENCE_OF):
    """ASN1F_SEQUENCE_OF carrying the SET tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.SET
628
629
class ASN1F_IPADDRESS(ASN1F_STRING):
    """String field carrying the IPADDRESS tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.IPADDRESS
632
633
class ASN1F_TIME_TICKS(ASN1F_INTEGER):
    """Integer field carrying the TIME_TICKS tag."""
    ASN1_tag = ASN1_Class_UNIVERSAL.TIME_TICKS
636
637
638#############################
639#    Complex ASN1 Fields    #
640#############################
641
class ASN1F_optional(ASN1F_element):
    """
    Wrapper making a field optional.

    Decoding errors of the wrapped field (ASN1_Error, ASN1F_badsequence,
    BER_Decoding_Error) are interpreted as "field absent": the value is
    set to None and the input is left unconsumed. Other attributes are
    delegated to the wrapped field.
    """

    def __init__(self, field):
        # type: (ASN1F_field[Any, Any]) -> None
        # tag mismatches must raise so that absence can be detected
        field.flexible_tag = False
        self._field = field

    def __getattr__(self, attr):
        # type: (str) -> Optional[Any]
        # __getattr__ only runs for attributes not found normally. If
        # "_field" itself is missing (e.g. on an instance rebuilt by
        # copy.copy or unpickling before its state is restored), delegating
        # would call __getattr__ again forever; raise AttributeError instead.
        if attr == "_field":
            raise AttributeError(attr)
        return getattr(self._field, attr)

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[Any, bytes]
        try:
            return self._field.m2i(pkt, s)
        except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
            # ASN1_Error may be raised by ASN1F_CHOICE
            return None, s

    def dissect(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> bytes
        try:
            return self._field.dissect(pkt, s)
        except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
            self._field.set_val(pkt, None)
            return s

    def build(self, pkt):
        # type: (ASN1_Packet) -> bytes
        """Encode nothing when the wrapped field is empty."""
        if self._field.is_empty(pkt):
            return b""
        return self._field.build(pkt)

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> Any
        return self._field.any2i(pkt, x)

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, Any) -> str
        return self._field.i2repr(pkt, x)
681
682
# Kinds of alternative accepted by ASN1F_CHOICE: a packet class, a field
# class, or an ASN1F_PACKET instance.
_CHOICE_T = Union[
    "ASN1_Packet",
    Type[ASN1F_field[Any, Any]],
    "ASN1F_PACKET",
]
684
685
class ASN1F_CHOICE(ASN1F_field[_CHOICE_T, ASN1_Object[Any]]):
    """
    Multiple types are allowed: ASN1_Packet, ASN1F_field and ASN1F_PACKET(),
    See layers/x509.py for examples.
    Other ASN1F_field instances than ASN1F_PACKET instances must not be used.

    The alternatives are indexed by their tag (``self.choices``); decoding
    reads the next tag and dispatches to the matching alternative.
    """
    holds_packets = 1
    ASN1_tag = ASN1_Class_UNIVERSAL.ANY

    def __init__(self, name, default, *args, **kwargs):
        # type: (str, Any, *_CHOICE_T, **Any) -> None
        """
        :param name: name of the field
        :param default: default value, stored as-is
        :param args: the alternatives
        :param kwargs: only ``context`` and ``explicit_tag`` are used
        :raises ASN1_Error: if ``implicit_tag`` is given, or if an
            alternative exposes no tag
        """
        if "implicit_tag" in kwargs:
            err_msg = "ASN1F_CHOICE has been called with an implicit_tag"
            raise ASN1_Error(err_msg)
        self.implicit_tag = None
        for kwarg in ["context", "explicit_tag"]:
            setattr(self, kwarg, kwargs.get(kwarg))
        super(ASN1F_CHOICE, self).__init__(
            name, None, context=self.context,
            explicit_tag=self.explicit_tag
        )
        self.default = default
        self.current_choice = None
        self.choices = {}  # type: Dict[int, _CHOICE_T]
        # hash(packet class) -> (implicit_tag, explicit_tag) of the
        # ASN1F_PACKET alternative, re-applied by i2m
        self.pktchoices = {}  # type: Dict[int, Tuple[Optional[int], Optional[int]]]  # noqa: E501
        for p in args:
            if hasattr(p, "ASN1_root"):
                p = cast('ASN1_Packet', p)
                # should be ASN1_Packet
                if hasattr(p.ASN1_root, "choices"):
                    root = cast(ASN1F_CHOICE, p.ASN1_root)
                    for k, v in root.choices.items():
                        # ASN1F_CHOICE recursion
                        self.choices[k] = v
                else:
                    self.choices[p.ASN1_root.network_tag] = p
            elif hasattr(p, "ASN1_tag"):
                if isinstance(p, type):
                    # should be ASN1F_field class
                    self.choices[int(p.ASN1_tag)] = p
                else:
                    # should be ASN1F_field instance
                    self.choices[p.network_tag] = p
                    self.pktchoices[hash(p.cls)] = (p.implicit_tag, p.explicit_tag)  # noqa: E501
            else:
                raise ASN1_Error("ASN1F_CHOICE: no tag found for one field")

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[ASN1_Object[Any], bytes]
        """
        First we have to retrieve the appropriate choice.
        Then we extract the field/packet, according to this choice.

        :raises ASN1_Error: on empty input, or when the tag found matches
            no alternative and flexible_tag is False
        """
        if len(s) == 0:
            raise ASN1_Error("ASN1F_CHOICE: got empty string")
        _, s = BER_tagging_dec(s, hidden_tag=self.ASN1_tag,
                               explicit_tag=self.explicit_tag)
        if len(s) == 0:
            # An explicit tag may wrap an empty payload, which leaves no
            # inner tag to dispatch on. Report it as an ASN1_Error (handled
            # by ASN1F_optional) rather than calling BER_id_dec on b"".
            raise ASN1_Error("ASN1F_CHOICE: got empty string")
        tag, _ = BER_id_dec(s)
        if tag in self.choices:
            choice = self.choices[tag]
        else:
            if self.flexible_tag:
                choice = ASN1F_field
            else:
                raise ASN1_Error(
                    "ASN1F_CHOICE: unexpected field in '%s' "
                    "(tag %s not in possible tags %s)" % (
                        self.name, tag, list(self.choices.keys())
                    )
                )
        if hasattr(choice, "ASN1_root"):
            # we don't want to import ASN1_Packet in this module...
            return self.extract_packet(choice, s, _underlayer=pkt)  # type: ignore
        elif isinstance(choice, type):
            return choice(self.name, b"").m2i(pkt, s)
        else:
            # XXX check properly if this is an ASN1F_PACKET
            return choice.m2i(pkt, s)

    def i2m(self, pkt, x):
        # type: (ASN1_Packet, Any) -> bytes
        """Encode ``x``, re-applying the tags of its ASN1F_PACKET alternative."""
        if x is None:
            s = b""
        else:
            s = raw(x)
            if hash(type(x)) in self.pktchoices:
                imp, exp = self.pktchoices[hash(type(x))]
                s = BER_tagging_enc(s,
                                    implicit_tag=imp,
                                    explicit_tag=exp)
        return BER_tagging_enc(s, explicit_tag=self.explicit_tag)

    def randval(self):
        # type: () -> RandChoice
        """Random value drawn from one of the alternatives."""
        randchoices = []
        for p in self.choices.values():
            if hasattr(p, "ASN1_root"):
                # should be ASN1_Packet class
                randchoices.append(packet.fuzz(p()))  # type: ignore
            elif hasattr(p, "ASN1_tag"):
                if isinstance(p, type):
                    # should be (basic) ASN1F_field class
                    randchoices.append(p("dummy", None).randval())
                else:
                    # should be ASN1F_PACKET instance
                    randchoices.append(p.randval())
        return RandChoice(*randchoices)
793
794
class ASN1F_PACKET(ASN1F_field['ASN1_Packet', Optional['ASN1_Packet']]):
    """
    Field holding a whole packet.

    ``cls`` is normally an ASN1_Packet class. A regular (non-ASN.1) Packet
    class is also handled: it is dissected/built without any ASN.1 tagging.
    """
    holds_packets = 1

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[ASN1_Packet]
                 cls,  # type: Type[ASN1_Packet]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 next_cls_cb=None,  # type: Optional[Callable[[ASN1_Packet], Type[ASN1_Packet]]]  # noqa: E501
                 ):
        # type: (...) -> None
        """
        :param cls: packet class used to dissect the field
        :param next_cls_cb: optional callback receiving the parent packet
            and returning the class to use instead of ``cls`` (``cls`` is
            used when it returns a falsy value)
        """
        self.cls = cls
        self.next_cls_cb = next_cls_cb
        super(ASN1F_PACKET, self).__init__(
            name, None, context=context,
            implicit_tag=implicit_tag, explicit_tag=explicit_tag
        )
        if implicit_tag is None and explicit_tag is None and cls is not None:
            # Only ASN.1 packets have an ASN1_root; regular packets (which
            # m2i/i2m support) must not make construction fail here.
            root = getattr(cls, "ASN1_root", None)
            if root is not None and \
                    root.ASN1_tag == ASN1_Class_UNIVERSAL.SEQUENCE:
                self.network_tag = 16 | 0x20  # 16 + CONSTRUCTED
        self.default = default

    def m2i(self, pkt, s):
        # type: (ASN1_Packet, bytes) -> Tuple[Any, bytes]
        """
        Dissect the packet; returns (packet or None, remaining bytes).
        None is returned when the tagged content is empty.
        """
        if self.next_cls_cb:
            cls = self.next_cls_cb(pkt) or self.cls
        else:
            cls = self.cls
        if not hasattr(cls, "ASN1_root"):
            # A normal Packet (!= ASN1)
            return self.extract_packet(cls, s, _underlayer=pkt)
        diff_tag, s = BER_tagging_dec(s, hidden_tag=cls.ASN1_root.ASN1_tag,  # noqa: E501
                                      implicit_tag=self.implicit_tag,
                                      explicit_tag=self.explicit_tag,
                                      safe=self.flexible_tag,
                                      _fname=self.name)
        if diff_tag is not None:
            if self.implicit_tag is not None:
                self.implicit_tag = diff_tag
            elif self.explicit_tag is not None:
                self.explicit_tag = diff_tag
        if not s:
            return None, s
        return self.extract_packet(cls, s, _underlayer=pkt)

    def i2m(self,
            pkt,  # type: ASN1_Packet
            x  # type: Union[bytes, ASN1_Packet, None, ASN1_Object[Optional[ASN1_Packet]]]  # noqa: E501
            ):
        # type: (...) -> bytes
        """Serialize ``x``; regular (non-ASN.1) packets are left untagged."""
        if x is None:
            s = b""
        elif isinstance(x, bytes):
            s = x
        elif isinstance(x, ASN1_Object):
            if x.val:
                s = raw(x.val)
            else:
                s = b""
        else:
            s = raw(x)
            if not hasattr(x, "ASN1_root"):
                # A normal Packet (!= ASN1)
                return s
        return BER_tagging_enc(s,
                               implicit_tag=self.implicit_tag,
                               explicit_tag=self.explicit_tag)

    def any2i(self,
              pkt,  # type: ASN1_Packet
              x  # type: Union[bytes, ASN1_Packet, None, ASN1_Object[Optional[ASN1_Packet]]]  # noqa: E501
              ):
        # type: (...) -> 'ASN1_Packet'
        """Attach ``pkt`` as underlayer of ``x`` when ``x`` supports it."""
        if hasattr(x, "add_underlayer"):
            x.add_underlayer(pkt)  # type: ignore
        return super(ASN1F_PACKET, self).any2i(pkt, x)

    def randval(self):  # type: ignore
        # type: () -> ASN1_Packet
        return packet.fuzz(self.cls())
877
878
class ASN1F_BIT_STRING_ENCAPS(ASN1F_BIT_STRING):
    """
    We may emulate simple string encapsulation with explicit_tag=0x04,
    but we need a specific class for bit strings because of unused bits, etc.

    The bit string's content is dissected as a ``cls`` packet.
    """
    ASN1_tag = ASN1_Class_UNIVERSAL.BIT_STRING

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[ASN1_Packet]
                 cls,  # type: Type[ASN1_Packet]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[int]
                 ):
        # type: (...) -> None
        self.cls = cls
        super(ASN1F_BIT_STRING_ENCAPS, self).__init__(  # type: ignore
            name,
            default and raw(default),
            context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag
        )

    def m2i(self, pkt, s):  # type: ignore
        # type: (ASN1_Packet, bytes) -> Tuple[Optional[ASN1_Packet], bytes]
        """
        Decode the bit string and dissect its content as a ``cls`` packet.

        :returns: (packet, or None for an empty bit string; bytes following
            the bit string)
        :raises BER_Decoding_Error: if the bit length is not a multiple of
            8, or if the encapsulated packet leaves unparsed bytes
        """
        bit_string, remain = super(ASN1F_BIT_STRING_ENCAPS, self).m2i(pkt, s)
        if len(bit_string.val) % 8 != 0:
            raise BER_Decoding_Error("wrong bit string", remaining=s)
        if not bit_string.val_readable:
            # Nothing encapsulated, but the bytes that follow this field
            # still belong to the caller and must be handed back.
            return None, remain
        p, s = self.extract_packet(self.cls, bit_string.val_readable,
                                   _underlayer=pkt)
        if len(s) > 0:
            raise BER_Decoding_Error("unexpected remainder", remaining=s)
        return p, remain

    def i2m(self, pkt, x):  # type: ignore
        # type: (ASN1_Packet, Optional[ASN1_BIT_STRING]) -> bytes
        """Wrap a packet/bytes value (or None) into a readable bit string."""
        if not isinstance(x, ASN1_BIT_STRING):
            x = ASN1_BIT_STRING(
                b"" if x is None else bytes(x),  # type: ignore
                readable=True,
            )
        return super(ASN1F_BIT_STRING_ENCAPS, self).i2m(pkt, x)
926
927
class ASN1F_FLAGS(ASN1F_BIT_STRING):
    """
    BIT STRING field whose bits are named by ``mapping``.

    Bit ``i`` of the string corresponds to ``mapping[i]``. A string value
    may be given either as '0'/'1' characters or as flag names joined with
    '+'.
    """

    def __init__(self,
                 name,  # type: str
                 default,  # type: Optional[str]
                 mapping,  # type: List[str]
                 context=None,  # type: Optional[Any]
                 implicit_tag=None,  # type: Optional[int]
                 explicit_tag=None,  # type: Optional[Any]
                 ):
        # type: (...) -> None
        self.mapping = mapping
        parent = super(ASN1F_FLAGS, self)
        parent.__init__(
            name,
            default,
            context=context,
            implicit_tag=implicit_tag,
            explicit_tag=explicit_tag,
            default_readable=False,
        )

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> str
        """
        Normalize ``x``. Strings are turned into ``ASN1_BIT_STRING`` objects.

        A string containing any character other than '0'/'1' is read as
        '+'-separated flag names and converted to a bit string with one bit
        per ``mapping`` entry. An unknown name raises ``ValueError`` (from
        ``list.index``).
        """
        if isinstance(x, str):
            if not all(ch in "01" for ch in x):
                # resolve the flags
                set_bits = {self.mapping.index(flag) for flag in x.split("+")}
                x = "".join("1" if pos in set_bits else "0"
                            for pos in range(len(self.mapping)))
            x = ASN1_BIT_STRING(x)
        return super(ASN1F_FLAGS, self).any2i(pkt, x)

    def get_flags(self, pkt):
        # type: (ASN1_Packet) -> List[str]
        """Return the names of the set bits of this field's value in ``pkt``."""
        bits = getattr(pkt, self.name).val
        # zip() stops at the shorter sequence: bits beyond the end of
        # ``mapping`` have no name and are ignored.
        return [flag for flag, bit in zip(self.mapping, bits) if bit == '1']

    def i2repr(self, pkt, x):
        # type: (ASN1_Packet, Any) -> str
        """Prefix ``repr(x)`` with the comma-separated names of the set flags."""
        if x is None:
            return repr(x)
        # NOTE(review): flag names come from pkt's field value, not from x.
        names = ", ".join(self.get_flags(pkt))
        return "%s %r" % (names, x)
971
972
class ASN1F_STRING_PacketField(ASN1F_STRING):
    """
    ASN1F_STRING that holds packets.

    Values carrying an ``ASN1_root`` attribute are serialized with
    ``bytes()`` and wrapped in an ``ASN1_STRING`` before encoding. In
    ``any2i``, values exposing ``add_underlayer`` are linked to the parent
    packet.
    """
    holds_packets = 1

    def i2m(self, pkt, val):
        # type: (ASN1_Packet, Any) -> bytes
        if not hasattr(val, "ASN1_root"):
            return super(ASN1F_STRING_PacketField, self).i2m(pkt, val)
        wrapped = ASN1_STRING(bytes(val))  # type: ignore
        return super(ASN1F_STRING_PacketField, self).i2m(pkt, wrapped)

    def any2i(self, pkt, x):
        # type: (ASN1_Packet, Any) -> Any
        link_to_parent = getattr(x, "add_underlayer", None)
        if link_to_parent is not None:
            link_to_parent(pkt)
        return super(ASN1F_STRING_PacketField, self).any2i(pkt, x)
990