• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"crypto"
9	"crypto/ecdsa"
10	"crypto/ed25519"
11	"crypto/elliptic"
12	"crypto/rsa"
13	"crypto/subtle"
14	"crypto/x509"
15	"errors"
16	"fmt"
17	"io"
18	"math/big"
19
20	"golang.org/x/crypto/curve25519"
21)
22
// keyType identifies the kind of server public key that a signed key
// agreement expects; signedKeyAgreement.verifyParameters checks the peer's key
// against it.
type keyType int

const (
	// keyTypeRSA requires an RSA server public key. The constants start at 1,
	// so the zero keyType matches neither of them.
	keyTypeRSA keyType = iota + 1
	// keyTypeECDSA accepts an ECDSA or Ed25519 server public key.
	keyTypeECDSA
)
29
// Sentinel errors returned by the key agreements in this file when a key
// exchange message fails their length and framing checks.
var (
	// errClientKeyExchange reports a malformed ClientKeyExchange message.
	errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
	// errServerKeyExchange reports a malformed ServerKeyExchange message.
	errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
)
32
// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct {
	// version is the protocol version used when selecting a signature
	// algorithm for, and signing, the ephemeral RSA parameters.
	version       uint16
	// clientVersion is the version from the ClientHello. It is saved by
	// generateServerKeyExchange and compared against the version embedded in
	// the decrypted pre-master secret.
	clientVersion uint16
	// exportKey, if non-nil, is the ephemeral RSA key generated when
	// config.Bugs.RSAEphemeralKey is set. It replaces the certificate's key
	// when decrypting the ClientKeyExchange.
	exportKey     *rsa.PrivateKey
}
40
41func (ka *rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
42	// Save the client version for comparison later.
43	ka.clientVersion = clientHello.vers
44
45	if !config.Bugs.RSAEphemeralKey {
46		return nil, nil
47	}
48
49	// Generate an ephemeral RSA key to use instead of the real
50	// one, as in RSA_EXPORT.
51	key, err := rsa.GenerateKey(config.rand(), 512)
52	if err != nil {
53		return nil, err
54	}
55	ka.exportKey = key
56
57	modulus := key.N.Bytes()
58	exponent := big.NewInt(int64(key.E)).Bytes()
59	serverRSAParams := make([]byte, 0, 2+len(modulus)+2+len(exponent))
60	serverRSAParams = append(serverRSAParams, byte(len(modulus)>>8), byte(len(modulus)))
61	serverRSAParams = append(serverRSAParams, modulus...)
62	serverRSAParams = append(serverRSAParams, byte(len(exponent)>>8), byte(len(exponent)))
63	serverRSAParams = append(serverRSAParams, exponent...)
64
65	var sigAlg signatureAlgorithm
66	if ka.version >= VersionTLS12 {
67		sigAlg, err = selectSignatureAlgorithm(ka.version, cert.PrivateKey, config, clientHello.signatureAlgorithms)
68		if err != nil {
69			return nil, err
70		}
71	}
72
73	sig, err := signMessage(ka.version, cert.PrivateKey, config, sigAlg, serverRSAParams)
74	if err != nil {
75		return nil, errors.New("failed to sign RSA parameters: " + err.Error())
76	}
77
78	skx := new(serverKeyExchangeMsg)
79	sigAlgsLen := 0
80	if ka.version >= VersionTLS12 {
81		sigAlgsLen = 2
82	}
83	skx.key = make([]byte, len(serverRSAParams)+sigAlgsLen+2+len(sig))
84	copy(skx.key, serverRSAParams)
85	k := skx.key[len(serverRSAParams):]
86	if ka.version >= VersionTLS12 {
87		k[0] = byte(sigAlg >> 8)
88		k[1] = byte(sigAlg)
89		k = k[2:]
90	}
91	k[0] = byte(len(sig) >> 8)
92	k[1] = byte(len(sig))
93	copy(k[2:], sig)
94
95	return skx, nil
96}
97
98func (ka *rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
99	preMasterSecret := make([]byte, 48)
100	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
101	if err != nil {
102		return nil, err
103	}
104
105	if len(ckx.ciphertext) < 2 {
106		return nil, errClientKeyExchange
107	}
108
109	ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
110	if ciphertextLen != len(ckx.ciphertext)-2 {
111		return nil, errClientKeyExchange
112	}
113	ciphertext := ckx.ciphertext[2:]
114
115	key := cert.PrivateKey.(*rsa.PrivateKey)
116	if ka.exportKey != nil {
117		key = ka.exportKey
118	}
119	err = rsa.DecryptPKCS1v15SessionKey(config.rand(), key, ciphertext, preMasterSecret)
120	if err != nil {
121		return nil, err
122	}
123	// This check should be done in constant-time, but this is a testing
124	// implementation. See the discussion at the end of section 7.4.7.1 of
125	// RFC 4346.
126	vers := uint16(preMasterSecret[0])<<8 | uint16(preMasterSecret[1])
127	if ka.clientVersion != vers {
128		return nil, fmt.Errorf("tls: invalid version in RSA premaster (got %04x, wanted %04x)", vers, ka.clientVersion)
129	}
130	return preMasterSecret, nil
131}
132
133func (ka *rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
134	return errors.New("tls: unexpected ServerKeyExchange")
135}
136
// rsaSize returns the size of pub's modulus in bytes, rounding a partial
// leading byte up.
func rsaSize(pub *rsa.PublicKey) int {
	bits := pub.N.BitLen()
	return (bits + 7) >> 3
}

