1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 //! Generate Rust unit tests for canonical test vectors.
16
17 use quote::{format_ident, quote};
18 use serde::{Deserialize, Serialize};
19 use serde_json::Value;
20
21 #[derive(Debug, Deserialize)]
22 struct Packet {
23 #[serde(rename = "packet")]
24 name: String,
25 tests: Vec<TestVector>,
26 }
27
28 #[derive(Debug, Deserialize)]
29 struct TestVector {
30 packed: String,
31 unpacked: Value,
32 packet: Option<String>,
33 }
34
35 /// Convert a string of hexadecimal characters into a Rust vector of
36 /// bytes.
37 ///
38 /// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream39 fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
40 assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
41 let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
42 let number = format!("0x{}", std::str::from_utf8(chunk).unwrap());
43 syn::parse_str::<syn::LitInt>(&number).unwrap()
44 });
45
46 quote! {
47 vec![#(#bytes),*]
48 }
49 }
50
51 /// Convert `value` to a JSON string literal.
52 ///
53 /// The string literal is a raw literal to avoid escaping
54 /// double-quotes.
to_json<T: Serialize>(value: &T) -> syn::LitStr55 fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
56 let json = serde_json::to_string(value).unwrap();
57 assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
58 syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
59 }
60
generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str)61 fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
62 eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());
63
64 let data = std::fs::read_to_string(input)
65 .unwrap_or_else(|err| panic!("Could not read {input}: {err}"));
66 let packets: Vec<Packet> = serde_json::from_str(&data).expect("Could not parse JSON");
67
68 let module = format_ident!("{}", module_name);
69 let mut tests = Vec::new();
70 for packet in &packets {
71 for (i, test_vector) in packet.tests.iter().enumerate() {
72 let test_packet = test_vector.packet.as_deref().unwrap_or(packet.name.as_str());
73 if !packet_names.contains(&test_packet) {
74 eprintln!("Skipping packet {}", test_packet);
75 continue;
76 }
77 eprintln!("Generating tests for packet {}", test_packet);
78
79 let parse_test_name = format_ident!(
80 "test_parse_{}_vector_{}_0x{}",
81 test_packet,
82 i + 1,
83 &test_vector.packed
84 );
85 let serialize_test_name = format_ident!(
86 "test_serialize_{}_vector_{}_0x{}",
87 test_packet,
88 i + 1,
89 &test_vector.packed
90 );
91 let packed = hexadecimal_to_vec(&test_vector.packed);
92 let packet_name = format_ident!("{}", test_packet);
93 let builder_name = format_ident!("{}Builder", test_packet);
94
95 let object = test_vector.unpacked.as_object().unwrap_or_else(|| {
96 panic!("Expected test vector object, found: {}", test_vector.unpacked)
97 });
98 let assertions = object.iter().map(|(key, value)| {
99 let getter = format_ident!("get_{key}");
100 let expected = format_ident!("expected_{key}");
101 let json = to_json(&value);
102 quote! {
103 let #expected: serde_json::Value = serde_json::from_str(#json)
104 .expect("Could not create expected value from canonical JSON data");
105 assert_eq!(json!(actual.#getter()), #expected);
106 }
107 });
108
109 let json = to_json(&object);
110 tests.push(quote! {
111 #[test]
112 fn #parse_test_name() {
113 let packed = #packed;
114 let actual = #module::#packet_name::parse(&packed).unwrap();
115 #(#assertions)*
116 }
117
118 #[test]
119 fn #serialize_test_name() {
120 let builder: #module::#builder_name = serde_json::from_str(#json)
121 .expect("Could not create builder from canonical JSON data");
122 let packet = builder.build();
123 let packed: Vec<u8> = #packed;
124 assert_eq!(packet.to_vec(), packed);
125 }
126 });
127 }
128 }
129
130 // TODO(mgeisler): make the generated code clean from warnings.
131 println!("#![allow(warnings, missing_docs)]");
132 println!();
133 println!(
134 "{}",
135 "e! {
136 use #module::Packet;
137 use serde_json::json;
138
139 #(#tests)*
140 }
141 );
142 }
143
main()144 fn main() {
145 let input_path = std::env::args().nth(1).expect("Need path to JSON file with test vectors");
146 let module_name = std::env::args().nth(2).expect("Need name for the generated module");
147 // TODO(mgeisler): remove the `packet_names` argument when we
148 // support all canonical packets.
149 generate_unit_tests(
150 &input_path,
151 &[
152 "EnumChild_A",
153 "EnumChild_B",
154 "Packet_Array_Field_ByteElement_ConstantSize",
155 "Packet_Array_Field_ByteElement_UnknownSize",
156 "Packet_Array_Field_ByteElement_VariableCount",
157 "Packet_Array_Field_ByteElement_VariableSize",
158 "Packet_Array_Field_EnumElement",
159 "Packet_Array_Field_EnumElement_ConstantSize",
160 "Packet_Array_Field_EnumElement_UnknownSize",
161 "Packet_Array_Field_EnumElement_VariableCount",
162 "Packet_Array_Field_EnumElement_VariableCount",
163 "Packet_Array_Field_ScalarElement",
164 "Packet_Array_Field_ScalarElement_ConstantSize",
165 "Packet_Array_Field_ScalarElement_UnknownSize",
166 "Packet_Array_Field_ScalarElement_VariableCount",
167 "Packet_Array_Field_ScalarElement_VariableSize",
168 "Packet_Array_Field_SizedElement_ConstantSize",
169 "Packet_Array_Field_SizedElement_UnknownSize",
170 "Packet_Array_Field_SizedElement_VariableCount",
171 "Packet_Array_Field_SizedElement_VariableSize",
172 "Packet_Array_Field_UnsizedElement_ConstantSize",
173 "Packet_Array_Field_UnsizedElement_UnknownSize",
174 "Packet_Array_Field_UnsizedElement_VariableCount",
175 "Packet_Array_Field_UnsizedElement_VariableSize",
176 "Packet_Body_Field_UnknownSize",
177 "Packet_Body_Field_UnknownSize_Terminal",
178 "Packet_Body_Field_VariableSize",
179 "Packet_Count_Field",
180 "Packet_Enum8_Field",
181 "Packet_Enum_Field",
182 "Packet_FixedEnum_Field",
183 "Packet_FixedScalar_Field",
184 "Packet_Payload_Field_UnknownSize",
185 "Packet_Payload_Field_UnknownSize_Terminal",
186 "Packet_Payload_Field_VariableSize",
187 "Packet_Reserved_Field",
188 "Packet_Scalar_Field",
189 "Packet_Size_Field",
190 "Packet_Struct_Field",
191 "ScalarChild_A",
192 "ScalarChild_B",
193 "Struct_Count_Field",
194 "Struct_Array_Field_ByteElement_ConstantSize",
195 "Struct_Array_Field_ByteElement_UnknownSize",
196 "Struct_Array_Field_ByteElement_UnknownSize",
197 "Struct_Array_Field_ByteElement_VariableCount",
198 "Struct_Array_Field_ByteElement_VariableCount",
199 "Struct_Array_Field_ByteElement_VariableSize",
200 "Struct_Array_Field_ByteElement_VariableSize",
201 "Struct_Array_Field_EnumElement_ConstantSize",
202 "Struct_Array_Field_EnumElement_UnknownSize",
203 "Struct_Array_Field_EnumElement_UnknownSize",
204 "Struct_Array_Field_EnumElement_VariableCount",
205 "Struct_Array_Field_EnumElement_VariableCount",
206 "Struct_Array_Field_EnumElement_VariableSize",
207 "Struct_Array_Field_EnumElement_VariableSize",
208 "Struct_Array_Field_ScalarElement_ConstantSize",
209 "Struct_Array_Field_ScalarElement_UnknownSize",
210 "Struct_Array_Field_ScalarElement_UnknownSize",
211 "Struct_Array_Field_ScalarElement_VariableCount",
212 "Struct_Array_Field_ScalarElement_VariableCount",
213 "Struct_Array_Field_ScalarElement_VariableSize",
214 "Struct_Array_Field_ScalarElement_VariableSize",
215 "Struct_Array_Field_SizedElement_ConstantSize",
216 "Struct_Array_Field_SizedElement_UnknownSize",
217 "Struct_Array_Field_SizedElement_UnknownSize",
218 "Struct_Array_Field_SizedElement_VariableCount",
219 "Struct_Array_Field_SizedElement_VariableCount",
220 "Struct_Array_Field_SizedElement_VariableSize",
221 "Struct_Array_Field_SizedElement_VariableSize",
222 "Struct_Array_Field_UnsizedElement_ConstantSize",
223 "Struct_Array_Field_UnsizedElement_UnknownSize",
224 "Struct_Array_Field_UnsizedElement_UnknownSize",
225 "Struct_Array_Field_UnsizedElement_VariableCount",
226 "Struct_Array_Field_UnsizedElement_VariableCount",
227 "Struct_Array_Field_UnsizedElement_VariableSize",
228 "Struct_Array_Field_UnsizedElement_VariableSize",
229 "Struct_Enum_Field",
230 "Struct_FixedEnum_Field",
231 "Struct_FixedScalar_Field",
232 "Struct_Size_Field",
233 "Struct_Struct_Field",
234 ],
235 &module_name,
236 );
237 }
238