• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# XXX TypeErrors on calling handlers, or on bad return values from a
2# handler, are obscure and unhelpful.
3
4import os
5import sys
6import sysconfig
7import unittest
8import traceback
9from io import BytesIO
10from test import support
11from test.support import os_helper
12
13from xml.parsers import expat
14from xml.parsers.expat import errors
15
16from test.support import sortdict
17
18
class SetAttributeTest(unittest.TestCase):
    """Reading and writing the configuration attributes of a parser."""

    def setUp(self):
        self.parser = expat.ParserCreate(namespace_separator='!')

    def _check_boolean_attribute(self, name):
        # The flag starts out False and afterwards reflects the truth value
        # of whatever was last assigned to it.
        self.assertIs(getattr(self.parser, name), False)
        for value in (0, 1, 2, 0):
            setattr(self.parser, name, value)
            self.assertIs(getattr(self.parser, name), bool(value))

    def test_buffer_text(self):
        self._check_boolean_attribute('buffer_text')

    def test_namespace_prefixes(self):
        self._check_boolean_attribute('namespace_prefixes')

    def test_ordered_attributes(self):
        self._check_boolean_attribute('ordered_attributes')

    def test_specified_attributes(self):
        self._check_boolean_attribute('specified_attributes')

    def test_invalid_attributes(self):
        # returns_unicode is not a parser attribute: neither set nor get works.
        with self.assertRaises(AttributeError):
            self.parser.returns_unicode = 1
        with self.assertRaises(AttributeError):
            self.parser.returns_unicode

        # Issue #25019: a non-string attribute name must raise TypeError.
        bogus_name = range(0xF)
        with self.assertRaises(TypeError):
            setattr(self.parser, bogus_name, 0)
        with self.assertRaises(TypeError):
            self.parser.__setattr__(bogus_name, 0)
        with self.assertRaises(TypeError):
            getattr(self.parser, bogus_name)
