• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import io
2import marshal
3import os
4import sys
5from test import support
6from test.support import import_helper
7import types
8import unittest
9from unittest import mock
10import warnings
11
12from . import util as test_util
13
# Load each importlib module under test through the shared test helper.
# NOTE(review): the returned objects are used as mappings further down (see
# ``abc.items()`` in make_abc_subclasses) -- presumably one entry per flavour
# of importlib (frozen/source); confirm against test_util.import_importlib.
init = test_util.import_importlib('importlib')
abc = test_util.import_importlib('importlib.abc')
machinery = test_util.import_importlib('importlib.machinery')
util = test_util.import_importlib('importlib.util')
18
19
20##### Inheritance ##############################################################
class InheritanceTests:

    """Test that the specified class is a subclass/superclass of the expected
    classes."""

    # Subclasses declare ``superclass_names`` (resolved on ``self.abc``) and,
    # optionally, ``subclass_names`` (resolved on ``self.abc.machinery``).
    # The class being checked is ``getattr(self.abc, self._NAME)``.
    subclasses = []
    superclasses = []

    def setUp(self):
        resolved_supers = []
        for class_name in self.superclass_names:
            resolved_supers.append(getattr(self.abc, class_name))
        self.superclasses = resolved_supers
        if hasattr(self, 'subclass_names'):
            # test.support.import_fresh_module() builds a separate
            # importlib._bootstrap for every module it loads, so classes
            # reached through different modules do not share a hierarchy.
            # Borrowing machinery from the abc module keeps both sides of the
            # issubclass() checks on the same module instance.
            machinery = self.abc.machinery
            resolved_subs = []
            for class_name in self.subclass_names:
                resolved_subs.append(getattr(machinery, class_name))
            self.subclasses = resolved_subs
        assert self.subclasses or self.superclasses, self.__class__
        self.__test = getattr(self.abc, self._NAME)

    def test_subclasses(self):
        # Every expected subclass must derive from the class under test.
        for candidate in self.subclasses:
            inherits = issubclass(candidate, self.__test)
            message = "{0} is not a subclass of {1}".format(candidate,
                                                            self.__test)
            self.assertTrue(inherits, message)

    def test_superclasses(self):
        # The class under test must derive from every expected superclass.
        for parent in self.superclasses:
            inherits = issubclass(self.__test, parent)
            message = "{0} is not a superclass of {1}".format(parent,
                                                              self.__test)
            self.assertTrue(inherits, message)
55
56
class MetaPathFinder(InheritanceTests):
    # Names resolved by InheritanceTests.setUp(): the listed classes must be
    # subclasses of MetaPathFinder; no superclasses are checked.
    subclass_names = ['BuiltinImporter', 'FrozenImporter', 'PathFinder',
                      'WindowsRegistryFinder']
    superclass_names = []


Frozen_MetaPathFinderInheritanceTests, Source_MetaPathFinderInheritanceTests = (
    test_util.test_both(MetaPathFinder, abc=abc))
66
67
class PathEntryFinder(InheritanceTests):
    # FileFinder must be a subclass of PathEntryFinder; no superclasses are
    # checked.
    subclass_names = ['FileFinder']
    superclass_names = []


Frozen_PathEntryFinderInheritanceTests, Source_PathEntryFinderInheritanceTests = (
    test_util.test_both(PathEntryFinder, abc=abc))
76
77
class ResourceLoader(InheritanceTests):
    # ResourceLoader must derive from Loader; no subclasses are checked.
    superclass_names = ['Loader']


Frozen_ResourceLoaderInheritanceTests, Source_ResourceLoaderInheritanceTests = (
    test_util.test_both(ResourceLoader, abc=abc))
85
86
class InspectLoader(InheritanceTests):
    # InspectLoader must derive from Loader, and the listed machinery classes
    # must derive from InspectLoader.
    subclass_names = ['BuiltinImporter', 'FrozenImporter', 'ExtensionFileLoader']
    superclass_names = ['Loader']


Frozen_InspectLoaderInheritanceTests, Source_InspectLoaderInheritanceTests = (
    test_util.test_both(InspectLoader, abc=abc))
95
96
class ExecutionLoader(InheritanceTests):
    # ExecutionLoader must derive from InspectLoader, and ExtensionFileLoader
    # must derive from ExecutionLoader.
    subclass_names = ['ExtensionFileLoader']
    superclass_names = ['InspectLoader']


Frozen_ExecutionLoaderInheritanceTests, Source_ExecutionLoaderInheritanceTests = (
    test_util.test_both(ExecutionLoader, abc=abc))
105
106
class FileLoader(InheritanceTests):
    # FileLoader must derive from both listed ABCs, and both file loaders in
    # machinery must derive from FileLoader.
    subclass_names = ['SourceFileLoader', 'SourcelessFileLoader']
    superclass_names = ['ResourceLoader', 'ExecutionLoader']


Frozen_FileLoaderInheritanceTests, Source_FileLoaderInheritanceTests = (
    test_util.test_both(FileLoader, abc=abc))