// rsaRawEncrypt computes msg^E mod N with no padding of its own. msg must be
// exactly rsaSize(pub) bytes long, otherwise an error is returned. The result
// is always rsaSize(pub) bytes, left-padded with zeros.
func rsaRawEncrypt(pub *rsa.PublicKey, msg []byte) ([]byte, error) {
	size := rsaSize(pub)
	if len(msg) != size {
		return nil, errors.New("tls: bad padded RSA input")
	}

	exponent := new(big.Int).SetInt64(int64(pub.E))
	c := new(big.Int).SetBytes(msg)
	c.Exp(c, exponent, pub.N)

	// c is reduced modulo N, so it fits in size bytes; FillBytes writes it
	// big-endian into the buffer and zero-fills the leading bytes.
	return c.FillBytes(make([]byte, size)), nil
}
154
// nonZeroRandomBytes overwrites every byte of s with a non-zero octet read
// from rand. The whole slice is filled in one read first; any byte that came
// back zero is then re-read, one byte at a time, until it is non-zero. It
// panics if rand returns an error.
func nonZeroRandomBytes(s []byte, rand io.Reader) {
	mustFill := func(dst []byte) {
		if _, err := io.ReadFull(rand, dst); err != nil {
			panic(err)
		}
	}

	mustFill(s)
	for i := 0; i < len(s); i++ {
		for s[i] == 0 {
			mustFill(s[i : i+1])
		}
	}
}
169
170func (ka *rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
171	bad := config.Bugs.BadRSAClientKeyExchange
172	preMasterSecret := make([]byte, 48)
173	vers := clientHello.vers
174	if bad == RSABadValueWrongVersion1 {
175		vers ^= 1
176	} else if bad == RSABadValueWrongVersion2 {
177		vers ^= 0x100
178	}
179	preMasterSecret[0] = byte(vers >> 8)
180	preMasterSecret[1] = byte(vers)
181	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
182	if err != nil {
183		return nil, nil, err
184	}
185
186	sentPreMasterSecret := preMasterSecret
187	if bad == RSABadValueTooLong {
188		sentPreMasterSecret = make([]byte, 1, len(sentPreMasterSecret)+1)
189		sentPreMasterSecret = append(sentPreMasterSecret, preMasterSecret...)
190	} else if bad == RSABadValueTooShort {
191		sentPreMasterSecret = sentPreMasterSecret[:len(sentPreMasterSecret)-1]
192	}
193
194	// Pad for PKCS#1 v1.5.
195	padded := make([]byte, rsaSize(cert.PublicKey.(*rsa.PublicKey)))
196	padded[1] = 2
197	nonZeroRandomBytes(padded[2:len(padded)-len(sentPreMasterSecret)-1], config.rand())
198	copy(padded[len(padded)-len(sentPreMasterSecret):], sentPreMasterSecret)
199
200	if bad == RSABadValueWrongBlockType {
201		padded[1] = 3
202	} else if bad == RSABadValueWrongLeadingByte {
203		padded[0] = 1
204	} else if bad == RSABadValueNoZero {
205		for i := 2; i < len(padded); i++ {
206			if padded[i] == 0 {
207				padded[i]++
208			}
209		}
210	}
211
212	encrypted, err := rsaRawEncrypt(cert.PublicKey.(*rsa.PublicKey), padded)
213	if err != nil {
214		return nil, nil, err
215	}
216	if bad == RSABadValueCorrupt {
217		encrypted[len(encrypted)-1] ^= 1
218		// Clear the high byte to ensure |encrypted| is still below the RSA modulus.
219		encrypted[0] = 0
220	}
221	ckx := new(clientKeyExchangeMsg)
222	ckx.ciphertext = make([]byte, len(encrypted)+2)
223	ckx.ciphertext[0] = byte(len(encrypted) >> 8)
224	ckx.ciphertext[1] = byte(len(encrypted))
225	copy(ckx.ciphertext[2:], encrypted)
226	return preMasterSecret, ckx, nil
227}
228
229func (ka *rsaKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
230	return 0
231}
232
// A kemImplementation is an instance of KEM-style construction for TLS.
//
// The implementations in this file keep the private key created by generate
// (which encap also calls) and use it in decap, so decap is only meaningful
// after one of those calls.
type kemImplementation interface {
	// generate generates a keypair using rand. It returns the encoded public key.
	generate(rand io.Reader) (publicKey []byte, err error)

	// encap generates a symmetric shared secret and encapsulates it to
	// |peerKey|, drawing randomness from rand. It returns the encapsulated
	// shared secret and the secret itself.
	encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error)

	// decap decapsulates |ciphertext| and returns the resulting shared secret.
	decap(ciphertext []byte) (secret []byte, err error)
}
245
// ecdhKEM implements kemImplementation with an elliptic.Curve.
//
// TODO(davidben): Move this to Go's crypto/ecdh.
type ecdhKEM struct {
	// curve is the curve all operations are performed on.
	curve elliptic.Curve
	// privateKey is the scalar produced by the most recent generate call.
	privateKey []byte
	// sendCompressed makes generate emit the compressed point encoding.
	sendCompressed bool
}

