• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! See [RequestDispatcher].
2 use std::{fmt, panic, thread};
3 
4 use ide::Cancelled;
5 use lsp_server::ExtractError;
6 use serde::{de::DeserializeOwned, Serialize};
7 use stdx::thread::ThreadIntent;
8 
9 use crate::{
10     global_state::{GlobalState, GlobalStateSnapshot},
11     main_loop::Task,
12     version::version,
13     LspError, Result,
14 };
15 
16 /// A visitor for routing a raw JSON request to an appropriate handler function.
17 ///
18 /// Most requests are read-only and async and are handled on the threadpool
19 /// (`on` method).
20 ///
21 /// Some read-only requests are latency sensitive, and are immediately handled
22 /// on the main loop thread (`on_sync`). These are typically typing-related
23 /// requests.
24 ///
25 /// Some requests modify the state, and are run on the main thread to get
26 /// `&mut` (`on_sync_mut`).
27 ///
28 /// Read-only requests are wrapped into `catch_unwind` -- they don't modify the
29 /// state, so it's OK to recover from their failures.
30 pub(crate) struct RequestDispatcher<'a> {
31     pub(crate) req: Option<lsp_server::Request>,
32     pub(crate) global_state: &'a mut GlobalState,
33 }
34 
35 impl<'a> RequestDispatcher<'a> {
36     /// Dispatches the request onto the current thread, given full access to
37     /// mutable global state. Unlike all other methods here, this one isn't
38     /// guarded by `catch_unwind`, so, please, don't make bugs :-)
on_sync_mut<R>( &mut self, f: fn(&mut GlobalState, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request, R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug, R::Result: Serialize,39     pub(crate) fn on_sync_mut<R>(
40         &mut self,
41         f: fn(&mut GlobalState, R::Params) -> Result<R::Result>,
42     ) -> &mut Self
43     where
44         R: lsp_types::request::Request,
45         R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug,
46         R::Result: Serialize,
47     {
48         let (req, params, panic_context) = match self.parse::<R>() {
49             Some(it) => it,
50             None => return self,
51         };
52         let result = {
53             let _pctx = stdx::panic_context::enter(panic_context);
54             f(self.global_state, params)
55         };
56         if let Ok(response) = result_to_response::<R>(req.id, result) {
57             self.global_state.respond(response);
58         }
59 
60         self
61     }
62 
63     /// Dispatches the request onto the current thread.
on_sync<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request, R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug, R::Result: Serialize,64     pub(crate) fn on_sync<R>(
65         &mut self,
66         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
67     ) -> &mut Self
68     where
69         R: lsp_types::request::Request,
70         R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug,
71         R::Result: Serialize,
72     {
73         let (req, params, panic_context) = match self.parse::<R>() {
74             Some(it) => it,
75             None => return self,
76         };
77         let global_state_snapshot = self.global_state.snapshot();
78 
79         let result = panic::catch_unwind(move || {
80             let _pctx = stdx::panic_context::enter(panic_context);
81             f(global_state_snapshot, params)
82         });
83 
84         if let Ok(response) = thread_result_to_response::<R>(req.id, result) {
85             self.global_state.respond(response);
86         }
87 
88         self
89     }
90 
91     /// Dispatches a non-latency-sensitive request onto the thread pool
92     /// without retrying it if it panics.
on_no_retry<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, R::Result: Serialize,93     pub(crate) fn on_no_retry<R>(
94         &mut self,
95         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
96     ) -> &mut Self
97     where
98         R: lsp_types::request::Request + 'static,
99         R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
100         R::Result: Serialize,
101     {
102         let (req, params, panic_context) = match self.parse::<R>() {
103             Some(it) => it,
104             None => return self,
105         };
106 
107         self.global_state.task_pool.handle.spawn(ThreadIntent::Worker, {
108             let world = self.global_state.snapshot();
109             move || {
110                 let result = panic::catch_unwind(move || {
111                     let _pctx = stdx::panic_context::enter(panic_context);
112                     f(world, params)
113                 });
114                 match thread_result_to_response::<R>(req.id.clone(), result) {
115                     Ok(response) => Task::Response(response),
116                     Err(_) => Task::Response(lsp_server::Response::new_err(
117                         req.id,
118                         lsp_server::ErrorCode::ContentModified as i32,
119                         "content modified".to_string(),
120                     )),
121                 }
122             }
123         });
124 
125         self
126     }
127 
128     /// Dispatches a non-latency-sensitive request onto the thread pool.
on<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, R::Result: Serialize,129     pub(crate) fn on<R>(
130         &mut self,
131         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
132     ) -> &mut Self
133     where
134         R: lsp_types::request::Request + 'static,
135         R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
136         R::Result: Serialize,
137     {
138         self.on_with_thread_intent::<true, R>(ThreadIntent::Worker, f)
139     }
140 
141     /// Dispatches a latency-sensitive request onto the thread pool.
on_latency_sensitive<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, R::Result: Serialize,142     pub(crate) fn on_latency_sensitive<R>(
143         &mut self,
144         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
145     ) -> &mut Self
146     where
147         R: lsp_types::request::Request + 'static,
148         R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
149         R::Result: Serialize,
150     {
151         self.on_with_thread_intent::<true, R>(ThreadIntent::LatencySensitive, f)
152     }
153 
154     /// Formatting requests should never block on waiting a for task thread to open up, editors will wait
155     /// on the response and a late formatting update might mess with the document and user.
156     /// We can't run this on the main thread though as we invoke rustfmt which may take arbitrary time to complete!
on_fmt_thread<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, R::Result: Serialize,157     pub(crate) fn on_fmt_thread<R>(
158         &mut self,
159         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
160     ) -> &mut Self
161     where
162         R: lsp_types::request::Request + 'static,
163         R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
164         R::Result: Serialize,
165     {
166         self.on_with_thread_intent::<false, R>(ThreadIntent::LatencySensitive, f)
167     }
168 
finish(&mut self)169     pub(crate) fn finish(&mut self) {
170         if let Some(req) = self.req.take() {
171             tracing::error!("unknown request: {:?}", req);
172             let response = lsp_server::Response::new_err(
173                 req.id,
174                 lsp_server::ErrorCode::MethodNotFound as i32,
175                 "unknown request".to_string(),
176             );
177             self.global_state.respond(response);
178         }
179     }
180 
on_with_thread_intent<const MAIN_POOL: bool, R>( &mut self, intent: ThreadIntent, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, R::Result: Serialize,181     fn on_with_thread_intent<const MAIN_POOL: bool, R>(
182         &mut self,
183         intent: ThreadIntent,
184         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
185     ) -> &mut Self
186     where
187         R: lsp_types::request::Request + 'static,
188         R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
189         R::Result: Serialize,
190     {
191         let (req, params, panic_context) = match self.parse::<R>() {
192             Some(it) => it,
193             None => return self,
194         };
195 
196         let world = self.global_state.snapshot();
197         if MAIN_POOL {
198             &mut self.global_state.task_pool.handle
199         } else {
200             &mut self.global_state.fmt_pool.handle
201         }
202         .spawn(intent, move || {
203             let result = panic::catch_unwind(move || {
204                 let _pctx = stdx::panic_context::enter(panic_context);
205                 f(world, params)
206             });
207             match thread_result_to_response::<R>(req.id.clone(), result) {
208                 Ok(response) => Task::Response(response),
209                 Err(_) => Task::Retry(req),
210             }
211         });
212 
213         self
214     }
215 
parse<R>(&mut self) -> Option<(lsp_server::Request, R::Params, String)> where R: lsp_types::request::Request, R::Params: DeserializeOwned + fmt::Debug,216     fn parse<R>(&mut self) -> Option<(lsp_server::Request, R::Params, String)>
217     where
218         R: lsp_types::request::Request,
219         R::Params: DeserializeOwned + fmt::Debug,
220     {
221         let req = match &self.req {
222             Some(req) if req.method == R::METHOD => self.req.take()?,
223             _ => return None,
224         };
225 
226         let res = crate::from_json(R::METHOD, &req.params);
227         match res {
228             Ok(params) => {
229                 let panic_context =
230                     format!("\nversion: {}\nrequest: {} {params:#?}", version(), R::METHOD);
231                 Some((req, params, panic_context))
232             }
233             Err(err) => {
234                 let response = lsp_server::Response::new_err(
235                     req.id,
236                     lsp_server::ErrorCode::InvalidParams as i32,
237                     err.to_string(),
238                 );
239                 self.global_state.respond(response);
240                 None
241             }
242         }
243     }
244 }
245 
thread_result_to_response<R>( id: lsp_server::RequestId, result: thread::Result<Result<R::Result>>, ) -> Result<lsp_server::Response, Cancelled> where R: lsp_types::request::Request, R::Params: DeserializeOwned, R::Result: Serialize,246 fn thread_result_to_response<R>(
247     id: lsp_server::RequestId,
248     result: thread::Result<Result<R::Result>>,
249 ) -> Result<lsp_server::Response, Cancelled>
250 where
251     R: lsp_types::request::Request,
252     R::Params: DeserializeOwned,
253     R::Result: Serialize,
254 {
255     match result {
256         Ok(result) => result_to_response::<R>(id, result),
257         Err(panic) => {
258             let panic_message = panic
259                 .downcast_ref::<String>()
260                 .map(String::as_str)
261                 .or_else(|| panic.downcast_ref::<&str>().copied());
262 
263             let mut message = "request handler panicked".to_string();
264             if let Some(panic_message) = panic_message {
265                 message.push_str(": ");
266                 message.push_str(panic_message)
267             };
268 
269             Ok(lsp_server::Response::new_err(
270                 id,
271                 lsp_server::ErrorCode::InternalError as i32,
272                 message,
273             ))
274         }
275     }
276 }
277 
result_to_response<R>( id: lsp_server::RequestId, result: Result<R::Result>, ) -> Result<lsp_server::Response, Cancelled> where R: lsp_types::request::Request, R::Params: DeserializeOwned, R::Result: Serialize,278 fn result_to_response<R>(
279     id: lsp_server::RequestId,
280     result: Result<R::Result>,
281 ) -> Result<lsp_server::Response, Cancelled>
282 where
283     R: lsp_types::request::Request,
284     R::Params: DeserializeOwned,
285     R::Result: Serialize,
286 {
287     let res = match result {
288         Ok(resp) => lsp_server::Response::new_ok(id, &resp),
289         Err(e) => match e.downcast::<LspError>() {
290             Ok(lsp_error) => lsp_server::Response::new_err(id, lsp_error.code, lsp_error.message),
291             Err(e) => match e.downcast::<Cancelled>() {
292                 Ok(cancelled) => return Err(*cancelled),
293                 Err(e) => lsp_server::Response::new_err(
294                     id,
295                     lsp_server::ErrorCode::InternalError as i32,
296                     e.to_string(),
297                 ),
298             },
299         },
300     };
301     Ok(res)
302 }
303 
304 pub(crate) struct NotificationDispatcher<'a> {
305     pub(crate) not: Option<lsp_server::Notification>,
306     pub(crate) global_state: &'a mut GlobalState,
307 }
308 
309 impl<'a> NotificationDispatcher<'a> {
on_sync_mut<N>( &mut self, f: fn(&mut GlobalState, N::Params) -> Result<()>, ) -> Result<&mut Self> where N: lsp_types::notification::Notification, N::Params: DeserializeOwned + Send,310     pub(crate) fn on_sync_mut<N>(
311         &mut self,
312         f: fn(&mut GlobalState, N::Params) -> Result<()>,
313     ) -> Result<&mut Self>
314     where
315         N: lsp_types::notification::Notification,
316         N::Params: DeserializeOwned + Send,
317     {
318         let not = match self.not.take() {
319             Some(it) => it,
320             None => return Ok(self),
321         };
322         let params = match not.extract::<N::Params>(N::METHOD) {
323             Ok(it) => it,
324             Err(ExtractError::JsonError { method, error }) => {
325                 panic!("Invalid request\nMethod: {method}\n error: {error}",)
326             }
327             Err(ExtractError::MethodMismatch(not)) => {
328                 self.not = Some(not);
329                 return Ok(self);
330             }
331         };
332         let _pctx = stdx::panic_context::enter(format!(
333             "\nversion: {}\nnotification: {}",
334             version(),
335             N::METHOD
336         ));
337         f(self.global_state, params)?;
338         Ok(self)
339     }
340 
finish(&mut self)341     pub(crate) fn finish(&mut self) {
342         if let Some(not) = &self.not {
343             if !not.method.starts_with("$/") {
344                 tracing::error!("unhandled notification: {:?}", not);
345             }
346         }
347     }
348 }
349