• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::load_protos;
2 use crate::{Flag, FlagSource};
3 use crate::{FlagPermission, FlagValue, ValuePickedFrom};
4 use aconfigd_protos::{
5     ProtoFlagOverrideMessage, ProtoFlagOverrideType, ProtoFlagQueryReturnMessage,
6     ProtoListStorageMessage, ProtoListStorageMessageMsg, ProtoRemoveLocalOverrideMessage,
7     ProtoRemoveOverrideType, ProtoStorageRequestMessage, ProtoStorageRequestMessageMsg,
8     ProtoStorageRequestMessages, ProtoStorageReturnMessage, ProtoStorageReturnMessageMsg,
9     ProtoStorageReturnMessages,
10 };
11 use anyhow::anyhow;
12 use anyhow::Result;
13 use protobuf::Message;
14 use protobuf::SpecialFields;
15 use std::collections::HashMap;
16 use std::io::{Read, Write};
17 use std::net::Shutdown;
18 use std::os::unix::net::UnixStream;
19 
/// Flag source that talks to the aconfigd daemons over their Unix domain sockets.
///
/// The type carries no state; all behavior lives in its `FlagSource` implementation.
pub struct AconfigStorageSource {}
21 
/// Filesystem path of the socket selected by [`AconfigdSocket::System`].
static ACONFIGD_SYSTEM_SOCKET_NAME: &str = "/dev/socket/aconfigd_system";
/// Filesystem path of the socket selected by [`AconfigdSocket::Mainline`].
static ACONFIGD_MAINLINE_SOCKET_NAME: &str = "/dev/socket/aconfigd_mainline";

/// Selects which aconfigd socket a request is sent to.
enum AconfigdSocket {
    /// The socket at `ACONFIGD_SYSTEM_SOCKET_NAME`.
    System,
    /// The socket at `ACONFIGD_MAINLINE_SOCKET_NAME`.
    Mainline,
}

