• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# regression test for SAX 2.0            -*- coding: utf-8 -*-
2# $Id$
3
4from xml.sax import make_parser, ContentHandler, \
5                    SAXException, SAXReaderNotAvailable, SAXParseException, \
6                    saxutils
7try:
8    make_parser()
9except SAXReaderNotAvailable:
10    # don't try to test this module if we cannot create a parser
11    raise ImportError("no XML parsers available")
12from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
13                             XMLFilterBase, prepare_input_source
14from xml.sax.expatreader import create_parser
15from xml.sax.handler import feature_namespaces
16from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
17from cStringIO import StringIO
18import io
19import gc
20import os.path
21import shutil
22import test.test_support as support
23from test.test_support import findfile, run_unittest, TESTFN
24import unittest
25
# Sample document shipped with the test suite, and the exact output that
# XMLGenerator is expected to produce when re-serializing it.
TEST_XMLFILE = findfile("test.xml", subdir="xmltestdata")
TEST_XMLFILE_OUT = findfile("test.xml.out", subdir="xmltestdata")

# Decide whether support.TESTFN_UNICODE can be used as an on-disk file name:
# either the platform reports native unicode file-name support, or the name
# must be encodable in support.TESTFN_ENCODING.
if os.path.supports_unicode_filenames:
    supports_unicode_filenames = True
else:
    try:
        support.TESTFN_UNICODE.encode(support.TESTFN_ENCODING)
    except (AttributeError, UnicodeError, TypeError):
        # Either the file system encoding is None, or the file name
        # cannot be encoded in the file system encoding.
        supports_unicode_filenames = False
    else:
        supports_unicode_filenames = True

# Decorator that skips a test needing a unicode file name when unsupported.
requires_unicode_filenames = unittest.skipUnless(
    supports_unicode_filenames, 'Requires unicode filenames support')

# Namespace URI used by the namespace-aware attribute tests.
ns_uri = "http://www.python.org/xml-ns/saxtest/"
42
class XmlTestBase(unittest.TestCase):
    """Shared assertions for SAX Attributes / AttributesNS objects."""

    def verify_empty_attrs(self, attrs):
        """Assert that the non-namespace *attrs* holds no attributes.

        Every lookup of the name "attr" must miss (KeyError or default),
        and every collection view must be empty.
        """
        self.assertRaises(KeyError, attrs.getValue, "attr")
        self.assertRaises(KeyError, attrs.getValueByQName, "attr")
        self.assertRaises(KeyError, attrs.getNameByQName, "attr")
        self.assertRaises(KeyError, attrs.getQNameByName, "attr")
        self.assertRaises(KeyError, attrs.__getitem__, "attr")
        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertFalse(attrs.has_key("attr"))
        self.assertEqual(attrs.keys(), [])
        # Probe the same key as the other lookups above (previously the
        # typo "attrs" was used here).
        self.assertEqual(attrs.get("attr"), None)
        self.assertEqual(attrs.get("attr", 25), 25)
        self.assertEqual(attrs.items(), [])
        self.assertEqual(attrs.values(), [])

    def verify_empty_nsattrs(self, attrs):
        """Assert that the namespace-aware *attrs* holds no attributes.

        Names are (namespace URI, local name) pairs; qualified names are
        "prefix:local" strings.
        """
        self.assertRaises(KeyError, attrs.getValue, (ns_uri, "attr"))
        self.assertRaises(KeyError, attrs.getValueByQName, "ns:attr")
        self.assertRaises(KeyError, attrs.getNameByQName, "ns:attr")
        self.assertRaises(KeyError, attrs.getQNameByName, (ns_uri, "attr"))
        self.assertRaises(KeyError, attrs.__getitem__, (ns_uri, "attr"))
        self.assertEqual(attrs.getLength(), 0)
        self.assertEqual(attrs.getNames(), [])
        self.assertEqual(attrs.getQNames(), [])
        self.assertEqual(len(attrs), 0)
        self.assertFalse(attrs.has_key((ns_uri, "attr")))
        self.assertEqual(attrs.keys(), [])
        self.assertEqual(attrs.get((ns_uri, "attr")), None)
        self.assertEqual(attrs.get((ns_uri, "attr"), 25), 25)
        self.assertEqual(attrs.items(), [])
        self.assertEqual(attrs.values(), [])

    def verify_attrs_wattr(self, attrs):
        """Assert that *attrs* holds exactly one attribute, attr="val"."""
        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), ["attr"])
        self.assertEqual(attrs.getQNames(), ["attr"])
        self.assertEqual(len(attrs), 1)
        self.assertTrue(attrs.has_key("attr"))
        self.assertEqual(attrs.keys(), ["attr"])
        self.assertEqual(attrs.get("attr"), "val")
        self.assertEqual(attrs.get("attr", 25), "val")
        self.assertEqual(attrs.items(), [("attr", "val")])
        self.assertEqual(attrs.values(), ["val"])
        self.assertEqual(attrs.getValue("attr"), "val")
        self.assertEqual(attrs.getValueByQName("attr"), "val")
        self.assertEqual(attrs.getNameByQName("attr"), "attr")
        self.assertEqual(attrs["attr"], "val")
        self.assertEqual(attrs.getQNameByName("attr"), "attr")
94
95
def xml_unicode(doc, encoding=None):
    """Return *doc*, prefixed with an XML declaration naming *encoding*.

    When *encoding* is None, *doc* is returned unchanged and no declaration
    is added.
    """
    if encoding is not None:
        return u'<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
    return doc
