• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2018 The Go Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package jsonrpc2
16
17import (
18	"bufio"
19	"context"
20	"encoding/json"
21	"fmt"
22	"io"
23	"strconv"
24	"strings"
25	"sync"
26)
27
// Stream abstracts the transport mechanics away from the JSON RPC protocol.
//
// A Conn reads and writes messages through the Stream it was constructed
// with, and relies on every successful Read or Write call transferring
// exactly one complete message; anything less must be reported as an error.
type Stream interface {
	// Read returns the next message from the stream together with a byte
	// count for it. Whether that count includes framing (such as headers)
	// is up to the implementation. Read is never called concurrently.
	Read(ctx context.Context) ([]byte, int64, error)

	// Write sends a single message to the stream and reports a byte count
	// for what was written. Implementations must be safe for concurrent use.
	Write(ctx context.Context, data []byte) (int64, error)
}
40
41// NewStream returns a Stream built on top of an io.Reader and io.Writer
42// The messages are sent with no wrapping, and rely on json decode consistency
43// to determine message boundaries.
44func NewStream(in io.Reader, out io.Writer) Stream {
45	return &plainStream{
46		in:  json.NewDecoder(in),
47		out: out,
48	}
49}
50
51type plainStream struct {
52	in    *json.Decoder
53	outMu sync.Mutex
54	out   io.Writer
55}
56
57func (s *plainStream) Read(ctx context.Context) ([]byte, int64, error) {
58	select {
59	case <-ctx.Done():
60		return nil, 0, ctx.Err()
61	default:
62	}
63	var raw json.RawMessage
64	if err := s.in.Decode(&raw); err != nil {
65		return nil, 0, err
66	}
67	return raw, int64(len(raw)), nil
68}
69
70func (s *plainStream) Write(ctx context.Context, data []byte) (int64, error) {
71	select {
72	case <-ctx.Done():
73		return 0, ctx.Err()
74	default:
75	}
76	s.outMu.Lock()
77	n, err := s.out.Write(data)
78	s.outMu.Unlock()
79	return int64(n), err
80}
81
82// NewHeaderStream returns a Stream built on top of an io.Reader and io.Writer
83// The messages are sent with HTTP content length and MIME type headers.
84// This is the format used by LSP and others.
85func NewHeaderStream(in io.Reader, out io.Writer) Stream {
86	return &headerStream{
87		in:  bufio.NewReader(in),
88		out: out,
89	}
90}
91
92type headerStream struct {
93	in    *bufio.Reader
94	outMu sync.Mutex
95	out   io.Writer
96}
97
98func (s *headerStream) Read(ctx context.Context) ([]byte, int64, error) {
99	select {
100	case <-ctx.Done():
101		return nil, 0, ctx.Err()
102	default:
103	}
104	var total, length int64
105	// read the header, stop on the first empty line
106	for {
107		line, err := s.in.ReadString('\n')
108		total += int64(len(line))
109		if err != nil {
110			return nil, total, fmt.Errorf("failed reading header line %q", err)
111		}
112		line = strings.TrimSpace(line)
113		// check we have a header line
114		if line == "" {
115			break
116		}
117		colon := strings.IndexRune(line, ':')
118		if colon < 0 {
119			return nil, total, fmt.Errorf("invalid header line %q", line)
120		}
121		name, value := line[:colon], strings.TrimSpace(line[colon+1:])
122		switch name {
123		case "Content-Length":
124			if length, err = strconv.ParseInt(value, 10, 32); err != nil {
125				return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value)
126			}
127			if length <= 0 {
128				return nil, total, fmt.Errorf("invalid Content-Length: %v", length)
129			}
130		default:
131			// ignoring unknown headers
132		}
133	}
134	if length == 0 {
135		return nil, total, fmt.Errorf("missing Content-Length header")
136	}
137	data := make([]byte, length)
138	if _, err := io.ReadFull(s.in, data); err != nil {
139		return nil, total, err
140	}
141	total += length
142	return data, total, nil
143}
144
145func (s *headerStream) Write(ctx context.Context, data []byte) (int64, error) {
146	select {
147	case <-ctx.Done():
148		return 0, ctx.Err()
149	default:
150	}
151	s.outMu.Lock()
152	defer s.outMu.Unlock()
153	n, err := fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(data))
154	total := int64(n)
155	if err == nil {
156		n, err = s.out.Write(data)
157		total += int64(n)
158	}
159	return total, err
160}
161