impl AconfigdSocket {
    /// Returns the filesystem path of the socket this variant refers to.
    pub fn name(&self) -> &str {
        let path = match self {
            Self::System => ACONFIGD_SYSTEM_SOCKET_NAME,
            Self::Mainline => ACONFIGD_MAINLINE_SOCKET_NAME,
        };
        path
    }
}
38 
convert(msg: ProtoFlagQueryReturnMessage, containers: &HashMap<String, String>) -> Result<Flag>39 fn convert(msg: ProtoFlagQueryReturnMessage, containers: &HashMap<String, String>) -> Result<Flag> {
40     let value = FlagValue::try_from(
41         msg.boot_flag_value
42             .clone()
43             .ok_or(anyhow!("no boot flag value for {:?}", msg.flag_name))?
44             .as_str(),
45     )?;
46 
47     let value_picked_from = if msg.has_boot_local_override.unwrap_or(false) {
48         ValuePickedFrom::Local
49     } else if msg.boot_flag_value == msg.default_flag_value {
50         ValuePickedFrom::Default
51     } else {
52         ValuePickedFrom::Server
53     };
54 
55     let staged_value = if msg.has_local_override.unwrap_or(false) {
56         // If a local override is staged, display it.
57         if msg.boot_flag_value == msg.local_flag_value {
58             None
59         } else {
60             Some(FlagValue::try_from(
61                 msg.local_flag_value.ok_or(anyhow!("no local flag value"))?.as_str(),
62             )?)
63         }
64     } else {
65         // Otherwise, display if we're flipping to the default, or a server value.
66         let boot_value = msg.boot_flag_value.unwrap_or("".to_string());
67         let server_value = msg.server_flag_value.unwrap_or("".to_string());
68         let default_value = msg.default_flag_value.unwrap_or("".to_string());
69 
70         if boot_value != server_value && server_value != *"" {
71             Some(FlagValue::try_from(server_value.as_str())?)
72         } else if msg.has_boot_local_override.unwrap_or(false) && boot_value != default_value {
73             Some(FlagValue::try_from(default_value.as_str())?)
74         } else {
75             None
76         }
77     };
78 
79     let permission = match msg.is_readwrite {
80         Some(is_readwrite) => {
81             if is_readwrite {
82                 FlagPermission::ReadWrite
83             } else {
84                 FlagPermission::ReadOnly
85             }
86         }
87         None => return Err(anyhow!("missing permission")),
88     };
89 
90     let name = msg.flag_name.ok_or(anyhow!("missing flag name"))?;
91     let package = msg.package_name.ok_or(anyhow!("missing package name"))?;
92     let qualified_name = format!("{package}.{name}");
93     Ok(Flag {
94         name,
95         package,
96         value,
97         permission,
98         value_picked_from,
99         staged_value,
100         container: containers
101             .get(&qualified_name)
102             .cloned()
103             .unwrap_or_else(|| "<no container>".to_string())
104             .to_string(),
105         // TODO: remove once DeviceConfig is not in the CLI.
106         namespace: "-".to_string(),
107     })
108 }
109 
write_socket_messages( socket: AconfigdSocket, messages: ProtoStorageRequestMessages, ) -> Result<ProtoStorageReturnMessages>110 fn write_socket_messages(
111     socket: AconfigdSocket,
112     messages: ProtoStorageRequestMessages,
113 ) -> Result<ProtoStorageReturnMessages> {
114     let mut socket = UnixStream::connect(socket.name())?;
115 
116     let message_buffer = messages.write_to_bytes()?;
117     let mut message_length_buffer: [u8; 4] = [0; 4];
118     let message_size = &message_buffer.len();
119     message_length_buffer[0] = (message_size >> 24) as u8;
120     message_length_buffer[1] = (message_size >> 16) as u8;
121     message_length_buffer[2] = (message_size >> 8) as u8;
122     message_length_buffer[3] = *message_size as u8;
123     socket.write_all(&message_length_buffer)?;
124     socket.write_all(&message_buffer)?;
125     socket.shutdown(Shutdown::Write)?;
126 
127     let mut response_length_buffer: [u8; 4] = [0; 4];
128     socket.read_exact(&mut response_length_buffer)?;
129     let response_length = u32::from_be_bytes(response_length_buffer) as usize;
130     let mut response_buffer = vec![0; response_length];
131     socket.read_exact(&mut response_buffer)?;
132 
133     let response: ProtoStorageReturnMessages =
134         protobuf::Message::parse_from_bytes(&response_buffer)?;
135 
136     Ok(response)
137 }
138 
send_list_flags_command(socket: AconfigdSocket) -> Result<Vec<ProtoFlagQueryReturnMessage>>139 fn send_list_flags_command(socket: AconfigdSocket) -> Result<Vec<ProtoFlagQueryReturnMessage>> {
140     let messages = ProtoStorageRequestMessages {
141         msgs: vec![ProtoStorageRequestMessage {
142             msg: Some(ProtoStorageRequestMessageMsg::ListStorageMessage(ProtoListStorageMessage {
143                 msg: Some(ProtoListStorageMessageMsg::All(true)),
144                 special_fields: SpecialFields::new(),
145             })),
146             special_fields: SpecialFields::new(),
147         }],
148         special_fields: SpecialFields::new(),
149     };
150 
151     let response = write_socket_messages(socket, messages)?;
152     match response.msgs.as_slice() {
153         [ProtoStorageReturnMessage {
154             msg: Some(ProtoStorageReturnMessageMsg::ListStorageMessage(list_storage_message)),
155             ..
156         }] => Ok(list_storage_message.flags.clone()),
157         _ => Err(anyhow!("unexpected response from aconfigd")),
158     }
159 }
160 
send_override_command( socket: AconfigdSocket, package_name: &str, flag_name: &str, value: &str, immediate: bool, ) -> Result<()>161 fn send_override_command(
162     socket: AconfigdSocket,
163     package_name: &str,
164     flag_name: &str,
165     value: &str,
166     immediate: bool,
167 ) -> Result<()> {
168     let override_type = if immediate {
169         ProtoFlagOverrideType::LOCAL_IMMEDIATE
170     } else {
171         ProtoFlagOverrideType::LOCAL_ON_REBOOT
172     };
173 
174     let messages = ProtoStorageRequestMessages {
175         msgs: vec![ProtoStorageRequestMessage {
176             msg: Some(ProtoStorageRequestMessageMsg::FlagOverrideMessage(
177                 ProtoFlagOverrideMessage {
178                     package_name: Some(package_name.to_string()),
179                     flag_name: Some(flag_name.to_string()),
180                     flag_value: Some(value.to_string()),
181                     override_type: Some(override_type.into()),
182                     special_fields: SpecialFields::new(),
183                 },
184             )),
185             special_fields: SpecialFields::new(),
186         }],
187         special_fields: SpecialFields::new(),
188     };
189 
190     write_socket_messages(socket, messages)?;
191     Ok(())
192 }
193 
194 impl FlagSource for AconfigStorageSource {
list_flags() -> Result<Vec<Flag>>195     fn list_flags() -> Result<Vec<Flag>> {
196         let flag_defaults = load_protos::load()?;
197         let system_messages = send_list_flags_command(AconfigdSocket::System);
198         let mainline_messages = send_list_flags_command(AconfigdSocket::Mainline);
199 
200         let mut all_messages = vec![];
201         if let Ok(system_messages) = system_messages {
202             all_messages.extend_from_slice(&system_messages);
203         }
204         if let Ok(mainline_messages) = mainline_messages {
205             all_messages.extend_from_slice(&mainline_messages);
206         }
207 
208         let container_map: HashMap<String, String> = flag_defaults
209             .clone()
210             .into_iter()
211             .map(|default| (default.qualified_name(), default.container))
212             .collect();
213         let socket_flags: Vec<Result<Flag>> = all_messages
214             .into_iter()
215             .map(|query_message| convert(query_message.clone(), &container_map))
216             .collect();
217         let socket_flags: Result<Vec<Flag>> = socket_flags.into_iter().collect();
218 
219         // Load the defaults from the on-device protos.
220         // If the sockets are unavailable, just display the proto defaults.
221         let mut flags = flag_defaults.clone();
222         let name_to_socket_flag: HashMap<String, Flag> =
223             socket_flags?.into_iter().map(|p| (p.qualified_name(), p)).collect();
224         flags.iter_mut().for_each(|flag| {
225             if let Some(socket_flag) = name_to_socket_flag.get(&flag.qualified_name()) {
226                 *flag = socket_flag.clone();
227             }
228         });
229 
230         Ok(flags)
231     }
232 
override_flag( _namespace: &str, qualified_name: &str, value: &str, immediate: bool, ) -> Result<()>233     fn override_flag(
234         _namespace: &str,
235         qualified_name: &str,
236         value: &str,
237         immediate: bool,
238     ) -> Result<()> {
239         let (package, flag_name) = if let Some(last_dot_index) = qualified_name.rfind('.') {
240             (&qualified_name[..last_dot_index], &qualified_name[last_dot_index + 1..])
241         } else {
242             return Err(anyhow!(format!("invalid flag name: {qualified_name}")));
243         };
244 
245         let _ = send_override_command(AconfigdSocket::System, package, flag_name, value, immediate);
246         let _ =
247             send_override_command(AconfigdSocket::Mainline, package, flag_name, value, immediate);
248         Ok(())
249     }
250 
unset_flag(_namespace: &str, qualified_name: &str, immediate: bool) -> Result<()>251     fn unset_flag(_namespace: &str, qualified_name: &str, immediate: bool) -> Result<()> {
252         let last_period_index = qualified_name.rfind('.').ok_or(anyhow!("No period found"))?;
253         let (package, flag_name) = qualified_name.split_at(last_period_index);
254 
255         let removal_type = if immediate {
256             ProtoRemoveOverrideType::REMOVE_LOCAL_IMMEDIATE
257         } else {
258             ProtoRemoveOverrideType::REMOVE_LOCAL_ON_REBOOT
259         };
260 
261         let socket_message = ProtoStorageRequestMessages {
262             msgs: vec![ProtoStorageRequestMessage {
263                 msg: Some(ProtoStorageRequestMessageMsg::RemoveLocalOverrideMessage(
264                     ProtoRemoveLocalOverrideMessage {
265                         package_name: Some(package.to_string()),
266                         flag_name: Some(flag_name[1..].to_string()),
267                         remove_all: Some(false),
268                         remove_override_type: Some(removal_type.into()),
269                         special_fields: SpecialFields::new(),
270                     },
271                 )),
272                 special_fields: SpecialFields::new(),
273             }],
274             special_fields: SpecialFields::new(),
275         };
276 
277         let _ = write_socket_messages(AconfigdSocket::Mainline, socket_message.clone());
278         let _ = write_socket_messages(AconfigdSocket::System, socket_message);
279 
280         Ok(())
281     }
282 }
283