## This file is part of Scapy ## See http://www.secdev.org/projects/scapy for more informations ## Copyright (C) Philippe Biondi ## Modified by Maxence Tury ## This program is published under a GPLv2 license """ ASN.1 (Abstract Syntax Notation One) """ from __future__ import absolute_import from __future__ import print_function import random from datetime import datetime from scapy.config import conf from scapy.error import Scapy_Exception, warning from scapy.volatile import RandField, RandIP, GeneralizedTime from scapy.utils import Enum_metaclass, EnumElement, binrepr from scapy.compat import plain_str, chb, raw, orb import scapy.modules.six as six from scapy.modules.six.moves import range class RandASN1Object(RandField): def __init__(self, objlist=None): self.objlist = [ x._asn1_obj for x in six.itervalues(ASN1_Class_UNIVERSAL.__rdict__) if hasattr(x, "_asn1_obj") ] if objlist is None else objlist self.chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" def _fix(self, n=0): o = random.choice(self.objlist) if issubclass(o, ASN1_INTEGER): return o(int(random.gauss(0,1000))) elif issubclass(o, ASN1_IPADDRESS): z = RandIP()._fix() return o(z) elif issubclass(o, ASN1_GENERALIZED_TIME) or issubclass(o, ASN1_UTC_TIME): z = GeneralizedTime()._fix() return o(z) elif issubclass(o, ASN1_STRING): z = int(random.expovariate(0.05)+1) return o("".join(random.choice(self.chars) for _ in range(z))) elif issubclass(o, ASN1_SEQUENCE) and (n < 10): z = int(random.expovariate(0.08)+1) return o([self.__class__(objlist=self.objlist)._fix(n + 1) for _ in range(z)]) return ASN1_INTEGER(int(random.gauss(0,1000))) ############## #### ASN1 #### ############## class ASN1_Error(Scapy_Exception): pass class ASN1_Encoding_Error(ASN1_Error): pass class ASN1_Decoding_Error(ASN1_Error): pass class ASN1_BadTag_Decoding_Error(ASN1_Decoding_Error): pass class ASN1Codec(EnumElement): def register_stem(cls, stem): cls._stem = stem def dec(cls, s, context=None): return cls._stem.dec(s, context=context) def safedec(cls, s, context=None): return cls._stem.safedec(s, context=context) def get_stem(cls): return cls.stem class ASN1_Codecs_metaclass(Enum_metaclass): element_class = ASN1Codec class ASN1_Codecs(six.with_metaclass(ASN1_Codecs_metaclass)): BER = 1 DER = 2 PER = 3 CER = 4 LWER = 5 BACnet = 6 OER = 7 SER = 8 XER = 9 class ASN1Tag(EnumElement): def __init__(self, key, value, context=None, codec=None): EnumElement.__init__(self, key, value) self._context = context if codec == None: codec = {} self._codec = codec def clone(self): # /!\ not a real deep copy. self.codec is shared return self.__class__(self._key, self._value, self._context, self._codec) def register_asn1_object(self, asn1obj): self._asn1_obj = asn1obj def asn1_object(self, val): if hasattr(self,"_asn1_obj"): return self._asn1_obj(val) raise ASN1_Error("%r does not have any assigned ASN1 object" % self) def register(self, codecnum, codec): self._codec[codecnum] = codec def get_codec(self, codec): try: c = self._codec[codec] except KeyError as msg: raise ASN1_Error("Codec %r not found for tag %r" % (codec, self)) return c class ASN1_Class_metaclass(Enum_metaclass): element_class = ASN1Tag def __new__(cls, name, bases, dct): # XXX factorise a bit with Enum_metaclass.__new__() for b in bases: for k,v in six.iteritems(b.__dict__): if k not in dct and isinstance(v,ASN1Tag): dct[k] = v.clone() rdict = {} for k,v in six.iteritems(dct): if isinstance(v, int): v = ASN1Tag(k,v) dct[k] = v rdict[v] = v elif isinstance(v, ASN1Tag): rdict[v] = v dct["__rdict__"] = rdict cls = type.__new__(cls, name, bases, dct) for v in cls.__dict__.values(): if isinstance(v, ASN1Tag): v.context = cls # overwrite ASN1Tag contexts, even cloned ones return cls class ASN1_Class(six.with_metaclass(ASN1_Class_metaclass)): pass class ASN1_Class_UNIVERSAL(ASN1_Class): name = "UNIVERSAL" ERROR = -3 RAW = -2 NONE = -1 ANY = 0 BOOLEAN = 1 INTEGER = 2 BIT_STRING = 3 STRING = 4 NULL = 5 OID = 6 OBJECT_DESCRIPTOR = 7 EXTERNAL = 8 REAL = 9 ENUMERATED = 10 EMBEDDED_PDF = 11 UTF8_STRING = 12 RELATIVE_OID = 13 SEQUENCE = 16|0x20 # constructed encoding SET = 17|0x20 # constructed encoding NUMERIC_STRING = 18 PRINTABLE_STRING = 19 T61_STRING = 20 # aka TELETEX_STRING VIDEOTEX_STRING = 21 IA5_STRING = 22 UTC_TIME = 23 GENERALIZED_TIME = 24 GRAPHIC_STRING = 25 ISO646_STRING = 26 # aka VISIBLE_STRING GENERAL_STRING = 27 UNIVERSAL_STRING = 28 CHAR_STRING = 29 BMP_STRING = 30 IPADDRESS = 0|0x40 # application-specific encoding COUNTER32 = 1|0x40 # application-specific encoding GAUGE32 = 2|0x40 # application-specific encoding TIME_TICKS = 3|0x40 # application-specific encoding class ASN1_Object_metaclass(type): def __new__(cls, name, bases, dct): c = super(ASN1_Object_metaclass, cls).__new__(cls, name, bases, dct) try: c.tag.register_asn1_object(c) except: warning("Error registering %r for %r" % (c.tag, c.codec)) return c class ASN1_Object(six.with_metaclass(ASN1_Object_metaclass)): tag = ASN1_Class_UNIVERSAL.ANY def __init__(self, val): self.val = val def enc(self, codec): return self.tag.get_codec(codec).enc(self.val) def __repr__(self): return "<%s[%r]>" % (self.__dict__.get("name", self.__class__.__name__), self.val) def __str__(self): return self.enc(conf.ASN1_default_codec) def __bytes__(self): return self.enc(conf.ASN1_default_codec) def strshow(self, lvl=0): return (" "*lvl)+repr(self)+"\n" def show(self, lvl=0): print(self.strshow(lvl)) def __eq__(self, other): return self.val == other def __lt__(self, other): return self.val < other def __le__(self, other): return self.val <= other def __gt__(self, other): return self.val > other def __ge__(self, other): return self.val >= other def __ne__(self, other): return self.val != other ####################### #### ASN1 objects #### ####################### # on the whole, we order the classes by ASN1_Class_UNIVERSAL tag value class ASN1_DECODING_ERROR(ASN1_Object): tag = ASN1_Class_UNIVERSAL.ERROR def __init__(self, val, exc=None): ASN1_Object.__init__(self, val) self.exc = exc def __repr__(self): return "<%s[%r]{{%r}}>" % (self.__dict__.get("name", self.__class__.__name__), self.val, self.exc.args[0]) def enc(self, codec): if isinstance(self.val, ASN1_Object): return self.val.enc(codec) return self.val class ASN1_force(ASN1_Object): tag = ASN1_Class_UNIVERSAL.RAW def enc(self, codec): if isinstance(self.val, ASN1_Object): return self.val.enc(codec) return self.val class ASN1_BADTAG(ASN1_force): pass class ASN1_INTEGER(ASN1_Object): tag = ASN1_Class_UNIVERSAL.INTEGER def __repr__(self): h = hex(self.val) if h[-1] == "L": h = h[:-1] # cut at 22 because with leading '0x', x509 serials should be < 23 if len(h) > 22: h = h[:12] + "..." + h[-10:] r = repr(self.val) if len(r) > 20: r = r[:10] + "..." + r[-10:] return h + " <%s[%s]>" % (self.__dict__.get("name", self.__class__.__name__), r) class ASN1_BOOLEAN(ASN1_INTEGER): tag = ASN1_Class_UNIVERSAL.BOOLEAN # BER: 0 means False, anything else means True def __repr__(self): return '%s %s' % (not (self.val==0), ASN1_Object.__repr__(self)) class ASN1_BIT_STRING(ASN1_Object): """ /!\ ASN1_BIT_STRING values are bit strings like "011101". /!\ A zero-bit padded readable string is provided nonetheless, /!\ which is also output when __str__ is called. """ tag = ASN1_Class_UNIVERSAL.BIT_STRING def __init__(self, val, readable=False): if not readable: self.val = val else: self.val_readable = val def __setattr__(self, name, value): str_value = None if isinstance(value, str): str_value = value value = raw(value) if name == "val_readable": if isinstance(value, bytes): val = b"".join(binrepr(orb(x)).zfill(8).encode("utf8") for x in value) else: val = "" super(ASN1_Object, self).__setattr__("val", val) super(ASN1_Object, self).__setattr__(name, value) super(ASN1_Object, self).__setattr__("unused_bits", 0) elif name == "val": if not str_value: str_value = plain_str(value) if isinstance(value, bytes): if any(c for c in str_value if c not in ["0", "1"]): print("Invalid operation: 'val' is not a valid bit string.") return else: if len(value) % 8 == 0: unused_bits = 0 else: unused_bits = 8 - (len(value) % 8) padded_value = str_value + ("0" * unused_bits) bytes_arr = zip(*[iter(padded_value)]*8) val_readable = b"".join(chb(int("".join(x),2)) for x in bytes_arr) else: val_readable = "" unused_bits = 0 super(ASN1_Object, self).__setattr__("val_readable", val_readable) super(ASN1_Object, self).__setattr__(name, value) super(ASN1_Object, self).__setattr__("unused_bits", unused_bits) elif name == "unused_bits": print("Invalid operation: unused_bits rewriting is not supported.") else: super(ASN1_Object, self).__setattr__(name, value) def __repr__(self): if len(self.val) <= 16: v = plain_str(self.val) return "<%s[%s] (%d unused bit%s)>" % (self.__dict__.get("name", self.__class__.__name__), v, self.unused_bits, "s" if self.unused_bits>1 else "") else: s = self.val_readable if len(s) > 20: s = s[:10] + b"..." + s[-10:] v = plain_str(self.val) return "<%s[%s] (%d unused bit%s)>" % (self.__dict__.get("name", self.__class__.__name__), v, self.unused_bits, "s" if self.unused_bits>1 else "") def __str__(self): return self.val_readable def __bytes__(self): return self.val_readable class ASN1_STRING(ASN1_Object): tag = ASN1_Class_UNIVERSAL.STRING class ASN1_NULL(ASN1_Object): tag = ASN1_Class_UNIVERSAL.NULL def __repr__(self): return ASN1_Object.__repr__(self) class ASN1_OID(ASN1_Object): tag = ASN1_Class_UNIVERSAL.OID def __init__(self, val): val = conf.mib._oid(plain_str(val)) ASN1_Object.__init__(self, val) self.oidname = conf.mib._oidname(val) def __repr__(self): return "<%s[%r]>" % (self.__dict__.get("name", self.__class__.__name__), self.oidname) class ASN1_ENUMERATED(ASN1_INTEGER): tag = ASN1_Class_UNIVERSAL.ENUMERATED class ASN1_UTF8_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.UTF8_STRING class ASN1_NUMERIC_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.NUMERIC_STRING class ASN1_PRINTABLE_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.PRINTABLE_STRING class ASN1_T61_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.T61_STRING class ASN1_VIDEOTEX_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.VIDEOTEX_STRING class ASN1_IA5_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.IA5_STRING class ASN1_UTC_TIME(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.UTC_TIME def __init__(self, val): super(ASN1_UTC_TIME, self).__init__(val) def __setattr__(self, name, value): if isinstance(value, bytes): value = plain_str(value) if name == "val": pretty_time = None if (isinstance(value, str) and len(value) == 13 and value[-1] == "Z"): dt = datetime.strptime(value[:-1], "%y%m%d%H%M%S") pretty_time = dt.strftime("%b %d %H:%M:%S %Y GMT") else: pretty_time = "%s [invalid utc_time]" % value super(ASN1_UTC_TIME, self).__setattr__("pretty_time", pretty_time) super(ASN1_UTC_TIME, self).__setattr__(name, value) elif name == "pretty_time": print("Invalid operation: pretty_time rewriting is not supported.") else: super(ASN1_UTC_TIME, self).__setattr__(name, value) def __repr__(self): return "%s %s" % (self.pretty_time, ASN1_STRING.__repr__(self)) class ASN1_GENERALIZED_TIME(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.GENERALIZED_TIME def __init__(self, val): super(ASN1_GENERALIZED_TIME, self).__init__(val) def __setattr__(self, name, value): if isinstance(value, bytes): value = plain_str(value) if name == "val": pretty_time = None if (isinstance(value, str) and len(value) == 15 and value[-1] == "Z"): dt = datetime.strptime(value[:-1], "%Y%m%d%H%M%S") pretty_time = dt.strftime("%b %d %H:%M:%S %Y GMT") else: pretty_time = "%s [invalid generalized_time]" % value super(ASN1_GENERALIZED_TIME, self).__setattr__("pretty_time", pretty_time) super(ASN1_GENERALIZED_TIME, self).__setattr__(name, value) elif name == "pretty_time": print("Invalid operation: pretty_time rewriting is not supported.") else: super(ASN1_GENERALIZED_TIME, self).__setattr__(name, value) def __repr__(self): return "%s %s" % (self.pretty_time, ASN1_STRING.__repr__(self)) class ASN1_ISO646_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.ISO646_STRING class ASN1_UNIVERSAL_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.UNIVERSAL_STRING class ASN1_BMP_STRING(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.BMP_STRING class ASN1_SEQUENCE(ASN1_Object): tag = ASN1_Class_UNIVERSAL.SEQUENCE def strshow(self, lvl=0): s = (" "*lvl)+("# %s:" % self.__class__.__name__)+"\n" for o in self.val: s += o.strshow(lvl=lvl+1) return s class ASN1_SET(ASN1_SEQUENCE): tag = ASN1_Class_UNIVERSAL.SET class ASN1_IPADDRESS(ASN1_STRING): tag = ASN1_Class_UNIVERSAL.IPADDRESS class ASN1_COUNTER32(ASN1_INTEGER): tag = ASN1_Class_UNIVERSAL.COUNTER32 class ASN1_GAUGE32(ASN1_INTEGER): tag = ASN1_Class_UNIVERSAL.GAUGE32 class ASN1_TIME_TICKS(ASN1_INTEGER): tag = ASN1_Class_UNIVERSAL.TIME_TICKS conf.ASN1_default_codec = ASN1_Codecs.BER