1 // Inspired by Clang's clang-format-diff:
2 //
3 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
4
5 #![deny(warnings)]
6
7 #[macro_use]
8 extern crate log;
9
10 use serde::{Deserialize, Serialize};
11 use serde_json as json;
12 use thiserror::Error;
13
14 use std::collections::HashSet;
15 use std::env;
16 use std::ffi::OsStr;
17 use std::io::{self, BufRead};
18 use std::process;
19
20 use regex::Regex;
21
22 use clap::{CommandFactory, Parser};
23
/// The default pattern of files to format.
///
/// We only want to format rust files by default. `scan_diff` wraps the
/// filter in `^...$` before matching, so this selects every path that ends
/// in `.rs`.
const DEFAULT_PATTERN: &str = r".*\.rs";
28
29 #[derive(Error, Debug)]
30 enum FormatDiffError {
31 #[error("{0}")]
32 IncorrectOptions(#[from] getopts::Fail),
33 #[error("{0}")]
34 IncorrectFilter(#[from] regex::Error),
35 #[error("{0}")]
36 IoError(#[from] io::Error),
37 }
38
// Command-line options for `rustfmt-format-diff`, parsed by clap's derive.
// NOTE: the `///` comments on the fields are runtime text, not just docs:
// clap shows them as the `--help` descriptions, so edit them with care.
#[derive(Parser, Debug)]
#[clap(
    name = "rustfmt-format-diff",
    disable_version_flag = true,
    next_line_help = true
)]
pub struct Opts {
    /// Skip the smallest prefix containing NUMBER slashes
    // Applied to the file names on the diff's `+++` lines; e.g. `-p 1`
    // turns `b/src/lib.rs` into `src/lib.rs`.
    #[clap(
        short = 'p',
        long = "skip-prefix",
        value_name = "NUMBER",
        default_value = "0"
    )]
    skip_prefix: u32,

    /// Custom pattern selecting file paths to reformat
    // A regular expression that is anchored with `^...$` before use, so it
    // must match the whole (prefix-stripped) path. Defaults to Rust sources.
    #[clap(
        short = 'f',
        long = "filter",
        value_name = "PATTERN",
        default_value = DEFAULT_PATTERN
    )]
    filter: String,
}
64
main()65 fn main() {
66 env_logger::Builder::from_env("RUSTFMT_LOG").init();
67 let opts = Opts::parse();
68 if let Err(e) = run(opts) {
69 println!("{}", e);
70 Opts::command()
71 .print_help()
72 .expect("cannot write to stdout");
73 process::exit(1);
74 }
75 }
76
77 #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
78 struct Range {
79 file: String,
80 range: [u32; 2],
81 }
82
run(opts: Opts) -> Result<(), FormatDiffError>83 fn run(opts: Opts) -> Result<(), FormatDiffError> {
84 let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
85 run_rustfmt(&files, &ranges)
86 }
87
run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError>88 fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
89 if files.is_empty() || ranges.is_empty() {
90 debug!("No files to format found");
91 return Ok(());
92 }
93
94 let ranges_as_json = json::to_string(ranges).unwrap();
95
96 debug!("Files: {:?}", files);
97 debug!("Ranges: {:?}", ranges);
98
99 let rustfmt_var = env::var_os("RUSTFMT");
100 let rustfmt = match &rustfmt_var {
101 Some(rustfmt) => rustfmt,
102 None => OsStr::new("rustfmt"),
103 };
104 let exit_status = process::Command::new(rustfmt)
105 .args(files)
106 .arg("--file-lines")
107 .arg(ranges_as_json)
108 .status()?;
109
110 if !exit_status.success() {
111 return Err(FormatDiffError::IoError(io::Error::new(
112 io::ErrorKind::Other,
113 format!("rustfmt failed with {}", exit_status),
114 )));
115 }
116 Ok(())
117 }
118
119 /// Scans a diff from `from`, and returns the set of files found, and the ranges
120 /// in those files.
scan_diff<R>( from: R, skip_prefix: u32, file_filter: &str, ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError> where R: io::Read,121 fn scan_diff<R>(
122 from: R,
123 skip_prefix: u32,
124 file_filter: &str,
125 ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
126 where
127 R: io::Read,
128 {
129 let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix);
130 let diff_pattern = Regex::new(&diff_pattern).unwrap();
131
132 let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
133
134 let file_filter = Regex::new(&format!("^{}$", file_filter))?;
135
136 let mut current_file = None;
137
138 let mut files = HashSet::new();
139 let mut ranges = vec![];
140 for line in io::BufReader::new(from).lines() {
141 let line = line.unwrap();
142
143 if let Some(captures) = diff_pattern.captures(&line) {
144 current_file = Some(captures.get(1).unwrap().as_str().to_owned());
145 }
146
147 let file = match current_file {
148 Some(ref f) => &**f,
149 None => continue,
150 };
151
152 // FIXME(emilio): We could avoid this most of the time if needed, but
153 // it's not clear it's worth it.
154 if !file_filter.is_match(file) {
155 continue;
156 }
157
158 let lines_captures = match lines_pattern.captures(&line) {
159 Some(captures) => captures,
160 None => continue,
161 };
162
163 let start_line = lines_captures
164 .get(1)
165 .unwrap()
166 .as_str()
167 .parse::<u32>()
168 .unwrap();
169 let line_count = match lines_captures.get(3) {
170 Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
171 None => 1,
172 };
173
174 if line_count == 0 {
175 continue;
176 }
177
178 let end_line = start_line + line_count - 1;
179 files.insert(file.to_owned());
180 ranges.push(Range {
181 file: file.to_owned(),
182 range: [start_line, end_line],
183 });
184 }
185
186 Ok((files, ranges))
187 }
188
#[test]
fn scan_simple_git_diff() {
    // `skip_prefix = 1` drops the first path component of each `+++` name.
    const DIFF: &str = include_str!("test/bindgen.diff");
    let scanned = scan_diff(DIFF.as_bytes(), 1, r".*\.rs");
    let (files, ranges) = scanned.expect("scan_diff failed?");

    let kept = files.contains("src/ir/traversal.rs");
    assert!(kept, "Should've matched the filter");

    let dropped = !files.contains("tests/headers/anon_enum.hpp");
    assert!(dropped, "Shouldn't have matched the filter");

    // Shorthand for building the expected inclusive ranges.
    let range = |file: &str, first: u32, last: u32| Range {
        file: file.to_owned(),
        range: [first, last],
    };
    let expected = vec![
        range("src/ir/item.rs", 148, 158),
        range("src/ir/item.rs", 160, 170),
        range("src/ir/traversal.rs", 9, 16),
        range("src/ir/traversal.rs", 35, 43),
    ];
    assert_eq!(ranges, expected);
}
226
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    /// Returns `true` when clap refuses `args`, whose first element is the
    /// program name.
    fn is_rejected(args: &[&str]) -> bool {
        Opts::command().try_get_matches_from(args).is_err()
    }

    #[test]
    fn default_options() {
        let no_args: Vec<String> = Vec::new();
        let parsed = Opts::parse_from(&no_args);
        assert_eq!(DEFAULT_PATTERN, parsed.filter);
        assert_eq!(0, parsed.skip_prefix);
    }

    #[test]
    fn good_options() {
        let parsed = Opts::parse_from(&["test", "-p", "10", "-f", r".*\.hs"]);
        assert_eq!(r".*\.hs", parsed.filter);
        assert_eq!(10, parsed.skip_prefix);
    }

    #[test]
    fn unexpected_option() {
        // Stray positional arguments are not accepted.
        assert!(is_rejected(&["test", "unexpected"]));
    }

    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    #[test]
    fn overridden_option() {
        // Giving `-p` twice is an error rather than last-one-wins.
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    #[test]
    fn negative_filter() {
        // Despite its name, this exercises `-p`: a negative skip prefix must
        // be refused.
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}
282