// generate creates a fresh key pair on e.curve, stores the private scalar in
// e, and returns the encoded public point: uncompressed by default, or
// compressed when e.sendCompressed is set.
func (e *ecdhKEM) generate(rand io.Reader) (publicKey []byte, err error) {
	var (
		priv []byte
		x, y *big.Int
	)
	priv, x, y, err = elliptic.GenerateKey(e.curve, rand)
	e.privateKey = priv
	if err != nil {
		return nil, err
	}

	encoded := elliptic.Marshal(e.curve, x, y)
	if !e.sendCompressed {
		return encoded, nil
	}

	// The uncompressed form is a type byte followed by X and Y at equal
	// width. Keep only the type byte and X, then rewrite the type byte as 2
	// or 3 depending on the parity of Y.
	coordLen := (len(encoded) - 1) / 2
	compressed := make([]byte, 1+coordLen)
	copy(compressed, encoded[:1+coordLen])
	compressed[0] = 2 | byte(y.Bit(0))
	return compressed, nil
}

// encap generates a new key pair and combines it with peerKey. It returns the
// new public key as the ciphertext, along with the shared secret.
func (e *ecdhKEM) encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error) {
	if ciphertext, err = e.generate(rand); err != nil {
		return nil, nil, err
	}
	if secret, err = e.decap(peerKey); err != nil {
		return nil, nil, err
	}
	return ciphertext, secret, nil
}