115
116
class SourceLoader(InheritanceTests):
    # SourceLoader must derive from both listed ABCs, and SourceFileLoader
    # must derive from SourceLoader.
    subclass_names = ['SourceFileLoader']
    superclass_names = ['ResourceLoader', 'ExecutionLoader']


Frozen_SourceLoaderInheritanceTests, Source_SourceLoaderInheritanceTests = (
    test_util.test_both(SourceLoader, abc=abc))
125
126
127##### Default return values ####################################################
128
def make_abc_subclasses(base_class, name=None, inst=False, **kwargs):
    """Return a dict keyed by ``_KIND`` of classes built from *base_class*.

    The attribute called *name* (default: ``base_class.__name__``) is looked
    up on every module in the module-level ``abc`` mapping and handed to
    ``test_util.split_frozen`` together with *base_class* and **kwargs.  When
    *inst* is true the values are instances rather than classes.
    """
    abc_name = base_class.__name__ if name is None else name
    abc_by_kind = {}
    for kind, splitabc in abc.items():
        abc_by_kind[kind] = getattr(splitabc, abc_name)
    by_kind = {}
    for cls in test_util.split_frozen(base_class, abc_by_kind, **kwargs):
        kind = cls._KIND
        by_kind[kind] = cls() if inst else cls
    return by_kind
136
137
class ABCTestHarness:
    # Expects ``SPLIT`` (mapping of kind -> class) and ``_KIND`` to be
    # available on the instance.

    @property
    def ins(self):
        # First access builds the instance and then replaces this property on
        # the instance's class with the finished object, so later lookups on
        # that class get the same instance without coming back here.
        instance = self.SPLIT[self._KIND]()
        setattr(self.__class__, 'ins', instance)
        return instance
147
148
class MetaPathFinder:
    # Mixin: find_module() defers to the next class in the MRO.

    def find_module(self, fullname, path):
        inherited = super().find_module
        return inherited(fullname, path)
153
154
class MetaPathFinderDefaultsTests(ABCTestHarness):
    # Exercises the methods a MetaPathFinder subclass gets by default.

    SPLIT = make_abc_subclasses(MetaPathFinder)

    def test_invalidate_caches(self):
        # The call must simply succeed.
        self.ins.invalidate_caches()

    def test_find_module(self):
        # The default warns about deprecation and finds nothing.
        with self.assertWarns(DeprecationWarning):
            result = self.ins.find_module('something', None)
        self.assertIsNone(result)


Frozen_MPFDefaultTests, Source_MPFDefaultTests = test_util.test_both(
    MetaPathFinderDefaultsTests)
173
174
class PathEntryFinder:
    # Mixin: find_loader() defers to the next class in the MRO.

    def find_loader(self, fullname):
        inherited = super().find_loader
        return inherited(fullname)
179
180
class PathEntryFinderDefaultsTests(ABCTestHarness):
    # Exercises the methods a PathEntryFinder subclass gets by default.

    SPLIT = make_abc_subclasses(PathEntryFinder)

    def test_find_loader(self):
        # The deprecated default reports "no loader, no portions".
        with self.assertWarns(DeprecationWarning):
            found = self.ins.find_loader('something')
        self.assertEqual(found, (None, []))

    def test_find_module(self):
        # This check used to be called ``find_module`` and was therefore never
        # collected by unittest.  Warnings are silenced the same way the other
        # tests in this file do for deprecated APIs.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            found = self.ins.find_module('something')
        self.assertIsNone(found)

    def test_invalidate_caches(self):
        # Should be a no-op.
        self.ins.invalidate_caches()


(Frozen_PEFDefaultTests,
 Source_PEFDefaultTests
 ) = test_util.test_both(PathEntryFinderDefaultsTests)
201
202
class Loader:
    # Mixin: load_module() defers to the next class in the MRO.

    def load_module(self, fullname):
        inherited = super().load_module
        return inherited(fullname)
207
208
class LoaderDefaultsTests(ABCTestHarness):
    # Exercises the methods a Loader subclass gets by default.

    SPLIT = make_abc_subclasses(Loader)

    def test_create_module(self):
        # The default create_module() returns None for the spec given here.
        spec = 'a spec'
        self.assertIsNone(self.ins.create_module(spec))

    def test_load_module(self):
        with self.assertRaises(ImportError):
            self.ins.load_module('something')

    def test_module_repr(self):
        mod = types.ModuleType('blah')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            with self.assertRaises(NotImplementedError):
                self.ins.module_repr(mod)
            # (An unused ``original_repr`` local was dropped here.)
            mod.__loader__ = self.ins
            # Should still return a proper repr.
            self.assertTrue(repr(mod))


(Frozen_LDefaultTests,
 SourceLDefaultTests
 ) = test_util.test_both(LoaderDefaultsTests)
236
237
class ResourceLoader(Loader):
    # Mixin: get_data() defers to the next class in the MRO.

    def get_data(self, path):
        inherited = super().get_data
        return inherited(path)
