• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""\
2A library of useful helper classes to the SAX classes, for the
3convenience of application and driver writers.
4"""
5
6import os, urlparse, urllib, types
7import io
8import sys
9import handler
10import xmlreader
11
# Types that prepare_input_source() treats as a system identifier string.
try:
    _StringTypes = [types.StringType, types.UnicodeType]
except AttributeError:
    # Presumably types.UnicodeType is missing on this interpreter; fall
    # back to accepting types.StringType only.
    _StringTypes = [types.StringType]
16
def __dict_replace(s, d):
    """Return *s* with every key of *d* substituted by its value.

    The substitutions are applied one after another, in the iteration
    order of *d*, each one working on the result of the previous one.
    """
    replaced = s
    for old in d:
        replaced = replaced.replace(old, d[old])
    return replaced
22
def escape(data, entities={}):
    """Return *data* with &, < and > replaced by their XML entities.

    Further replacements can be requested through the optional
    *entities* dictionary: every key found in the escaped text is
    replaced with the matching value.  Keys and values must be strings.
    """
    # The ampersand goes first; otherwise the '&' introduced by the
    # other two replacements would itself be escaped.
    escaped = (data.replace("&", "&amp;")
                   .replace(">", "&gt;")
                   .replace("<", "&lt;"))
    if not entities:
        return escaped
    return __dict_replace(escaped, entities)
38
def unescape(data, entities={}):
    """Return *data* with &amp;, &lt; and &gt; turned back into characters.

    Further replacements can be requested through the optional
    *entities* dictionary: every key found in the text is replaced with
    the matching value.  Keys and values must be strings.
    """
    text = data.replace("&lt;", "<").replace("&gt;", ">")
    if entities:
        text = __dict_replace(text, entities)
    # The ampersand goes last so that text such as "&amp;lt;" ends up as
    # "&lt;" instead of being unescaped twice.
    return text.replace("&amp;", "&")
52
def quoteattr(data, entities={}):
    """Return *data* escaped and wrapped in quotes for use as an attribute.

    &, < and > are escaped, and newline, carriage return and tab are
    written as character references.  Double quotes are used as the
    delimiter unless the value contains a double quote and no single
    quote, in which case single quotes are used.  If both kinds occur,
    the double quotes inside the value are written as &quot;.

    Further replacements can be requested through the optional
    *entities* dictionary: every key found in the text is replaced with
    the matching value.  Keys and values must be strings.
    """
    # Work on a copy so the caller's dictionary is left untouched.
    table = entities.copy()
    whitespace_refs = {'\n': '&#10;', '\r': '&#13;', '\t':'&#9;'}
    table.update(whitespace_refs)
    escaped = escape(data, table)

    if '"' not in escaped:
        return '"%s"' % escaped
    if "'" not in escaped:
        return "'%s'" % escaped
    # Both quote characters occur: delimit with double quotes and escape
    # the embedded ones.
    return '"%s"' % escaped.replace('"', "&quot;")
75
76
def _gettextwriter(out, encoding):
    """Return a text writer that encodes to *encoding* and targets *out*.

    If *out* is None, sys.stdout is used.  The returned object is an
    _UnbufferedTextIOWrapper built with errors='xmlcharrefreplace' and a
    line-feed newline setting, layered over an io.BufferedIOBase object
    prepared below.
    """
    if out is None:
        import sys
        out = sys.stdout

    if isinstance(out, io.RawIOBase):
        # NOTE(review): *out* is handed to the io.BufferedIOBase()
        # constructor, but nothing in this block forwards write() to it
        # -- confirm that raw streams actually receive the output.
        buffer = io.BufferedIOBase(out)
        # Keep the original file open when the TextIOWrapper is
        # destroyed
        buffer.close = lambda: None
    else:
        # This is to handle passed objects that aren't in the
        # IOBase hierarchy, but just have a write method
        buffer = io.BufferedIOBase()
        buffer.writable = lambda: True
        buffer.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if BOM (for UTF-16, etc) should be added
            buffer.seekable = out.seekable
            buffer.tell = out.tell
        except AttributeError:
            # *out* lacks seekable()/tell(); leave the defaults in place.
            pass
    # wrap a binary writer with TextIOWrapper
    return _UnbufferedTextIOWrapper(buffer, encoding=encoding,
                                   errors='xmlcharrefreplace',
                                   newline='\n')
104
105
class _UnbufferedTextIOWrapper(io.TextIOWrapper):
    """TextIOWrapper whose write() flushes after every call."""

    def write(self, s):
        """Write *s* through the base class, then flush; returns None."""
        base = super(_UnbufferedTextIOWrapper, self)
        base.write(s)
        self.flush()
110
111
class XMLGenerator(handler.ContentHandler):
    """ContentHandler that writes the events it receives out as XML text.

    The markup is written through the text writer that _gettextwriter()
    builds for *out* and *encoding*.
    """

    def __init__(self, out=None, encoding="iso-8859-1"):
        handler.ContentHandler.__init__(self)
        writer = _gettextwriter(out, encoding)
        self._write = writer.write
        self._flush = writer.flush
        # list of uri -> prefix dicts
        self._ns_contexts = [{}]
        self._current_context = self._ns_contexts[-1]
        # (prefix, uri) pairs not yet written out as xmlns attributes
        self._undeclared_ns_maps = []
        self._encoding = encoding

    def _qname(self, name):
        """Build the qualified name for a (ns_url, localname) pair."""
        ns_url = name[0]
        if not ns_url:
            # No namespace: the local name is used unqualified.
            return name[1]
        if 'http://www.w3.org/XML/1998/namespace' == ns_url:
            # The 'xml' prefix is bound to this namespace by definition
            # (see http://www.w3.org/XML/1998/namespace), so it needs no
            # declaration and is normally absent from
            # self._current_context.
            return 'xml:' + name[1]
        prefix = self._current_context[ns_url]
        if not prefix:
            # Default namespace: there is no prefix to prepend.
            return name[1]
        return prefix + ":" + name[1]

    # ContentHandler methods

    def startDocument(self):
        """Write the XML declaration naming the configured encoding."""
        declaration = u'<?xml version="1.0" encoding="%s"?>\n' % self._encoding
        self._write(declaration)

    def endDocument(self):
        """Flush the underlying writer."""
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        """Record the uri -> prefix binding and queue its declaration."""
        saved = self._current_context.copy()
        self._ns_contexts.append(saved)
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        """Make the most recently saved context the current one."""
        self._current_context = self._ns_contexts.pop()

    def startElement(self, name, attrs):
        """Write the start tag for *name* with the attributes in *attrs*."""
        self._write(u'<' + name)
        for attr_name, attr_value in attrs.items():
            self._write(u' %s=%s' % (attr_name, quoteattr(attr_value)))
        self._write(u'>')

    def endElement(self, name):
        """Write the end tag for *name*."""
        closing = u'</%s>' % name
        self._write(closing)

    def startElementNS(self, name, qname, attrs):
        """Write a namespace-aware start tag.

        Prefix mappings queued by startPrefixMapping() are written as
        xmlns attributes on this tag and the queue is then emptied.
        """
        self._write(u'<' + self._qname(name))

        for prefix, uri in self._undeclared_ns_maps:
            if prefix:
                declaration = u' xmlns:%s="%s"' % (prefix, uri)
            else:
                declaration = u' xmlns="%s"' % uri
            self._write(declaration)
        self._undeclared_ns_maps = []

        for attr_name, attr_value in attrs.items():
            qualified = self._qname(attr_name)
            self._write(u' %s=%s' % (qualified, quoteattr(attr_value)))
        self._write(u'>')

    def endElementNS(self, name, qname):
        """Write the end tag for the (ns_url, localname) pair *name*."""
        closing = u'</%s>' % self._qname(name)
        self._write(closing)

    def characters(self, content):
        """Write *content* with &, < and > escaped.

        Content that is not unicode is first decoded with the
        configured encoding.
        """
        if isinstance(content, unicode):
            text = content
        else:
            text = unicode(content, self._encoding)
        self._write(escape(text))

    def ignorableWhitespace(self, content):
        """Write *content* unescaped, decoding it first if necessary."""
        if isinstance(content, unicode):
            text = content
        else:
            text = unicode(content, self._encoding)
        self._write(text)

    def processingInstruction(self, target, data):
        """Write a processing instruction made of *target* and *data*."""
        instruction = u'<?%s %s?>' % (target, data)
        self._write(instruction)
197
198
class XMLFilterBase(xmlreader.XMLReader):
    """Pass-through filter placed between an XMLReader and the handlers
    of the client application.

    Every configuration request is forwarded to the parent reader and
    every event is forwarded to the registered handler without change.
    Subclasses override individual methods to alter the event stream or
    the configuration requests on their way through.
    """

    def __init__(self, parent=None):
        xmlreader.XMLReader.__init__(self)
        self._parent = parent

    # ErrorHandler methods -- forwarded to the registered error handler

    def error(self, exception):
        """Forward *exception* to the error handler's error()."""
        self._err_handler.error(exception)

    def fatalError(self, exception):
        """Forward *exception* to the error handler's fatalError()."""
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        """Forward *exception* to the error handler's warning()."""
        self._err_handler.warning(exception)

    # ContentHandler methods -- forwarded to the registered content handler

    def setDocumentLocator(self, locator):
        """Forward the locator to the content handler."""
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        """Forward the start-of-document event."""
        self._cont_handler.startDocument()

    def endDocument(self):
        """Forward the end-of-document event."""
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        """Forward the start of a prefix mapping."""
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        """Forward the end of a prefix mapping."""
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        """Forward a start-element event."""
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        """Forward an end-element event."""
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        """Forward a namespace-aware start-element event."""
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        """Forward a namespace-aware end-element event."""
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        """Forward character data."""
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        """Forward ignorable whitespace."""
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        """Forward a processing instruction."""
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        """Forward a skipped-entity event."""
        self._cont_handler.skippedEntity(name)

    # DTDHandler methods -- forwarded to the registered DTD handler

    def notationDecl(self, name, publicId, systemId):
        """Forward a notation declaration."""
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        """Forward an unparsed entity declaration."""
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # EntityResolver methods -- forwarded to the registered entity resolver

    def resolveEntity(self, publicId, systemId):
        """Return whatever the entity resolver returns for the entity."""
        return self._ent_handler.resolveEntity(publicId, systemId)

    # XMLReader methods -- forwarded to the parent reader

    def parse(self, source):
        """Install this filter as every handler of the parent reader,
        then have the parent parse *source*."""
        reader = self._parent
        reader.setContentHandler(self)
        reader.setErrorHandler(self)
        reader.setEntityResolver(self)
        reader.setDTDHandler(self)
        reader.parse(source)

    def setLocale(self, locale):
        """Forward the locale to the parent reader."""
        self._parent.setLocale(locale)

    def getFeature(self, name):
        """Return the parent reader's state for feature *name*."""
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        """Set feature *name* on the parent reader."""
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        """Return the parent reader's value for property *name*."""
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        """Set property *name* on the parent reader."""
        self._parent.setProperty(name, value)

    # XMLFilter methods

    def getParent(self):
        """Return the parent reader."""
        return self._parent

    def setParent(self, parent):
        """Replace the parent reader with *parent*."""
        self._parent = parent