100
def xml_bytes(doc, encoding, decl_encoding=Ellipsis):
    """Encode *doc*, with an optional XML declaration, as *encoding* bytes.

    *decl_encoding* is the name written into the declaration: Ellipsis (the
    default) means "same as *encoding*", None means "omit the declaration".
    Characters *encoding* cannot represent become character references.
    """
    declared = encoding if decl_encoding is Ellipsis else decl_encoding
    text = xml_unicode(doc, declared)
    return text.encode(encoding, 'xmlcharrefreplace')
105
def make_xml_file(doc, encoding, decl_encoding=Ellipsis):
    """Write *doc* to TESTFN encoded as *encoding*.

    *decl_encoding* selects the XML declaration: Ellipsis (the default)
    declares *encoding*, None writes no declaration.  Characters *encoding*
    cannot represent are written as character references.
    """
    declared = encoding if decl_encoding is Ellipsis else decl_encoding
    with io.open(TESTFN, 'w', encoding=encoding,
                 errors='xmlcharrefreplace') as out:
        out.write(xml_unicode(doc, declared))
111
112
class ParseTest(unittest.TestCase):
    """Round-trip tests for xml.sax.parse() and xml.sax.parseString().

    Each test feeds ``data`` (encoded in various ways, through various kinds
    of source) to the parser and checks that XMLGenerator re-serializes it
    as the UTF-8 form of the same document.
    """

    # Characters from ASCII, Latin-1, the BMP and a supplementary plane.
    # support.u() presumably decodes the escape sequences in the raw
    # strings into a unicode object.
    data = support.u(r'<money value="$\xa3\u20ac\U0001017b">'
                     r'$\xa3\u20ac\U0001017b</money>')

    def tearDown(self):
        support.unlink(TESTFN)

    def check_parse(self, f):
        """Parse *f* with xml.sax.parse() and compare with the expected output."""
        from xml.sax import parse
        result = StringIO()
        parse(f, XMLGenerator(result, 'utf-8'))
        self.assertEqual(result.getvalue(), xml_bytes(self.data, 'utf-8'))

    def _check_parse_sources(self, encoding, decl_encoding=Ellipsis):
        """Parse ``data`` encoded as *encoding* from every kind of source.

        The document is supplied as an in-memory byte stream, as a file name
        and as an open binary file.  *decl_encoding* has the same meaning as
        in xml_bytes(): Ellipsis declares *encoding*, None omits the
        declaration.
        """
        self.check_parse(io.BytesIO(xml_bytes(self.data, encoding,
                                              decl_encoding)))
        make_xml_file(self.data, encoding, decl_encoding)
        self.check_parse(TESTFN)
        with io.open(TESTFN, 'rb') as f:
            self.check_parse(f)

    def _check_parse_sources_fail(self, encoding, decl_encoding=Ellipsis):
        """Like _check_parse_sources(), but every parse must raise SAXException."""
        with self.assertRaises(SAXException):
            self.check_parse(io.BytesIO(xml_bytes(self.data, encoding,
                                                  decl_encoding)))
        make_xml_file(self.data, encoding, decl_encoding)
        with self.assertRaises(SAXException):
            self.check_parse(TESTFN)
        with io.open(TESTFN, 'rb') as f:
            with self.assertRaises(SAXException):
                self.check_parse(f)

    def test_parse_bytes(self):
        # UTF-8 is default encoding, US-ASCII is compatible with UTF-8,
        # UTF-16 is autodetected
        encodings = ('us-ascii', 'utf-8', 'utf-16', 'utf-16le', 'utf-16be')
        for encoding in encodings:
            self._check_parse_sources(encoding)
            self._check_parse_sources(encoding, None)
        # accept UTF-8 with BOM
        self._check_parse_sources('utf-8-sig', 'utf-8')
        self._check_parse_sources('utf-8-sig', None)
        # accept data with declared encoding
        self._check_parse_sources('iso-8859-1')
        # fail on non-UTF-8 incompatible data without declared encoding
        self._check_parse_sources_fail('iso-8859-1', None)

    def test_parse_InputSource(self):
        # accept data without declared but with explicitly specified encoding
        make_xml_file(self.data, 'iso-8859-1', None)
        with io.open(TESTFN, 'rb') as f:
            input = InputSource()
            input.setByteStream(f)
            input.setEncoding('iso-8859-1')
            self.check_parse(input)

    def test_parse_close_source(self):
        # A file that parse() opens itself (given a file name) must be closed
        # again even when parsing fails.  The swap below assumes saxutils
        # resolves ``open`` through its module globals, so the injected
        # wrapper can capture the file object it returns.
        builtin_open = open
        non_local = {'fileobj': None}

        def mock_open(*args):
            fileobj = builtin_open(*args)
            non_local['fileobj'] = fileobj
            return fileobj

        with support.swap_attr(saxutils, 'open', mock_open):
            make_xml_file(self.data, 'iso-8859-1', None)
            with self.assertRaises(SAXException):
                self.check_parse(TESTFN)
            # Fail with a clear message (instead of an AttributeError on
            # None) if the file was never opened through the wrapper.
            self.assertIsNotNone(non_local['fileobj'])
            self.assertTrue(non_local['fileobj'].closed)

    def check_parseString(self, s):
        """Parse *s* with xml.sax.parseString() and compare with the expected output."""
        from xml.sax import parseString
        result = StringIO()
        parseString(s, XMLGenerator(result, 'utf-8'))
        self.assertEqual(result.getvalue(), xml_bytes(self.data, 'utf-8'))

    def test_parseString_bytes(self):
        # UTF-8 is default encoding, US-ASCII is compatible with UTF-8,
        # UTF-16 is autodetected
        encodings = ('us-ascii', 'utf-8', 'utf-16', 'utf-16le', 'utf-16be')
        for encoding in encodings:
            self.check_parseString(xml_bytes(self.data, encoding))
            self.check_parseString(xml_bytes(self.data, encoding, None))
        # accept UTF-8 with BOM
        self.check_parseString(xml_bytes(self.data, 'utf-8-sig', 'utf-8'))
        self.check_parseString(xml_bytes(self.data, 'utf-8-sig', None))
        # accept data with declared encoding
        self.check_parseString(xml_bytes(self.data, 'iso-8859-1'))
        # fail on non-UTF-8 incompatible data without declared encoding
        with self.assertRaises(SAXException):
            self.check_parseString(xml_bytes(self.data, 'iso-8859-1', None))
