• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Generate Rust unit tests for canonical test vectors.
16 
17 use quote::{format_ident, quote};
18 use serde::{Deserialize, Serialize};
19 use serde_json::Value;
20 
21 #[derive(Debug, Deserialize)]
22 struct Packet {
23     #[serde(rename = "packet")]
24     name: String,
25     tests: Vec<TestVector>,
26 }
27 
28 #[derive(Debug, Deserialize)]
29 struct TestVector {
30     packed: String,
31     unpacked: Value,
32     packet: Option<String>,
33 }
34 
35 /// Convert a string of hexadecimal characters into a Rust vector of
36 /// bytes.
37 ///
38 /// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream39 fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
40     assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
41     let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
42         let number = format!("0x{}", std::str::from_utf8(chunk).unwrap());
43         syn::parse_str::<syn::LitInt>(&number).unwrap()
44     });
45 
46     quote! {
47         vec![#(#bytes),*]
48     }
49 }
50 
51 /// Convert `value` to a JSON string literal.
52 ///
53 /// The string literal is a raw literal to avoid escaping
54 /// double-quotes.
to_json<T: Serialize>(value: &T) -> syn::LitStr55 fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
56     let json = serde_json::to_string(value).unwrap();
57     assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
58     syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
59 }
60 
generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str)61 fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
62     eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());
63 
64     let data = std::fs::read_to_string(input)
65         .unwrap_or_else(|err| panic!("Could not read {input}: {err}"));
66     let packets: Vec<Packet> = serde_json::from_str(&data).expect("Could not parse JSON");
67 
68     let module = format_ident!("{}", module_name);
69     let mut tests = Vec::new();
70     for packet in &packets {
71         for (i, test_vector) in packet.tests.iter().enumerate() {
72             let test_packet = test_vector.packet.as_deref().unwrap_or(packet.name.as_str());
73             if !packet_names.contains(&test_packet) {
74                 eprintln!("Skipping packet {}", test_packet);
75                 continue;
76             }
77             eprintln!("Generating tests for packet {}", test_packet);
78 
79             let parse_test_name = format_ident!(
80                 "test_parse_{}_vector_{}_0x{}",
81                 test_packet,
82                 i + 1,
83                 &test_vector.packed
84             );
85             let serialize_test_name = format_ident!(
86                 "test_serialize_{}_vector_{}_0x{}",
87                 test_packet,
88                 i + 1,
89                 &test_vector.packed
90             );
91             let packed = hexadecimal_to_vec(&test_vector.packed);
92             let packet_name = format_ident!("{}", test_packet);
93             let builder_name = format_ident!("{}Builder", test_packet);
94 
95             let object = test_vector.unpacked.as_object().unwrap_or_else(|| {
96                 panic!("Expected test vector object, found: {}", test_vector.unpacked)
97             });
98             let assertions = object.iter().map(|(key, value)| {
99                 let getter = format_ident!("get_{key}");
100                 let expected = format_ident!("expected_{key}");
101                 let json = to_json(&value);
102                 quote! {
103                     let #expected: serde_json::Value = serde_json::from_str(#json)
104                         .expect("Could not create expected value from canonical JSON data");
105                     assert_eq!(json!(actual.#getter()), #expected);
106                 }
107             });
108 
109             let json = to_json(&object);
110             tests.push(quote! {
111                 #[test]
112                 fn #parse_test_name() {
113                     let packed = #packed;
114                     let actual = #module::#packet_name::parse(&packed).unwrap();
115                     #(#assertions)*
116                 }
117 
118                 #[test]
119                 fn #serialize_test_name() {
120                     let builder: #module::#builder_name = serde_json::from_str(#json)
121                         .expect("Could not create builder from canonical JSON data");
122                     let packet = builder.build();
123                     let packed: Vec<u8> = #packed;
124                     assert_eq!(packet.to_vec(), packed);
125                 }
126             });
127         }
128     }
129 
130     // TODO(mgeisler): make the generated code clean from warnings.
131     println!("#![allow(warnings, missing_docs)]");
132     println!();
133     println!(
134         "{}",
135         &quote! {
136             use #module::Packet;
137             use serde_json::json;
138 
139             #(#tests)*
140         }
141     );
142 }
143 
main()144 fn main() {
145     let input_path = std::env::args().nth(1).expect("Need path to JSON file with test vectors");
146     let module_name = std::env::args().nth(2).expect("Need name for the generated module");
147     // TODO(mgeisler): remove the `packet_names` argument when we
148     // support all canonical packets.
149     generate_unit_tests(
150         &input_path,
151         &[
152             "EnumChild_A",
153             "EnumChild_B",
154             "Packet_Array_Field_ByteElement_ConstantSize",
155             "Packet_Array_Field_ByteElement_UnknownSize",
156             "Packet_Array_Field_ByteElement_VariableCount",
157             "Packet_Array_Field_ByteElement_VariableSize",
158             "Packet_Array_Field_EnumElement",
159             "Packet_Array_Field_EnumElement_ConstantSize",
160             "Packet_Array_Field_EnumElement_UnknownSize",
161             "Packet_Array_Field_EnumElement_VariableCount",
162             "Packet_Array_Field_EnumElement_VariableCount",
163             "Packet_Array_Field_ScalarElement",
164             "Packet_Array_Field_ScalarElement_ConstantSize",
165             "Packet_Array_Field_ScalarElement_UnknownSize",
166             "Packet_Array_Field_ScalarElement_VariableCount",
167             "Packet_Array_Field_ScalarElement_VariableSize",
168             "Packet_Array_Field_SizedElement_ConstantSize",
169             "Packet_Array_Field_SizedElement_UnknownSize",
170             "Packet_Array_Field_SizedElement_VariableCount",
171             "Packet_Array_Field_SizedElement_VariableSize",
172             "Packet_Array_Field_UnsizedElement_ConstantSize",
173             "Packet_Array_Field_UnsizedElement_UnknownSize",
174             "Packet_Array_Field_UnsizedElement_VariableCount",
175             "Packet_Array_Field_UnsizedElement_VariableSize",
176             "Packet_Body_Field_UnknownSize",
177             "Packet_Body_Field_UnknownSize_Terminal",
178             "Packet_Body_Field_VariableSize",
179             "Packet_Count_Field",
180             "Packet_Enum8_Field",
181             "Packet_Enum_Field",
182             "Packet_FixedEnum_Field",
183             "Packet_FixedScalar_Field",
184             "Packet_Payload_Field_UnknownSize",
185             "Packet_Payload_Field_UnknownSize_Terminal",
186             "Packet_Payload_Field_VariableSize",
187             "Packet_Reserved_Field",
188             "Packet_Scalar_Field",
189             "Packet_Size_Field",
190             "Packet_Struct_Field",
191             "ScalarChild_A",
192             "ScalarChild_B",
193             "Struct_Count_Field",
194             "Struct_Array_Field_ByteElement_ConstantSize",
195             "Struct_Array_Field_ByteElement_UnknownSize",
196             "Struct_Array_Field_ByteElement_UnknownSize",
197             "Struct_Array_Field_ByteElement_VariableCount",
198             "Struct_Array_Field_ByteElement_VariableCount",
199             "Struct_Array_Field_ByteElement_VariableSize",
200             "Struct_Array_Field_ByteElement_VariableSize",
201             "Struct_Array_Field_EnumElement_ConstantSize",
202             "Struct_Array_Field_EnumElement_UnknownSize",
203             "Struct_Array_Field_EnumElement_UnknownSize",
204             "Struct_Array_Field_EnumElement_VariableCount",
205             "Struct_Array_Field_EnumElement_VariableCount",
206             "Struct_Array_Field_EnumElement_VariableSize",
207             "Struct_Array_Field_EnumElement_VariableSize",
208             "Struct_Array_Field_ScalarElement_ConstantSize",
209             "Struct_Array_Field_ScalarElement_UnknownSize",
210             "Struct_Array_Field_ScalarElement_UnknownSize",
211             "Struct_Array_Field_ScalarElement_VariableCount",
212             "Struct_Array_Field_ScalarElement_VariableCount",
213             "Struct_Array_Field_ScalarElement_VariableSize",
214             "Struct_Array_Field_ScalarElement_VariableSize",
215             "Struct_Array_Field_SizedElement_ConstantSize",
216             "Struct_Array_Field_SizedElement_UnknownSize",
217             "Struct_Array_Field_SizedElement_UnknownSize",
218             "Struct_Array_Field_SizedElement_VariableCount",
219             "Struct_Array_Field_SizedElement_VariableCount",
220             "Struct_Array_Field_SizedElement_VariableSize",
221             "Struct_Array_Field_SizedElement_VariableSize",
222             "Struct_Array_Field_UnsizedElement_ConstantSize",
223             "Struct_Array_Field_UnsizedElement_UnknownSize",
224             "Struct_Array_Field_UnsizedElement_UnknownSize",
225             "Struct_Array_Field_UnsizedElement_VariableCount",
226             "Struct_Array_Field_UnsizedElement_VariableCount",
227             "Struct_Array_Field_UnsizedElement_VariableSize",
228             "Struct_Array_Field_UnsizedElement_VariableSize",
229             "Struct_Enum_Field",
230             "Struct_FixedEnum_Field",
231             "Struct_FixedScalar_Field",
232             "Struct_Size_Field",
233             "Struct_Struct_Field",
234         ],
235         &module_name,
236     );
237 }
238