• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Shim module exporting the same ElementTree API for lxml and
2xml.etree backends.
3
4When lxml is installed, it is automatically preferred over the built-in
5xml.etree module.
6On Python 2.7, the cElementTree module is preferred over the pure-python
7ElementTree module.
8
9Besides exporting a unified interface, this also defines extra functions
10or subclasses built-in ElementTree classes to add features that are
11only available in lxml, like OrderedDict for attributes, pretty_print and
12iterwalk.
13"""
14from fontTools.misc.textTools import tostr
15
16
17XML_DECLARATION = """<?xml version='1.0' encoding='%s'?>"""
18
19__all__ = [
20    # public symbols
21    "Comment",
22    "dump",
23    "Element",
24    "ElementTree",
25    "fromstring",
26    "fromstringlist",
27    "iselement",
28    "iterparse",
29    "parse",
30    "ParseError",
31    "PI",
32    "ProcessingInstruction",
33    "QName",
34    "SubElement",
35    "tostring",
36    "tostringlist",
37    "TreeBuilder",
38    "XML",
39    "XMLParser",
40    "register_namespace",
41]
42
43try:
44    from lxml.etree import *
45
46    _have_lxml = True
47except ImportError:
48    try:
49        from xml.etree.cElementTree import *
50
51        # the cElementTree version of XML function doesn't support
52        # the optional 'parser' keyword argument
53        from xml.etree.ElementTree import XML
54    except ImportError:  # pragma: no cover
55        from xml.etree.ElementTree import *
56    _have_lxml = False
57
58    import sys
59
60    # dict is always ordered in python >= 3.6 and on pypy
61    PY36 = sys.version_info >= (3, 6)
62    try:
63        import __pypy__
64    except ImportError:
65        __pypy__ = None
66    _dict_is_ordered = bool(PY36 or __pypy__)
67    del PY36, __pypy__
68
69    if _dict_is_ordered:
70        _Attrib = dict
71    else:
72        from collections import OrderedDict as _Attrib
73
74    if isinstance(Element, type):
75        _Element = Element
76    else:
77        # in py27, cElementTree.Element cannot be subclassed, so
78        # we need to import the pure-python class
79        from xml.etree.ElementTree import Element as _Element
80
81    class Element(_Element):
82        """Element subclass that keeps the order of attributes."""
83
84        def __init__(self, tag, attrib=_Attrib(), **extra):
85            super(Element, self).__init__(tag)
86            self.attrib = _Attrib()
87            if attrib:
88                self.attrib.update(attrib)
89            if extra:
90                self.attrib.update(extra)
91
92    def SubElement(parent, tag, attrib=_Attrib(), **extra):
93        """Must override SubElement as well otherwise _elementtree.SubElement
94        fails if 'parent' is a subclass of Element object.
95        """
96        element = parent.__class__(tag, attrib, **extra)
97        parent.append(element)
98        return element
99
100    def _iterwalk(element, events, tag):
101        include = tag is None or element.tag == tag
102        if include and "start" in events:
103            yield ("start", element)
104        for e in element:
105            for item in _iterwalk(e, events, tag):
106                yield item
107        if include:
108            yield ("end", element)
109
110    def iterwalk(element_or_tree, events=("end",), tag=None):
111        """A tree walker that generates events from an existing tree as
112        if it was parsing XML data with iterparse().
113        Drop-in replacement for lxml.etree.iterwalk.
114        """
115        if iselement(element_or_tree):
116            element = element_or_tree
117        else:
118            element = element_or_tree.getroot()
119        if tag == "*":
120            tag = None
121        for item in _iterwalk(element, events, tag):
122            yield item
123
124    _ElementTree = ElementTree
125
126    class ElementTree(_ElementTree):
127        """ElementTree subclass that adds 'pretty_print' and 'doctype'
128        arguments to the 'write' method.
129        Currently these are only supported for the default XML serialization
130        'method', and not also for "html" or "text", for these are delegated
131        to the base class.
132        """
133
134        def write(
135            self,
136            file_or_filename,
137            encoding=None,
138            xml_declaration=False,
139            method=None,
140            doctype=None,
141            pretty_print=False,
142        ):
143            if method and method != "xml":
144                # delegate to super-class
145                super(ElementTree, self).write(
146                    file_or_filename,
147                    encoding=encoding,
148                    xml_declaration=xml_declaration,
149                    method=method,
150                )
151                return
152
153            if encoding is not None and encoding.lower() == "unicode":
154                if xml_declaration:
155                    raise ValueError(
156                        "Serialisation to unicode must not request an XML declaration"
157                    )
158                write_declaration = False
159                encoding = "unicode"
160            elif xml_declaration is None:
161                # by default, write an XML declaration only for non-standard encodings
162                write_declaration = encoding is not None and encoding.upper() not in (
163                    "ASCII",
164                    "UTF-8",
165                    "UTF8",
166                    "US-ASCII",
167                )
168            else:
169                write_declaration = xml_declaration
170
171            if encoding is None:
172                encoding = "ASCII"
173
174            if pretty_print:
175                # NOTE this will modify the tree in-place
176                _indent(self._root)
177
178            with _get_writer(file_or_filename, encoding) as write:
179                if write_declaration:
180                    write(XML_DECLARATION % encoding.upper())
181                    if pretty_print:
182                        write("\n")
183                if doctype:
184                    write(_tounicode(doctype))
185                    if pretty_print:
186                        write("\n")
187
188                qnames, namespaces = _namespaces(self._root)
189                _serialize_xml(write, self._root, qnames, namespaces)
190
191    import io
192
193    def tostring(
194        element,
195        encoding=None,
196        xml_declaration=None,
197        method=None,
198        doctype=None,
199        pretty_print=False,
200    ):
201        """Custom 'tostring' function that uses our ElementTree subclass, with
202        pretty_print support.
203        """
204        stream = io.StringIO() if encoding == "unicode" else io.BytesIO()
205        ElementTree(element).write(
206            stream,
207            encoding=encoding,
208            xml_declaration=xml_declaration,
209            method=method,
210            doctype=doctype,
211            pretty_print=pretty_print,
212        )
213        return stream.getvalue()
214
215    # serialization support
216
217    import re
218
219    # Valid XML strings can include any Unicode character, excluding control
220    # characters, the surrogate blocks, FFFE, and FFFF:
221    #   Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
222    # Here we reversed the pattern to match only the invalid characters.
223    # For the 'narrow' python builds supporting only UCS-2, which represent
224    # characters beyond BMP as UTF-16 surrogate pairs, we need to pass through
225    # the surrogate block. I haven't found a more elegant solution...
226    UCS2 = sys.maxunicode < 0x10FFFF
227    if UCS2:
228        _invalid_xml_string = re.compile(
229            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uFFFE-\uFFFF]"
230        )
231    else:
232        _invalid_xml_string = re.compile(
233            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE-\uFFFF]"
234        )
235
236    def _tounicode(s):
237        """Test if a string is valid user input and decode it to unicode string
238        using ASCII encoding if it's a bytes string.
239        Reject all bytes/unicode input that contains non-XML characters.
240        Reject all bytes input that contains non-ASCII characters.
241        """
242        try:
243            s = tostr(s, encoding="ascii", errors="strict")
244        except UnicodeDecodeError:
245            raise ValueError(
246                "Bytes strings can only contain ASCII characters. "
247                "Use unicode strings for non-ASCII characters.")
248        except AttributeError:
249            _raise_serialization_error(s)
250        if s and _invalid_xml_string.search(s):
251            raise ValueError(
252                "All strings must be XML compatible: Unicode or ASCII, "
253                "no NULL bytes or control characters"
254            )
255        return s
256
257    import contextlib
258
259    @contextlib.contextmanager
260    def _get_writer(file_or_filename, encoding):
261        # returns text write method and release all resources after using
262        try:
263            write = file_or_filename.write
264        except AttributeError:
265            # file_or_filename is a file name
266            f = open(
267                file_or_filename,
268                "w",
269                encoding="utf-8" if encoding == "unicode" else encoding,
270                errors="xmlcharrefreplace",
271            )
272            with f:
273                yield f.write
274        else:
275            # file_or_filename is a file-like object
276            # encoding determines if it is a text or binary writer
277            if encoding == "unicode":
278                # use a text writer as is
279                yield write
280            else:
281                # wrap a binary writer with TextIOWrapper
282                detach_buffer = False
283                if isinstance(file_or_filename, io.BufferedIOBase):
284                    buf = file_or_filename
285                elif isinstance(file_or_filename, io.RawIOBase):
286                    buf = io.BufferedWriter(file_or_filename)
287                    detach_buffer = True
288                else:
289                    # This is to handle passed objects that aren't in the
290                    # IOBase hierarchy, but just have a write method
291                    buf = io.BufferedIOBase()
292                    buf.writable = lambda: True
293                    buf.write = write
294                    try:
295                        # TextIOWrapper uses this methods to determine
296                        # if BOM (for UTF-16, etc) should be added
297                        buf.seekable = file_or_filename.seekable
298                        buf.tell = file_or_filename.tell
299                    except AttributeError:
300                        pass
301                wrapper = io.TextIOWrapper(
302                    buf,
303                    encoding=encoding,
304                    errors="xmlcharrefreplace",
305                    newline="\n",
306                )
307                try:
308                    yield wrapper.write
309                finally:
310                    # Keep the original file open when the TextIOWrapper and
311                    # the BufferedWriter are destroyed
312                    wrapper.detach()
313                    if detach_buffer:
314                        buf.detach()
315
316    from xml.etree.ElementTree import _namespace_map
317
318    def _namespaces(elem):
319        # identify namespaces used in this tree
320
321        # maps qnames to *encoded* prefix:local names
322        qnames = {None: None}
323
324        # maps uri:s to prefixes
325        namespaces = {}
326
327        def add_qname(qname):
328            # calculate serialized qname representation
329            try:
330                qname = _tounicode(qname)
331                if qname[:1] == "{":
332                    uri, tag = qname[1:].rsplit("}", 1)
333                    prefix = namespaces.get(uri)
334                    if prefix is None:
335                        prefix = _namespace_map.get(uri)
336                        if prefix is None:
337                            prefix = "ns%d" % len(namespaces)
338                        else:
339                            prefix = _tounicode(prefix)
340                        if prefix != "xml":
341                            namespaces[uri] = prefix
342                    if prefix:
343                        qnames[qname] = "%s:%s" % (prefix, tag)
344                    else:
345                        qnames[qname] = tag  # default element
346                else:
347                    qnames[qname] = qname
348            except TypeError:
349                _raise_serialization_error(qname)
350
351        # populate qname and namespaces table
352        for elem in elem.iter():
353            tag = elem.tag
354            if isinstance(tag, QName):
355                if tag.text not in qnames:
356                    add_qname(tag.text)
357            elif isinstance(tag, str):
358                if tag not in qnames:
359                    add_qname(tag)
360            elif tag is not None and tag is not Comment and tag is not PI:
361                _raise_serialization_error(tag)
362            for key, value in elem.items():
363                if isinstance(key, QName):
364                    key = key.text
365                if key not in qnames:
366                    add_qname(key)
367                if isinstance(value, QName) and value.text not in qnames:
368                    add_qname(value.text)
369            text = elem.text
370            if isinstance(text, QName) and text.text not in qnames:
371                add_qname(text.text)
372        return qnames, namespaces
373
374    def _serialize_xml(write, elem, qnames, namespaces, **kwargs):
375        tag = elem.tag
376        text = elem.text
377        if tag is Comment:
378            write("<!--%s-->" % _tounicode(text))
379        elif tag is ProcessingInstruction:
380            write("<?%s?>" % _tounicode(text))
381        else:
382            tag = qnames[_tounicode(tag) if tag is not None else None]
383            if tag is None:
384                if text:
385                    write(_escape_cdata(text))
386                for e in elem:
387                    _serialize_xml(write, e, qnames, None)
388            else:
389                write("<" + tag)
390                if namespaces:
391                    for uri, prefix in sorted(
392                        namespaces.items(), key=lambda x: x[1]
393                    ):  # sort on prefix
394                        if prefix:
395                            prefix = ":" + prefix
396                        write(' xmlns%s="%s"' % (prefix, _escape_attrib(uri)))
397                attrs = elem.attrib
398                if attrs:
399                    # try to keep existing attrib order
400                    if len(attrs) <= 1 or type(attrs) is _Attrib:
401                        items = attrs.items()
402                    else:
403                        # if plain dict, use lexical order
404                        items = sorted(attrs.items())
405                    for k, v in items:
406                        if isinstance(k, QName):
407                            k = _tounicode(k.text)
408                        else:
409                            k = _tounicode(k)
410                        if isinstance(v, QName):
411                            v = qnames[_tounicode(v.text)]
412                        else:
413                            v = _escape_attrib(v)
414                        write(' %s="%s"' % (qnames[k], v))
415                if text is not None or len(elem):
416                    write(">")
417                    if text:
418                        write(_escape_cdata(text))
419                    for e in elem:
420                        _serialize_xml(write, e, qnames, None)
421                    write("</" + tag + ">")
422                else:
423                    write("/>")
424        if elem.tail:
425            write(_escape_cdata(elem.tail))
426
427    def _raise_serialization_error(text):
428        raise TypeError(
429            "cannot serialize %r (type %s)" % (text, type(text).__name__)
430        )
431
432    def _escape_cdata(text):
433        # escape character data
434        try:
435            text = _tounicode(text)
436            # it's worth avoiding do-nothing calls for short strings
437            if "&" in text:
438                text = text.replace("&", "&amp;")
439            if "<" in text:
440                text = text.replace("<", "&lt;")
441            if ">" in text:
442                text = text.replace(">", "&gt;")
443            return text
444        except (TypeError, AttributeError):
445            _raise_serialization_error(text)
446
447    def _escape_attrib(text):
448        # escape attribute value
449        try:
450            text = _tounicode(text)
451            if "&" in text:
452                text = text.replace("&", "&amp;")
453            if "<" in text:
454                text = text.replace("<", "&lt;")
455            if ">" in text:
456                text = text.replace(">", "&gt;")
457            if '"' in text:
458                text = text.replace('"', "&quot;")
459            if "\n" in text:
460                text = text.replace("\n", "&#10;")
461            return text
462        except (TypeError, AttributeError):
463            _raise_serialization_error(text)
464
465    def _indent(elem, level=0):
466        # From http://effbot.org/zone/element-lib.htm#prettyprint
467        i = "\n" + level * "  "
468        if len(elem):
469            if not elem.text or not elem.text.strip():
470                elem.text = i + "  "
471            if not elem.tail or not elem.tail.strip():
472                elem.tail = i
473            for elem in elem:
474                _indent(elem, level + 1)
475            if not elem.tail or not elem.tail.strip():
476                elem.tail = i
477        else:
478            if level and (not elem.tail or not elem.tail.strip()):
479                elem.tail = i
480