242
243
class ResourceLoaderDefaultsTests(ABCTestHarness):
    # Exercises the get_data() a ResourceLoader subclass gets by default.

    SPLIT = make_abc_subclasses(ResourceLoader)

    def test_get_data(self):
        # The default refuses to return data (IOError is the OSError class).
        with self.assertRaises(OSError):
            self.ins.get_data('/some/path')


Frozen_RLDefaultTests, Source_RLDefaultTests = test_util.test_both(
    ResourceLoaderDefaultsTests)
256
257
class InspectLoader(Loader):
    # Mixin: both methods defer to the next class in the MRO.

    def get_source(self, fullname):
        inherited = super().get_source
        return inherited(fullname)

    def is_package(self, fullname):
        inherited = super().is_package
        return inherited(fullname)


# Built once at module level because several test classes below reuse it.
SPLIT_IL = make_abc_subclasses(InspectLoader)
268
269
class InspectLoaderDefaultsTests(ABCTestHarness):
    # Both default methods are expected to raise ImportError.

    SPLIT = SPLIT_IL

    def test_get_source(self):
        with self.assertRaises(ImportError):
            self.ins.get_source('blah')

    def test_is_package(self):
        with self.assertRaises(ImportError):
            self.ins.is_package('blah')


Frozen_ILDefaultTests, Source_ILDefaultTests = test_util.test_both(
    InspectLoaderDefaultsTests)
286
287
class ExecutionLoader(InspectLoader):
    # Mixin: get_filename() defers to the next class in the MRO.

    def get_filename(self, fullname):
        inherited = super().get_filename
        return inherited(fullname)


# Built once at module level because several test classes below reuse it.
SPLIT_EL = make_abc_subclasses(ExecutionLoader)
295
296
class ExecutionLoaderDefaultsTests(ABCTestHarness):
    # Exercises the get_filename() an ExecutionLoader subclass gets by default.

    SPLIT = SPLIT_EL

    def test_get_filename(self):
        # The default is expected to raise ImportError.
        with self.assertRaises(ImportError):
            self.ins.get_filename('blah')


# Register the class defined just above.  The registration used to pass
# InspectLoaderDefaultsTests, so test_get_filename() was never run.
(Frozen_ELDefaultTests,
 Source_ELDefaultsTests
 ) = test_util.test_both(ExecutionLoaderDefaultsTests)
309
310
class ResourceReader:
    # Mixin: every method forwards its arguments unchanged to the next class
    # in the MRO.

    def open_resource(self, *args, **kwargs):
        inherited = super().open_resource
        return inherited(*args, **kwargs)

    def resource_path(self, *args, **kwargs):
        inherited = super().resource_path
        return inherited(*args, **kwargs)

    def is_resource(self, *args, **kwargs):
        inherited = super().is_resource
        return inherited(*args, **kwargs)

    def contents(self, *args, **kwargs):
        inherited = super().contents
        return inherited(*args, **kwargs)
318
319    def is_resource(self, *args, **kwargs):
320        return super().is_resource(*args, **kwargs)
321
322    def contents(self, *args, **kwargs):
323        return super().contents(*args, **kwargs)
324
325
326class ResourceReaderDefaultsTests(ABCTestHarness):
327
328    SPLIT = make_abc_subclasses(ResourceReader)
329
330    def test_open_resource(self):
331        with self.assertRaises(FileNotFoundError):
332            self.ins.open_resource('dummy_file')
333
334    def test_resource_path(self):
335        with self.assertRaises(FileNotFoundError):
336            self.ins.resource_path('dummy_file')
337
338    def test_is_resource(self):
339        with self.assertRaises(FileNotFoundError):
340            self.ins.is_resource('dummy_file')
341
342    def test_contents(self):
343        with self.assertRaises(FileNotFoundError):
344            self.ins.contents()
345
346
347(Frozen_RRDefaultTests,
348 Source_RRDefaultsTests
349 ) = test_util.test_both(ResourceReaderDefaultsTests)
350
351
352##### MetaPathFinder concrete methods ##########################################
class MetaPathFinderFindModuleTests:
    # Uses a MetaPathFinder subclass that only implements find_spec().

    @classmethod
    def finder(cls, spec):
        # The returned finder records (fullname, path) on each find_spec()
        # call and always answers with the spec supplied here.
        class MetaPathSpecFinder(cls.abc.MetaPathFinder):

            def find_spec(self, fullname, path, target=None):
                self.called_for = (fullname, path)
                return spec

        return MetaPathSpecFinder()

    def test_find_module(self):
        # With no spec to offer, the deprecated find_module() gives None.
        finder = self.finder(None)
        with self.assertWarns(DeprecationWarning):
            found = finder.find_module('blah', ['a', 'b', 'c'])
        self.assertIsNone(found)

    def test_find_spec_with_explicit_target(self):
        spec = self.util.spec_from_loader('blah', object())
        found = self.finder(spec).find_spec('blah', 'blah', None)
        self.assertEqual(found, spec)

    def test_no_spec(self):
        module_name = 'blah'
        search_path = ['a', 'b', 'c']
        finder = self.finder(None)
        found = finder.find_spec(module_name, search_path, None)
        self.assertIsNone(found)
        recorded_name, recorded_path = finder.called_for
        self.assertEqual(module_name, recorded_name)
        self.assertEqual(search_path, recorded_path)

    def test_spec(self):
        spec = self.util.spec_from_loader('blah', object())
        found = self.finder(spec).find_spec('blah', None)
        self.assertIs(found, spec)


