• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use quote::quote;
16 use syn::{parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Expr, Ident};
17 
18 struct AccumulatePartsState {
19     var_num: usize,
20     error_message_ident: Ident,
21     statements: Vec<proc_macro2::TokenStream>,
22 }
23 
expr_to_string(expr: &Expr) -> String24 fn expr_to_string(expr: &Expr) -> String {
25     quote!(#expr).to_string()
26 }
27 
28 impl AccumulatePartsState {
new() -> Self29     fn new() -> Self {
30         Self {
31             var_num: 0,
32             error_message_ident: Ident::new(
33                 "__googletest__verify_pred__error_message",
34                 ::proc_macro2::Span::call_site(),
35             ),
36             statements: vec![],
37         }
38     }
39 
40     /// Takes an expression with chained field accesses and method calls and
41     /// accumulates intermediate expressions used for computing `verify_pred!`'s
42     /// expression, including intermediate variable assignments to evaluate
43     /// parts of the expression exactly once, and the format string used to
44     /// output intermediate values on condition failure. It returns the new form
45     /// of the input expression with parts of it potentially replaced by the
46     /// intermediate variables.
accumulate_parts(&mut self, expr: Expr) -> Expr47     fn accumulate_parts(&mut self, expr: Expr) -> Expr {
48         // Literals don't need to be printed or stored in intermediate variables.
49         if is_literal(&expr) {
50             return expr;
51         }
52         let expr_string = expr_to_string(&expr);
53         let new_expr = match expr {
54             Expr::Group(mut group) => {
55                 // This is an invisible group added for correct precedence in the AST. Just pass
56                 // through without having a separate printing result.
57                 *group.expr = self.accumulate_parts(*group.expr);
58                 return Expr::Group(group);
59             }
60             Expr::Field(mut field) => {
61                 // Don't assign field access to an intermediate variable to avoid moving out of
62                 // non-`Copy` fields.
63                 *field.base = self.accumulate_parts(*field.base);
64                 Expr::Field(field)
65             }
66             Expr::Call(mut call) => {
67                 // Cache args into intermediate variables.
68                 call.args = self.define_variables_for_args(call.args);
69                 // Cache function value into an intermediate variable.
70                 self.define_variable(&Expr::Call(call))
71             }
72             Expr::MethodCall(mut method_call) => {
73                 *method_call.receiver = self.accumulate_parts(*method_call.receiver);
74                 // Cache args into intermediate variables.
75                 method_call.args = self.define_variables_for_args(method_call.args);
76                 // Cache method value into an intermediate variable.
77                 self.define_variable(&Expr::MethodCall(method_call))
78             }
79             Expr::Binary(mut binary) => {
80                 *binary.left = self.accumulate_parts(*binary.left);
81                 *binary.right = self.accumulate_parts(*binary.right);
82                 Expr::Binary(binary)
83             }
84             Expr::Unary(mut unary) => {
85                 *unary.expr = self.accumulate_parts(*unary.expr);
86                 Expr::Unary(unary)
87             }
88             // A path expression doesn't need to be stored in an intermediate variable.
89             // This avoids moving out of an existing variable.
90             Expr::Path(_) => expr,
91             // By default, assume it's some expression that needs to be cached to avoid
92             // double-evaluation.
93             _ => self.define_variable(&expr),
94         };
95         let error_message_ident = &self.error_message_ident;
96         self.statements.push(quote! {
97             ::googletest::fmt::internal::__googletest__write_expr_value!(
98                 &mut #error_message_ident,
99                 #expr_string,
100                 #new_expr,
101             );
102         });
103         new_expr
104     }
105 
106     // Defines a variable for each argument expression so that it's evaluated
107     // exactly once.
define_variables_for_args( &mut self, args: Punctuated<Expr, Comma>, ) -> Punctuated<Expr, Comma>108     fn define_variables_for_args(
109         &mut self,
110         args: Punctuated<Expr, Comma>,
111     ) -> Punctuated<Expr, Comma> {
112         args.into_pairs()
113             .map(|mut pair| {
114                 // Don't need to assign literals to intermediate variables.
115                 if is_literal(pair.value()) {
116                     return pair;
117                 }
118 
119                 let var_expr = self.define_variable(pair.value());
120                 let error_message_ident = &self.error_message_ident;
121                 let expr_string = expr_to_string(pair.value());
122                 self.statements.push(quote! {
123                     ::googletest::fmt::internal::__googletest__write_expr_value!(
124                         &mut #error_message_ident,
125                         #expr_string,
126                         #var_expr,
127                     );
128                 });
129 
130                 *pair.value_mut() = var_expr;
131                 pair
132             })
133             .collect()
134     }
135 
136     /// Defines a new variable assigned to the expression and returns the
137     /// variable as an expression to be used in place of the passed-in
138     /// expression.
define_variable(&mut self, value: &Expr) -> Expr139     fn define_variable(&mut self, value: &Expr) -> Expr {
140         let var_name =
141             Ident::new(&format!("__googletest__verify_pred__var{}", self.var_num), value.span());
142         self.var_num += 1;
143         self.statements.push(quote! {
144             #[allow(non_snake_case)]
145             let mut #var_name = #value;
146         });
147         syn::parse::<Expr>(quote!(#var_name).into()).unwrap()
148     }
149 }
150 
151 // Whether it's a literal or unary operator applied to a literal (1, -1).
is_literal(expr: &Expr) -> bool152 fn is_literal(expr: &Expr) -> bool {
153     match expr {
154         Expr::Lit(_) => true,
155         Expr::Unary(unary) => matches!(&*unary.expr, Expr::Lit(_)),
156         _ => false,
157     }
158 }
159 
verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream160 pub fn verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
161     let parsed = parse_macro_input!(input as Expr);
162     let error_message = quote!(#parsed).to_string() + " was false with";
163 
164     let mut state = AccumulatePartsState::new();
165     let pred_value = state.accumulate_parts(parsed);
166     let AccumulatePartsState { error_message_ident, mut statements, .. } = state;
167 
168     let _ = statements.pop(); // The last statement prints the full expression itself.
169     quote! {
170         {
171             let mut #error_message_ident = #error_message.to_string();
172             #(#statements)*
173             if (#pred_value) {
174                 Ok(())
175             } else {
176                 ::core::result::Result::Err(
177                     ::googletest::internal::test_outcome::TestAssertionFailure::create(
178                         #error_message_ident))
179             }
180         }
181     }
182     .into()
183 }
184