// decap parses ciphertext as a point on e.curve, multiplies it by the stored
// private key, and returns the X coordinate of the result, left-padded with
// zeros to the byte length of the curve's field.
func (e *ecdhKEM) decap(ciphertext []byte) (secret []byte, err error) {
	peerX, peerY := elliptic.Unmarshal(e.curve, ciphertext)
	if peerX == nil {
		return nil, errors.New("tls: invalid peer key")
	}

	sharedX, _ := e.curve.ScalarMult(peerX, peerY, e.privateKey)
	fieldLen := (e.curve.Params().BitSize + 7) / 8
	return sharedX.FillBytes(make([]byte, fieldLen)), nil
}
295
296// x25519KEM implements kemImplementation with X25519.
297type x25519KEM struct {
298	privateKey [32]byte
299	setHighBit bool
300}
301
302func (e *x25519KEM) generate(rand io.Reader) (publicKey []byte, err error) {
303	_, err = io.ReadFull(rand, e.privateKey[:])
304	if err != nil {
305		return
306	}
307	var out [32]byte
308	curve25519.ScalarBaseMult(&out, &e.privateKey)
309	if e.setHighBit {
310		out[31] |= 0x80
311	}
312	return out[:], nil
313}
314
315func (e *x25519KEM) encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error) {
316	ciphertext, err = e.generate(rand)
317	if err != nil {
318		return nil, nil, err
319	}
320	secret, err = e.decap(peerKey)
321	if err != nil {
322		return nil, nil, err
323	}
324	return
325}
326
327func (e *x25519KEM) decap(ciphertext []byte) (secret []byte, err error) {
328	if len(ciphertext) != 32 {
329		return nil, errors.New("tls: invalid peer key")
330	}
331	var out [32]byte
332	curve25519.ScalarMult(&out, &e.privateKey, (*[32]byte)(ciphertext))
333
334	// Per RFC 7748, reject the all-zero value in constant time.
335	var zeros [32]byte
336	if subtle.ConstantTimeCompare(zeros[:], out[:]) == 1 {
337		return nil, errors.New("tls: X25519 value with wrong order")
338	}
339
340	return out[:], nil
341}
342
343func kemForCurveID(id CurveID, config *Config) (kemImplementation, bool) {
344	switch id {
345	case CurveP224:
346		return &ecdhKEM{curve: elliptic.P224(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
347	case CurveP256:
348		return &ecdhKEM{curve: elliptic.P256(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
349	case CurveP384:
350		return &ecdhKEM{curve: elliptic.P384(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
351	case CurveP521:
352		return &ecdhKEM{curve: elliptic.P521(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
353	case CurveX25519:
354		return &x25519KEM{setHighBit: config.Bugs.SetX25519HighBit}, true
355	default:
356		return nil, false
357	}
358
359}
360
// keyAgreementAuthentication is a helper interface that specifies how
// to authenticate the ServerKeyExchange parameters.
type keyAgreementAuthentication interface {
	// signParameters returns the ServerKeyExchange message to send for the
	// already-encoded key exchange parameters in params.
	signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error)
	// verifyParameters checks params against sig, the bytes that followed
	// params in the peer's ServerKeyExchange, using the peer's public key.
	verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, params []byte, sig []byte) error
}
367
368// nilKeyAgreementAuthentication does not authenticate the key
369// agreement parameters.
370type nilKeyAgreementAuthentication struct{}
371
372func (ka *nilKeyAgreementAuthentication) signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) {
373	skx := new(serverKeyExchangeMsg)
374	skx.key = params
375	return skx, nil
376}
377
378func (ka *nilKeyAgreementAuthentication) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, params []byte, sig []byte) error {
379	return nil
380}
381
// signedKeyAgreement signs the ServerKeyExchange parameters with the
// server's private key.
type signedKeyAgreement struct {
	// keyType is the kind of server public key verifyParameters accepts.
	keyType                keyType
	// version is the protocol version. From TLS 1.2 on, a two-byte signature
	// algorithm is written before, and read before, the signature.
	version                uint16
	// peerSignatureAlgorithm is the algorithm read from the peer's
	// ServerKeyExchange by verifyParameters. It stays zero below TLS 1.2.
	peerSignatureAlgorithm signatureAlgorithm
}
389
390func (ka *signedKeyAgreement) signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) {
391	// The message to be signed is prepended by the randoms.
392	var msg []byte
393	msg = append(msg, clientHello.random...)
394	msg = append(msg, hello.random...)
395	msg = append(msg, params...)
396
397	var sigAlg signatureAlgorithm
398	var err error
399	if ka.version >= VersionTLS12 {
400		sigAlg, err = selectSignatureAlgorithm(ka.version, cert.PrivateKey, config, clientHello.signatureAlgorithms)
401		if err != nil {
402			return nil, err
403		}
404	}
405
406	sig, err := signMessage(ka.version, cert.PrivateKey, config, sigAlg, msg)
407	if err != nil {
408		return nil, err
409	}
410	if config.Bugs.SendSignatureAlgorithm != 0 {
411		sigAlg = config.Bugs.SendSignatureAlgorithm
412	}
413
414	skx := new(serverKeyExchangeMsg)
415	if config.Bugs.UnauthenticatedECDH {
416		skx.key = params
417	} else {
418		sigAlgsLen := 0
419		if ka.version >= VersionTLS12 {
420			sigAlgsLen = 2
421		}
422		skx.key = make([]byte, len(params)+sigAlgsLen+2+len(sig))
423		copy(skx.key, params)
424		k := skx.key[len(params):]
425		if ka.version >= VersionTLS12 {
426			k[0] = byte(sigAlg >> 8)
427			k[1] = byte(sigAlg)
428			k = k[2:]
429		}
430		k[0] = byte(len(sig) >> 8)
431		k[1] = byte(len(sig))
432		copy(k[2:], sig)
433	}
434
435	return skx, nil
436}
437
438func (ka *signedKeyAgreement) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, publicKey crypto.PublicKey, params []byte, sig []byte) error {
439	// The peer's key must match the cipher type.
440	switch ka.keyType {
441	case keyTypeECDSA:
442		_, edsaOk := publicKey.(*ecdsa.PublicKey)
443		_, ed25519Ok := publicKey.(ed25519.PublicKey)
444		if !edsaOk && !ed25519Ok {
445			return errors.New("tls: ECDHE ECDSA requires a ECDSA or Ed25519 server public key")
446		}
447	case keyTypeRSA:
448		_, ok := publicKey.(*rsa.PublicKey)
449		if !ok {
450			return errors.New("tls: ECDHE RSA requires a RSA server public key")
451		}
452	default:
453		return errors.New("tls: unknown key type")
454	}
455
456	// The message to be signed is prepended by the randoms.
457	var msg []byte
458	msg = append(msg, clientHello.random...)
459	msg = append(msg, serverHello.random...)
460	msg = append(msg, params...)
461
462	var sigAlg signatureAlgorithm
463	if ka.version >= VersionTLS12 {
464		if len(sig) < 2 {
465			return errServerKeyExchange
466		}
467		sigAlg = signatureAlgorithm(sig[0])<<8 | signatureAlgorithm(sig[1])
468		sig = sig[2:]
469		// Stash the signature algorithm to be extracted by the handshake.
470		ka.peerSignatureAlgorithm = sigAlg
471	}
472
473	if len(sig) < 2 {
474		return errServerKeyExchange
475	}
476	sigLen := int(sig[0])<<8 | int(sig[1])
477	if sigLen+2 != len(sig) {
478		return errServerKeyExchange
479	}
480	sig = sig[2:]
481
482	return verifyMessage(ka.version, publicKey, config, sigAlg, msg, sig)
483}
484
// ecdheKeyAgreement implements a TLS key agreement where the server
// generates a ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// either be ECDSA or RSA.
type ecdheKeyAgreement struct {
	// auth signs or verifies the ServerKeyExchange parameters.
	auth    keyAgreementAuthentication
	// kem is the key exchange for the negotiated group. It is set by
	// generateServerKeyExchange (server) or processServerKeyExchange (client).
	kem     kemImplementation
	// curveID is the group selected by the server or read from its message.
	curveID CurveID
	// peerKey is the server's public key, saved by processServerKeyExchange
	// for use in generateClientKeyExchange.
	peerKey []byte
}
495
496func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
497	var curveid CurveID
498	preferredCurves := config.curvePreferences()
499
500NextCandidate:
501	for _, candidate := range preferredCurves {
502		if isPqGroup(candidate) && version < VersionTLS13 {
503			// Post-quantum "groups" require TLS 1.3.
504			continue
505		}
506
507		for _, c := range clientHello.supportedCurves {
508			if candidate == c {
509				curveid = c
510				break NextCandidate
511			}
512		}
513	}
514
515	if curveid == 0 {
516		return nil, errors.New("tls: no supported elliptic curves offered")
517	}
518
519	var ok bool
520	if ka.kem, ok = kemForCurveID(curveid, config); !ok {
521		return nil, errors.New("tls: preferredCurves includes unsupported curve")
522	}
523	ka.curveID = curveid
524
525	publicKey, err := ka.kem.generate(config.rand())
526	if err != nil {
527		return nil, err
528	}
529
530	// http://tools.ietf.org/html/rfc4492#section-5.4
531	serverECDHParams := make([]byte, 1+2+1+len(publicKey))
532	serverECDHParams[0] = 3 // named curve
533	if config.Bugs.SendCurve != 0 {
534		curveid = config.Bugs.SendCurve
535	}
536	serverECDHParams[1] = byte(curveid >> 8)
537	serverECDHParams[2] = byte(curveid)
538	serverECDHParams[3] = byte(len(publicKey))
539	copy(serverECDHParams[4:], publicKey)
540	if config.Bugs.InvalidECDHPoint {
541		serverECDHParams[4] ^= 0xff
542	}
543
544	return ka.auth.signParameters(config, cert, clientHello, hello, serverECDHParams)
545}
546
547func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
548	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
549		return nil, errClientKeyExchange
550	}
551	return ka.kem.decap(ckx.ciphertext[1:])
552}
553
554func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
555	if len(skx.key) < 4 {
556		return errServerKeyExchange
557	}
558	if skx.key[0] != 3 { // named curve
559		return errors.New("tls: server selected unsupported curve")
560	}
561	curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
562	ka.curveID = curveID
563
564	var ok bool
565	if ka.kem, ok = kemForCurveID(curveID, config); !ok {
566		return errors.New("tls: server selected unsupported curve")
567	}
568
569	publicLen := int(skx.key[3])
570	if publicLen+4 > len(skx.key) {
571		return errServerKeyExchange
572	}
573	// Save the peer key for later.
574	ka.peerKey = skx.key[4 : 4+publicLen]
575
576	// Check the signature.
577	serverECDHParams := skx.key[:4+publicLen]
578	sig := skx.key[4+publicLen:]
579	return ka.auth.verifyParameters(config, clientHello, serverHello, key, serverECDHParams, sig)
580}
581
582func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
583	if ka.kem == nil {
584		return nil, nil, errors.New("missing ServerKeyExchange message")
585	}
586
587	ciphertext, secret, err := ka.kem.encap(config.rand(), ka.peerKey)
588	if err != nil {
589		return nil, nil, err
590	}
591
592	ckx := new(clientKeyExchangeMsg)
593	ckx.ciphertext = make([]byte, 1+len(ciphertext))
594	ckx.ciphertext[0] = byte(len(ciphertext))
595	copy(ckx.ciphertext[1:], ciphertext)
596	if config.Bugs.InvalidECDHPoint {
597		ckx.ciphertext[1] ^= 0xff
598	}
599
600	return secret, ckx, nil
601}
602
603func (ka *ecdheKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
604	if auth, ok := ka.auth.(*signedKeyAgreement); ok {
605		return auth.peerSignatureAlgorithm
606	}
607	return 0
608}
609
610// nilKeyAgreement is a fake key agreement used to implement the plain PSK key
611// exchange.
612type nilKeyAgreement struct{}
613
614func (ka *nilKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
615	return nil, nil
616}
617
618func (ka *nilKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
619	if len(ckx.ciphertext) != 0 {
620		return nil, errClientKeyExchange
621	}
622
623	// Although in plain PSK, otherSecret is all zeros, the base key
624	// agreement does not access to the length of the pre-shared
625	// key. pskKeyAgreement instead interprets nil to mean to use all zeros
626	// of the appropriate length.
627	return nil, nil
628}
629
630func (ka *nilKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
631	if len(skx.key) != 0 {
632		return errServerKeyExchange
633	}
634	return nil
635}
636
637func (ka *nilKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
638	// Although in plain PSK, otherSecret is all zeros, the base key
639	// agreement does not access to the length of the pre-shared
640	// key. pskKeyAgreement instead interprets nil to mean to use all zeros
641	// of the appropriate length.
642	return nil, &clientKeyExchangeMsg{}, nil
643}
644
645func (ka *nilKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
646	return 0
647}
648
// makePSKPremaster formats a PSK pre-master secret from otherSecret (the
// output of the base key exchange) and psk. Each value is written with a
// two-byte big-endian length prefix, otherSecret first.
func makePSKPremaster(otherSecret, psk []byte) []byte {
	otherLen, pskLen := len(otherSecret), len(psk)
	out := make([]byte, 2+otherLen+2+pskLen)

	out[0], out[1] = byte(otherLen>>8), byte(otherLen)
	copy(out[2:], otherSecret)

	rest := out[2+otherLen:]
	rest[0], rest[1] = byte(pskLen>>8), byte(pskLen)
	copy(rest[2:], psk)
	return out
}
659
// pskKeyAgreement implements the PSK key agreement.
type pskKeyAgreement struct {
	// base is the wrapped key agreement. Its messages follow the PSK identity
	// or hint on the wire, and its secret is combined with the pre-shared key.
	base         keyAgreement
	// identityHint is the hint read from the ServerKeyExchange by
	// processServerKeyExchange.
	identityHint string
}
665
666func (ka *pskKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
667	// Assemble the identity hint.
668	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
669	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
670	bytes[1] = byte(len(config.PreSharedKeyIdentity))
671	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
672
673	// If there is one, append the base key agreement's
674	// ServerKeyExchange.
675	baseSkx, err := ka.base.generateServerKeyExchange(config, cert, clientHello, hello, version)
676	if err != nil {
677		return nil, err
678	}
679
680	if baseSkx != nil {
681		bytes = append(bytes, baseSkx.key...)
682	} else if config.PreSharedKeyIdentity == "" && !config.Bugs.AlwaysSendPreSharedKeyIdentityHint {
683		// ServerKeyExchange is optional if the identity hint is empty
684		// and there would otherwise be no ServerKeyExchange.
685		return nil, nil
686	}
687
688	skx := new(serverKeyExchangeMsg)
689	skx.key = bytes
690	return skx, nil
691}
692
693func (ka *pskKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
694	// First, process the PSK identity.
695	if len(ckx.ciphertext) < 2 {
696		return nil, errClientKeyExchange
697	}
698	identityLen := (int(ckx.ciphertext[0]) << 8) | int(ckx.ciphertext[1])
699	if 2+identityLen > len(ckx.ciphertext) {
700		return nil, errClientKeyExchange
701	}
702	identity := string(ckx.ciphertext[2 : 2+identityLen])
703
704	if identity != config.PreSharedKeyIdentity {
705		return nil, errors.New("tls: unexpected identity")
706	}
707
708	if config.PreSharedKey == nil {
709		return nil, errors.New("tls: pre-shared key not configured")
710	}
711
712	// Process the remainder of the ClientKeyExchange to compute the base
713	// pre-master secret.
714	newCkx := new(clientKeyExchangeMsg)
715	newCkx.ciphertext = ckx.ciphertext[2+identityLen:]
716	otherSecret, err := ka.base.processClientKeyExchange(config, cert, newCkx, version)
717	if err != nil {
718		return nil, err
719	}
720
721	if otherSecret == nil {
722		// Special-case for the plain PSK key exchanges.
723		otherSecret = make([]byte, len(config.PreSharedKey))
724	}
725	return makePSKPremaster(otherSecret, config.PreSharedKey), nil
726}
727
728func (ka *pskKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
729	if len(skx.key) < 2 {
730		return errServerKeyExchange
731	}
732	identityLen := (int(skx.key[0]) << 8) | int(skx.key[1])
733	if 2+identityLen > len(skx.key) {
734		return errServerKeyExchange
735	}
736	ka.identityHint = string(skx.key[2 : 2+identityLen])
737
738	// Process the remainder of the ServerKeyExchange.
739	newSkx := new(serverKeyExchangeMsg)
740	newSkx.key = skx.key[2+identityLen:]
741	return ka.base.processServerKeyExchange(config, clientHello, serverHello, key, newSkx)
742}
743
744func (ka *pskKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
745	// The server only sends an identity hint but, for purposes of
746	// test code, the server always sends the hint and it is
747	// required to match.
748	if ka.identityHint != config.PreSharedKeyIdentity {
749		return nil, nil, errors.New("tls: unexpected identity")
750	}
751
752	// Serialize the identity.
753	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
754	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
755	bytes[1] = byte(len(config.PreSharedKeyIdentity))
756	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
757
758	// Append the base key exchange's ClientKeyExchange.
759	otherSecret, baseCkx, err := ka.base.generateClientKeyExchange(config, clientHello, cert)
760	if err != nil {
761		return nil, nil, err
762	}
763	ckx := new(clientKeyExchangeMsg)
764	ckx.ciphertext = append(bytes, baseCkx.ciphertext...)
765
766	if config.PreSharedKey == nil {
767		return nil, nil, errors.New("tls: pre-shared key not configured")
768	}
769	if otherSecret == nil {
770		otherSecret = make([]byte, len(config.PreSharedKey))
771	}
772	return makePSKPremaster(otherSecret, config.PreSharedKey), ckx, nil
773}
774
775func (ka *pskKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
776	return 0
777}
778