• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tests for pw_cli.plugins."""
15
16from pathlib import Path
17import sys
18import tempfile
19import types
20from typing import Dict, Iterator
21import unittest
22
23from pw_cli import plugins
24
25
26def _no_docstring() -> int:
27    return 123
28
29
30def _with_docstring() -> int:
31    """This docstring is brought to you courtesy of Pigweed."""
32    return 456
33
34
35def _create_files(directory: str, files: Dict[str, str]) -> Iterator[Path]:
36    for relative_path, contents in files.items():
37        path = Path(directory) / relative_path
38        path.parent.mkdir(exist_ok=True, parents=True)
39        path.write_text(contents)
40        yield path
41
42
class TestPlugin(unittest.TestCase):
    """Tests for plugins.Plugin."""

    def test_target_name_attribute(self) -> None:
        """target_name is '<module>.<name>' for a regular function."""
        self.assertEqual(
            plugins.Plugin('abc', _no_docstring).target_name,
            f'{__name__}._no_docstring',
        )

    def test_target_name_no_name_attribute(self) -> None:
        """A target lacking __name__ still yields a readable target_name."""
        has_no_name = 'no __name__'
        # Guard the precondition: a str instance has no __name__ attribute.
        self.assertFalse(hasattr(has_no_name, '__name__'))

        # Expected: '<unknown>' as the module part, then the target's text.
        self.assertEqual(
            plugins.Plugin('abc', has_no_name).target_name,
            '<unknown>.no __name__',
        )
60
61
# Plugins files written by the directory-registration tests. Key order
# matters: tests index the created files positionally, so the top-level file
# must come first and the nested one second.
_TEST_PLUGINS = {
    'TEST_PLUGINS': ''.join(
        f'{plugin_name} {__name__} {target_name}\n'
        for plugin_name, target_name in (
            ('test_plugin', '_with_docstring'),
            ('other_plugin', '_no_docstring'),
        )
    ),
    # Overrides test_plugin with a different target than the top-level file.
    'nested/in/dirs/TEST_PLUGINS': f'test_plugin {__name__} _no_docstring\n',
}
69
70
class TestPluginRegistry(unittest.TestCase):
    """Tests for plugins.Registry."""

    def setUp(self) -> None:
        # Each test gets a fresh, empty registry. The validator is expected to
        # reject non-callable targets (see test_register_fails_validation).
        self._registry = plugins.Registry(
            validator=plugins.callable_with_no_args
        )

    def test_register(self) -> None:
        """register() stores the target under the given name."""
        self.assertIsNotNone(self._registry.register('a_plugin', _no_docstring))
        self.assertIs(self._registry['a_plugin'].target, _no_docstring)

    def test_register_by_name(self) -> None:
        """Registering by module/attribute name resolves the same object."""
        self.assertIsNotNone(
            self._registry.register_by_name(
                'plugin_one', __name__, '_no_docstring'
            )
        )
        self.assertIsNotNone(
            self._registry.register('plugin_two', _no_docstring)
        )

        self.assertIs(self._registry['plugin_one'].target, _no_docstring)
        self.assertIs(self._registry['plugin_two'].target, _no_docstring)

    def test_register_by_name_undefined_module(self) -> None:
        """An unimportable module is reported as plugins.Error."""
        with self.assertRaisesRegex(plugins.Error, 'No module named'):
            self._registry.register_by_name(
                'plugin_two', 'not a module', 'something'
            )

    def test_register_by_name_undefined_function(self) -> None:
        """A missing attribute in an existing module raises plugins.Error."""
        with self.assertRaisesRegex(plugins.Error, 'does not exist'):
            self._registry.register_by_name('plugin_two', __name__, 'nothing')

    def test_register_fails_validation(self) -> None:
        """Targets rejected by the validator raise plugins.Error."""
        with self.assertRaisesRegex(plugins.Error, 'must be callable'):
            self._registry.register('plugin_two', 'not function')

    def test_run_with_argv_sets_sys_argv(self) -> None:
        """sys.argv is replaced while the plugin runs, then restored."""

        def stash_argv() -> int:
            # Runs inside run_with_argv, so it observes the temporary argv.
            self.assertEqual(['pw go', '1', '2'], sys.argv)
            return 1

        self.assertIsNotNone(self._registry.register('go', stash_argv))

        original_argv = sys.argv
        self.assertEqual(self._registry.run_with_argv('go', ['1', '2']), 1)
        # The original list object itself (not just an equal copy) is back.
        self.assertIs(sys.argv, original_argv)

    def test_run_with_argv_registered_plugin(self) -> None:
        """Running a name that was never registered raises KeyError."""
        # NOTE: despite the method name, 'plugin_one' is NOT registered here.
        with self.assertRaises(KeyError):
            self._registry.run_with_argv('plugin_one', [])

    def test_register_cannot_overwrite(self) -> None:
        """Re-using a name raises, however the first plugin was registered."""
        self.assertIsNotNone(self._registry.register('foo', lambda: None))
        self.assertIsNotNone(
            self._registry.register_by_name('bar', __name__, '_no_docstring')
        )

        with self.assertRaises(plugins.Error):
            self._registry.register('foo', lambda: None)

        with self.assertRaises(plugins.Error):
            self._registry.register('bar', lambda: None)

    def test_register_directory_innermost_takes_priority(self) -> None:
        """A plugins file nearer the start directory wins a name conflict."""
        with tempfile.TemporaryDirectory() as tempdir:
            paths = list(_create_files(tempdir, _TEST_PLUGINS))
            # paths[1] is the nested file, mapping test_plugin to
            # _no_docstring (123); the outer file maps it to _with_docstring
            # (456).
            self._registry.register_directory(paths[1].parent, 'TEST_PLUGINS')

        self.assertEqual(self._registry.run_with_argv('test_plugin', []), 123)

    def test_register_directory_only_searches_up(self) -> None:
        """Starting at the top directory must not pick up nested files."""
        with tempfile.TemporaryDirectory() as tempdir:
            paths = list(_create_files(tempdir, _TEST_PLUGINS))
            self._registry.register_directory(paths[0].parent, 'TEST_PLUGINS')

        self.assertEqual(self._registry.run_with_argv('test_plugin', []), 456)

    def test_register_directory_with_restriction(self) -> None:
        """Plugins files outside the restriction directory are ignored."""
        with tempfile.TemporaryDirectory() as tempdir:
            paths = list(_create_files(tempdir, _TEST_PLUGINS))
            # The top-level file (the only one defining other_plugin) lies
            # outside nested/in, so it must be skipped.
            self._registry.register_directory(
                paths[0].parent, 'TEST_PLUGINS', Path(tempdir, 'nested', 'in')
            )

        self.assertNotIn('other_plugin', self._registry)

    def test_register_same_file_multiple_times_no_error(self) -> None:
        """Registering one plugins file repeatedly is not an error."""
        with tempfile.TemporaryDirectory() as tempdir:
            paths = list(_create_files(tempdir, _TEST_PLUGINS))
            self._registry.register_file(paths[0])
            self._registry.register_file(paths[0])
            self._registry.register_file(paths[0])

        self.assertIs(self._registry['test_plugin'].target, _with_docstring)

    def test_help_uses_function_or_module_docstring(self) -> None:
        """Help prefers the target's docstring, else the module docstring."""
        self.assertIsNotNone(self._registry.register('a', _no_docstring))
        self.assertIsNotNone(self._registry.register('b', _with_docstring))

        # _no_docstring has no docstring, so this module's docstring is used.
        self.assertIn(__doc__, '\n'.join(self._registry.detailed_help(['a'])))

        self.assertNotIn(
            __doc__, '\n'.join(self._registry.detailed_help(['b']))
        )
        self.assertIn(
            _with_docstring.__doc__,
            '\n'.join(self._registry.detailed_help(['b'])),
        )

    def test_empty_string_if_no_help(self) -> None:
        """help() returns '' when no docstring is available at all."""
        fake_module_name = f'{__name__}.fake_module_for_test{id(self)}'
        fake_module = types.ModuleType(fake_module_name)
        self.assertIsNone(fake_module.__doc__)

        # Make the docstring-less fake module importable by name.
        sys.modules[fake_module_name] = fake_module

        try:

            def function() -> None:
                pass

            # Attribute the (docstring-less) target to the fake module.
            function.__module__ = fake_module_name
            self.assertIsNotNone(self._registry.register('a', function))

            self.assertEqual(self._registry['a'].help(full=False), '')
            self.assertEqual(self._registry['a'].help(full=True), '')
        finally:
            del sys.modules[fake_module_name]

    def test_decorator_not_called(self) -> None:
        """Bare @plugin registers the function under its own name."""

        @self._registry.plugin
        def nifty() -> None:
            pass

        self.assertIs(self._registry['nifty'].target, nifty)

    def test_decorator_called_no_args(self) -> None:
        """@plugin() with no arguments behaves like the bare decorator."""

        @self._registry.plugin()
        def nifty() -> None:
            pass

        self.assertIs(self._registry['nifty'].target, nifty)

    def test_decorator_called_with_args(self) -> None:
        """@plugin(name=...) registers the function under the given name."""

        @self._registry.plugin(name='nifty')
        def my_nifty_keen_plugin() -> None:
            pass

        self.assertIs(self._registry['nifty'].target, my_nifty_keen_plugin)
220
221
if __name__ == '__main__':
    # When executed as a script, run every TestCase defined in this module.
    unittest.main(module='__main__')
224