// Copyright 2018 The Go Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec.
// https://www.jsonrpc.org/specification
// It is intended to be compatible with other implementations at the wire level.
package jsonrpc2

import (
	"context"
	"encoding/json"
	"fmt"
	"sync"
	"sync/atomic"
)

// Conn is a JSON RPC 2 client server connection.
// Conn is bidirectional; it does not have a designated server or client end.
type Conn struct {
	// seq is the last request ID number issued by Call.
	// Keep it as the first field: sync/atomic requires 64-bit alignment for
	// 64-bit atomic operations on 32-bit platforms, and only the first word
	// of an allocated struct is guaranteed to be aligned that way.
	seq int64 // must only be accessed using atomic operations
	// handlers are consulted in slice order for every message; AddHandler
	// prepends, so the most recently added handler comes first.
	handlers []Handler
	stream   Stream
	// NOTE(review): err is not read or written anywhere in this file;
	// confirm whether another file in the package uses it before removing.
	err       error
	pendingMu sync.Mutex // protects the pending map
	// pending maps the ID of each outstanding Call to the channel on which
	// Run delivers the matching response.
	pending    map[ID]chan *WireResponse
	handlingMu sync.Mutex // protects the handling map
	// handling holds incoming calls (not notifications) that are still being
	// processed, keyed by ID, so that Cancel can find them.
	handling map[ID]*Request
}

// requestState tracks how far an incoming Request has progressed.
// States only ever move forward, and code compares them with >= and <, so
// the declaration order below is significant.
type requestState int

const (
	// requestWaiting: created by Run; the handler goroutine is waiting for
	// the previous request to release it.
	requestWaiting = requestState(iota)
	// requestSerial: being handled; later requests may not start yet.
	requestSerial
	// requestParallel: Parallel has been called; later requests may start.
	requestParallel
	// requestReplied: Reply has been called.
	requestReplied
	// requestDone: NOTE(review): not assigned anywhere in this file.
	requestDone
)

// Request is sent to a server to represent a Call or Notify operation.
// Run builds one for each incoming request message and hands it to the
// handlers' Deliver methods.
type Request struct {
	conn *Conn
	// cancel cancels the context the request is handled under; Conn.Cancel
	// invokes it.
	cancel context.CancelFunc
	// state is expected to be accessed only from the goroutine handling the
	// request (see Parallel and Reply).
	state requestState
	// nextRequest is closed by Parallel to let the next incoming request
	// begin being handled.
	nextRequest chan struct{}

	// The Wire values of the request.
	WireRequest
}

