1 //! Oneof-related codegen functions. 2 3 use std::collections::HashSet; 4 5 use code_writer::CodeWriter; 6 use field::FieldElem; 7 use field::FieldGen; 8 use message::MessageGen; 9 use protobuf::descriptor::FieldDescriptorProto_Type; 10 use protobuf_name::ProtobufAbsolutePath; 11 use rust_name::RustIdent; 12 use rust_types_values::RustType; 13 use scope::FieldWithContext; 14 use scope::OneofVariantWithContext; 15 use scope::OneofWithContext; 16 use scope::RootScope; 17 use scope::WithScope; 18 use serde; 19 use Customize; 20 21 // oneof one { ... } 22 #[derive(Clone)] 23 pub(crate) struct OneofField<'a> { 24 pub elem: FieldElem<'a>, 25 pub oneof_rust_field_name: RustIdent, 26 pub oneof_type_name: RustType, 27 pub boxed: bool, 28 } 29 30 impl<'a> OneofField<'a> { 31 // Detecting recursion: if oneof fields contains a self-reference 32 // or another message which has a reference to self, 33 // put oneof variant into a box. need_boxed(field: &FieldWithContext, root_scope: &RootScope, owner_name: &str) -> bool34 fn need_boxed(field: &FieldWithContext, root_scope: &RootScope, owner_name: &str) -> bool { 35 let mut visited_messages = HashSet::new(); 36 let mut fields = vec![field.clone()]; 37 while let Some(field) = fields.pop() { 38 if field.field.get_field_type() == FieldDescriptorProto_Type::TYPE_MESSAGE { 39 let message_name = ProtobufAbsolutePath::from(field.field.get_type_name()); 40 if !visited_messages.insert(message_name.clone()) { 41 continue; 42 } 43 if message_name.path == owner_name { 44 return true; 45 } 46 let message = root_scope.find_message(&message_name); 47 fields.extend(message.fields().into_iter().filter(|f| f.is_oneof())); 48 } 49 } 50 false 51 } 52 parse( oneof: &OneofWithContext<'a>, field: &FieldWithContext<'a>, elem: FieldElem<'a>, root_scope: &RootScope, ) -> OneofField<'a>53 pub fn parse( 54 oneof: &OneofWithContext<'a>, 55 field: &FieldWithContext<'a>, 56 elem: FieldElem<'a>, 57 root_scope: &RootScope, 58 ) -> OneofField<'a> { 59 let boxed = OneofField::need_boxed(field, root_scope, &oneof.message.name_absolute().path); 60 61 OneofField { 62 elem, 63 boxed, 64 oneof_rust_field_name: oneof.field_name().into(), 65 oneof_type_name: RustType::Oneof(oneof.rust_name().to_string()), 66 } 67 } 68 rust_type(&self) -> RustType69 pub fn rust_type(&self) -> RustType { 70 let t = self.elem.rust_storage_type(); 71 72 if self.boxed { 73 RustType::Uniq(Box::new(t)) 74 } else { 75 t 76 } 77 } 78 } 79 80 #[derive(Clone)] 81 pub(crate) struct OneofVariantGen<'a> { 82 _oneof: &'a OneofGen<'a>, 83 _variant: OneofVariantWithContext<'a>, 84 oneof_field: OneofField<'a>, 85 pub field: FieldGen<'a>, 86 path: String, 87 _customize: Customize, 88 } 89 90 impl<'a> OneofVariantGen<'a> { parse( oneof: &'a OneofGen<'a>, variant: OneofVariantWithContext<'a>, field: &'a FieldGen, _root_scope: &RootScope, customize: Customize, ) -> OneofVariantGen<'a>91 fn parse( 92 oneof: &'a OneofGen<'a>, 93 variant: OneofVariantWithContext<'a>, 94 field: &'a FieldGen, 95 _root_scope: &RootScope, 96 customize: Customize, 97 ) -> OneofVariantGen<'a> { 98 OneofVariantGen { 99 _oneof: oneof, 100 _variant: variant.clone(), 101 field: field.clone(), 102 path: format!( 103 "{}::{}", 104 oneof.type_name.to_code(&field.customize), 105 field.rust_name 106 ), 107 oneof_field: OneofField::parse( 108 variant.oneof, 109 &field.proto_field, 110 field.oneof().elem.clone(), 111 oneof.message.root_scope, 112 ), 113 _customize: customize, 114 } 115 } 116 rust_type(&self) -> RustType117 fn rust_type(&self) -> RustType { 118 self.oneof_field.rust_type() 119 } 120 path(&self) -> String121 pub fn path(&self) -> String { 122 self.path.clone() 123 } 124 } 125 126 #[derive(Clone)] 127 pub(crate) struct OneofGen<'a> { 128 // Message containing this oneof 129 message: &'a MessageGen<'a>, 130 pub oneof: OneofWithContext<'a>, 131 type_name: RustType, 132 customize: Customize, 133 } 134 135 impl<'a> OneofGen<'a> { parse( message: &'a MessageGen, oneof: OneofWithContext<'a>, customize: &Customize, ) -> OneofGen<'a>136 pub fn parse( 137 message: &'a MessageGen, 138 oneof: OneofWithContext<'a>, 139 customize: &Customize, 140 ) -> OneofGen<'a> { 141 let rust_name = oneof.rust_name(); 142 OneofGen { 143 message, 144 oneof, 145 type_name: RustType::Oneof(rust_name.to_string()), 146 customize: customize.clone(), 147 } 148 } 149 variants_except_group(&'a self) -> Vec<OneofVariantGen<'a>>150 pub fn variants_except_group(&'a self) -> Vec<OneofVariantGen<'a>> { 151 self.oneof 152 .variants() 153 .into_iter() 154 .filter_map(|v| { 155 let field = self 156 .message 157 .fields 158 .iter() 159 .filter(|f| f.proto_field.name() == v.field.get_name()) 160 .next() 161 .expect(&format!("field not found by name: {}", v.field.get_name())); 162 match field.proto_type { 163 FieldDescriptorProto_Type::TYPE_GROUP => None, 164 _ => Some(OneofVariantGen::parse( 165 self, 166 v, 167 field, 168 self.message.root_scope, 169 self.customize.clone(), 170 )), 171 } 172 }) 173 .collect() 174 } 175 full_storage_type(&self) -> RustType176 pub fn full_storage_type(&self) -> RustType { 177 RustType::Option(Box::new(self.type_name.clone())) 178 } 179 write_enum(&self, w: &mut CodeWriter)180 pub fn write_enum(&self, w: &mut CodeWriter) { 181 let derive = vec!["Clone", "PartialEq", "Debug"]; 182 w.derive(&derive); 183 serde::write_serde_attr( 184 w, 185 &self.customize, 186 "derive(::serde::Serialize, ::serde::Deserialize)", 187 ); 188 w.pub_enum(&self.type_name.to_code(&self.customize), |w| { 189 for variant in self.variants_except_group() { 190 w.write_line(&format!( 191 "{}({}),", 192 variant.field.rust_name, 193 &variant.rust_type().to_code(&self.customize) 194 )); 195 } 196 }); 197 } 198 } 199