Frozen_MPFFindModuleTests, Source_MPFFindModuleTests = test_util.test_both(
    MetaPathFinderFindModuleTests, abc=abc, util=util)
400
401
402##### PathEntryFinder concrete methods #########################################
class PathEntryFinderFindLoaderTests:
    # Uses a PathEntryFinder subclass that only implements find_spec().

    @classmethod
    def finder(cls, spec):
        # The returned finder records the requested name on each find_spec()
        # call and always answers with the spec supplied here.
        class PathEntrySpecFinder(cls.abc.PathEntryFinder):

            def find_spec(self, fullname, target=None):
                self.called_for = fullname
                return spec

        return PathEntrySpecFinder()

    def test_no_spec(self):
        module_name = 'blah'
        finder = self.finder(None)
        with self.assertWarns(DeprecationWarning):
            found = finder.find_loader(module_name)
        self.assertIsNone(found[0])
        self.assertEqual([], found[1])
        self.assertEqual(module_name, finder.called_for)

    def test_spec_with_loader(self):
        spec = self.util.spec_from_loader('blah', object())
        finder = self.finder(spec)
        with self.assertWarns(DeprecationWarning):
            found = finder.find_loader('blah')
        self.assertIs(found[0], spec.loader)

    def test_spec_with_portions(self):
        portions = ['a', 'b', 'c']
        spec = self.machinery.ModuleSpec('blah', None)
        spec.submodule_search_locations = portions
        finder = self.finder(spec)
        with self.assertWarns(DeprecationWarning):
            found = finder.find_loader('blah')
        self.assertIsNone(found[0])
        self.assertEqual(portions, found[1])


Frozen_PEFFindLoaderTests, Source_PEFFindLoaderTests = test_util.test_both(
    PathEntryFinderFindLoaderTests, abc=abc, util=util, machinery=machinery)
447
448
449##### Loader concrete methods ##################################################
class LoaderLoadModuleTests:
    # Exercises load_module() on a Loader subclass that only defines
    # exec_module() and is_package().

    def loader(self):
        class SpecLoader(self.abc.Loader):
            found = None
            def exec_module(self, module):
                # Remember the module that was handed over for execution.
                self.found = module

            def is_package(self, fullname):
                """Force some non-default module state to be set."""
                return True

        return SpecLoader()

    def test_fresh(self):
        # Loading a name not yet in sys.modules sets the import attributes.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            loader = self.loader()
            name = 'blah'
            with test_util.uncache(name):
                loader.load_module(name)
                module = loader.found
                self.assertIs(sys.modules[name], module)
            self.assertEqual(loader, module.__loader__)
            self.assertEqual(loader, module.__spec__.loader)
            self.assertEqual(name, module.__name__)
            self.assertEqual(name, module.__spec__.name)
            self.assertIsNotNone(module.__path__)
            # This was assertIsNotNone(a, b), which treats ``b`` as the
            # failure message and so never compared the two values.
            self.assertEqual(module.__path__,
                             module.__spec__.submodule_search_locations)

    def test_reload(self):
        # Loading a name already in sys.modules reuses the existing module.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            name = 'blah'
            loader = self.loader()
            module = types.ModuleType(name)
            module.__spec__ = self.util.spec_from_loader(name, loader)
            module.__loader__ = loader
            with test_util.uncache(name):
                sys.modules[name] = module
                loader.load_module(name)
                found = loader.found
                self.assertIs(found, sys.modules[name])
                self.assertIs(module, sys.modules[name])


(Frozen_LoaderLoadModuleTests,
 Source_LoaderLoadModuleTests
 ) = test_util.test_both(LoaderLoadModuleTests, abc=abc, util=util)
500
501
502##### InspectLoader concrete methods ###########################################
class InspectLoaderSourceToCodeTests:
    # Exercises source_to_code() on the class supplied as
    # ``InspectLoaderSubclass``.

    def source_to_module(self, data, path=None):
        """Help with source_to_code() tests."""
        # Only pass *path* along when the caller supplied one.
        extra_args = () if path is None else (path,)
        loader = self.InspectLoaderSubclass()
        code = loader.source_to_code(data, *extra_args)
        module = types.ModuleType('blah')
        exec(code, module.__dict__)
        return module

    def test_source_to_code_source(self):
        # Since compile() can handle strings, so should source_to_code().
        module = self.source_to_module('attr = 42')
        self.assertTrue(hasattr(module, 'attr'))
        self.assertEqual(module.attr, 42)

    def test_source_to_code_bytes(self):
        # Since compile() can handle bytes, so should source_to_code().
        module = self.source_to_module(b'attr = 42')
        self.assertTrue(hasattr(module, 'attr'))
        self.assertEqual(module.attr, 42)

    def test_source_to_code_path(self):
        # Specifying a path should set it for the code object.
        expected_path = 'path/to/somewhere'
        code = self.InspectLoaderSubclass().source_to_code('', expected_path)
        self.assertEqual(code.co_filename, expected_path)

    def test_source_to_code_no_path(self):
        # Not setting a path should still work and be set to <string> since that
        # is a pre-existing practice as a default to compile().
        code = self.InspectLoaderSubclass().source_to_code('')
        self.assertEqual(code.co_filename, '<string>')


