• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::{
2     fmt,
3     io::{self, BufRead, Write},
4 };
5 
6 use serde::{de::DeserializeOwned, Deserialize, Serialize};
7 
8 use crate::error::ExtractError;
9 
10 #[derive(Serialize, Deserialize, Debug, Clone)]
11 #[serde(untagged)]
12 pub enum Message {
13     Request(Request),
14     Response(Response),
15     Notification(Notification),
16 }
17 
18 impl From<Request> for Message {
from(request: Request) -> Message19     fn from(request: Request) -> Message {
20         Message::Request(request)
21     }
22 }
23 
24 impl From<Response> for Message {
from(response: Response) -> Message25     fn from(response: Response) -> Message {
26         Message::Response(response)
27     }
28 }
29 
30 impl From<Notification> for Message {
from(notification: Notification) -> Message31     fn from(notification: Notification) -> Message {
32         Message::Notification(notification)
33     }
34 }
35 
36 #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
37 #[serde(transparent)]
38 pub struct RequestId(IdRepr);
39 
40 #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
41 #[serde(untagged)]
42 enum IdRepr {
43     I32(i32),
44     String(String),
45 }
46 
47 impl From<i32> for RequestId {
from(id: i32) -> RequestId48     fn from(id: i32) -> RequestId {
49         RequestId(IdRepr::I32(id))
50     }
51 }
52 
53 impl From<String> for RequestId {
from(id: String) -> RequestId54     fn from(id: String) -> RequestId {
55         RequestId(IdRepr::String(id))
56     }
57 }
58 
59 impl fmt::Display for RequestId {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result60     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61         match &self.0 {
62             IdRepr::I32(it) => fmt::Display::fmt(it, f),
63             // Use debug here, to make it clear that `92` and `"92"` are
64             // different, and to reduce WTF factor if the sever uses `" "` as an
65             // ID.
66             IdRepr::String(it) => fmt::Debug::fmt(it, f),
67         }
68     }
69 }
70 
71 #[derive(Debug, Serialize, Deserialize, Clone)]
72 pub struct Request {
73     pub id: RequestId,
74     pub method: String,
75     #[serde(default = "serde_json::Value::default")]
76     #[serde(skip_serializing_if = "serde_json::Value::is_null")]
77     pub params: serde_json::Value,
78 }
79 
80 #[derive(Debug, Serialize, Deserialize, Clone)]
81 pub struct Response {
82     // JSON RPC allows this to be null if it was impossible
83     // to decode the request's id. Ignore this special case
84     // and just die horribly.
85     pub id: RequestId,
86     #[serde(skip_serializing_if = "Option::is_none")]
87     pub result: Option<serde_json::Value>,
88     #[serde(skip_serializing_if = "Option::is_none")]
89     pub error: Option<ResponseError>,
90 }
91 
92 #[derive(Debug, Serialize, Deserialize, Clone)]
93 pub struct ResponseError {
94     pub code: i32,
95     pub message: String,
96     #[serde(skip_serializing_if = "Option::is_none")]
97     pub data: Option<serde_json::Value>,
98 }
99 
/// Predefined error codes.
///
/// Every variant has an explicit discriminant, so `code as i32` yields the
/// numeric value. The enum is `#[non_exhaustive]`: code outside this crate
/// must keep a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum ErrorCode {
    // Defined by JSON RPC:
    ParseError = -32700,
    InvalidRequest = -32600,
    MethodNotFound = -32601,
    InvalidParams = -32602,
    InternalError = -32603,
    ServerErrorStart = -32099,
    ServerErrorEnd = -32000,

    /// Error code indicating that a server received a notification or
    /// request before the server has received the `initialize` request.
    ServerNotInitialized = -32002,
    UnknownErrorCode = -32001,

    // Defined by the protocol:
    /// The client has canceled a request and a server has detected
    /// the cancel.
    RequestCanceled = -32800,

    /// The server detected that the content of a document got
    /// modified outside normal conditions. A server should
    /// NOT send this error code if it detects a content change
    /// in its unprocessed messages. The result, even if computed
    /// on an older state, might still be useful for the client.
    ///
    /// If a client decides that a result is not of any use anymore
    /// the client should cancel the request.
    ContentModified = -32801,

    /// The server cancelled the request. This error code should
    /// only be used for requests that explicitly support being
    /// server cancellable.
    ///
    /// @since 3.17.0
    ServerCancelled = -32802,

    /// A request failed but it was syntactically correct, e.g. the
    /// method name was known and the parameters were valid. The error
    /// message should contain human readable information about why
    /// the request failed.
    ///
    /// @since 3.17.0
    RequestFailed = -32803,
}
147 
148 #[derive(Debug, Serialize, Deserialize, Clone)]
149 pub struct Notification {
150     pub method: String,
151     #[serde(default = "serde_json::Value::default")]
152     #[serde(skip_serializing_if = "serde_json::Value::is_null")]
153     pub params: serde_json::Value,
154 }
155 
156 impl Message {
read(r: &mut impl BufRead) -> io::Result<Option<Message>>157     pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> {
158         Message::_read(r)
159     }
_read(r: &mut dyn BufRead) -> io::Result<Option<Message>>160     fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> {
161         let text = match read_msg_text(r)? {
162             None => return Ok(None),
163             Some(text) => text,
164         };
165         let msg = serde_json::from_str(&text)?;
166         Ok(Some(msg))
167     }
write(self, w: &mut impl Write) -> io::Result<()>168     pub fn write(self, w: &mut impl Write) -> io::Result<()> {
169         self._write(w)
170     }
_write(self, w: &mut dyn Write) -> io::Result<()>171     fn _write(self, w: &mut dyn Write) -> io::Result<()> {
172         #[derive(Serialize)]
173         struct JsonRpc {
174             jsonrpc: &'static str,
175             #[serde(flatten)]
176             msg: Message,
177         }
178         let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?;
179         write_msg_text(w, &text)
180     }
181 }
182 
183 impl Response {
new_ok<R: Serialize>(id: RequestId, result: R) -> Response184     pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response {
185         Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None }
186     }
new_err(id: RequestId, code: i32, message: String) -> Response187     pub fn new_err(id: RequestId, code: i32, message: String) -> Response {
188         let error = ResponseError { code, message, data: None };
189         Response { id, result: None, error: Some(error) }
190     }
191 }
192 
193 impl Request {
new<P: Serialize>(id: RequestId, method: String, params: P) -> Request194     pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request {
195         Request { id, method, params: serde_json::to_value(params).unwrap() }
196     }
extract<P: DeserializeOwned>( self, method: &str, ) -> Result<(RequestId, P), ExtractError<Request>>197     pub fn extract<P: DeserializeOwned>(
198         self,
199         method: &str,
200     ) -> Result<(RequestId, P), ExtractError<Request>> {
201         if self.method != method {
202             return Err(ExtractError::MethodMismatch(self));
203         }
204         match serde_json::from_value(self.params) {
205             Ok(params) => Ok((self.id, params)),
206             Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
207         }
208     }
209 
is_shutdown(&self) -> bool210     pub(crate) fn is_shutdown(&self) -> bool {
211         self.method == "shutdown"
212     }
is_initialize(&self) -> bool213     pub(crate) fn is_initialize(&self) -> bool {
214         self.method == "initialize"
215     }
216 }
217 
218 impl Notification {
new(method: String, params: impl Serialize) -> Notification219     pub fn new(method: String, params: impl Serialize) -> Notification {
220         Notification { method, params: serde_json::to_value(params).unwrap() }
221     }
extract<P: DeserializeOwned>( self, method: &str, ) -> Result<P, ExtractError<Notification>>222     pub fn extract<P: DeserializeOwned>(
223         self,
224         method: &str,
225     ) -> Result<P, ExtractError<Notification>> {
226         if self.method != method {
227             return Err(ExtractError::MethodMismatch(self));
228         }
229         match serde_json::from_value(self.params) {
230             Ok(params) => Ok(params),
231             Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
232         }
233     }
is_exit(&self) -> bool234     pub(crate) fn is_exit(&self) -> bool {
235         self.method == "exit"
236     }
is_initialized(&self) -> bool237     pub(crate) fn is_initialized(&self) -> bool {
238         self.method == "initialized"
239     }
240 }
241 
read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>>242 fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
243     fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
244         io::Error::new(io::ErrorKind::InvalidData, error)
245     }
246     macro_rules! invalid_data {
247         ($($tt:tt)*) => (invalid_data(format!($($tt)*)))
248     }
249 
250     let mut size = None;
251     let mut buf = String::new();
252     loop {
253         buf.clear();
254         if inp.read_line(&mut buf)? == 0 {
255             return Ok(None);
256         }
257         if !buf.ends_with("\r\n") {
258             return Err(invalid_data!("malformed header: {:?}", buf));
259         }
260         let buf = &buf[..buf.len() - 2];
261         if buf.is_empty() {
262             break;
263         }
264         let mut parts = buf.splitn(2, ": ");
265         let header_name = parts.next().unwrap();
266         let header_value =
267             parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
268         if header_name == "Content-Length" {
269             size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
270         }
271     }
272     let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
273     let mut buf = buf.into_bytes();
274     buf.resize(size, 0);
275     inp.read_exact(&mut buf)?;
276     let buf = String::from_utf8(buf).map_err(invalid_data)?;
277     log::debug!("< {}", buf);
278     Ok(Some(buf))
279 }
280 
write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()>281 fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> {
282     log::debug!("> {}", msg);
283     write!(out, "Content-Length: {}\r\n\r\n", msg.len())?;
284     out.write_all(msg.as_bytes())?;
285     out.flush()?;
286     Ok(())
287 }
288 
#[cfg(test)]
mod tests {
    use super::{Message, Notification, Request, RequestId};

