1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use quote::quote;
16 use syn::{parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Expr, Ident};
17
18 struct AccumulatePartsState {
19 var_num: usize,
20 error_message_ident: Ident,
21 statements: Vec<proc_macro2::TokenStream>,
22 }
23
expr_to_string(expr: &Expr) -> String24 fn expr_to_string(expr: &Expr) -> String {
25 quote!(#expr).to_string()
26 }
27
28 impl AccumulatePartsState {
new() -> Self29 fn new() -> Self {
30 Self {
31 var_num: 0,
32 error_message_ident: Ident::new(
33 "__googletest__verify_pred__error_message",
34 ::proc_macro2::Span::call_site(),
35 ),
36 statements: vec![],
37 }
38 }
39
40 /// Takes an expression with chained field accesses and method calls and
41 /// accumulates intermediate expressions used for computing `verify_pred!`'s
42 /// expression, including intermediate variable assignments to evaluate
43 /// parts of the expression exactly once, and the format string used to
44 /// output intermediate values on condition failure. It returns the new form
45 /// of the input expression with parts of it potentially replaced by the
46 /// intermediate variables.
accumulate_parts(&mut self, expr: Expr) -> Expr47 fn accumulate_parts(&mut self, expr: Expr) -> Expr {
48 // Literals don't need to be printed or stored in intermediate variables.
49 if is_literal(&expr) {
50 return expr;
51 }
52 let expr_string = expr_to_string(&expr);
53 let new_expr = match expr {
54 Expr::Group(mut group) => {
55 // This is an invisible group added for correct precedence in the AST. Just pass
56 // through without having a separate printing result.
57 *group.expr = self.accumulate_parts(*group.expr);
58 return Expr::Group(group);
59 }
60 Expr::Field(mut field) => {
61 // Don't assign field access to an intermediate variable to avoid moving out of
62 // non-`Copy` fields.
63 *field.base = self.accumulate_parts(*field.base);
64 Expr::Field(field)
65 }
66 Expr::Call(mut call) => {
67 // Cache args into intermediate variables.
68 call.args = self.define_variables_for_args(call.args);
69 // Cache function value into an intermediate variable.
70 self.define_variable(&Expr::Call(call))
71 }
72 Expr::MethodCall(mut method_call) => {
73 *method_call.receiver = self.accumulate_parts(*method_call.receiver);
74 // Cache args into intermediate variables.
75 method_call.args = self.define_variables_for_args(method_call.args);
76 // Cache method value into an intermediate variable.
77 self.define_variable(&Expr::MethodCall(method_call))
78 }
79 Expr::Binary(mut binary) => {
80 *binary.left = self.accumulate_parts(*binary.left);
81 *binary.right = self.accumulate_parts(*binary.right);
82 Expr::Binary(binary)
83 }
84 Expr::Unary(mut unary) => {
85 *unary.expr = self.accumulate_parts(*unary.expr);
86 Expr::Unary(unary)
87 }
88 // A path expression doesn't need to be stored in an intermediate variable.
89 // This avoids moving out of an existing variable.
90 Expr::Path(_) => expr,
91 // By default, assume it's some expression that needs to be cached to avoid
92 // double-evaluation.
93 _ => self.define_variable(&expr),
94 };
95 let error_message_ident = &self.error_message_ident;
96 self.statements.push(quote! {
97 ::googletest::fmt::internal::__googletest__write_expr_value!(
98 &mut #error_message_ident,
99 #expr_string,
100 #new_expr,
101 );
102 });
103 new_expr
104 }
105
106 // Defines a variable for each argument expression so that it's evaluated
107 // exactly once.
define_variables_for_args( &mut self, args: Punctuated<Expr, Comma>, ) -> Punctuated<Expr, Comma>108 fn define_variables_for_args(
109 &mut self,
110 args: Punctuated<Expr, Comma>,
111 ) -> Punctuated<Expr, Comma> {
112 args.into_pairs()
113 .map(|mut pair| {
114 // Don't need to assign literals to intermediate variables.
115 if is_literal(pair.value()) {
116 return pair;
117 }
118
119 let var_expr = self.define_variable(pair.value());
120 let error_message_ident = &self.error_message_ident;
121 let expr_string = expr_to_string(pair.value());
122 self.statements.push(quote! {
123 ::googletest::fmt::internal::__googletest__write_expr_value!(
124 &mut #error_message_ident,
125 #expr_string,
126 #var_expr,
127 );
128 });
129
130 *pair.value_mut() = var_expr;
131 pair
132 })
133 .collect()
134 }
135
136 /// Defines a new variable assigned to the expression and returns the
137 /// variable as an expression to be used in place of the passed-in
138 /// expression.
define_variable(&mut self, value: &Expr) -> Expr139 fn define_variable(&mut self, value: &Expr) -> Expr {
140 let var_name =
141 Ident::new(&format!("__googletest__verify_pred__var{}", self.var_num), value.span());
142 self.var_num += 1;
143 self.statements.push(quote! {
144 #[allow(non_snake_case)]
145 let mut #var_name = #value;
146 });
147 syn::parse::<Expr>(quote!(#var_name).into()).unwrap()
148 }
149 }
150
151 // Whether it's a literal or unary operator applied to a literal (1, -1).
is_literal(expr: &Expr) -> bool152 fn is_literal(expr: &Expr) -> bool {
153 match expr {
154 Expr::Lit(_) => true,
155 Expr::Unary(unary) => matches!(&*unary.expr, Expr::Lit(_)),
156 _ => false,
157 }
158 }
159
verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream160 pub fn verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
161 let parsed = parse_macro_input!(input as Expr);
162 let error_message = quote!(#parsed).to_string() + " was false with";
163
164 let mut state = AccumulatePartsState::new();
165 let pred_value = state.accumulate_parts(parsed);
166 let AccumulatePartsState { error_message_ident, mut statements, .. } = state;
167
168 let _ = statements.pop(); // The last statement prints the full expression itself.
169 quote! {
170 {
171 let mut #error_message_ident = #error_message.to_string();
172 #(#statements)*
173 if (#pred_value) {
174 Ok(())
175 } else {
176 ::core::result::Result::Err(
177 ::googletest::internal::test_outcome::TestAssertionFailure::create(
178 #error_message_ident))
179 }
180 }
181 }
182 .into()
183 }
184