• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2018 The Go Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec.
16// https://www.jsonrpc.org/specification
17// It is intended to be compatible with other implementations at the wire level.
18package jsonrpc2
19
20import (
21	"context"
22	"encoding/json"
23	"fmt"
24	"sync"
25	"sync/atomic"
26)
27
28// Conn is a JSON RPC 2 client server connection.
29// Conn is bidirectional; it does not have a designated server or client end.
30type Conn struct {
31	seq        int64 // must only be accessed using atomic operations
32	handlers   []Handler
33	stream     Stream
34	err        error
35	pendingMu  sync.Mutex // protects the pending map
36	pending    map[ID]chan *WireResponse
37	handlingMu sync.Mutex // protects the handling map
38	handling   map[ID]*Request
39}
40
// requestState records how far an incoming Request has progressed. The values
// are compared with < and >=, so their declaration order is significant.
type requestState int

const (
	requestWaiting  requestState = iota // queued behind the previous request
	requestSerial                       // running; later requests still blocked
	requestParallel                     // running; later requests released
	requestReplied                      // Reply has been called
	requestDone                         // NOTE(review): not assigned in the code visible here
)
50
51// Request is sent to a server to represent a Call or Notify operaton.
52type Request struct {
53	conn        *Conn
54	cancel      context.CancelFunc
55	state       requestState
56	nextRequest chan struct{}
57
58	// The Wire values of the request.
59	WireRequest
60}
61
62// NewErrorf builds a Error struct for the supplied message and code.
63// If args is not empty, message and args will be passed to Sprintf.
64func NewErrorf(code int64, format string, args ...interface{}) *Error {
65	return &Error{
66		Code:    code,
67		Message: fmt.Sprintf(format, args...),
68	}
69}
70
71// NewConn creates a new connection object around the supplied stream.
72// You must call Run for the connection to be active.
73func NewConn(s Stream) *Conn {
74	conn := &Conn{
75		handlers: []Handler{defaultHandler{}},
76		stream:   s,
77		pending:  make(map[ID]chan *WireResponse),
78		handling: make(map[ID]*Request),
79	}
80	return conn
81}
82
83// AddHandler adds a new handler to the set the connection will invoke.
84// Handlers are invoked in the reverse order of how they were added, this
85// allows the most recent addition to be the first one to attempt to handle a
86// message.
87func (c *Conn) AddHandler(handler Handler) {
88	// prepend the new handlers so we use them first
89	c.handlers = append([]Handler{handler}, c.handlers...)
90}
91
92// Cancel cancels a pending Call on the server side.
93// The call is identified by its id.
94// JSON RPC 2 does not specify a cancel message, so cancellation support is not
95// directly wired in. This method allows a higher level protocol to choose how
96// to propagate the cancel.
97func (c *Conn) Cancel(id ID) {
98	c.handlingMu.Lock()
99	handling, found := c.handling[id]
100	c.handlingMu.Unlock()
101	if found {
102		handling.cancel()
103	}
104}
105
106// Notify is called to send a notification request over the connection.
107// It will return as soon as the notification has been sent, as no response is
108// possible.
109func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) {
110	jsonParams, err := marshalToRaw(params)
111	if err != nil {
112		return fmt.Errorf("marshalling notify parameters: %v", err)
113	}
114	request := &WireRequest{
115		Method: method,
116		Params: jsonParams,
117	}
118	data, err := json.Marshal(request)
119	if err != nil {
120		return fmt.Errorf("marshalling notify request: %v", err)
121	}
122	for _, h := range c.handlers {
123		ctx = h.Request(ctx, c, Send, request)
124	}
125	defer func() {
126		for _, h := range c.handlers {
127			h.Done(ctx, err)
128		}
129	}()
130	n, err := c.stream.Write(ctx, data)
131	for _, h := range c.handlers {
132		ctx = h.Wrote(ctx, n)
133	}
134	return err
135}
136
137// Call sends a request over the connection and then waits for a response.
138// If the response is not an error, it will be decoded into result.
139// result must be of a type you an pass to json.Unmarshal.
140func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) {
141	// generate a new request identifier
142	id := ID{Number: atomic.AddInt64(&c.seq, 1)}
143	jsonParams, err := marshalToRaw(params)
144	if err != nil {
145		return fmt.Errorf("marshalling call parameters: %v", err)
146	}
147	request := &WireRequest{
148		ID:     &id,
149		Method: method,
150		Params: jsonParams,
151	}
152	// marshal the request now it is complete
153	data, err := json.Marshal(request)
154	if err != nil {
155		return fmt.Errorf("marshalling call request: %v", err)
156	}
157	for _, h := range c.handlers {
158		ctx = h.Request(ctx, c, Send, request)
159	}
160	// we have to add ourselves to the pending map before we send, otherwise we
161	// are racing the response
162	rchan := make(chan *WireResponse)
163	c.pendingMu.Lock()
164	c.pending[id] = rchan
165	c.pendingMu.Unlock()
166	defer func() {
167		// clean up the pending response handler on the way out
168		c.pendingMu.Lock()
169		delete(c.pending, id)
170		c.pendingMu.Unlock()
171		for _, h := range c.handlers {
172			h.Done(ctx, err)
173		}
174	}()
175	// now we are ready to send
176	n, err := c.stream.Write(ctx, data)
177	for _, h := range c.handlers {
178		ctx = h.Wrote(ctx, n)
179	}
180	if err != nil {
181		// sending failed, we will never get a response, so don't leave it pending
182		return err
183	}
184	// now wait for the response
185	select {
186	case response := <-rchan:
187		for _, h := range c.handlers {
188			ctx = h.Response(ctx, c, Receive, response)
189		}
190		// is it an error response?
191		if response.Error != nil {
192			return response.Error
193		}
194		if result == nil || response.Result == nil {
195			return nil
196		}
197		if err := json.Unmarshal(*response.Result, result); err != nil {
198			return fmt.Errorf("unmarshalling result: %v", err)
199		}
200		return nil
201	case <-ctx.Done():
202		// allow the handler to propagate the cancel
203		cancelled := false
204		for _, h := range c.handlers {
205			if h.Cancel(ctx, c, id, cancelled) {
206				cancelled = true
207			}
208		}
209		return ctx.Err()
210	}
211}
212
213// Conn returns the connection that created this request.
214func (r *Request) Conn() *Conn { return r.conn }
215
216// IsNotify returns true if this request is a notification.
217func (r *Request) IsNotify() bool {
218	return r.ID == nil
219}
220
221// Parallel indicates that the system is now allowed to process other requests
222// in parallel with this one.
223// It is safe to call any number of times, but must only be called from the
224// request handling go routine.
225// It is implied by both reply and by the handler returning.
226func (r *Request) Parallel() {
227	if r.state >= requestParallel {
228		return
229	}
230	r.state = requestParallel
231	close(r.nextRequest)
232}
233
234// Reply sends a reply to the given request.
235// It is an error to call this if request was not a call.
236// You must call this exactly once for any given request.
237// It should only be called from the handler go routine.
238// If err is set then result will be ignored.
239// If the request has not yet dropped into parallel mode
240// it will be before this function returns.
241func (r *Request) Reply(ctx context.Context, result interface{}, err error) error {
242	if r.state >= requestReplied {
243		return fmt.Errorf("reply invoked more than once")
244	}
245	if r.IsNotify() {
246		return fmt.Errorf("reply not invoked with a valid call")
247	}
248	// reply ends the handling phase of a call, so if we are not yet
249	// parallel we should be now. The go routine is allowed to continue
250	// to do work after replying, which is why it is important to unlock
251	// the rpc system at this point.
252	r.Parallel()
253	r.state = requestReplied
254
255	var raw *json.RawMessage
256	if err == nil {
257		raw, err = marshalToRaw(result)
258	}
259	response := &WireResponse{
260		Result: raw,
261		ID:     r.ID,
262	}
263	if err != nil {
264		if callErr, ok := err.(*Error); ok {
265			response.Error = callErr
266		} else {
267			response.Error = NewErrorf(0, "%s", err)
268		}
269	}
270	data, err := json.Marshal(response)
271	if err != nil {
272		return err
273	}
274	for _, h := range r.conn.handlers {
275		ctx = h.Response(ctx, r.conn, Send, response)
276	}
277	n, err := r.conn.stream.Write(ctx, data)
278	for _, h := range r.conn.handlers {
279		ctx = h.Wrote(ctx, n)
280	}
281
282	if err != nil {
283		// TODO(iancottrell): if a stream write fails, we really need to shut down
284		// the whole stream
285		return err
286	}
287	return nil
288}
289
290func (c *Conn) setHandling(r *Request, active bool) {
291	if r.ID == nil {
292		return
293	}
294	r.conn.handlingMu.Lock()
295	defer r.conn.handlingMu.Unlock()
296	if active {
297		r.conn.handling[*r.ID] = r
298	} else {
299		delete(r.conn.handling, *r.ID)
300	}
301}
302
303// combined has all the fields of both Request and Response.
304// We can decode this and then work out which it is.
305type combined struct {
306	VersionTag VersionTag       `json:"jsonrpc"`
307	ID         *ID              `json:"id,omitempty"`
308	Method     string           `json:"method"`
309	Params     *json.RawMessage `json:"params,omitempty"`
310	Result     *json.RawMessage `json:"result,omitempty"`
311	Error      *Error           `json:"error,omitempty"`
312}
313
314// Run blocks until the connection is terminated, and returns any error that
315// caused the termination.
316// It must be called exactly once for each Conn.
317// It returns only when the reader is closed or there is an error in the stream.
318func (c *Conn) Run(runCtx context.Context) error {
319	// we need to make the next request "lock" in an unlocked state to allow
320	// the first incoming request to proceed. All later requests are unlocked
321	// by the preceding request going to parallel mode.
322	nextRequest := make(chan struct{})
323	close(nextRequest)
324	for {
325		// get the data for a message
326		data, n, err := c.stream.Read(runCtx)
327		if err != nil {
328			// the stream failed, we cannot continue
329			return err
330		}
331		// read a combined message
332		msg := &combined{}
333		if err := json.Unmarshal(data, msg); err != nil {
334			// a badly formed message arrived, log it and continue
335			// we trust the stream to have isolated the error to just this message
336			for _, h := range c.handlers {
337				h.Error(runCtx, fmt.Errorf("unmarshal failed: %v", err))
338			}
339			continue
340		}
341		// work out which kind of message we have
342		switch {
343		case msg.Method != "":
344			// if method is set it must be a request
345			reqCtx, cancelReq := context.WithCancel(runCtx)
346			thisRequest := nextRequest
347			nextRequest = make(chan struct{})
348			req := &Request{
349				conn:        c,
350				cancel:      cancelReq,
351				nextRequest: nextRequest,
352				WireRequest: WireRequest{
353					VersionTag: msg.VersionTag,
354					Method:     msg.Method,
355					Params:     msg.Params,
356					ID:         msg.ID,
357				},
358			}
359			for _, h := range c.handlers {
360				reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest)
361				reqCtx = h.Read(reqCtx, n)
362			}
363			c.setHandling(req, true)
364			go func() {
365				<-thisRequest
366				req.state = requestSerial
367				defer func() {
368					c.setHandling(req, false)
369					if !req.IsNotify() && req.state < requestReplied {
370						req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
371					}
372					req.Parallel()
373					for _, h := range c.handlers {
374						h.Done(reqCtx, err)
375					}
376					cancelReq()
377				}()
378				delivered := false
379				for _, h := range c.handlers {
380					if h.Deliver(reqCtx, req, delivered) {
381						delivered = true
382					}
383				}
384			}()
385		case msg.ID != nil:
386			// we have a response, get the pending entry from the map
387			c.pendingMu.Lock()
388			rchan := c.pending[*msg.ID]
389			if rchan != nil {
390				delete(c.pending, *msg.ID)
391			}
392			c.pendingMu.Unlock()
393			// and send the reply to the channel
394			response := &WireResponse{
395				Result: msg.Result,
396				Error:  msg.Error,
397				ID:     msg.ID,
398			}
399			rchan <- response
400			close(rchan)
401		default:
402			for _, h := range c.handlers {
403				h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring"))
404			}
405		}
406	}
407}
408
// marshalToRaw encodes obj as JSON and returns the bytes as a
// *json.RawMessage, so they can be embedded in a wire message without being
// encoded again. On failure it returns nil and the json.Marshal error.
func marshalToRaw(obj interface{}) (*json.RawMessage, error) {
	encoded, err := json.Marshal(obj)
	if err != nil {
		return nil, err
	}
	msg := new(json.RawMessage)
	*msg = encoded
	return msg, nil
}
417