57
58
# Sample document parsed by ParseTest and by InterningTest.test_issue9402.
# It declares iso-8859-1 (note the raw \xb5 byte near the end) and touches
# many handler types: an XML declaration, a processing instruction, a
# comment, a DOCTYPE with an external id and an internal subset (element,
# attribute-list, notation, internal/external/unparsed entity declarations,
# plus a %unparsed_entity; reference for which no parameter entity is
# declared), a namespaced element, a CDATA section, a reference to an
# external entity and a reference to an undeclared (skipped) entity.
data = b'''\
<?xml version="1.0" encoding="iso-8859-1" standalone="no"?>
<?xml-stylesheet href="stylesheet.css"?>
<!-- comment data -->
<!DOCTYPE quotations SYSTEM "quotations.dtd" [
<!ELEMENT root ANY>
<!ATTLIST root attr1 CDATA #REQUIRED attr2 CDATA #IMPLIED>
<!NOTATION notation SYSTEM "notation.jpeg">
<!ENTITY acirc "&#226;">
<!ENTITY external_entity SYSTEM "entity.file">
<!ENTITY unparsed_entity SYSTEM "entity.file" NDATA notation>
%unparsed_entity;
]>

<root attr1="value1" attr2="value2&#8000;">
<myns:subelement xmlns:myns="http://www.python.org/namespace">
     Contents of subelements
</myns:subelement>
<sub2><![CDATA[contents of CDATA section]]></sub2>
&external_entity;
&skipped_entity;
\xb5
</root>
'''
83
84
# Record every parser callback as text and compare it with a fixed transcript.
class ParseTest(unittest.TestCase):
    """Parse the module-level ``data`` document with every handler installed."""

    class Outputter:
        """Parser handlers that append a description of each event to ``out``.

        Entries are either formatted strings or ``(label, args)`` tuples.
        """

        def __init__(self):
            self.out = []

        def StartElementHandler(self, name, attrs):
            self.out.append('Start element: ' + repr(name) + ' ' +
                            sortdict(attrs))

        def EndElementHandler(self, name):
            self.out.append('End element: ' + repr(name))

        def CharacterDataHandler(self, data):
            # Whitespace-only chunks (e.g. newlines between elements) are
            # not recorded.
            data = data.strip()
            if data:
                self.out.append('Character data: ' + repr(data))

        def ProcessingInstructionHandler(self, target, data):
            self.out.append('PI: ' + repr(target) + ' ' + repr(data))

        def StartNamespaceDeclHandler(self, prefix, uri):
            self.out.append('NS decl: ' + repr(prefix) + ' ' + repr(uri))

        def EndNamespaceDeclHandler(self, prefix):
            self.out.append('End of NS decl: ' + repr(prefix))

        def StartCdataSectionHandler(self):
            self.out.append('Start of CDATA section')

        def EndCdataSectionHandler(self):
            self.out.append('End of CDATA section')

        def CommentHandler(self, text):
            self.out.append('Comment: ' + repr(text))

        def NotationDeclHandler(self, *args):
            # Unpacking enforces the expected number of arguments.
            name, base, sysid, pubid = args
            self.out.append('Notation declared: %s' %(args,))

        def UnparsedEntityDeclHandler(self, *args):
            # Unpacking enforces the expected number of arguments.
            entityName, base, systemId, publicId, notationName = args
            self.out.append('Unparsed entity decl: %s' %(args,))

        def NotStandaloneHandler(self):
            self.out.append('Not standalone')
            return 1

        def ExternalEntityRefHandler(self, *args):
            # Unpacking enforces the expected number of arguments; the
            # opaque context argument is left out of the transcript.
            context, base, sysId, pubId = args
            self.out.append('External entity ref: %s' %(args[1:],))
            return 1

        def StartDoctypeDeclHandler(self, *args):
            self.out.append(('Start doctype', args))
            return 1

        def EndDoctypeDeclHandler(self):
            self.out.append("End doctype")
            return 1

        def EntityDeclHandler(self, *args):
            self.out.append(('Entity declaration', args))
            return 1

        def XmlDeclHandler(self, *args):
            self.out.append(('XML declaration', args))
            return 1

        def ElementDeclHandler(self, *args):
            self.out.append(('Element declaration', args))
            return 1

        def AttlistDeclHandler(self, *args):
            self.out.append(('Attribute list declaration', args))
            return 1

        def SkippedEntityHandler(self, *args):
            self.out.append(("Skipped entity", args))
            return 1

        def DefaultHandler(self, userData):
            pass

        def DefaultHandlerExpand(self, userData):
            pass

    handler_names = [
        'StartElementHandler', 'EndElementHandler', 'CharacterDataHandler',
        'ProcessingInstructionHandler', 'UnparsedEntityDeclHandler',
        'NotationDeclHandler', 'StartNamespaceDeclHandler',
        'EndNamespaceDeclHandler', 'CommentHandler',
        'StartCdataSectionHandler', 'EndCdataSectionHandler', 'DefaultHandler',
        'DefaultHandlerExpand', 'NotStandaloneHandler',
        'ExternalEntityRefHandler', 'StartDoctypeDeclHandler',
        'EndDoctypeDeclHandler', 'EntityDeclHandler', 'XmlDeclHandler',
        'ElementDeclHandler', 'AttlistDeclHandler', 'SkippedEntityHandler',
        ]

    def _hookup_callbacks(self, parser, handler):
        """
        Set each of the callbacks defined on handler and named in
        self.handler_names on the given parser.
        """
        for name in self.handler_names:
            setattr(parser, name, getattr(handler, name))

    def _verify_parse_output(self, operations):
        """Assert that *operations* is exactly the expected event transcript.

        The whole lists are compared so that missing or unexpected extra
        callbacks fail the test too. Zipping the two lists would silently
        ignore any length difference.
        """
        expected_operations = [
            ('XML declaration', ('1.0', 'iso-8859-1', 0)),
            'PI: \'xml-stylesheet\' \'href="stylesheet.css"\'',
            "Comment: ' comment data '",
            "Not standalone",
            ("Start doctype", ('quotations', 'quotations.dtd', None, 1)),
            ('Element declaration', ('root', (2, 0, None, ()))),
            ('Attribute list declaration', ('root', 'attr1', 'CDATA', None,
                1)),
            ('Attribute list declaration', ('root', 'attr2', 'CDATA', None,
                0)),
            "Notation declared: ('notation', None, 'notation.jpeg', None)",
            ('Entity declaration', ('acirc', 0, '\xe2', None, None, None, None)),
            ('Entity declaration', ('external_entity', 0, None, None,
                'entity.file', None, None)),
            "Unparsed entity decl: ('unparsed_entity', None, 'entity.file', None, 'notation')",
            "Not standalone",
            "End doctype",
            "Start element: 'root' {'attr1': 'value1', 'attr2': 'value2\u1f40'}",
            "NS decl: 'myns' 'http://www.python.org/namespace'",
            "Start element: 'http://www.python.org/namespace!subelement' {}",
            "Character data: 'Contents of subelements'",
            "End element: 'http://www.python.org/namespace!subelement'",
            "End of NS decl: 'myns'",
            "Start element: 'sub2' {}",
            'Start of CDATA section',
            "Character data: 'contents of CDATA section'",
            'End of CDATA section',
            "End element: 'sub2'",
            "External entity ref: (None, 'entity.file', None)",
            ('Skipped entity', ('skipped_entity', 0)),
            "Character data: '\xb5'",
            "End element: 'root'",
        ]
        self.assertEqual(operations, expected_operations)

    def test_parse_bytes(self):
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)

        parser.Parse(data, True)

        operations = out.out
        self._verify_parse_output(operations)
        # Issue #6697.
        self.assertRaises(AttributeError, getattr, parser, '\uD800')

    def test_parse_str(self):
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)

        parser.Parse(data.decode('iso-8859-1'), True)

        operations = out.out
        self._verify_parse_output(operations)

    def test_parse_file(self):
        # Try parsing a file
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)
        file = BytesIO(data)

        parser.ParseFile(file)

        operations = out.out
        self._verify_parse_output(operations)

    def test_parse_again(self):
        parser = expat.ParserCreate()
        file = BytesIO(data)
        parser.ParseFile(file)
        # Issue 6676: ensure a meaningful exception is raised when attempting
        # to parse more than one XML document per xmlparser instance,
        # a limitation of the Expat library.
        with self.assertRaises(expat.error) as cm:
            parser.ParseFile(file)
        self.assertEqual(expat.ErrorString(cm.exception.code),
                          expat.errors.XML_ERROR_FINISHED)
