• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
14"""Tests for the formatter core."""
15
16from pathlib import Path
17from tempfile import TemporaryDirectory
18import unittest
19
20from pw_presubmit.format.core import (
21    FileChecker,
22    FormattedDiff,
23    FormattedFileContents,
24)
25
26
class FakeFileChecker(FileChecker):
    """A stand-in formatter driven entirely by a fixed lookup table.

    A file's complete contents are looked up in ``FORMAT_MAP``. A hit is
    reported as a successful format producing the mapped text; a miss is
    reported as a failure with an error message quoting the input.
    """

    # Exact original contents -> contents this fake "formats" them into.
    # An entry that maps to itself ('bar') models an already-formatted file.
    FORMAT_MAP = {
        'foo': 'bar',
        'bar': 'bar',
        'baz': '\nbaz\n',
        'new\n': 'newer\n',
    }

    def format_file_in_memory(
        self, file_path: Path, file_contents: bytes
    ) -> FormattedFileContents:
        """Formats ``file_contents`` by looking it up in ``FORMAT_MAP``.

        Args:
            file_path: Path of the file being formatted; unused by this fake.
            file_contents: Raw file bytes, decoded with the default (UTF-8)
                codec before the lookup.

        Returns:
            On a known input, a successful result holding the mapped text
            encoded back to bytes. Otherwise, an unsuccessful result with
            empty contents and an error message naming the input.
        """
        text = file_contents.decode()

        if text not in self.FORMAT_MAP:
            return FormattedFileContents(
                ok=False,
                formatted_file_contents=b'',
                error_message=f'I do not know how to "{text}".',
            )

        return FormattedFileContents(
            ok=True,
            formatted_file_contents=self.FORMAT_MAP[text].encode(),
            error_message='',
        )
49
50
def _check_files(
    formatter: FileChecker,
    file_contents: dict[str, str],
    dry_run: bool = False,
) -> list[FormattedDiff]:
    """Runs ``formatter`` over throwaway files holding the given contents.

    Each entry of ``file_contents`` is written to a file of that name inside a
    fresh temporary directory. The resulting paths are then handed to
    ``formatter.get_formatting_diffs()``.

    Args:
        formatter: The checker under test.
        file_contents: Maps a file name (relative to the temporary directory)
            to the text to write into it, encoded with the default UTF-8
            codec.
        dry_run: Forwarded unchanged to ``get_formatting_diffs()``.

    Returns:
        Every diff yielded by ``get_formatting_diffs()`` for those files.
    """
    with TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        paths = []
        for name, contents in file_contents.items():
            file_path = tmp_dir / name
            file_path.write_bytes(contents.encode())
            paths.append(file_path)

        # Materialize the results while the temporary directory still exists:
        # get_formatting_diffs() may return a lazy iterator that reads the
        # files on demand, and they are deleted when this `with` block exits.
        return list(formatter.get_formatting_diffs(paths, dry_run))
62
63
class TestFormatCore(unittest.TestCase):
    """Tests for the format core."""

    def setUp(self) -> None:
        # Each test gets its own FakeFileChecker instance.
        self.formatter = FakeFileChecker()

    def test_check_files(self):
        """Tests that get_formatting_diffs() reports diffs as intended.

        ``bar.txt`` is already formatted, so it must not appear in the
        results. Every other file must yield exactly one diff with the
        expected headers and hunk body.
        """
        file_contents = {
            'foo.txt': 'foo',
            'bar.txt': 'bar',
            'baz.txt': 'baz',
            'yep.txt': 'new\n',
        }
        expected_diffs = {
            'foo.txt': '\n'.join(
                (
                    '-foo',
                    '+bar',
                    ' No newline at end of file',
                )
            ),
            'baz.txt': '\n'.join(
                (
                    '+',
                    ' baz',
                    '-No newline at end of file',
                )
            ),
            'yep.txt': '\n'.join(
                (
                    '-new',
                    '+newer',
                )
            ),
        }

        for result in _check_files(self.formatter, file_contents):
            filename = result.file_path.name
            # subTest makes a failure report which file it came from.
            with self.subTest(filename=filename):
                self.assertIn(filename, expected_diffs)
                self.assertTrue(result.ok)

                # A unified diff starts with two file headers and a hunk
                # header; everything after that is the hunk body.
                old_header, new_header, hunk_header, *body = (
                    result.diff.splitlines()
                )
                self.assertEqual(
                    old_header, f'--- {result.file_path}  (original)'
                )
                self.assertEqual(
                    new_header, f'+++ {result.file_path}  (reformatted)'
                )
                self.assertTrue(hunk_header.startswith('@@'))

                self.assertMultiLineEqual(
                    '\n'.join(body), expected_diffs[filename]
                )
                # Popping after a match means a duplicate result for the same
                # file fails the assertIn above.
                expected_diffs.pop(filename)

        # Every expected file must have produced a matching diff.
        self.assertFalse(expected_diffs)

    def test_check_files_error(self):
        """Tests that get_formatting_diffs() propagates error messages.

        Only ``foo.txt`` fails to format. ``bar.txt`` is already formatted
        and is expected to produce no result at all.
        """
        file_contents = {
            'foo.txt': 'broken',
            'bar.txt': 'bar',
        }
        expected_errors = {
            'foo.txt': 'I do not know how to "broken".',
        }

        for result in _check_files(self.formatter, file_contents):
            filename = result.file_path.name
            with self.subTest(filename=filename):
                self.assertIn(filename, expected_errors)
                self.assertFalse(result.ok)
                self.assertEqual(result.diff, '')
                self.assertEqual(
                    result.error_message, expected_errors[filename]
                )
                expected_errors.pop(filename)

        # Every expected failure must have been reported.
        self.assertFalse(expected_errors)

    def test_check_files_dry_run(self):
        """Tests that a dry run reports no diffs, even for unformatted files."""
        file_contents = {
            'foo.txt': 'foo',
            'bar.txt': 'bar',
            'baz.txt': 'baz',
            'yep.txt': 'new\n',
        }
        result = _check_files(self.formatter, file_contents, dry_run=True)
        # assertEqual shows any unexpected diffs in its failure message.
        self.assertEqual(result, [])
150
151
if __name__ == '__main__':
    # When executed as a script, hand control to unittest's command-line
    # runner, which runs the TestCase classes defined in this module.
    unittest.main()
154