Frozen_ILSourceToCodeTests, Source_ILSourceToCodeTests = test_util.test_both(
    InspectLoaderSourceToCodeTests, InspectLoaderSubclass=SPLIT_IL)
549
550
class InspectLoaderGetCodeTests:
    # Exercises get_code() on the class supplied as ``InspectLoaderSubclass``,
    # patching get_source() where a particular source is needed.

    def test_get_code(self):
        # Test success.
        patcher = mock.patch.object(self.InspectLoaderSubclass, 'get_source')
        with patcher as mocked:
            mocked.return_value = 'attr = 42'
            code = self.InspectLoaderSubclass().get_code('blah')
        module = types.ModuleType('blah')
        exec(code, module.__dict__)
        self.assertEqual(module.attr, 42)

    def test_get_code_source_is_None(self):
        # If get_source() is None then this should be None.
        patcher = mock.patch.object(self.InspectLoaderSubclass, 'get_source')
        with patcher as mocked:
            mocked.return_value = None
            code = self.InspectLoaderSubclass().get_code('blah')
        self.assertIsNone(code)

    def test_get_code_source_not_found(self):
        # If there is no source then there is no code object.
        loader = self.InspectLoaderSubclass()
        with self.assertRaises(ImportError):
            loader.get_code('blah')


Frozen_ILGetCodeTests, Source_ILGetCodeTests = test_util.test_both(
    InspectLoaderGetCodeTests, InspectLoaderSubclass=SPLIT_IL)
582
583
class InspectLoaderLoadModuleTests:

    """Test InspectLoader.load_module()."""

    module_name = 'blah'

    def setUp(self):
        # Start without the module imported and drop it again afterwards.
        import_helper.unload(self.module_name)
        self.addCleanup(import_helper.unload, self.module_name)

    def load(self, loader):
        # Load ``module_name`` through *loader* with deprecation warnings
        # silenced, returning whatever the bootstrap loader returns.
        spec = self.util.spec_from_loader(self.module_name, loader)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            loaded = self.init._bootstrap._load_unlocked(spec)
        return loaded

    def mock_get_code(self):
        # Patcher that replaces get_code() on the loader class under test.
        return mock.patch.object(self.InspectLoaderSubclass, 'get_code')

    def test_get_code_ImportError(self):
        # If get_code() raises ImportError, it should propagate.
        with self.mock_get_code() as patched:
            patched.side_effect = ImportError
            with self.assertRaises(ImportError):
                self.load(self.InspectLoaderSubclass())

    def test_get_code_None(self):
        # If get_code() returns None, raise ImportError.
        with self.mock_get_code() as patched:
            patched.return_value = None
            with self.assertRaises(ImportError):
                self.load(self.InspectLoaderSubclass())

    def test_module_returned(self):
        # The loaded module should be returned.
        code = compile('attr = 42', '<string>', 'exec')
        with self.mock_get_code() as patched:
            patched.return_value = code
            module = self.load(self.InspectLoaderSubclass())
            self.assertEqual(module, sys.modules[self.module_name])


Frozen_ILLoadModuleTests, Source_ILLoadModuleTests = test_util.test_both(
    InspectLoaderLoadModuleTests,
    InspectLoaderSubclass=SPLIT_IL,
    init=init,
    util=util)