    /// Parses `text` and checks that it is the `shutdown` request with id 3.
    fn assert_shutdown_request(text: &str) {
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    /// Parses `text` and checks that it is the `exit` notification.
    fn assert_exit_notification(text: &str) {
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    // An explicit `"params": null` must be accepted on a request.
    #[test]
    fn shutdown_with_explicit_null() {
        assert_shutdown_request(
            "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\", \"params\": null }",
        );
    }

    // A request without any `params` key must be accepted.
    #[test]
    fn shutdown_with_no_params() {
        assert_shutdown_request("{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\"}");
    }

    // An explicit `"params": null` must be accepted on a notification.
    #[test]
    fn notification_with_explicit_null() {
        assert_exit_notification("{\"jsonrpc\": \"2.0\",\"method\": \"exit\", \"params\": null }");
    }

    // A notification without any `params` key must be accepted.
    #[test]
    fn notification_with_no_params() {
        assert_exit_notification("{\"jsonrpc\": \"2.0\",\"method\": \"exit\"}");
    }

    // `Null` params are left out of a serialized request.
    #[test]
    fn serialize_request_with_null_params() {
        let request = Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        };
        let serialized = serde_json::to_string(&Message::Request(request)).unwrap();

        assert_eq!("{\"id\":3,\"method\":\"shutdown\"}", serialized);
    }

    // `Null` params are left out of a serialized notification.
    #[test]
    fn serialize_notification_with_null_params() {
        let notification =
            Notification { method: "exit".into(), params: serde_json::Value::Null };
        let serialized = serde_json::to_string(&Message::Notification(notification)).unwrap();

        assert_eq!("{\"method\":\"exit\"}", serialized);
    }
}
352