213
214
class MakeParserTest(unittest.TestCase):
    """Regression test for repeated parser construction."""

    def test_make_parser2(self):
        # Creating parsers several times in a row should succeed.
        # Testing this because there have been failures of this kind
        # before.  Each of the six rounds re-imports make_parser and
        # builds a fresh parser.
        for _ in range(6):
            from xml.sax import make_parser
            make_parser()
232
233
234# ===========================================================================
235#
236#   saxutils tests
237#
238# ===========================================================================
239
class SaxutilsTest(unittest.TestCase):
    """Tests for escape(), unescape(), quoteattr() and make_parser()."""

    # ===== escape
    def test_escape_basic(self):
        raw, escaped = "Donald Duck & Co", "Donald Duck &amp; Co"
        self.assertEqual(escape(raw), escaped)

    def test_escape_all(self):
        raw = "<Donald Duck & Co>"
        self.assertEqual(escape(raw), "&lt;Donald Duck &amp; Co&gt;")

    def test_escape_extra(self):
        # Extra entities are substituted in addition to the built-in ones.
        extra = {"å" : "&aring;"}
        self.assertEqual(escape("Hei på deg", extra), "Hei p&aring; deg")

    # ===== unescape
    def test_unescape_basic(self):
        escaped = "Donald Duck &amp; Co"
        self.assertEqual(unescape(escaped), "Donald Duck & Co")

    def test_unescape_all(self):
        escaped = "&lt;Donald Duck &amp; Co&gt;"
        self.assertEqual(unescape(escaped), "<Donald Duck & Co>")

    def test_unescape_extra(self):
        # The extra mapping is applied as a plain key -> value replacement.
        extra = {"å" : "&aring;"}
        self.assertEqual(unescape("Hei på deg", extra), "Hei p&aring; deg")

    def test_unescape_amp_extra(self):
        # Decoding "&amp;" must not produce a fresh match for an extra entity.
        extra = {"&foo;": "splat"}
        self.assertEqual(unescape("&amp;foo;", extra), "&foo;")

    # ===== quoteattr
    def test_quoteattr_basic(self):
        quoted = quoteattr("Donald Duck & Co")
        self.assertEqual(quoted, '"Donald Duck &amp; Co"')

    def test_single_quoteattr(self):
        # Double quotes in the value: wrap it in single quotes instead.
        quoted = quoteattr('Includes "double" quotes')
        self.assertEqual(quoted, '\'Includes "double" quotes\'')

    def test_double_quoteattr(self):
        # Single quotes in the value: the usual double quotes are kept.
        quoted = quoteattr("Includes 'single' quotes")
        self.assertEqual(quoted, "\"Includes 'single' quotes\"")

    def test_single_double_quoteattr(self):
        # Both kinds present: double quotes, with inner ones escaped.
        quoted = quoteattr("Includes 'single' and \"double\" quotes")
        self.assertEqual(quoted,
                         "\"Includes 'single' and &quot;double&quot; quotes\"")

    # ===== make_parser
    def test_make_parser(self):
        # An unknown driver in the list must not prevent parser creation;
        # make_parser() should fall back to the expatreader.
        make_parser(['xml.parsers.no_such_parser'])
290
291
292class PrepareInputSourceTest(unittest.TestCase):
293
294    def setUp(self):
295        self.file = support.TESTFN
296        with open(self.file, "w") as tmp:
297            tmp.write("This was read from a file.")
298
299    def tearDown(self):
300        support.unlink(self.file)
301
302    def make_byte_stream(self):
303        return io.BytesIO(b"This is a byte stream.")
304
305    def checkContent(self, stream, content):
306        self.assertIsNotNone(stream)
307        self.assertEqual(stream.read(), content)
308        stream.close()
309
310
311    def test_byte_stream(self):
312        # If the source is an InputSource that does not have a character
313        # stream but does have a byte stream, use the byte stream.
314        src = InputSource(self.file)
315        src.setByteStream(self.make_byte_stream())
316        prep = prepare_input_source(src)
317        self.assertIsNone(prep.getCharacterStream())
318        self.checkContent(prep.getByteStream(),
319                          b"This is a byte stream.")
320
321    def test_system_id(self):
322        # If the source is an InputSource that has neither a character
323        # stream nor a byte stream, open the system ID.
324        src = InputSource(self.file)
325        prep = prepare_input_source(src)
326        self.assertIsNone(prep.getCharacterStream())
327        self.checkContent(prep.getByteStream(),
328                          b"This was read from a file.")
329
330    def test_string(self):
331        # If the source is a string, use it as a system ID and open it.
332        prep = prepare_input_source(self.file)
333        self.assertIsNone(prep.getCharacterStream())
334        self.checkContent(prep.getByteStream(),
335                          b"This was read from a file.")
336
337    def test_binary_file(self):
338        # If the source is a binary file-like object, use it as a byte
339        # stream.
340        prep = prepare_input_source(self.make_byte_stream())
341        self.assertIsNone(prep.getCharacterStream())
342        self.checkContent(prep.getByteStream(),
343                          b"This is a byte stream.")
344
345
346# ===== XMLGenerator
347
# XML declaration that XMLGenerator emits from startDocument() when it is
# created without an explicit encoding (the tests below rely on this).
start = ('<?xml version="1.0" '
         'encoding="iso-8859-1"?>\n')