275
class NamespaceSeparatorTest(unittest.TestCase):
    """Validation of the ``namespace_separator`` argument of ParserCreate()."""

    def test_legal(self):
        # Omitted, None and a single character are all accepted.
        accepted = ({}, {'namespace_separator': None},
                    {'namespace_separator': ' '})
        for kwargs in accepted:
            expat.ParserCreate(**kwargs)

    def test_illegal(self):
        # A non-string separator is rejected with TypeError...
        type_message = (r"ParserCreate\(\) argument (2|'namespace_separator') "
                        r"must be str or None, not int")
        with self.assertRaisesRegex(TypeError, type_message):
            expat.ParserCreate(namespace_separator=42)

        # ...and a multi-character string with ValueError.
        with self.assertRaises(ValueError) as cm:
            expat.ParserCreate(namespace_separator='too long')
        self.assertEqual(str(cm.exception),
            'namespace_separator must be at most one character, omitted, or None')

    def test_zero_length(self):
        """An empty separator must be accepted.

        RDF applications need to glue the namespace URI and the local name
        together with nothing in between. This is considered a wart of the RDF
        specifications, but it still has to be supported. See the XML-SIG
        mailing list thread starting with
        http://mail.python.org/pipermail/xml-sig/2001-April/005202.html
        """
        expat.ParserCreate(namespace_separator='')