635
636
637##### ExecutionLoader concrete methods #########################################
class ExecutionLoaderGetCodeTests:
    # Exercises get_code() on the class supplied as
    # ``ExecutionLoaderSubclass``, patching get_source()/get_filename().

    def mock_methods(self, *, get_source=False, get_filename=False):
        # Return a (source patcher, filename patcher) pair; an entry is None
        # when the corresponding flag is false.
        loader_cls_attr = 'ExecutionLoaderSubclass'
        source_mock_context = (
            mock.patch.object(getattr(self, loader_cls_attr), 'get_source')
            if get_source else None)
        filename_mock_context = (
            mock.patch.object(getattr(self, loader_cls_attr), 'get_filename')
            if get_filename else None)
        return source_mock_context, filename_mock_context

    def test_get_code(self):
        expected_path = 'blah.py'
        source_patch, filename_patch = self.mock_methods(
                get_source=True, get_filename=True)
        with source_patch as source_mock, filename_patch as name_mock:
            source_mock.return_value = 'attr = 42'
            name_mock.return_value = expected_path
            code = self.ExecutionLoaderSubclass().get_code('blah')
        self.assertEqual(code.co_filename, expected_path)
        module = types.ModuleType('blah')
        exec(code, module.__dict__)
        self.assertEqual(module.attr, 42)

    def test_get_code_source_is_None(self):
        # If get_source() is None then this should be None.
        source_patch, _ = self.mock_methods(get_source=True)
        with source_patch as source_mock:
            source_mock.return_value = None
            code = self.ExecutionLoaderSubclass().get_code('blah')
        self.assertIsNone(code)

    def test_get_code_source_not_found(self):
        # If there is no source then there is no code object.
        loader = self.ExecutionLoaderSubclass()
        with self.assertRaises(ImportError):
            loader.get_code('blah')

    def test_get_code_no_path(self):
        # If get_filename() raises ImportError then simply skip setting the path
        # on the code object.
        source_patch, filename_patch = self.mock_methods(
                get_source=True, get_filename=True)
        with source_patch as source_mock, filename_patch as name_mock:
            source_mock.return_value = 'attr = 42'
            name_mock.side_effect = ImportError
            code = self.ExecutionLoaderSubclass().get_code('blah')
        self.assertEqual(code.co_filename, '<string>')
        module = types.ModuleType('blah')
        exec(code, module.__dict__)
        self.assertEqual(module.attr, 42)


Frozen_ELGetCodeTests, Source_ELGetCodeTests = test_util.test_both(
    ExecutionLoaderGetCodeTests, ExecutionLoaderSubclass=SPLIT_EL)
699
700
701##### SourceLoader concrete methods ############################################
class SourceOnlyLoader:
    # Loader mixin that serves one fixed source blob for a single path.

    # Globals that should be defined for all modules.
    source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, "
              b"repr(__loader__)])")

    def __init__(self, path):
        self.path = path

    def get_filename(self, fullname):
        # The same path is reported for every module name.
        return self.path

    def get_data(self, path):
        # Only the configured path has data (IOError is the OSError class).
        if path == self.path:
            return self.source
        raise OSError

    def module_repr(self, module):
        return '<module>'
721
722
# Combine the SourceOnlyLoader mixin with the ABC attribute named
# 'SourceLoader' (the explicit name overrides the mixin's own class name).
SPLIT_SOL = make_abc_subclasses(SourceOnlyLoader, 'SourceLoader')
724
725
class SourceLoader(SourceOnlyLoader):
    # Extends SourceOnlyLoader with in-memory bytecode and a record of every
    # set_data() call.

    source_mtime = 1

    def __init__(self, path, magic=None):
        super().__init__(path)
        self.bytecode_path = self.util.cache_from_source(self.path)
        self.source_size = len(self.source)
        # Bytecode layout built here: magic number, then three packed uint32
        # fields (0, mtime, size), then the marshalled code object.
        payload = bytearray(self.util.MAGIC_NUMBER if magic is None else magic)
        for field in (0, self.source_mtime, self.source_size):
            payload.extend(self.init._pack_uint32(field))
        code_object = compile(self.source, self.path, 'exec',
                              dont_inherit=True)
        payload.extend(marshal.dumps(code_object))
        self.bytecode = bytes(payload)
        self.written = {}

    def get_data(self, path):
        # Source path -> source, bytecode path -> bytecode, anything else fails.
        if path == self.path:
            return super().get_data(path)
        if path == self.bytecode_path:
            return self.bytecode
        raise OSError

    def path_stats(self, path):
        # Stats exist only for the source path (IOError is the OSError class).
        if path == self.path:
            return {'mtime': self.source_mtime, 'size': self.source_size}
        raise OSError

    def set_data(self, path, data):
        # Record the write and report whether it targeted the bytecode path.
        self.written[path] = bytes(data)
        return path == self.bytecode_path


SPLIT_SL = make_abc_subclasses(SourceLoader, util=util, init=init)
765
766
class SourceLoaderTestHarness:
    # Shared fixture: expects ``util`` and ``loader_mock`` to be available on
    # the instance.

    def setUp(self, *, is_package=True, **kwargs):
        # Describe either the package 'pkg' itself or its submodule 'pkg.mod',
        # then build the loader for the resulting path.
        self.package = 'pkg'
        if is_package:
            filename = '__init__.py'
            qualified_name = self.package
        else:
            module_name = 'mod'
            filename = '.'.join(['mod', 'py'])
            qualified_name = '.'.join([self.package, module_name])
        self.path = os.path.join(self.package, filename)
        self.name = qualified_name
        self.cached = self.util.cache_from_source(self.path)
        self.loader = self.loader_mock(self.path, **kwargs)

    def verify_module(self, module):
        # Check the import attributes set on the module.
        self.assertEqual(module.__name__, self.name)
        self.assertEqual(module.__file__, self.path)
        self.assertEqual(module.__cached__, self.cached)
        self.assertEqual(module.__package__, self.package)
        self.assertEqual(module.__loader__, self.loader)
        # ``module._`` is a '::'-joined string; its first five fields must
        # match the same values, in this order.
        values = module._.split('::')
        expected_fields = (self.name, self.path, self.cached, self.package,
                           repr(self.loader))
        for index, expected in enumerate(expected_fields):
            self.assertEqual(values[index], expected)

    def verify_code(self, code_object):
        # Run the code in a hand-built module carrying the fixture's
        # attributes, then verify the resulting module.
        module = types.ModuleType(self.name)
        attributes = (('__file__', self.path),
                      ('__cached__', self.cached),
                      ('__package__', self.package),
                      ('__loader__', self.loader),
                      ('__path__', []))
        for attr_name, attr_value in attributes:
            setattr(module, attr_name, attr_value)
        exec(code_object, module.__dict__)
        self.verify_module(module)