// NewErrorf builds an Error with the supplied code and a message formatted
// from format and args by fmt.Sprintf.
// format is always passed through fmt.Sprintf, even when args is empty, so a
// literal '%' in the message must be written as "%%".
64func NewErrorf(code int64, format string, args ...interface{}) *Error { 65 return &Error{ 66 Code: code, 67 Message: fmt.Sprintf(format, args...), 68 } 69} 70 71// NewConn creates a new connection object around the supplied stream. 72// You must call Run for the connection to be active. 73func NewConn(s Stream) *Conn { 74 conn := &Conn{ 75 handlers: []Handler{defaultHandler{}}, 76 stream: s, 77 pending: make(map[ID]chan *WireResponse), 78 handling: make(map[ID]*Request), 79 } 80 return conn 81} 82 83// AddHandler adds a new handler to the set the connection will invoke. 84// Handlers are invoked in the reverse order of how they were added, this 85// allows the most recent addition to be the first one to attempt to handle a 86// message. 87func (c *Conn) AddHandler(handler Handler) { 88 // prepend the new handlers so we use them first 89 c.handlers = append([]Handler{handler}, c.handlers...) 90} 91 92// Cancel cancels a pending Call on the server side. 93// The call is identified by its id. 94// JSON RPC 2 does not specify a cancel message, so cancellation support is not 95// directly wired in. This method allows a higher level protocol to choose how 96// to propagate the cancel. 97func (c *Conn) Cancel(id ID) { 98 c.handlingMu.Lock() 99 handling, found := c.handling[id] 100 c.handlingMu.Unlock() 101 if found { 102 handling.cancel() 103 } 104} 105 106// Notify is called to send a notification request over the connection. 107// It will return as soon as the notification has been sent, as no response is 108// possible. 
109func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) { 110 jsonParams, err := marshalToRaw(params) 111 if err != nil { 112 return fmt.Errorf("marshalling notify parameters: %v", err) 113 } 114 request := &WireRequest{ 115 Method: method, 116 Params: jsonParams, 117 } 118 data, err := json.Marshal(request) 119 if err != nil { 120 return fmt.Errorf("marshalling notify request: %v", err) 121 } 122 for _, h := range c.handlers { 123 ctx = h.Request(ctx, c, Send, request) 124 } 125 defer func() { 126 for _, h := range c.handlers { 127 h.Done(ctx, err) 128 } 129 }() 130 n, err := c.stream.Write(ctx, data) 131 for _, h := range c.handlers { 132 ctx = h.Wrote(ctx, n) 133 } 134 return err 135} 136 137// Call sends a request over the connection and then waits for a response. 138// If the response is not an error, it will be decoded into result. 139// result must be of a type you an pass to json.Unmarshal. 140func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) { 141 // generate a new request identifier 142 id := ID{Number: atomic.AddInt64(&c.seq, 1)} 143 jsonParams, err := marshalToRaw(params) 144 if err != nil { 145 return fmt.Errorf("marshalling call parameters: %v", err) 146 } 147 request := &WireRequest{ 148 ID: &id, 149 Method: method, 150 Params: jsonParams, 151 } 152 // marshal the request now it is complete 153 data, err := json.Marshal(request) 154 if err != nil { 155 return fmt.Errorf("marshalling call request: %v", err) 156 } 157 for _, h := range c.handlers { 158 ctx = h.Request(ctx, c, Send, request) 159 } 160 // we have to add ourselves to the pending map before we send, otherwise we 161 // are racing the response 162 rchan := make(chan *WireResponse) 163 c.pendingMu.Lock() 164 c.pending[id] = rchan 165 c.pendingMu.Unlock() 166 defer func() { 167 // clean up the pending response handler on the way out 168 c.pendingMu.Lock() 169 delete(c.pending, id) 170 c.pendingMu.Unlock() 171 
for _, h := range c.handlers { 172 h.Done(ctx, err) 173 } 174 }() 175 // now we are ready to send 176 n, err := c.stream.Write(ctx, data) 177 for _, h := range c.handlers { 178 ctx = h.Wrote(ctx, n) 179 } 180 if err != nil { 181 // sending failed, we will never get a response, so don't leave it pending 182 return err 183 } 184 // now wait for the response 185 select { 186 case response := <-rchan: 187 for _, h := range c.handlers { 188 ctx = h.Response(ctx, c, Receive, response) 189 } 190 // is it an error response? 191 if response.Error != nil { 192 return response.Error 193 } 194 if result == nil || response.Result == nil { 195 return nil 196 } 197 if err := json.Unmarshal(*response.Result, result); err != nil { 198 return fmt.Errorf("unmarshalling result: %v", err) 199 } 200 return nil 201 case <-ctx.Done(): 202 // allow the handler to propagate the cancel 203 cancelled := false 204 for _, h := range c.handlers { 205 if h.Cancel(ctx, c, id, cancelled) { 206 cancelled = true 207 } 208 } 209 return ctx.Err() 210 } 211} 212 213// Conn returns the connection that created this request. 214func (r *Request) Conn() *Conn { return r.conn } 215 216// IsNotify returns true if this request is a notification. 217func (r *Request) IsNotify() bool { 218 return r.ID == nil 219} 220 221// Parallel indicates that the system is now allowed to process other requests 222// in parallel with this one. 223// It is safe to call any number of times, but must only be called from the 224// request handling go routine. 225// It is implied by both reply and by the handler returning. 226func (r *Request) Parallel() { 227 if r.state >= requestParallel { 228 return 229 } 230 r.state = requestParallel 231 close(r.nextRequest) 232} 233 234// Reply sends a reply to the given request. 235// It is an error to call this if request was not a call. 236// You must call this exactly once for any given request. 237// It should only be called from the handler go routine. 
238// If err is set then result will be ignored. 239// If the request has not yet dropped into parallel mode 240// it will be before this function returns. 241func (r *Request) Reply(ctx context.Context, result interface{}, err error) error { 242 if r.state >= requestReplied { 243 return fmt.Errorf("reply invoked more than once") 244 } 245 if r.IsNotify() { 246 return fmt.Errorf("reply not invoked with a valid call") 247 } 248 // reply ends the handling phase of a call, so if we are not yet 249 // parallel we should be now. The go routine is allowed to continue 250 // to do work after replying, which is why it is important to unlock 251 // the rpc system at this point. 252 r.Parallel() 253 r.state = requestReplied 254 255 var raw *json.RawMessage 256 if err == nil { 257 raw, err = marshalToRaw(result) 258 } 259 response := &WireResponse{ 260 Result: raw, 261 ID: r.ID, 262 } 263 if err != nil { 264 if callErr, ok := err.(*Error); ok { 265 response.Error = callErr 266 } else { 267 response.Error = NewErrorf(0, "%s", err) 268 } 269 } 270 data, err := json.Marshal(response) 271 if err != nil { 272 return err 273 } 274 for _, h := range r.conn.handlers { 275 ctx = h.Response(ctx, r.conn, Send, response) 276 } 277 n, err := r.conn.stream.Write(ctx, data) 278 for _, h := range r.conn.handlers { 279 ctx = h.Wrote(ctx, n) 280 } 281 282 if err != nil { 283 // TODO(iancottrell): if a stream write fails, we really need to shut down 284 // the whole stream 285 return err 286 } 287 return nil 288} 289 290func (c *Conn) setHandling(r *Request, active bool) { 291 if r.ID == nil { 292 return 293 } 294 r.conn.handlingMu.Lock() 295 defer r.conn.handlingMu.Unlock() 296 if active { 297 r.conn.handling[*r.ID] = r 298 } else { 299 delete(r.conn.handling, *r.ID) 300 } 301} 302 303// combined has all the fields of both Request and Response. 304// We can decode this and then work out which it is. 
305type combined struct { 306 VersionTag VersionTag `json:"jsonrpc"` 307 ID *ID `json:"id,omitempty"` 308 Method string `json:"method"` 309 Params *json.RawMessage `json:"params,omitempty"` 310 Result *json.RawMessage `json:"result,omitempty"` 311 Error *Error `json:"error,omitempty"` 312} 313 314// Run blocks until the connection is terminated, and returns any error that 315// caused the termination. 316// It must be called exactly once for each Conn. 317// It returns only when the reader is closed or there is an error in the stream. 318func (c *Conn) Run(runCtx context.Context) error { 319 // we need to make the next request "lock" in an unlocked state to allow 320 // the first incoming request to proceed. All later requests are unlocked 321 // by the preceding request going to parallel mode. 322 nextRequest := make(chan struct{}) 323 close(nextRequest) 324 for { 325 // get the data for a message 326 data, n, err := c.stream.Read(runCtx) 327 if err != nil { 328 // the stream failed, we cannot continue 329 return err 330 } 331 // read a combined message 332 msg := &combined{} 333 if err := json.Unmarshal(data, msg); err != nil { 334 // a badly formed message arrived, log it and continue 335 // we trust the stream to have isolated the error to just this message 336 for _, h := range c.handlers { 337 h.Error(runCtx, fmt.Errorf("unmarshal failed: %v", err)) 338 } 339 continue 340 } 341 // work out which kind of message we have 342 switch { 343 case msg.Method != "": 344 // if method is set it must be a request 345 reqCtx, cancelReq := context.WithCancel(runCtx) 346 thisRequest := nextRequest 347 nextRequest = make(chan struct{}) 348 req := &Request{ 349 conn: c, 350 cancel: cancelReq, 351 nextRequest: nextRequest, 352 WireRequest: WireRequest{ 353 VersionTag: msg.VersionTag, 354 Method: msg.Method, 355 Params: msg.Params, 356 ID: msg.ID, 357 }, 358 } 359 for _, h := range c.handlers { 360 reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest) 361 reqCtx = 
h.Read(reqCtx, n) 362 } 363 c.setHandling(req, true) 364 go func() { 365 <-thisRequest 366 req.state = requestSerial 367 defer func() { 368 c.setHandling(req, false) 369 if !req.IsNotify() && req.state < requestReplied { 370 req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method)) 371 } 372 req.Parallel() 373 for _, h := range c.handlers { 374 h.Done(reqCtx, err) 375 } 376 cancelReq() 377 }() 378 delivered := false 379 for _, h := range c.handlers { 380 if h.Deliver(reqCtx, req, delivered) { 381 delivered = true 382 } 383 } 384 }() 385 case msg.ID != nil: 386 // we have a response, get the pending entry from the map 387 c.pendingMu.Lock() 388 rchan := c.pending[*msg.ID] 389 if rchan != nil { 390 delete(c.pending, *msg.ID) 391 } 392 c.pendingMu.Unlock() 393 // and send the reply to the channel 394 response := &WireResponse{ 395 Result: msg.Result, 396 Error: msg.Error, 397 ID: msg.ID, 398 } 399 rchan <- response 400 close(rchan) 401 default: 402 for _, h := range c.handlers { 403 h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring")) 404 } 405 } 406 } 407} 408 409func marshalToRaw(obj interface{}) (*json.RawMessage, error) { 410 data, err := json.Marshal(obj) 411 if err != nil { 412 return nil, err 413 } 414 raw := json.RawMessage(data) 415 return &raw, nil 416} 417