307
308
class InterningTest(unittest.TestCase):
    def test(self):
        """Element names passed to handlers are interned (one shared object)."""
        p = expat.ParserCreate()
        names = []
        def collector(name, *args):
            names.append(name)
        p.StartElementHandler = collector
        p.EndElementHandler = collector
        p.Parse(b"<e> <e/> <e></e> </e>", True)
        # Check the count first so an empty list fails cleanly, not with
        # an IndexError.
        self.assertEqual(len(names), 6)
        tag = names[0]
        for entry in names:
            # Every callback should receive the very same string object.
            self.assertIs(entry, tag)

    def test_issue9402(self):
        """Regression test for issue 9402.

        Create an external-entity parser from a parser that has buffer_text
        enabled, and parse an empty, final document through it.
        """
        class ExternalOutputter:
            def __init__(self, parser):
                self.parser = parser
                self.parser_result = None

            def ExternalEntityRefHandler(self, context, base, sysId, pubId):
                external_parser = self.parser.ExternalEntityParserCreate("")
                self.parser_result = external_parser.Parse(b"", True)
                return 1

        parser = expat.ParserCreate(namespace_separator='!')
        parser.buffer_text = 1
        out = ExternalOutputter(parser)
        parser.ExternalEntityRefHandler = out.ExternalEntityRefHandler
        parser.Parse(data, True)
        self.assertEqual(out.parser_result, 1)