803
804
class SourceOnlyLoaderTests(SourceLoaderTestHarness):

    """Test importlib.abc.SourceLoader for source-only loading.

    Reload testing is subsumed by the tests for
    importlib.util.module_for_loader.

    """

    def test_get_source(self):
        # Verify the source code is returned as a string.
        # If an OSError is raised by get_data then raise ImportError.
        decoded = self.loader.source.decode('utf-8')
        self.assertEqual(self.loader.get_source(self.name), decoded)
        def failing_get_data(path):
            raise OSError
        self.loader.get_data = failing_get_data
        with self.assertRaises(ImportError) as caught:
            self.loader.get_source(self.name)
        self.assertEqual(caught.exception.name, self.name)

    def test_is_package(self):
        # Properly detect when loading a package.
        self.setUp(is_package=False)
        self.assertFalse(self.loader.is_package(self.name))
        self.setUp(is_package=True)
        self.assertTrue(self.loader.is_package(self.name))
        init_name = self.name + '.__init__'
        self.assertFalse(self.loader.is_package(init_name))

    def test_get_code(self):
        # Verify the code object is created.
        self.verify_code(self.loader.get_code(self.name))

    def test_source_to_code(self):
        # Verify the compiled code object.
        compiled = self.loader.source_to_code(self.loader.source, self.path)
        self.verify_code(compiled)

    def test_load_module(self):
        # Loading a module should set __name__, __loader__, __package__,
        # __path__ (for packages), __file__, and __cached__.
        # The module should also be put into sys.modules.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ImportWarning)
            with test_util.uncache(self.name):
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', DeprecationWarning)
                    loaded = self.loader.load_module(self.name)
                self.verify_module(loaded)
                package_dir = os.path.dirname(self.path)
                self.assertEqual(loaded.__path__, [package_dir])
                self.assertIn(self.name, sys.modules)

    def test_package_settings(self):
        # __package__ needs to be set, while __path__ is set on if the module
        # is a package.
        # Testing the values for a package are covered by test_load_module.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ImportWarning)
            self.setUp(is_package=False)
            with test_util.uncache(self.name):
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', DeprecationWarning)
                    loaded = self.loader.load_module(self.name)
                self.verify_module(loaded)
                self.assertFalse(hasattr(loaded, '__path__'))

    def test_get_source_encoding(self):
        # Source is considered encoded in UTF-8 by default unless otherwise
        # specified by an encoding line.
        cases = (("_ = 'ü'", 'utf-8'),
                 ("# coding: latin-1\n_ = ü", 'latin-1'))
        for source, encoding in cases:
            self.loader.source = source.encode(encoding)
            returned_source = self.loader.get_source(self.name)
            self.assertEqual(returned_source, source)


Frozen_SourceOnlyLoaderTests, Source_SourceOnlyLoaderTests = test_util.test_both(
    SourceOnlyLoaderTests, util=util, loader_mock=SPLIT_SOL)