349
class XmlgenTest:
    # Mixin holding the XMLGenerator tests.  Concrete subclasses combine it
    # with unittest.TestCase and set ``ioclass`` to the output-stream factory
    # under test; the stream object must provide getvalue().
    #
    # Tests that create XMLGenerator without an explicit encoding expect the
    # output to begin with the module-level ``start`` declaration.

    def test_xmlgen_basic(self):
        # Empty root element: startDocument() emits the XML declaration.
        result = self.ioclass()
        gen = XMLGenerator(result)
        gen.startDocument()
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), start + "<doc></doc>")

    def test_xmlgen_content(self):
        # Plain character data is written verbatim.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("huhei")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")

    def test_xmlgen_pi(self):
        # A processing instruction is serialized as <?target data?>.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.processingInstruction("test", "data")
        gen.startElement("doc", {})
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>")

    def test_xmlgen_content_escape(self):
        # '<' and '&' in character data are escaped.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.characters("<huhei&")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            start + "<doc>&lt;huhei&amp;</doc>")

    def test_xmlgen_attr_escape(self):
        # Attribute values: the quote character is chosen to avoid escaping
        # where possible ('"' -> single quotes, "'" -> double quotes); when
        # both occur, double quotes are used and '"' becomes &quot;.
        # Newline, CR and tab become character references.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {"a": '"'})
        gen.startElement("e", {"a": "'"})
        gen.endElement("e")
        gen.startElement("e", {"a": "'\""})
        gen.endElement("e")
        gen.startElement("e", {"a": "\n\r\t"})
        gen.endElement("e")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), start +
            ("<doc a='\"'><e a=\"'\"></e>"
             "<e a=\"'&quot;\"></e>"
             "<e a=\"&#10;&#13;&#9;\"></e></doc>"))

    def test_xmlgen_encoding(self):
        # Output (declaration included) is encoded in the requested encoding.
        encodings = ('iso-8859-15', 'utf-8',
                     'utf-16be', 'utf-16le',
                     'utf-32be', 'utf-32le')
        for encoding in encodings:
            result = self.ioclass()
            gen = XMLGenerator(result, encoding=encoding)

            gen.startDocument()
            gen.startElement("doc", {"a": u'\u20ac'})
            gen.characters(u"\u20ac")
            gen.endElement("doc")
            gen.endDocument()

            self.assertEqual(result.getvalue(), (
                u'<?xml version="1.0" encoding="%s"?>\n'
                u'<doc a="\u20ac">\u20ac</doc>' % encoding
                ).encode(encoding, 'xmlcharrefreplace'))

    def test_xmlgen_unencodable(self):
        # Characters the output encoding cannot represent are written as
        # numeric character references.
        result = self.ioclass()
        gen = XMLGenerator(result, encoding='ascii')

        gen.startDocument()
        gen.startElement("doc", {"a": u'\u20ac'})
        gen.characters(u"\u20ac")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(),
                '<?xml version="1.0" encoding="ascii"?>\n'
                '<doc a="&#8364;">&#8364;</doc>')

    def test_xmlgen_ignorable(self):
        # Ignorable whitespace is passed through unchanged.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElement("doc", {})
        gen.ignorableWhitespace(" ")
        gen.endElement("doc")
        gen.endDocument()

        self.assertEqual(result.getvalue(), start + "<doc> </doc>")

    def test_xmlgen_encoding_bytes(self):
        # characters() and ignorableWhitespace() also accept data that is
        # already encoded in the output encoding.
        encodings = ('iso-8859-15', 'utf-8',
                     'utf-16be', 'utf-16le',
                     'utf-32be', 'utf-32le')
        for encoding in encodings:
            result = self.ioclass()
            gen = XMLGenerator(result, encoding=encoding)

            gen.startDocument()
            gen.startElement("doc", {"a": u'\u20ac'})
            gen.characters(u"\u20ac".encode(encoding))
            gen.ignorableWhitespace(" ".encode(encoding))
            gen.endElement("doc")
            gen.endDocument()

            self.assertEqual(result.getvalue(), (
                u'<?xml version="1.0" encoding="%s"?>\n'
                u'<doc a="\u20ac">\u20ac </doc>' % encoding
                ).encode(encoding, 'xmlcharrefreplace'))

    def test_xmlgen_ns(self):
        # A prefix mapping started before an element is declared on it as
        # xmlns:prefix; an element with a None namespace is left unqualified.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping("ns1", ns_uri)
        gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
        # add an unqualified name
        gen.startElementNS((None, "udoc"), None, {})
        gen.endElementNS((None, "udoc"), None)
        gen.endElementNS((ns_uri, "doc"), "ns1:doc")
        gen.endPrefixMapping("ns1")
        gen.endDocument()

        self.assertEqual(result.getvalue(), start + \
           ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
                                         ns_uri))

    def test_1463026_1(self):
        # Regression test (the name suggests tracker issue 1463026): an
        # un-namespaced attribute on an un-namespaced element.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS((None, 'a'), 'a')
        gen.endDocument()

        self.assertEqual(result.getvalue(), start+'<a b="c"></a>')

    def test_1463026_2(self):
        # A None prefix maps the default namespace (emitted as xmlns="...").
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping(None, 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping(None)
        gen.endDocument()

        self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>')

    def test_1463026_3(self):
        # A prefixed element keeps its un-namespaced attribute unprefixed.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping('my', 'qux')
        gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
        gen.endElementNS(('qux', 'a'), 'a')
        gen.endPrefixMapping('my')
        gen.endDocument()

        self.assertEqual(result.getvalue(),
            start+'<my:a xmlns:my="qux" b="c"></my:a>')

    def test_5027_1(self):
        # The xml prefix (as in xml:lang below) is reserved and bound by
        # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
        # a bug whereby a KeyError is raised because this namespace is missing
        # from a dictionary.
        #
        # This test demonstrates the bug by parsing a document.
        test_xml = StringIO(
            '<?xml version="1.0"?>'
            '<a:g1 xmlns:a="http://example.com/ns">'
             '<a:g2 xml:lang="en">Hello</a:g2>'
            '</a:g1>')

        parser = make_parser()
        parser.setFeature(feature_namespaces, True)
        result = self.ioclass()
        gen = XMLGenerator(result)
        parser.setContentHandler(gen)
        parser.parse(test_xml)

        self.assertEqual(result.getvalue(),
                         start + (
                         '<a:g1 xmlns:a="http://example.com/ns">'
                          '<a:g2 xml:lang="en">Hello</a:g2>'
                         '</a:g1>'))

    def test_5027_2(self):
        # The xml prefix (as in xml:lang below) is reserved and bound by
        # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
        # a bug whereby a KeyError is raised because this namespace is missing
        # from a dictionary.
        #
        # This test demonstrates the bug by direct manipulation of the
        # XMLGenerator.
        result = self.ioclass()
        gen = XMLGenerator(result)

        gen.startDocument()
        gen.startPrefixMapping('a', 'http://example.com/ns')
        gen.startElementNS(('http://example.com/ns', 'g1'), 'g1', {})
        lang_attr = {('http://www.w3.org/XML/1998/namespace', 'lang'): 'en'}
        gen.startElementNS(('http://example.com/ns', 'g2'), 'g2', lang_attr)
        gen.characters('Hello')
        gen.endElementNS(('http://example.com/ns', 'g2'), 'g2')
        gen.endElementNS(('http://example.com/ns', 'g1'), 'g1')
        gen.endPrefixMapping('a')
        gen.endDocument()

        self.assertEqual(result.getvalue(),
                         start + (
                         '<a:g1 xmlns:a="http://example.com/ns">'
                          '<a:g2 xml:lang="en">Hello</a:g2>'
                         '</a:g1>'))

    def test_no_close_file(self):
        # The generator must not close the output stream once it goes out
        # of scope (it only lives inside func()).
        result = self.ioclass()
        def func(out):
            gen = XMLGenerator(out)
            gen.startDocument()
            gen.startElement("doc", {})
        func(result)
        self.assertFalse(result.closed)

    def test_xmlgen_fragment(self):
        # Without startDocument() no XML declaration is written, so several
        # top-level elements can be emitted as a fragment.
        result = self.ioclass()
        gen = XMLGenerator(result)

        # Don't call gen.startDocument()
        gen.startElement("foo", {"a": "1.0"})
        gen.characters("Hello")
        gen.endElement("foo")
        gen.startElement("bar", {"b": "2.0"})
        gen.endElement("bar")
        # Don't call gen.endDocument()

        self.assertEqual(result.getvalue(),
                         '<foo a="1.0">Hello</foo><bar b="2.0"></bar>')
616
class StringXmlgenTest(XmlgenTest, unittest.TestCase):
    # Run the XMLGenerator tests against a cStringIO buffer.
    ioclass = StringIO

class BytesIOXmlgenTest(XmlgenTest, unittest.TestCase):
    # Run the XMLGenerator tests against an io.BytesIO buffer.
    ioclass = io.BytesIO

class WriterXmlgenTest(XmlgenTest, unittest.TestCase):
    # Run the XMLGenerator tests against a minimal writer: a list that
    # records every chunk passed to write().
    class ioclass(list):
        closed = False
        write = list.append

        def getvalue(self):
            # Concatenate all recorded chunks into one bytes value.
            return b''.join(self)
630
631
class XMLFilterBaseTest(unittest.TestCase):
    """Check that XMLFilterBase passes events through to its handler."""

    def test_filter_basic(self):
        out = StringIO()
        xml_filter = XMLFilterBase()
        xml_filter.setContentHandler(XMLGenerator(out))

        # Drive the filter; each event should reach the generator unchanged.
        xml_filter.startDocument()
        xml_filter.startElement("doc", {})
        xml_filter.characters("content")
        xml_filter.ignorableWhitespace(" ")
        xml_filter.endElement("doc")
        xml_filter.endDocument()

        expected = start + "<doc>content </doc>"
        self.assertEqual(out.getvalue(), expected)
647
648# ===========================================================================
649#
650#   expatreader tests
651#
652# ===========================================================================
653
# Expected XMLGenerator output for TEST_XMLFILE, read once at import time.
# A context manager closes the handle immediately instead of leaking it
# until garbage collection.
with open(TEST_XMLFILE_OUT) as _f:
    xml_test_out = _f.read()
655
656class ExpatReaderTest(XmlTestBase):
657
658    # ===== XMLReader support
659
660    def test_expat_binary_file(self):
661        parser = create_parser()
662        result = StringIO()
663        xmlgen = XMLGenerator(result)
664
665        parser.setContentHandler(xmlgen)
666        parser.parse(open(TEST_XMLFILE))
667
668        self.assertEqual(result.getvalue(), xml_test_out)
669
670    @requires_unicode_filenames
671    def test_expat_file_unicode(self):
672        fname = support.TESTFN_UNICODE
673        shutil.copyfile(TEST_XMLFILE, fname)
674        self.addCleanup(support.unlink, fname)
675
676        parser = create_parser()
677        result = StringIO()
678        xmlgen = XMLGenerator(result)
679
680        parser.setContentHandler(xmlgen)
681        parser.parse(open(fname))
682
683        self.assertEqual(result.getvalue(), xml_test_out)
684
685    # ===== DTDHandler support
686
687    class TestDTDHandler:
688
689        def __init__(self):
690            self._notations = []
691            self._entities  = []
692
693        def notationDecl(self, name, publicId, systemId):
694            self._notations.append((name, publicId, systemId))
695
696        def unparsedEntityDecl(self, name, publicId, systemId, ndata):
697            self._entities.append((name, publicId, systemId, ndata))
698
699    def test_expat_dtdhandler(self):
700        parser = create_parser()
701        handler = self.TestDTDHandler()
702        parser.setDTDHandler(handler)
703
704        parser.feed('<!DOCTYPE doc [\n')
705        parser.feed('  <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
706        parser.feed('  <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
707        parser.feed(']>\n')
708        parser.feed('<doc></doc>')
709        parser.close()
710
711        self.assertEqual(handler._notations,
712            [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)])
713        self.assertEqual(handler._entities, [("img", None, "expat.gif", "GIF")])
714
715    # ===== EntityResolver support
716
717    class TestEntityResolver:
718
719        def resolveEntity(self, publicId, systemId):
720            inpsrc = InputSource()
721            inpsrc.setByteStream(StringIO("<entity/>"))
722            return inpsrc
723
724    def test_expat_entityresolver(self):
725        parser = create_parser()
726        parser.setEntityResolver(self.TestEntityResolver())
727        result = StringIO()
728        parser.setContentHandler(XMLGenerator(result))
729
730        parser.feed('<!DOCTYPE doc [\n')
731        parser.feed('  <!ENTITY test SYSTEM "whatever">\n')
732        parser.feed(']>\n')
733        parser.feed('<doc>&test;</doc>')
734        parser.close()
735
736        self.assertEqual(result.getvalue(), start +
737                         "<doc><entity></entity></doc>")
738
739    # ===== Attributes support
740
741    class AttrGatherer(ContentHandler):
742
743        def startElement(self, name, attrs):
744            self._attrs = attrs
745
746        def startElementNS(self, name, qname, attrs):
747            self._attrs = attrs
748
749    def test_expat_attrs_empty(self):
750        parser = create_parser()
751        gather = self.AttrGatherer()
752        parser.setContentHandler(gather)
753
754        parser.feed("<doc/>")
755        parser.close()
756
757        self.verify_empty_attrs(gather._attrs)
758
759    def test_expat_attrs_wattr(self):
760        parser = create_parser()
761        gather = self.AttrGatherer()
762        parser.setContentHandler(gather)
763
764        parser.feed("<doc attr='val'/>")
765        parser.close()
766
767        self.verify_attrs_wattr(gather._attrs)
768
769    def test_expat_nsattrs_empty(self):
770        parser = create_parser(1)
771        gather = self.AttrGatherer()
772        parser.setContentHandler(gather)
773
774        parser.feed("<doc/>")
775        parser.close()
776
777        self.verify_empty_nsattrs(gather._attrs)
778
779    def test_expat_nsattrs_wattr(self):
780        parser = create_parser(1)
781        gather = self.AttrGatherer()
782        parser.setContentHandler(gather)
783
784        parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
785        parser.close()
786
787        attrs = gather._attrs
788
789        self.assertEqual(attrs.getLength(), 1)
790        self.assertEqual(attrs.getNames(), [(ns_uri, "attr")])
791        self.assertTrue((attrs.getQNames() == [] or
792                         attrs.getQNames() == ["ns:attr"]))
793        self.assertEqual(len(attrs), 1)
794        self.assertTrue(attrs.has_key((ns_uri, "attr")))
795        self.assertEqual(attrs.get((ns_uri, "attr")), "val")
796        self.assertEqual(attrs.get((ns_uri, "attr"), 25), "val")
797        self.assertEqual(attrs.items(), [((ns_uri, "attr"), "val")])
798        self.assertEqual(attrs.values(), ["val"])
799        self.assertEqual(attrs.getValue((ns_uri, "attr")), "val")
800        self.assertEqual(attrs[(ns_uri, "attr")], "val")
801
802    # ===== InputSource support
803
804    def test_expat_inpsource_filename(self):
805        parser = create_parser()
806        result = StringIO()
807        xmlgen = XMLGenerator(result)
808
809        parser.setContentHandler(xmlgen)
810        parser.parse(TEST_XMLFILE)
811
812        self.assertEqual(result.getvalue(), xml_test_out)
813
814    def test_expat_inpsource_sysid(self):
815        parser = create_parser()
816        result = StringIO()
817        xmlgen = XMLGenerator(result)
818
819        parser.setContentHandler(xmlgen)
820        parser.parse(InputSource(TEST_XMLFILE))
821
822        self.assertEqual(result.getvalue(), xml_test_out)
823
824    @requires_unicode_filenames
825    def test_expat_inpsource_sysid_unicode(self):
826        fname = support.TESTFN_UNICODE
827        shutil.copyfile(TEST_XMLFILE, fname)
828        self.addCleanup(support.unlink, fname)
829
830        parser = create_parser()
831        result = StringIO()
832        xmlgen = XMLGenerator(result)
833
834        parser.setContentHandler(xmlgen)
835        parser.parse(InputSource(fname))
836
837        self.assertEqual(result.getvalue(), xml_test_out)
838
839    def test_expat_inpsource_byte_stream(self):
840        parser = create_parser()
841        result = StringIO()
842        xmlgen = XMLGenerator(result)
843
844        parser.setContentHandler(xmlgen)
845        inpsrc = InputSource()
846        inpsrc.setByteStream(open(TEST_XMLFILE))
847        parser.parse(inpsrc)
848
849        self.assertEqual(result.getvalue(), xml_test_out)
850
851    # ===== IncrementalParser support
852
853    def test_expat_incremental(self):
854        result = StringIO()
855        xmlgen = XMLGenerator(result)
856        parser = create_parser()
857        parser.setContentHandler(xmlgen)
858
859        parser.feed("<doc>")
860        parser.feed("</doc>")
861        parser.close()
862
863        self.assertEqual(result.getvalue(), start + "<doc></doc>")
864
865    def test_expat_incremental_reset(self):
866        result = StringIO()
867        xmlgen = XMLGenerator(result)
868        parser = create_parser()
869        parser.setContentHandler(xmlgen)
870
871        parser.feed("<doc>")
872        parser.feed("text")
873
874        result = StringIO()
875        xmlgen = XMLGenerator(result)
876        parser.setContentHandler(xmlgen)
877        parser.reset()
878
879        parser.feed("<doc>")
880        parser.feed("text")
881        parser.feed("</doc>")
882        parser.close()
883
884        self.assertEqual(result.getvalue(), start + "<doc>text</doc>")
885
886    # ===== Locator support
887
888    def test_expat_locator_noinfo(self):
889        result = StringIO()
890        xmlgen = XMLGenerator(result)
891        parser = create_parser()
892        parser.setContentHandler(xmlgen)
893
894        parser.feed("<doc>")
895        parser.feed("</doc>")
896        parser.close()
897
898        self.assertEqual(parser.getSystemId(), None)
899        self.assertEqual(parser.getPublicId(), None)
900        self.assertEqual(parser.getLineNumber(), 1)
901
902    def test_expat_locator_withinfo(self):
903        result = StringIO()
904        xmlgen = XMLGenerator(result)
905        parser = create_parser()
906        parser.setContentHandler(xmlgen)
907        parser.parse(TEST_XMLFILE)
908
909        self.assertEqual(parser.getSystemId(), TEST_XMLFILE)
910        self.assertEqual(parser.getPublicId(), None)
911
912    @requires_unicode_filenames
913    def test_expat_locator_withinfo_unicode(self):
914        fname = support.TESTFN_UNICODE
915        shutil.copyfile(TEST_XMLFILE, fname)
916        self.addCleanup(support.unlink, fname)
917
918        result = StringIO()
919        xmlgen = XMLGenerator(result)
920        parser = create_parser()
921        parser.setContentHandler(xmlgen)
922        parser.parse(fname)
923
924        self.assertEqual(parser.getSystemId(), fname)
925        self.assertEqual(parser.getPublicId(), None)
926
927
928# ===========================================================================
929#
930#   error reporting
931#
932# ===========================================================================
933
934class ErrorReportingTest(unittest.TestCase):
935    def test_expat_inpsource_location(self):
936        parser = create_parser()
937        parser.setContentHandler(ContentHandler()) # do nothing
938        source = InputSource()
939        source.setByteStream(StringIO("<foo bar foobar>"))   #ill-formed
940        name = "a file name"
941        source.setSystemId(name)
942        try:
943            parser.parse(source)
944            self.fail()
945        except SAXException, e:
946            self.assertEqual(e.getSystemId(), name)
947
948    def test_expat_incomplete(self):
949        parser = create_parser()
950        parser.setContentHandler(ContentHandler()) # do nothing
951        self.assertRaises(SAXParseException, parser.parse, StringIO("<foo>"))
952        self.assertEqual(parser.getColumnNumber(), 5)
953        self.assertEqual(parser.getLineNumber(), 1)
954
955    def test_sax_parse_exception_str(self):
956        # pass various values from a locator to the SAXParseException to
957        # make sure that the __str__() doesn't fall apart when None is
958        # passed instead of an integer line and column number
959        #
960        # use "normal" values for the locator:
961        str(SAXParseException("message", None,
962                              self.DummyLocator(1, 1)))
963        # use None for the line number:
964        str(SAXParseException("message", None,
965                              self.DummyLocator(None, 1)))
966        # use None for the column number:
967        str(SAXParseException("message", None,
968                              self.DummyLocator(1, None)))
969        # use None for both:
970        str(SAXParseException("message", None,
971                              self.DummyLocator(None, None)))
972
973    class DummyLocator:
974        def __init__(self, lineno, colno):
975            self._lineno = lineno
976            self._colno = colno
977
978        def getPublicId(self):
979            return "pubid"
980
981        def getSystemId(self):
982            return "sysid"
983
984        def getLineNumber(self):
985            return self._lineno
986
987        def getColumnNumber(self):
988            return self._colno
989
990# ===========================================================================
991#
992#   xmlreader tests
993#
994# ===========================================================================
995
class XmlReaderTest(XmlTestBase):
    """Exercise the xmlreader attribute containers directly, without a parser."""

    # ===== AttributesImpl
    def test_attrs_empty(self):
        self.verify_empty_attrs(AttributesImpl({}))

    def test_attrs_wattr(self):
        self.verify_attrs_wattr(AttributesImpl({"attr" : "val"}))

    def test_nsattrs_empty(self):
        self.verify_empty_nsattrs(AttributesNSImpl({}, {}))

    def test_nsattrs_wattr(self):
        attrs = AttributesNSImpl({(ns_uri, "attr") : "val"},
                                 {(ns_uri, "attr") : "ns:attr"})

        self.assertEqual(attrs.getLength(), 1)
        self.assertEqual(attrs.getNames(), [(ns_uri, "attr")])
        self.assertEqual(attrs.getQNames(), ["ns:attr"])
        self.assertEqual(len(attrs), 1)
        self.assertTrue(attrs.has_key((ns_uri, "attr")))
        self.assertEqual(attrs.keys(), [(ns_uri, "attr")])
        self.assertEqual(attrs.get((ns_uri, "attr")), "val")
        self.assertEqual(attrs.get((ns_uri, "attr"), 25), "val")
        self.assertEqual(attrs.items(), [((ns_uri, "attr"), "val")])
        self.assertEqual(attrs.values(), ["val"])
        self.assertEqual(attrs.getValue((ns_uri, "attr")), "val")
        self.assertEqual(attrs.getValueByQName("ns:attr"), "val")
        self.assertEqual(attrs.getNameByQName("ns:attr"), (ns_uri, "attr"))
        self.assertEqual(attrs[(ns_uri, "attr")], "val")
        self.assertEqual(attrs.getQNameByName((ns_uri, "attr")), "ns:attr")


    # During the development of Python 2.5, an attempt to move the "xml"
    # package implementation to a new package ("xmlcore") proved painful.
    # The goal of this change was to allow applications to be able to
    # obtain and rely on behavior in the standard library implementation
    # of the XML support without needing to be concerned about the
    # availability of the PyXML implementation.
    #
    # While the existing import hackery in Lib/xml/__init__.py can cause
    # PyXML's _xmlplus package to supplant the "xml" package, that only
    # works because either implementation uses the "xml" package name for
    # imports.
    #
    # The move resulted in a number of problems related to the fact that
    # the import machinery's "package context" is based on the name that's
    # being imported rather than the __name__ of the actual package
    # containment; it wasn't possible for the "xml" package to be replaced
    # by a simple module that indirected imports to the "xmlcore" package.
    #
    # The following two tests exercised bugs that were introduced in that
    # attempt.  Keeping these tests around will help detect problems with
    # other attempts to provide reliable access to the standard library's
    # implementation of the XML support.

    def test_sf_1511497(self):
        # Bug report: http://www.python.org/sf/1511497
        import sys
        old_modules = sys.modules.copy()
        # Iterate over a snapshot: entries are deleted inside the loop.
        for modname in list(sys.modules):
            if modname.startswith("xml."):
                del sys.modules[modname]
        try:
            import xml.sax.expatreader
            module = xml.sax.expatreader
            self.assertEqual(module.__name__, "xml.sax.expatreader")
        finally:
            # Leave sys.modules exactly as it was: drop any xml.* module
            # that only exists because of the fresh import above, then put
            # the original module objects back for the remaining tests.
            for modname in list(sys.modules):
                if modname.startswith("xml.") and modname not in old_modules:
                    del sys.modules[modname]
            sys.modules.update(old_modules)

    def test_sf_1513611(self):
        # Bug report: http://www.python.org/sf/1513611
        sio = StringIO("invalid")
        parser = make_parser()
        # Deliberately re-imported at call time, presumably so the class
        # raised by the parser is checked against the one reachable via
        # "from xml.sax import ..." (the subject of the bug report).
        from xml.sax import SAXParseException
        self.assertRaises(SAXParseException, parser.parse, sio)
1066    def test_sf_1513611(self):
1067        # Bug report: http://www.python.org/sf/1513611
1068        sio = StringIO("invalid")
1069        parser = make_parser()
1070        from xml.sax import SAXParseException
1071        self.assertRaises(SAXParseException, parser.parse, sio)
1072
1073
def test_main():
    # Every TestCase class in this module, in the order they are run.
    suites = (
        MakeParserTest,
        ParseTest,
        SaxutilsTest,
        PrepareInputSourceTest,
        StringXmlgenTest,
        BytesIOXmlgenTest,
        WriterXmlgenTest,
        ExpatReaderTest,
        ErrorReportingTest,
        XmlReaderTest,
    )
    run_unittest(*suites)

if __name__ == "__main__":
    test_main()
1088