307
308# --- Utility functions
309
def prepare_input_source(source, base = ""):
    """Return a fully resolved InputSource object ready for reading.

    *source* may be:
      - a string, which is used as the system identifier of a new
        InputSource;
      - an object with a read() method, which becomes the byte stream of
        a new InputSource (its ``name`` attribute, if any, is used as
        the system identifier);
      - anything else, which is used as the InputSource itself.

    *base* is an optional base URL / path against which the system
    identifier is resolved.

    If the resulting InputSource has no byte stream, one is opened: from
    the local file named by the system identifier (relative to the
    directory of *base*) when such a file exists, otherwise from the URL
    obtained by joining *base* and the system identifier.  Errors raised
    by open() or urllib.urlopen() are not caught here.
    """

    # isinstance() rather than an exact type() match, so that subclasses
    # of the string types are recognised as system identifiers as well
    # instead of falling through to the getByteStream() call below.
    if isinstance(source, tuple(_StringTypes)):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        f = source
        source = xmlreader.InputSource()
        source.setByteStream(f)
        if hasattr(f, "name"):
            source.setSystemId(f.name)

    if source.getByteStream() is None:
        try:
            sysid = source.getSystemId()
            basehead = os.path.dirname(os.path.normpath(base))
            encoding = sys.getfilesystemencoding()
            # Bring sysid and basehead to the same string type before
            # joining them: prefer unicode, and fall back to byte
            # strings when the byte-string side cannot be decoded.
            if isinstance(sysid, unicode):
                if not isinstance(basehead, unicode):
                    try:
                        basehead = basehead.decode(encoding)
                    except UnicodeDecodeError:
                        sysid = sysid.encode(encoding)
            else:
                if isinstance(basehead, unicode):
                    try:
                        sysid = sysid.decode(encoding)
                    except UnicodeDecodeError:
                        basehead = basehead.encode(encoding)
            sysidfilename = os.path.join(basehead, sysid)
            isfile = os.path.isfile(sysidfilename)
        except UnicodeError:
            # Any conversion failure not handled above means the system
            # identifier is not treated as a local file.
            isfile = False
        if isfile:
            source.setSystemId(sysidfilename)
            f = open(sysidfilename, "rb")
        else:
            source.setSystemId(urlparse.urljoin(base, source.getSystemId()))
            f = urllib.urlopen(source.getSystemId())

        source.setByteStream(f)

    return source
354