• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Inspired by Clang's clang-format-diff:
2 //
3 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
4 
5 #![deny(warnings)]
6 
7 #[macro_use]
8 extern crate log;
9 
10 use serde::{Deserialize, Serialize};
11 use serde_json as json;
12 use thiserror::Error;
13 
14 use std::collections::HashSet;
15 use std::env;
16 use std::ffi::OsStr;
17 use std::io::{self, BufRead};
18 use std::process;
19 
20 use regex::Regex;
21 
22 use clap::{CommandFactory, Parser};
23 
/// The default pattern of files to format.
///
/// We only want to format rust files by default. `scan_diff` wraps the filter
/// in `^…$`, so the pattern has to match a file's entire (prefix-stripped)
/// path, not just a substring of it.
const DEFAULT_PATTERN: &str = r".*\.rs";
28 
/// Errors that make `rustfmt-format-diff` fail.
///
/// Every variant displays as just the message of the error it wraps
/// (`#[error("{0}")]`), and the `#[from]` attributes let `?` convert each
/// source error into this type automatically.
#[derive(Error, Debug)]
enum FormatDiffError {
    /// A command-line parsing failure reported by `getopts`.
    // NOTE(review): options are parsed with clap (`Opts::parse`) and nothing
    // in this file constructs this variant; it looks like a leftover from the
    // getopts-based parser. Confirm before removing it.
    #[error("{0}")]
    IncorrectOptions(#[from] getopts::Fail),
    /// A regular expression (such as the `--filter` pattern) that failed to
    /// compile.
    #[error("{0}")]
    IncorrectFilter(#[from] regex::Error),
    /// An I/O failure; also used to report rustfmt exiting unsuccessfully.
    #[error("{0}")]
    IoError(#[from] io::Error),
}
38 
// Command-line options of `rustfmt-format-diff`.
//
// clap's derive uses the `///` comments on the fields below as their
// `--help` descriptions, so those lines are user-facing output. Maintainer
// notes use plain `//` comments instead.
#[derive(Parser, Debug)]
#[clap(
    name = "rustfmt-format-diff",
    disable_version_flag = true,
    next_line_help = true
)]
pub struct Opts {
    /// Skip the smallest prefix containing NUMBER slashes
    // Applied to the path in each `+++` diff header: with `-p 1`, a git path
    // such as `b/src/lib.rs` becomes `src/lib.rs`.
    #[clap(
        short = 'p',
        long = "skip-prefix",
        value_name = "NUMBER",
        default_value = "0"
    )]
    skip_prefix: u32,

    /// Custom pattern selecting file paths to reformat
    // A regular expression; `scan_diff` anchors it with `^…$`, so it must
    // match the whole prefix-stripped path.
    #[clap(
        short = 'f',
        long = "filter",
        value_name = "PATTERN",
        default_value = DEFAULT_PATTERN
    )]
    filter: String,
}
64 
main()65 fn main() {
66     env_logger::Builder::from_env("RUSTFMT_LOG").init();
67     let opts = Opts::parse();
68     if let Err(e) = run(opts) {
69         println!("{}", e);
70         Opts::command()
71             .print_help()
72             .expect("cannot write to stdout");
73         process::exit(1);
74     }
75 }
76 
/// One changed line range in one file, as handed to rustfmt.
///
/// A slice of these is serialized to JSON and passed to rustfmt's
/// `--file-lines` option. The field names become the JSON keys, so renaming
/// them changes what rustfmt receives.
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
struct Range {
    /// Path of the file as it appears in the diff after prefix stripping.
    file: String,
    /// `[first, last]` line numbers; `last` is inclusive
    /// (`scan_diff` computes it as `start + count - 1`).
    range: [u32; 2],
}
82 
run(opts: Opts) -> Result<(), FormatDiffError>83 fn run(opts: Opts) -> Result<(), FormatDiffError> {
84     let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
85     run_rustfmt(&files, &ranges)
86 }
87 
run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError>88 fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
89     if files.is_empty() || ranges.is_empty() {
90         debug!("No files to format found");
91         return Ok(());
92     }
93 
94     let ranges_as_json = json::to_string(ranges).unwrap();
95 
96     debug!("Files: {:?}", files);
97     debug!("Ranges: {:?}", ranges);
98 
99     let rustfmt_var = env::var_os("RUSTFMT");
100     let rustfmt = match &rustfmt_var {
101         Some(rustfmt) => rustfmt,
102         None => OsStr::new("rustfmt"),
103     };
104     let exit_status = process::Command::new(rustfmt)
105         .args(files)
106         .arg("--file-lines")
107         .arg(ranges_as_json)
108         .status()?;
109 
110     if !exit_status.success() {
111         return Err(FormatDiffError::IoError(io::Error::new(
112             io::ErrorKind::Other,
113             format!("rustfmt failed with {}", exit_status),
114         )));
115     }
116     Ok(())
117 }
118 
119 /// Scans a diff from `from`, and returns the set of files found, and the ranges
120 /// in those files.
scan_diff<R>( from: R, skip_prefix: u32, file_filter: &str, ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError> where R: io::Read,121 fn scan_diff<R>(
122     from: R,
123     skip_prefix: u32,
124     file_filter: &str,
125 ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
126 where
127     R: io::Read,
128 {
129     let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix);
130     let diff_pattern = Regex::new(&diff_pattern).unwrap();
131 
132     let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
133 
134     let file_filter = Regex::new(&format!("^{}$", file_filter))?;
135 
136     let mut current_file = None;
137 
138     let mut files = HashSet::new();
139     let mut ranges = vec![];
140     for line in io::BufReader::new(from).lines() {
141         let line = line.unwrap();
142 
143         if let Some(captures) = diff_pattern.captures(&line) {
144             current_file = Some(captures.get(1).unwrap().as_str().to_owned());
145         }
146 
147         let file = match current_file {
148             Some(ref f) => &**f,
149             None => continue,
150         };
151 
152         // FIXME(emilio): We could avoid this most of the time if needed, but
153         // it's not clear it's worth it.
154         if !file_filter.is_match(file) {
155             continue;
156         }
157 
158         let lines_captures = match lines_pattern.captures(&line) {
159             Some(captures) => captures,
160             None => continue,
161         };
162 
163         let start_line = lines_captures
164             .get(1)
165             .unwrap()
166             .as_str()
167             .parse::<u32>()
168             .unwrap();
169         let line_count = match lines_captures.get(3) {
170             Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
171             None => 1,
172         };
173 
174         if line_count == 0 {
175             continue;
176         }
177 
178         let end_line = start_line + line_count - 1;
179         files.insert(file.to_owned());
180         ranges.push(Range {
181             file: file.to_owned(),
182             range: [start_line, end_line],
183         });
184     }
185 
186     Ok((files, ranges))
187 }
188 
#[test]
fn scan_simple_git_diff() {
    const DIFF: &str = include_str!("test/bindgen.diff");
    let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");

    // A Rust source file from the diff passes the `.rs` filter...
    assert!(
        files.contains("src/ir/traversal.rs"),
        "Should've matched the filter"
    );
    // ...while a C++ header from the same diff is rejected by it.
    assert!(
        !files.contains("tests/headers/anon_enum.hpp"),
        "Shouldn't have matched the filter"
    );

    let range = |file: &str, first: u32, last: u32| Range {
        file: file.to_owned(),
        range: [first, last],
    };
    let expected = vec![
        range("src/ir/item.rs", 148, 158),
        range("src/ir/item.rs", 160, 170),
        range("src/ir/traversal.rs", 9, 16),
        range("src/ir/traversal.rs", 35, 43),
    ];
    assert_eq!(ranges, expected);
}
226 
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    // With no arguments at all, both options take their declared defaults.
    #[test]
    fn default_options() {
        let empty: Vec<String> = vec![];
        let o = Opts::parse_from(&empty);
        assert_eq!(DEFAULT_PATTERN, o.filter);
        assert_eq!(0, o.skip_prefix);
    }

    // The short flags for both options are parsed into the matching fields.
    #[test]
    fn good_options() {
        let o = Opts::parse_from(&["test", "-p", "10", "-f", r".*\.hs"]);
        assert_eq!(r".*\.hs", o.filter);
        assert_eq!(10, o.skip_prefix);
    }

    // `Opts` declares no positional arguments, so a bare word is rejected.
    #[test]
    fn unexpected_option() {
        assert!(
            Opts::command()
                .try_get_matches_from(&["test", "unexpected"])
                .is_err()
        );
    }

    // A flag that `Opts` does not declare is rejected.
    #[test]
    fn unexpected_flag() {
        assert!(
            Opts::command()
                .try_get_matches_from(&["test", "--flag"])
                .is_err()
        );
    }

    // Giving `-p` twice is an error rather than "last one wins".
    #[test]
    fn overridden_option() {
        assert!(
            Opts::command()
                .try_get_matches_from(&["test", "-p", "10", "-p", "20"])
                .is_err()
        );
    }

    // `skip_prefix` is a `u32`, so a negative `-p` value must not be accepted.
    #[test]
    fn negative_skip_prefix() {
        assert!(
            Opts::command()
                .try_get_matches_from(&["test", "-p", "-1"])
                .is_err()
        );
    }
}
282