# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the formatter core."""

from pathlib import Path
from tempfile import TemporaryDirectory
import unittest

from pw_presubmit.format.core import (
    FileChecker,
    FormattedDiff,
    FormattedFileContents,
)


class FakeFileChecker(FileChecker):
    """A FileChecker that "formats" files via a fixed lookup table.

    File contents that exactly match a key of ``FORMAT_MAP`` are "formatted"
    into the mapped value. Any other contents produce a failed result carrying
    an error message.
    """

    # Maps exact original file contents to the "formatted" output.
    FORMAT_MAP = {
        'foo': 'bar',
        'bar': 'bar',
        'baz': '\nbaz\n',
        'new\n': 'newer\n',
    }

    def format_file_in_memory(
        self, file_path: Path, file_contents: bytes
    ) -> FormattedFileContents:
        """Formats ``file_contents`` by looking them up in ``FORMAT_MAP``.

        Args:
            file_path: Path of the file being formatted (unused by this fake).
            file_contents: Raw bytes of the file.

        Returns:
            On a ``FORMAT_MAP`` hit, a successful result holding the mapped
            contents. Otherwise, a failed result with empty contents and an
            error message that quotes the original contents.
        """
        original = file_contents.decode()
        formatted = self.FORMAT_MAP.get(original)
        if formatted is None:
            return FormattedFileContents(
                ok=False,
                formatted_file_contents=b'',
                error_message=f'I do not know how to "{original}".',
            )
        return FormattedFileContents(
            ok=True,
            formatted_file_contents=formatted.encode(),
            error_message='',
        )


def _check_files(
    formatter: FileChecker,
    file_contents: dict[str, str],
    dry_run: bool = False,
) -> list[FormattedDiff]:
    """Writes files to a temporary directory and collects formatting diffs.

    Args:
        formatter: The checker under test.
        file_contents: Maps file names (created inside a fresh temporary
            directory) to the text written into each file.
        dry_run: Forwarded unchanged to ``get_formatting_diffs``.

    Returns:
        Everything ``formatter.get_formatting_diffs`` yielded, as a list.
    """
    with TemporaryDirectory() as tmp:
        paths = []
        for name, contents in file_contents.items():
            file_path = Path(tmp) / name
            file_path.write_bytes(contents.encode())
            paths.append(file_path)

        # Materialize the results while still inside the `with`, in case
        # get_formatting_diffs is lazy and reads the files after this point;
        # the temporary directory is deleted when the block exits.
        return list(formatter.get_formatting_diffs(paths, dry_run))


class TestFormatCore(unittest.TestCase):
    """Tests for the format core."""

    def setUp(self) -> None:
        self.formatter = FakeFileChecker()

    def test_check_files(self):
        """Tests that check_files() produces diffs as intended."""
        file_contents = {
            'foo.txt': 'foo',
            'bar.txt': 'bar',
            'baz.txt': 'baz',
            'yep.txt': 'new\n',
        }
        # 'bar.txt' is already formatted, so no diff is expected for it.
        expected_diffs = {
            'foo.txt': '\n'.join(
                (
                    '-foo',
                    '+bar',
                    ' No newline at end of file',
                )
            ),
            'baz.txt': '\n'.join(
                (
                    '+',
                    ' baz',
                    '-No newline at end of file',
                )
            ),
            'yep.txt': '\n'.join(
                (
                    '-new',
                    '+newer',
                )
            ),
        }

        for result in _check_files(self.formatter, file_contents):
            filename = result.file_path.name
            self.assertIn(filename, expected_diffs)
            self.assertTrue(result.ok)
            lines = result.diff.splitlines()
            # Unified-diff header: original/reformatted file lines, then hunk.
            self.assertEqual(
                lines.pop(0), f'--- {result.file_path} (original)'
            )
            self.assertEqual(
                lines.pop(0), f'+++ {result.file_path} (reformatted)'
            )
            self.assertTrue(lines.pop(0).startswith('@@'))

            self.assertMultiLineEqual(
                '\n'.join(lines), expected_diffs[filename]
            )
            expected_diffs.pop(filename)

        # Every expected diff must have been produced exactly once.
        self.assertFalse(expected_diffs)

    def test_check_files_error(self):
        """Tests that check_files() propagates error messages."""
        file_contents = {
            'foo.txt': 'broken',
            'bar.txt': 'bar',
        }
        expected_errors = {
            'foo.txt': 'I do not know how to "broken".',
        }
        for result in _check_files(self.formatter, file_contents):
            filename = result.file_path.name
            self.assertIn(filename, expected_errors)
            self.assertFalse(result.ok)
            self.assertEqual(result.diff, '')
            self.assertEqual(result.error_message, expected_errors[filename])
            expected_errors.pop(filename)

        # Every expected error must have been reported exactly once.
        self.assertFalse(expected_errors)

    def test_check_files_dry_run(self):
        """Tests that check_files() dry run produces no delta."""
        file_contents = {
            'foo.txt': 'foo',
            'bar.txt': 'bar',
            'baz.txt': 'baz',
            'yep.txt': 'new\n',
        }
        result = _check_files(self.formatter, file_contents, dry_run=True)
        self.assertFalse(result)


if __name__ == '__main__':
    unittest.main()