1 use crate::load_protos;
2 use crate::{Flag, FlagSource};
3 use crate::{FlagPermission, FlagValue, ValuePickedFrom};
4 use aconfigd_protos::{
5 ProtoFlagOverrideMessage, ProtoFlagOverrideType, ProtoFlagQueryReturnMessage,
6 ProtoListStorageMessage, ProtoListStorageMessageMsg, ProtoRemoveLocalOverrideMessage,
7 ProtoRemoveOverrideType, ProtoStorageRequestMessage, ProtoStorageRequestMessageMsg,
8 ProtoStorageRequestMessages, ProtoStorageReturnMessage, ProtoStorageReturnMessageMsg,
9 ProtoStorageReturnMessages,
10 };
11 use anyhow::anyhow;
12 use anyhow::Result;
13 use protobuf::Message;
14 use protobuf::SpecialFields;
15 use std::collections::HashMap;
16 use std::io::{Read, Write};
17 use std::net::Shutdown;
18 use std::os::unix::net::UnixStream;
19
/// [`FlagSource`] backed by the system and mainline aconfigd daemons, falling back to the
/// on-device flag protos for flags the daemons do not report.
pub struct AconfigStorageSource {}

/// Path of the system aconfigd daemon's Unix domain socket.
const ACONFIGD_SYSTEM_SOCKET_NAME: &str = "/dev/socket/aconfigd_system";
/// Path of the mainline aconfigd daemon's Unix domain socket.
const ACONFIGD_MAINLINE_SOCKET_NAME: &str = "/dev/socket/aconfigd_mainline";

/// Selects which aconfigd daemon a request is sent to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AconfigdSocket {
    System,
    Mainline,
}