343
344
class BufferTextTest(unittest.TestCase):
    """Character-data delivery while ``buffer_text`` is enabled.

    Handlers append to ``self.stuff``. Character data is appended verbatim,
    and tags and comments are appended in their markup form.
    """

    def setUp(self):
        self.stuff = []
        self.parser = expat.ParserCreate()
        self.parser.buffer_text = 1
        self.parser.CharacterDataHandler = self.CharacterDataHandler

    def check(self, expected, label):
        """Assert that the recorded events equal *expected*, tagged *label*."""
        # Format *expected* itself: map(str, expected) would only render as
        # an opaque "<map object ...>" in the failure message.
        self.assertEqual(self.stuff, expected,
                "%s\nstuff    = %r\nexpected = %r"
                % (label, self.stuff, expected))

    def CharacterDataHandler(self, text):
        self.stuff.append(text)

    def StartElementHandler(self, name, attrs):
        # A buffer-text="yes"/"no" attribute toggles buffering mid-document.
        self.stuff.append("<%s>" % name)
        bt = attrs.get("buffer-text")
        if bt == "yes":
            self.parser.buffer_text = 1
        elif bt == "no":
            self.parser.buffer_text = 0

    def EndElementHandler(self, name):
        self.stuff.append("</%s>" % name)

    def CommentHandler(self, data):
        self.stuff.append("<!--%s-->" % data)

    def setHandlers(self, handlers=()):
        """Install the handler methods named in *handlers* on the parser."""
        for name in handlers:
            setattr(self.parser, name, getattr(self, name))

    def test_default_to_disabled(self):
        parser = expat.ParserCreate()
        self.assertFalse(parser.buffer_text)

    def test_buffering_enabled(self):
        # Make sure buffering is turned on
        self.assertTrue(self.parser.buffer_text)
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", True)
        self.assertEqual(self.stuff, ['123'],
                         "buffered text not properly collapsed")

    def test1(self):
        # XXX This test exposes more detail of Expat's text chunking than we
        # XXX like, but it tests what we need to concisely.
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b buffer-text='no'/>2\n3<c buffer-text='yes'/>4\n5</a>", True)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "2", "\n", "3", "<c>", "4\n5"],
                         "buffering control not reacting as expected")

    def test2(self):
        self.parser.Parse(b"<a>1<b/>&lt;2&gt;<c/>&#32;\n&#x20;3</a>", True)
        self.assertEqual(self.stuff, ["1<2> \n 3"],
                         "buffered text not properly collapsed")

    def test3(self):
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", True)
        self.assertEqual(self.stuff, ["<a>", "1", "<b>", "2", "<c>", "3"],
                         "buffered text not properly split")

    def test4(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.CharacterDataHandler = None
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", True)
        self.assertEqual(self.stuff,
                         ["<a>", "<b>", "</b>", "<c>", "</c>", "</a>"])

    def test5(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.Parse(b"<a>1<b></b>2<c/>3</a>", True)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3", "</a>"])

    def test6(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>345</a> ", True)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "345", "</a>"],
            "buffered text not properly split")

    def test7(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>3<!--abc-->4<!--def-->5</a> ", True)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3",
                          "<!--abc-->", "4", "<!--def-->", "5", "</a>"],
                         "buffered text not properly split")
438
439
# Test handling of an exception raised from a parser callback:
class HandlerExceptionTest(unittest.TestCase):
    """An exception raised in a handler propagates out of Parse().

    The traceback must also contain the frame that pyexpat.c adds for the
    C callback (gh-66652).
    """

    def StartElementHandler(self, name, attrs):
        # Always fails, naming the element, so the test can check which
        # start tag triggered the exception.
        raise RuntimeError(f'StartElementHandler: <{name}>')

    def check_traceback_entry(self, entry, filename, funcname):
        """Assert that traceback *entry* is in *filename* (basename) and *funcname*."""
        self.assertEqual(os.path.basename(entry.filename), filename)
        self.assertEqual(entry.name, funcname)

    @support.cpython_only
    def test_exception(self):
        # gh-66652: test _PyTraceback_Add() used by pyexpat.c to inject frames

        # Change the current directory to the Python source code directory
        # if it is available.  The directory comes from the abs_builddir
        # config variable; have_source is True only if that is set and exists.
        src_dir = sysconfig.get_config_var('abs_builddir')
        if src_dir:
            have_source = os.path.isdir(src_dir)
        else:
            have_source = False
        if have_source:
            with os_helper.change_cwd(src_dir):
                self._test_exception(have_source)
        else:
            self._test_exception(have_source)

    def _test_exception(self, have_source):
        # NOTE: the frame assertions below expect Parse() to be called
        # directly from this method, so keep the call here.
        # Use path relative to the current directory which should be the Python
        # source code directory (if it is available).
        PYEXPAT_C = os.path.join('Modules', 'pyexpat.c')

        parser = expat.ParserCreate()
        parser.StartElementHandler = self.StartElementHandler
        try:
            parser.Parse(b"<a><b><c/></b></a>", True)

            self.fail("the parser did not raise RuntimeError")
        except RuntimeError as exc:
            # Parsing stops at the first start tag, <a>.
            self.assertEqual(exc.args[0], 'StartElementHandler: <a>', exc)
            entries = traceback.extract_tb(exc.__traceback__)

        # Expected frames: this method, the synthetic "StartElement" frame
        # in pyexpat.c, then the Python handler.
        self.assertEqual(len(entries), 3, entries)
        self.check_traceback_entry(entries[0],
                                   "test_pyexpat.py", "_test_exception")
        self.check_traceback_entry(entries[1],
                                   os.path.basename(PYEXPAT_C),
                                   "StartElement")
        self.check_traceback_entry(entries[2],
                                   "test_pyexpat.py", "StartElementHandler")

        # Check that the traceback contains the relevant line in
        # Modules/pyexpat.c. Skip the test if Modules/pyexpat.c is not
        # available.
        if have_source and os.path.exists(PYEXPAT_C):
            self.assertIn('call_with_frame("StartElement"',
                          entries[1].line)
496
497
# Test the Current* position attributes reported during parsing:
class PositionTest(unittest.TestCase):
    """Check CurrentByteIndex/LineNumber/ColumnNumber inside handlers."""

    def StartElementHandler(self, name, attrs):
        self.check_pos('s')

    def EndElementHandler(self, name):
        self.check_pos('e')

    def check_pos(self, event):
        """Compare the parser's current position with the next expected one.

        *event* is 's' for a start tag and 'e' for an end tag. Each call
        consumes one entry of ``self.expected_list``.
        """
        pos = (event,
               self.parser.CurrentByteIndex,
               self.parser.CurrentLineNumber,
               self.parser.CurrentColumnNumber)
        self.assertLess(self.upto, len(self.expected_list),
                        'too many parser events')
        expected = self.expected_list[self.upto]
        # Report the expected position first and the actual one second.
        self.assertEqual(pos, expected,
                'Expected position %s, got position %s' % (expected, pos))
        self.upto += 1

    def test(self):
        self.parser = expat.ParserCreate()
        self.parser.StartElementHandler = self.StartElementHandler
        self.parser.EndElementHandler = self.EndElementHandler
        self.upto = 0
        self.expected_list = [('s', 0, 1, 0), ('s', 5, 2, 1), ('s', 11, 3, 2),
                              ('e', 15, 3, 6), ('e', 17, 4, 1), ('e', 22, 5, 0)]

        xml = b'<a>\n <b>\n  <c/>\n </b>\n</a>'
        self.parser.Parse(xml, True)
        # Every expected event must have fired, not just a prefix of them.
        self.assertEqual(self.upto, len(self.expected_list),
                         'too few parser events')
528
529
class sf1296433Test(unittest.TestCase):
    def test_parse_only_xml_data(self):
        """An exception raised by CharacterDataHandler propagates out of Parse().

        Regression test for https://bugs.python.org/issue1296433. It uses an
        iso8859-encoded document with more than 1024 characters of text. The
        original comment notes that a plain UTF-8 document with 10000
        characters did not crash.
        """
        class SpecificException(Exception):
            pass

        def handler(text):
            raise SpecificException

        xml = "<?xml version='1.0' encoding='iso8859'?><s>%s</s>" % ('a' * 1025)
        payload = xml.encode('iso8859')

        parser = expat.ParserCreate()
        parser.CharacterDataHandler = handler
        with self.assertRaises(SpecificException):
            parser.Parse(payload)
548
class ChardataBufferTest(unittest.TestCase):
    """
    Tests for the buffer_size attribute, the size of the buffer that collects
    character data while buffer_text is enabled. Most tests count how many
    times CharacterDataHandler is called.
    """

    def test_1025_bytes(self):
        # One byte more than the 1024-byte buffer takes two handler calls.
        self.assertEqual(self.small_buffer_test(1025), 2)

    def test_1000_bytes(self):
        # Text that fits in the buffer is delivered in a single call.
        self.assertEqual(self.small_buffer_test(1000), 1)

    def test_wrong_size(self):
        parser = expat.ParserCreate()
        parser.buffer_text = 1
        with self.assertRaises(ValueError):
            parser.buffer_size = -1
        with self.assertRaises(ValueError):
            parser.buffer_size = 0
        with self.assertRaises((ValueError, OverflowError)):
            parser.buffer_size = sys.maxsize + 1
        with self.assertRaises(TypeError):
            parser.buffer_size = 512.0

    def test_unchanged_size(self):
        xml1 = b"<?xml version='1.0' encoding='iso8859'?><s>" + b'a' * 512
        xml2 = b'a'*512 + b'</s>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 512
        parser.buffer_text = 1

        # Feed 512 bytes of character data: the handler should be called
        # once.
        self.n = 0
        parser.Parse(xml1)
        self.assertEqual(self.n, 1)

        # Reassign to buffer_size, but assign the same size.
        parser.buffer_size = parser.buffer_size
        self.assertEqual(self.n, 1)

        # Try parsing rest of the document
        parser.Parse(xml2)
        self.assertEqual(self.n, 2)

    def test_disabling_buffer(self):
        xml1 = b"<?xml version='1.0' encoding='iso8859'?><a>" + b'a' * 512
        xml2 = b'b' * 1024
        xml3 = b'c' * 1024 + b'</a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        # Parse one chunk of XML
        self.n = 0
        parser.Parse(xml1, False)
        self.assertEqual(parser.buffer_size, 1024)
        self.assertEqual(self.n, 1)

        # Turn off buffering and parse the next chunk.
        parser.buffer_text = 0
        self.assertFalse(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        for i in range(10):
            parser.Parse(xml2, False)
        self.assertEqual(self.n, 11)

        parser.buffer_text = 1
        self.assertTrue(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(xml3, True)
        self.assertEqual(self.n, 12)

    def counting_handler(self, text):
        """CharacterDataHandler that only counts its invocations in self.n."""
        self.n += 1

    def small_buffer_test(self, buffer_len):
        """Parse *buffer_len* bytes of text with a 1024-byte buffer.

        Return the number of CharacterDataHandler calls that were made.
        """
        xml = b"<?xml version='1.0' encoding='iso8859'?><s>" + b'a' * buffer_len + b'</s>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 1024
        parser.buffer_text = 1

        self.n = 0
        parser.Parse(xml)
        return self.n

    def test_change_size_1(self):
        xml1 = b"<?xml version='1.0' encoding='iso8859'?><a><s>" + b'a' * 1024
        xml2 = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        self.n = 0
        parser.Parse(xml1, False)
        parser.buffer_size *= 2
        self.assertEqual(parser.buffer_size, 2048)
        parser.Parse(xml2, True)
        self.assertEqual(self.n, 2)

    def test_change_size_2(self):
        xml1 = b"<?xml version='1.0' encoding='iso8859'?><a>a<s>" + b'a' * 1023
        xml2 = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 2048
        self.assertEqual(parser.buffer_size, 2048)

        self.n = 0
        parser.Parse(xml1, False)
        parser.buffer_size = parser.buffer_size // 2
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(xml2, True)
        self.assertEqual(self.n, 4)
670
class MalformedInputTest(unittest.TestCase):
    """Error reporting for documents that are not well-formed."""

    def test1(self):
        # A NUL byte starts a token that is never closed.
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as cm:
            parser.Parse(b"\0\r\n", True)
        self.assertEqual(str(cm.exception), 'unclosed token: line 2, column 0')

    def test2(self):
        # \xc2\x85 is the UTF-8 encoding of U+0085 (NEXT LINE); inside the
        # XML declaration it makes the declaration not well-formed.
        parser = expat.ParserCreate()
        with self.assertRaisesRegex(
                expat.ExpatError,
                r'XML declaration not well-formed: line 1, column \d+'):
            parser.Parse(b"<?xml version\xc2\x85='1.0'?>\r\n", True)
688
class ErrorMessageTest(unittest.TestCase):
    """Consistency of the errors module tables and of ExpatError.code."""

    def test_codes(self):
        # errors.codes and errors.messages must round-trip a message.
        syntax_code = errors.codes[errors.XML_ERROR_SYNTAX]
        self.assertEqual(errors.XML_ERROR_SYNTAX, errors.messages[syntax_code])

    def test_expaterror(self):
        # A lone '<' is an unclosed token; the raised error carries its code.
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as cm:
            parser.Parse(b'<', True)
        self.assertEqual(cm.exception.code,
                         errors.codes[errors.XML_ERROR_UNCLOSED_TOKEN])
704
705
class ForeignDTDTests(unittest.TestCase):
    """
    Tests for the UseForeignDTD method of expat parser objects.
    """

    def _external_entity_calls(self, enable_foreign_dtd, document):
        """Parse *document* and return the ExternalEntityRefHandler calls.

        *enable_foreign_dtd* receives the fresh parser and is expected to call
        its UseForeignDTD() method. Parameter-entity parsing is always
        enabled. Each handler call is recorded as ``(public_id, system_id)``.
        """
        calls = []

        def resolve_entity(context, base, system_id, public_id):
            calls.append((public_id, system_id))
            return 1

        parser = expat.ParserCreate()
        enable_foreign_dtd(parser)
        parser.SetParamEntityParsing(expat.XML_PARAM_ENTITY_PARSING_ALWAYS)
        parser.ExternalEntityRefHandler = resolve_entity
        parser.Parse(document)
        return calls

    def test_use_foreign_dtd(self):
        """
        If UseForeignDTD is passed True and a document without an external
        entity reference is parsed, ExternalEntityRefHandler is first called
        with None for the public and system ids.
        """
        document = b"<?xml version='1.0'?><element/>"

        calls = self._external_entity_calls(
            lambda parser: parser.UseForeignDTD(True), document)
        self.assertEqual(calls, [(None, None)])

        # UseForeignDTD() with no argument must behave like UseForeignDTD(True).
        calls = self._external_entity_calls(
            lambda parser: parser.UseForeignDTD(), document)
        self.assertEqual(calls, [(None, None)])

    def test_ignore_use_foreign_dtd(self):
        """
        If UseForeignDTD is passed True and a document with an external
        entity reference is parsed, ExternalEntityRefHandler is called with
        the public and system ids from the document.
        """
        calls = self._external_entity_calls(
            lambda parser: parser.UseForeignDTD(True),
            b"<?xml version='1.0'?><!DOCTYPE foo PUBLIC 'bar' 'baz'><element/>")
        self.assertEqual(calls, [("bar", "baz")])
756
757
758class ReparseDeferralTest(unittest.TestCase):
759    def test_getter_setter_round_trip(self):
760        parser = expat.ParserCreate()
761        enabled = (expat.version_info >= (2, 6, 0))
762
763        self.assertIs(parser.GetReparseDeferralEnabled(), enabled)
764        parser.SetReparseDeferralEnabled(False)
765        self.assertIs(parser.GetReparseDeferralEnabled(), False)
766        parser.SetReparseDeferralEnabled(True)
767        self.assertIs(parser.GetReparseDeferralEnabled(), enabled)
768
769    def test_reparse_deferral_enabled(self):
770        if expat.version_info < (2, 6, 0):
771            self.skipTest(f'Expat {expat.version_info} does not '
772                          'support reparse deferral')
773
774        started = []
775
776        def start_element(name, _):
777            started.append(name)
778
779        parser = expat.ParserCreate()
780        parser.StartElementHandler = start_element
781        self.assertTrue(parser.GetReparseDeferralEnabled())
782
783        for chunk in (b'<doc', b'/>'):
784            parser.Parse(chunk, False)
785
786        # The key test: Have handlers already fired?  Expecting: no.
787        self.assertEqual(started, [])
788
789        parser.Parse(b'', True)
790
791        self.assertEqual(started, ['doc'])
792
793    def test_reparse_deferral_disabled(self):
794        started = []
795
796        def start_element(name, _):
797            started.append(name)
798
799        parser = expat.ParserCreate()
800        parser.StartElementHandler = start_element
801        if expat.version_info >= (2, 6, 0):
802            parser.SetReparseDeferralEnabled(False)
803        self.assertFalse(parser.GetReparseDeferralEnabled())
804
805        for chunk in (b'<doc', b'/>'):
806            parser.Parse(chunk, False)
807
808        # The key test: Have handlers already fired?  Expecting: yes.
809        self.assertEqual(started, ['doc'])
810
811
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
814