889
890
@unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true")
class SourceLoaderBytecodeTests(SourceLoaderTestHarness):

    """Test importlib.abc.SourceLoader's use of bytecode.

    Source-only testing handled by SourceOnlyLoaderTests.

    """

    def verify_code(self, code_object, *, bytecode_written=False):
        # Run the inherited checks, then (when requested) confirm that the
        # loader recorded freshly written bytecode under self.cached with
        # exactly the expected bytes.
        super().verify_code(code_object)
        if bytecode_written:
            self.assertIn(self.cached, self.loader.written)
            # Expected layout: magic number, a zero 32-bit field (presumably
            # the flags word), source mtime, source size, marshalled code.
            data = bytearray(self.util.MAGIC_NUMBER)
            data.extend(self.init._pack_uint32(0))
            data.extend(self.init._pack_uint32(self.loader.source_mtime))
            data.extend(self.init._pack_uint32(self.loader.source_size))
            data.extend(marshal.dumps(code_object))
            self.assertEqual(self.loader.written[self.cached], bytes(data))

    def test_code_with_everything(self):
        # When everything should work.
        code_object = self.loader.get_code(self.name)
        self.verify_code(code_object)

    def test_no_bytecode(self):
        # If no bytecode exists then move on to the source.
        self.loader.bytecode_path = "<does not exist>"
        # Sanity check: only the read of the cache path is expected to fail,
        # so computing that path stays outside the assertRaises block.
        bytecode_path = self.util.cache_from_source(self.path)
        with self.assertRaises(OSError):
            self.loader.get_data(bytecode_path)
        code_object = self.loader.get_code(self.name)
        self.verify_code(code_object, bytecode_written=True)

    def test_code_bad_timestamp(self):
        # Bytecode is only used when the timestamp matches the source EXACTLY.
        for source_mtime in (0, 2):
            # Precondition on the fixture; a real assertion so that it is
            # still checked when Python runs with -O.
            self.assertNotEqual(source_mtime, self.loader.source_mtime)
            original = self.loader.source_mtime
            self.loader.source_mtime = source_mtime
            try:
                # If bytecode is used then EOFError would be raised by
                # marshal.
                self.loader.bytecode = self.loader.bytecode[8:]
                code_object = self.loader.get_code(self.name)
                self.verify_code(code_object, bytecode_written=True)
            finally:
                # Put the mtime back even if this iteration failed.
                self.loader.source_mtime = original

    def test_code_bad_magic(self):
        # Skip over bytecode with a bad magic number.
        self.setUp(magic=b'0000')
        # If bytecode is used then EOFError would be raised by marshal.
        self.loader.bytecode = self.loader.bytecode[8:]
        code_object = self.loader.get_code(self.name)
        self.verify_code(code_object, bytecode_written=True)

    def test_dont_write_bytecode(self):
        # Bytecode is not written if sys.dont_write_bytecode is true.
        # The skipIf class decorator means the flag is normally false here,
        # but restore whatever value was actually in effect rather than
        # assuming it.
        previous = sys.dont_write_bytecode
        sys.dont_write_bytecode = True
        try:
            self.loader.bytecode_path = "<does not exist>"
            self.loader.get_code(self.name)
            self.assertNotIn(self.cached, self.loader.written)
        finally:
            sys.dont_write_bytecode = previous

    def test_no_set_data(self):
        # If set_data is not defined, one can still read bytecode.
        self.setUp(magic=b'0000')
        # set_data is removed from the second class in the loader's MRO and
        # put back afterwards, because the class object outlives this test.
        owner = self.loader.__class__.mro()[1]
        original_set_data = owner.set_data
        # Deleting before entering the try means a failed delete does not
        # trigger a "restore" of something that was never removed.
        del owner.set_data
        try:
            code_object = self.loader.get_code(self.name)
            self.verify_code(code_object)
        finally:
            owner.set_data = original_set_data

    def test_set_data_raises_exceptions(self):
        # Raising NotImplementedError or OSError is okay for set_data.
        # NOTE(review): only NotImplementedError is exercised below.
        def raise_exception(exc):
            def closure(*args, **kwargs):
                raise exc
            return closure

        self.setUp(magic=b'0000')
        self.loader.set_data = raise_exception(NotImplementedError)
        code_object = self.loader.get_code(self.name)
        self.verify_code(code_object)
979
980
# test_both() hands back a pair of concrete test classes built from
# SourceLoaderBytecodeTests; going by the bound names, presumably one per
# importlib flavour (frozen and source).
(
    Frozen_SLBytecodeTests,
    SourceSLBytecodeTests,
) = test_util.test_both(
    SourceLoaderBytecodeTests,
    init=init,
    util=util,
    loader_mock=SPLIT_SL,
)
985
986
class SourceLoaderGetSourceTests:

    """Tests for importlib.abc.SourceLoader.get_source()."""

    def _loader_for(self, text, encoding):
        # Build a loader mock for 'mod.file' whose raw source is *text*
        # encoded with *encoding*.
        loader = self.SourceOnlyLoaderMock('mod.file')
        loader.source = text.encode(encoding)
        return loader

    def test_default_encoding(self):
        # Should have no problems with UTF-8 text.
        text = 'x = "ü"'
        loader = self._loader_for(text, 'utf-8')
        self.assertEqual(loader.get_source('mod'), text)

    def test_decoded_source(self):
        # Decoding should work.
        text = "# coding: Latin-1\nx='ü'"
        # The two encodings must differ for this test to mean anything.
        assert text.encode('latin-1') != text.encode('utf-8')
        loader = self._loader_for(text, 'latin-1')
        self.assertEqual(loader.get_source('mod'), text)

    def test_universal_newlines(self):
        # PEP 302 says universal newlines should be used.
        text = "x = 42\r\ny = -13\r\n"
        loader = self._loader_for(text, 'utf-8')
        newline_decoder = io.IncrementalNewlineDecoder(None, True)
        expected = newline_decoder.decode(text)
        self.assertEqual(loader.get_source('mod'), expected)
1018
1019
# test_both() hands back a pair of concrete test classes built from
# SourceLoaderGetSourceTests; going by the bound names, presumably one per
# importlib flavour (frozen and source).
(
    Frozen_SourceOnlyLoaderGetSourceTests,
    Source_SourceOnlyLoaderGetSourceTests,
) = test_util.test_both(
    SourceLoaderGetSourceTests,
    SourceOnlyLoaderMock=SPLIT_SOL,
)
1024
1025
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
1028