impl AconfigdSocket {
    /// Returns the filesystem path of this daemon's socket.
    pub fn name(&self) -> &'static str {
        match self {
            AconfigdSocket::System => ACONFIGD_SYSTEM_SOCKET_NAME,
            AconfigdSocket::Mainline => ACONFIGD_MAINLINE_SOCKET_NAME,
        }
    }
}
38
convert(msg: ProtoFlagQueryReturnMessage, containers: &HashMap<String, String>) -> Result<Flag>39 fn convert(msg: ProtoFlagQueryReturnMessage, containers: &HashMap<String, String>) -> Result<Flag> {
40 let value = FlagValue::try_from(
41 msg.boot_flag_value
42 .clone()
43 .ok_or(anyhow!("no boot flag value for {:?}", msg.flag_name))?
44 .as_str(),
45 )?;
46
47 let value_picked_from = if msg.has_boot_local_override.unwrap_or(false) {
48 ValuePickedFrom::Local
49 } else if msg.boot_flag_value == msg.default_flag_value {
50 ValuePickedFrom::Default
51 } else {
52 ValuePickedFrom::Server
53 };
54
55 let staged_value = if msg.has_local_override.unwrap_or(false) {
56 // If a local override is staged, display it.
57 if msg.boot_flag_value == msg.local_flag_value {
58 None
59 } else {
60 Some(FlagValue::try_from(
61 msg.local_flag_value.ok_or(anyhow!("no local flag value"))?.as_str(),
62 )?)
63 }
64 } else {
65 // Otherwise, display if we're flipping to the default, or a server value.
66 let boot_value = msg.boot_flag_value.unwrap_or("".to_string());
67 let server_value = msg.server_flag_value.unwrap_or("".to_string());
68 let default_value = msg.default_flag_value.unwrap_or("".to_string());
69
70 if boot_value != server_value && server_value != *"" {
71 Some(FlagValue::try_from(server_value.as_str())?)
72 } else if msg.has_boot_local_override.unwrap_or(false) && boot_value != default_value {
73 Some(FlagValue::try_from(default_value.as_str())?)
74 } else {
75 None
76 }
77 };
78
79 let permission = match msg.is_readwrite {
80 Some(is_readwrite) => {
81 if is_readwrite {
82 FlagPermission::ReadWrite
83 } else {
84 FlagPermission::ReadOnly
85 }
86 }
87 None => return Err(anyhow!("missing permission")),
88 };
89
90 let name = msg.flag_name.ok_or(anyhow!("missing flag name"))?;
91 let package = msg.package_name.ok_or(anyhow!("missing package name"))?;
92 let qualified_name = format!("{package}.{name}");
93 Ok(Flag {
94 name,
95 package,
96 value,
97 permission,
98 value_picked_from,
99 staged_value,
100 container: containers
101 .get(&qualified_name)
102 .cloned()
103 .unwrap_or_else(|| "<no container>".to_string())
104 .to_string(),
105 // TODO: remove once DeviceConfig is not in the CLI.
106 namespace: "-".to_string(),
107 })
108 }
109
write_socket_messages( socket: AconfigdSocket, messages: ProtoStorageRequestMessages, ) -> Result<ProtoStorageReturnMessages>110 fn write_socket_messages(
111 socket: AconfigdSocket,
112 messages: ProtoStorageRequestMessages,
113 ) -> Result<ProtoStorageReturnMessages> {
114 let mut socket = UnixStream::connect(socket.name())?;
115
116 let message_buffer = messages.write_to_bytes()?;
117 let mut message_length_buffer: [u8; 4] = [0; 4];
118 let message_size = &message_buffer.len();
119 message_length_buffer[0] = (message_size >> 24) as u8;
120 message_length_buffer[1] = (message_size >> 16) as u8;
121 message_length_buffer[2] = (message_size >> 8) as u8;
122 message_length_buffer[3] = *message_size as u8;
123 socket.write_all(&message_length_buffer)?;
124 socket.write_all(&message_buffer)?;
125 socket.shutdown(Shutdown::Write)?;
126
127 let mut response_length_buffer: [u8; 4] = [0; 4];
128 socket.read_exact(&mut response_length_buffer)?;
129 let response_length = u32::from_be_bytes(response_length_buffer) as usize;
130 let mut response_buffer = vec![0; response_length];
131 socket.read_exact(&mut response_buffer)?;
132
133 let response: ProtoStorageReturnMessages =
134 protobuf::Message::parse_from_bytes(&response_buffer)?;
135
136 Ok(response)
137 }
138
send_list_flags_command(socket: AconfigdSocket) -> Result<Vec<ProtoFlagQueryReturnMessage>>139 fn send_list_flags_command(socket: AconfigdSocket) -> Result<Vec<ProtoFlagQueryReturnMessage>> {
140 let messages = ProtoStorageRequestMessages {
141 msgs: vec![ProtoStorageRequestMessage {
142 msg: Some(ProtoStorageRequestMessageMsg::ListStorageMessage(ProtoListStorageMessage {
143 msg: Some(ProtoListStorageMessageMsg::All(true)),
144 special_fields: SpecialFields::new(),
145 })),
146 special_fields: SpecialFields::new(),
147 }],
148 special_fields: SpecialFields::new(),
149 };
150
151 let response = write_socket_messages(socket, messages)?;
152 match response.msgs.as_slice() {
153 [ProtoStorageReturnMessage {
154 msg: Some(ProtoStorageReturnMessageMsg::ListStorageMessage(list_storage_message)),
155 ..
156 }] => Ok(list_storage_message.flags.clone()),
157 _ => Err(anyhow!("unexpected response from aconfigd")),
158 }
159 }
160
send_override_command( socket: AconfigdSocket, package_name: &str, flag_name: &str, value: &str, immediate: bool, ) -> Result<()>161 fn send_override_command(
162 socket: AconfigdSocket,
163 package_name: &str,
164 flag_name: &str,
165 value: &str,
166 immediate: bool,
167 ) -> Result<()> {
168 let override_type = if immediate {
169 ProtoFlagOverrideType::LOCAL_IMMEDIATE
170 } else {
171 ProtoFlagOverrideType::LOCAL_ON_REBOOT
172 };
173
174 let messages = ProtoStorageRequestMessages {
175 msgs: vec![ProtoStorageRequestMessage {
176 msg: Some(ProtoStorageRequestMessageMsg::FlagOverrideMessage(
177 ProtoFlagOverrideMessage {
178 package_name: Some(package_name.to_string()),
179 flag_name: Some(flag_name.to_string()),
180 flag_value: Some(value.to_string()),
181 override_type: Some(override_type.into()),
182 special_fields: SpecialFields::new(),
183 },
184 )),
185 special_fields: SpecialFields::new(),
186 }],
187 special_fields: SpecialFields::new(),
188 };
189
190 write_socket_messages(socket, messages)?;
191 Ok(())
192 }
193
194 impl FlagSource for AconfigStorageSource {
list_flags() -> Result<Vec<Flag>>195 fn list_flags() -> Result<Vec<Flag>> {
196 let flag_defaults = load_protos::load()?;
197 let system_messages = send_list_flags_command(AconfigdSocket::System);
198 let mainline_messages = send_list_flags_command(AconfigdSocket::Mainline);
199
200 let mut all_messages = vec![];
201 if let Ok(system_messages) = system_messages {
202 all_messages.extend_from_slice(&system_messages);
203 }
204 if let Ok(mainline_messages) = mainline_messages {
205 all_messages.extend_from_slice(&mainline_messages);
206 }
207
208 let container_map: HashMap<String, String> = flag_defaults
209 .clone()
210 .into_iter()
211 .map(|default| (default.qualified_name(), default.container))
212 .collect();
213 let socket_flags: Vec<Result<Flag>> = all_messages
214 .into_iter()
215 .map(|query_message| convert(query_message.clone(), &container_map))
216 .collect();
217 let socket_flags: Result<Vec<Flag>> = socket_flags.into_iter().collect();
218
219 // Load the defaults from the on-device protos.
220 // If the sockets are unavailable, just display the proto defaults.
221 let mut flags = flag_defaults.clone();
222 let name_to_socket_flag: HashMap<String, Flag> =
223 socket_flags?.into_iter().map(|p| (p.qualified_name(), p)).collect();
224 flags.iter_mut().for_each(|flag| {
225 if let Some(socket_flag) = name_to_socket_flag.get(&flag.qualified_name()) {
226 *flag = socket_flag.clone();
227 }
228 });
229
230 Ok(flags)
231 }
232
override_flag( _namespace: &str, qualified_name: &str, value: &str, immediate: bool, ) -> Result<()>233 fn override_flag(
234 _namespace: &str,
235 qualified_name: &str,
236 value: &str,
237 immediate: bool,
238 ) -> Result<()> {
239 let (package, flag_name) = if let Some(last_dot_index) = qualified_name.rfind('.') {
240 (&qualified_name[..last_dot_index], &qualified_name[last_dot_index + 1..])
241 } else {
242 return Err(anyhow!(format!("invalid flag name: {qualified_name}")));
243 };
244
245 let _ = send_override_command(AconfigdSocket::System, package, flag_name, value, immediate);
246 let _ =
247 send_override_command(AconfigdSocket::Mainline, package, flag_name, value, immediate);
248 Ok(())
249 }
250
unset_flag(_namespace: &str, qualified_name: &str, immediate: bool) -> Result<()>251 fn unset_flag(_namespace: &str, qualified_name: &str, immediate: bool) -> Result<()> {
252 let last_period_index = qualified_name.rfind('.').ok_or(anyhow!("No period found"))?;
253 let (package, flag_name) = qualified_name.split_at(last_period_index);
254
255 let removal_type = if immediate {
256 ProtoRemoveOverrideType::REMOVE_LOCAL_IMMEDIATE
257 } else {
258 ProtoRemoveOverrideType::REMOVE_LOCAL_ON_REBOOT
259 };
260
261 let socket_message = ProtoStorageRequestMessages {
262 msgs: vec![ProtoStorageRequestMessage {
263 msg: Some(ProtoStorageRequestMessageMsg::RemoveLocalOverrideMessage(
264 ProtoRemoveLocalOverrideMessage {
265 package_name: Some(package.to_string()),
266 flag_name: Some(flag_name[1..].to_string()),
267 remove_all: Some(false),
268 remove_override_type: Some(removal_type.into()),
269 special_fields: SpecialFields::new(),
270 },
271 )),
272 special_fields: SpecialFields::new(),
273 }],
274 special_fields: SpecialFields::new(),
275 };
276
277 let _ = write_socket_messages(AconfigdSocket::Mainline, socket_message.clone());
278 let _ = write_socket_messages(AconfigdSocket::System, socket_message);
279
280 Ok(())
281 }
282 }
283