• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Generate Rust unit tests for canonical test vectors.
16 
17 use quote::{format_ident, quote};
18 use serde::{Deserialize, Serialize};
19 use serde_json::Value;
20 
21 #[derive(Debug, Deserialize)]
22 struct Packet {
23     #[serde(rename = "packet")]
24     name: String,
25     tests: Vec<TestVector>,
26 }
27 
28 #[derive(Debug, Deserialize)]
29 struct TestVector {
30     packed: String,
31     unpacked: Value,
32     packet: Option<String>,
33 }
34 
35 /// Convert a string of hexadecimal characters into a Rust vector of
36 /// bytes.
37 ///
38 /// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream39 fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
40     assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
41     let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
42         let number = format!("0x{}", std::str::from_utf8(chunk).unwrap());
43         syn::parse_str::<syn::LitInt>(&number).unwrap()
44     });
45 
46     quote! {
47         vec![#(#bytes),*]
48     }
49 }
50 
51 /// Convert `value` to a JSON string literal.
52 ///
53 /// The string literal is a raw literal to avoid escaping
54 /// double-quotes.
to_json<T: Serialize>(value: &T) -> syn::LitStr55 fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
56     let json = serde_json::to_string(value).unwrap();
57     assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
58     syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
59 }
60 
generate_unit_tests(input: &str, packet_names: &[&str]) -> Result<String, String>61 fn generate_unit_tests(input: &str, packet_names: &[&str]) -> Result<String, String> {
62     eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());
63 
64     let data = std::fs::read_to_string(input)
65         .unwrap_or_else(|err| panic!("Could not read {input}: {err}"));
66     let packets: Vec<Packet> = serde_json::from_str(&data).expect("Could not parse JSON");
67 
68     let mut tests = Vec::new();
69     for packet in &packets {
70         for (i, test_vector) in packet.tests.iter().enumerate() {
71             let test_packet = test_vector.packet.as_deref().unwrap_or(packet.name.as_str());
72             if !packet_names.contains(&test_packet) {
73                 eprintln!("Skipping packet {}", test_packet);
74                 continue;
75             }
76             eprintln!("Generating tests for packet {}", test_packet);
77 
78             let parse_test_name = format_ident!(
79                 "test_parse_{}_vector_{}_0x{}",
80                 test_packet,
81                 i + 1,
82                 &test_vector.packed
83             );
84             let serialize_test_name = format_ident!(
85                 "test_serialize_{}_vector_{}_0x{}",
86                 test_packet,
87                 i + 1,
88                 &test_vector.packed
89             );
90             let packed = hexadecimal_to_vec(&test_vector.packed);
91             let packet_name = format_ident!("{}", test_packet);
92             let builder_name = format_ident!("{}Builder", test_packet);
93 
94             let object = test_vector.unpacked.as_object().unwrap_or_else(|| {
95                 panic!("Expected test vector object, found: {}", test_vector.unpacked)
96             });
97             let assertions = object.iter().map(|(key, value)| {
98                 let getter = format_ident!("get_{key}");
99                 let expected = format_ident!("expected_{key}");
100                 let json = to_json(&value);
101                 quote! {
102                     let #expected: serde_json::Value = serde_json::from_str(#json)
103                         .expect("Could not create expected value from canonical JSON data");
104                     assert_eq!(json!(actual.#getter()), #expected);
105                 }
106             });
107 
108             let json = to_json(&object);
109             tests.push(quote! {
110                 #[test]
111                 fn #parse_test_name() {
112                     let packed = #packed;
113                     let actual = #packet_name::parse(&packed).unwrap();
114                     #(#assertions)*
115                 }
116 
117                 #[test]
118                 fn #serialize_test_name() {
119                     let builder: #builder_name = serde_json::from_str(#json)
120                         .expect("Could not create builder from canonical JSON data");
121                     let packet = builder.build();
122                     let packed: Vec<u8> = #packed;
123                     assert_eq!(packet.encode_to_vec(), Ok(packed));
124                 }
125             });
126         }
127     }
128 
129     // TODO(mgeisler): make the generated code clean from warnings.
130     let code = quote! {
131         #[allow(warnings, missing_docs)]
132         #[cfg(test)]
133         mod test {
134             use pdl_runtime::Packet;
135             use serde_json::json;
136             use super::*;
137 
138             #(#tests)*
139         }
140     };
141     let syntax_tree = syn::parse2::<syn::File>(code).expect("Could not parse {code:#?}");
142     Ok(prettyplease::unparse(&syntax_tree))
143 }
144 
generate_tests(input_file: &str) -> Result<String, String>145 pub fn generate_tests(input_file: &str) -> Result<String, String> {
146     // TODO(mgeisler): remove the `packet_names` argument when we
147     // support all canonical packets.
148     generate_unit_tests(
149         input_file,
150         &[
151             "EnumChild_A",
152             "EnumChild_B",
153             "Packet_Array_Field_ByteElement_ConstantSize",
154             "Packet_Array_Field_ByteElement_UnknownSize",
155             "Packet_Array_Field_ByteElement_VariableCount",
156             "Packet_Array_Field_ByteElement_VariableSize",
157             "Packet_Array_Field_EnumElement",
158             "Packet_Array_Field_EnumElement_ConstantSize",
159             "Packet_Array_Field_EnumElement_UnknownSize",
160             "Packet_Array_Field_EnumElement_VariableCount",
161             "Packet_Array_Field_EnumElement_VariableCount",
162             "Packet_Array_Field_ScalarElement",
163             "Packet_Array_Field_ScalarElement_ConstantSize",
164             "Packet_Array_Field_ScalarElement_UnknownSize",
165             "Packet_Array_Field_ScalarElement_VariableCount",
166             "Packet_Array_Field_ScalarElement_VariableSize",
167             "Packet_Array_Field_SizedElement_ConstantSize",
168             "Packet_Array_Field_SizedElement_UnknownSize",
169             "Packet_Array_Field_SizedElement_VariableCount",
170             "Packet_Array_Field_SizedElement_VariableSize",
171             "Packet_Array_Field_UnsizedElement_ConstantSize",
172             "Packet_Array_Field_UnsizedElement_UnknownSize",
173             "Packet_Array_Field_UnsizedElement_VariableCount",
174             "Packet_Array_Field_UnsizedElement_VariableSize",
175             "Packet_Array_Field_SizedElement_VariableSize_Padded",
176             "Packet_Array_Field_UnsizedElement_VariableCount_Padded",
177             "Packet_Optional_Scalar_Field",
178             "Packet_Optional_Enum_Field",
179             "Packet_Optional_Struct_Field",
180             "Packet_Body_Field_UnknownSize",
181             "Packet_Body_Field_UnknownSize_Terminal",
182             "Packet_Body_Field_VariableSize",
183             "Packet_Count_Field",
184             "Packet_Enum8_Field",
185             "Packet_Enum_Field",
186             "Packet_FixedEnum_Field",
187             "Packet_FixedScalar_Field",
188             "Packet_Payload_Field_UnknownSize",
189             "Packet_Payload_Field_UnknownSize_Terminal",
190             "Packet_Payload_Field_VariableSize",
191             "Packet_Payload_Field_SizeModifier",
192             "Packet_Reserved_Field",
193             "Packet_Scalar_Field",
194             "Packet_Size_Field",
195             "Packet_Struct_Field",
196             "ScalarChild_A",
197             "ScalarChild_B",
198             "Struct_Count_Field",
199             "Struct_Array_Field_ByteElement_ConstantSize",
200             "Struct_Array_Field_ByteElement_UnknownSize",
201             "Struct_Array_Field_ByteElement_UnknownSize",
202             "Struct_Array_Field_ByteElement_VariableCount",
203             "Struct_Array_Field_ByteElement_VariableCount",
204             "Struct_Array_Field_ByteElement_VariableSize",
205             "Struct_Array_Field_ByteElement_VariableSize",
206             "Struct_Array_Field_EnumElement_ConstantSize",
207             "Struct_Array_Field_EnumElement_UnknownSize",
208             "Struct_Array_Field_EnumElement_UnknownSize",
209             "Struct_Array_Field_EnumElement_VariableCount",
210             "Struct_Array_Field_EnumElement_VariableCount",
211             "Struct_Array_Field_EnumElement_VariableSize",
212             "Struct_Array_Field_EnumElement_VariableSize",
213             "Struct_Array_Field_ScalarElement_ConstantSize",
214             "Struct_Array_Field_ScalarElement_UnknownSize",
215             "Struct_Array_Field_ScalarElement_UnknownSize",
216             "Struct_Array_Field_ScalarElement_VariableCount",
217             "Struct_Array_Field_ScalarElement_VariableCount",
218             "Struct_Array_Field_ScalarElement_VariableSize",
219             "Struct_Array_Field_ScalarElement_VariableSize",
220             "Struct_Array_Field_SizedElement_ConstantSize",
221             "Struct_Array_Field_SizedElement_UnknownSize",
222             "Struct_Array_Field_SizedElement_UnknownSize",
223             "Struct_Array_Field_SizedElement_VariableCount",
224             "Struct_Array_Field_SizedElement_VariableCount",
225             "Struct_Array_Field_SizedElement_VariableSize",
226             "Struct_Array_Field_SizedElement_VariableSize",
227             "Struct_Array_Field_UnsizedElement_ConstantSize",
228             "Struct_Array_Field_UnsizedElement_UnknownSize",
229             "Struct_Array_Field_UnsizedElement_UnknownSize",
230             "Struct_Array_Field_UnsizedElement_VariableCount",
231             "Struct_Array_Field_UnsizedElement_VariableCount",
232             "Struct_Array_Field_UnsizedElement_VariableSize",
233             "Struct_Array_Field_UnsizedElement_VariableSize",
234             "Struct_Array_Field_SizedElement_VariableSize_Padded",
235             "Struct_Array_Field_UnsizedElement_VariableCount_Padded",
236             "Struct_Optional_Scalar_Field",
237             "Struct_Optional_Enum_Field",
238             "Struct_Optional_Struct_Field",
239             "Struct_Enum_Field",
240             "Struct_FixedEnum_Field",
241             "Struct_FixedScalar_Field",
242             "Struct_Size_Field",
243             "Struct_Struct_Field",
244             "Enum_Incomplete_Truncated_Closed",
245             "Enum_Incomplete_Truncated_Open",
246             "Enum_Incomplete_Truncated_Closed_WithRange",
247             "Enum_Incomplete_Truncated_Open_WithRange",
248             "Enum_Complete_Truncated",
249             "Enum_Complete_Truncated_WithRange",
250             "Enum_Complete_WithRange",
251         ],
252     )
253 }
254