• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright (c) 2021, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15// testmodulewrapper is a modulewrapper binary that works with acvptool and
16// implements the primitives that BoringSSL's modulewrapper doesn't, so that
// we have something that can exercise all the code in acvptool.
18
19package main
20
21import (
22	"bytes"
23	"crypto/aes"
24	"crypto/cipher"
25	"crypto/hmac"
26	"crypto/rand"
27	"crypto/sha256"
28	"encoding/binary"
29	"errors"
30	"fmt"
31	"io"
32	"os"
33
34	"golang.org/x/crypto/hkdf"
35	"golang.org/x/crypto/xts"
36)
37
// output is the destination that reply writes responses to. do points it at
// outputBuffer so that nothing reaches stdout before a flush.
var output io.Writer

// outputBuffer accumulates responses until flush copies them to stdout and
// replaces it with a fresh buffer.
var outputBuffer *bytes.Buffer
42
43var handlers = map[string]func([][]byte) error{
44	"flush":                    flush,
45	"getConfig":                getConfig,
46	"KDF-counter":              kdfCounter,
47	"AES-XTS/encrypt":          xtsEncrypt,
48	"AES-XTS/decrypt":          xtsDecrypt,
49	"HKDF/SHA2-256":            hkdfMAC,
50	"hmacDRBG-reseed/SHA2-256": hmacDRBGReseed,
51	"hmacDRBG-pr/SHA2-256":     hmacDRBGPredictionResistance,
52	"AES-CBC-CS3/encrypt":      ctsEncrypt,
53	"AES-CBC-CS3/decrypt":      ctsDecrypt,
54}
55
56func flush(args [][]byte) error {
57	if outputBuffer == nil {
58		return nil
59	}
60
61	if _, err := os.Stdout.Write(outputBuffer.Bytes()); err != nil {
62		return err
63	}
64	outputBuffer = new(bytes.Buffer)
65	output = outputBuffer
66	return nil
67}
68
69func getConfig(args [][]byte) error {
70	if len(args) != 0 {
71		return fmt.Errorf("getConfig received %d args", len(args))
72	}
73
74	if err := reply([]byte(`[
75	{
76		"algorithm": "acvptool",
77		"features": ["batch"]
78	}, {
79		"algorithm": "KDF",
80		"revision": "1.0",
81		"capabilities": [{
82			"kdfMode": "counter",
83			"macMode": [
84				"HMAC-SHA2-256"
85			],
86			"supportedLengths": [{
87				"min": 8,
88				"max": 4096,
89				"increment": 8
90			}],
91			"fixedDataOrder": [
92				"before fixed data"
93			],
94			"counterLength": [
95				32
96			]
97		}]
98	}, {
99		"algorithm": "ACVP-AES-XTS",
100		"revision": "1.0",
101		"direction": [
102		  "encrypt",
103		  "decrypt"
104		],
105		"keyLen": [
106		  128,
107		  256
108		],
109		"payloadLen": [
110		  1024
111		],
112		"tweakMode": [
113		  "number"
114		]
115	}, {
116		"algorithm": "KDA",
117		"mode": "HKDF",
118		"revision": "Sp800-56Cr1",
119		"fixedInfoPattern": "uPartyInfo||vPartyInfo",
120		"encoding": [
121			"concatenation"
122		],
123		"hmacAlg": [
124			"SHA2-256"
125		],
126		"macSaltMethods": [
127			"default",
128			"random"
129		],
130		"l": 256,
131		"z": [256, 384]
132	}, {
133		"algorithm": "hmacDRBG",
134		"revision": "1.0",
135		"predResistanceEnabled": [false, true],
136		"reseedImplemented": true,
137		"capabilities": [{
138			"mode": "SHA2-256",
139			"derFuncEnabled": false,
140			"entropyInputLen": [
141				256
142			],
143			"nonceLen": [
144				128
145			],
146			"persoStringLen": [
147				256
148			],
149			"additionalInputLen": [
150				256
151			],
152			"returnedBitsLen": 256
153		}]
154	}, {
155		"algorithm": "ACVP-AES-CBC-CS3",
156		"revision": "1.0",
157		"payloadLen": [{
158			"min": 128,
159			"max": 2048,
160			"increment": 8
161		}],
162		"direction": [
163		  "encrypt",
164		  "decrypt"
165		],
166		"keyLen": [
167		  128,
168		  256
169		]
170	}
171]`)); err != nil {
172		return err
173	}
174
175	return flush(nil)
176}
177
178func kdfCounter(args [][]byte) error {
179	if len(args) != 5 {
180		return fmt.Errorf("KDF received %d args", len(args))
181	}
182
183	outputBytes32, prf, counterLocation, key, counterBits32 := args[0], args[1], args[2], args[3], args[4]
184	outputBytes := binary.LittleEndian.Uint32(outputBytes32)
185	counterBits := binary.LittleEndian.Uint32(counterBits32)
186
187	if !bytes.Equal(prf, []byte("HMAC-SHA2-256")) {
188		return fmt.Errorf("KDF received unsupported PRF %q", string(prf))
189	}
190	if !bytes.Equal(counterLocation, []byte("before fixed data")) {
191		return fmt.Errorf("KDF received unsupported counter location %q", counterLocation)
192	}
193	if counterBits != 32 {
194		return fmt.Errorf("KDF received unsupported counter length %d", counterBits)
195	}
196
197	if len(key) == 0 {
198		key = make([]byte, 32)
199		rand.Reader.Read(key)
200	}
201
202	// See https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf section 5.1
203	if outputBytes+31 < outputBytes {
204		return fmt.Errorf("KDF received excessive output length %d", outputBytes)
205	}
206
207	n := (outputBytes + 31) / 32
208	result := make([]byte, 0, 32*n)
209	mac := hmac.New(sha256.New, key)
210	var input [4 + 8]byte
211	var digest []byte
212	rand.Reader.Read(input[4:])
213	for i := uint32(1); i <= n; i++ {
214		mac.Reset()
215		binary.BigEndian.PutUint32(input[:4], i)
216		mac.Write(input[:])
217		digest = mac.Sum(digest[:0])
218		result = append(result, digest...)
219	}
220
221	return reply(key, input[4:], result[:outputBytes])
222}
223
224func reply(responses ...[]byte) error {
225	if len(responses) > maxArgs {
226		return fmt.Errorf("%d responses is too many", len(responses))
227	}
228
229	var lengths [4 * (1 + maxArgs)]byte
230	binary.LittleEndian.PutUint32(lengths[:4], uint32(len(responses)))
231	for i, response := range responses {
232		binary.LittleEndian.PutUint32(lengths[4*(i+1):4*(i+2)], uint32(len(response)))
233	}
234
235	lengthsLength := (1 + len(responses)) * 4
236	if n, err := output.Write(lengths[:lengthsLength]); n != lengthsLength || err != nil {
237		return fmt.Errorf("write failed: %s", err)
238	}
239
240	for _, response := range responses {
241		if n, err := output.Write(response); n != len(response) || err != nil {
242			return fmt.Errorf("write failed: %s", err)
243		}
244	}
245
246	return nil
247}
248
249func xtsEncrypt(args [][]byte) error {
250	return doXTS(args, false)
251}
252
253func xtsDecrypt(args [][]byte) error {
254	return doXTS(args, true)
255}
256
257func doXTS(args [][]byte, decrypt bool) error {
258	if len(args) != 3 {
259		return fmt.Errorf("XTS received %d args, wanted 3", len(args))
260	}
261	key := args[0]
262	msg := args[1]
263	tweak := args[2]
264
265	if len(msg)%16 != 0 {
266		return fmt.Errorf("XTS received %d-byte msg, need multiple of 16", len(msg))
267	}
268	if len(tweak) != 16 {
269		return fmt.Errorf("XTS received %d-byte tweak, wanted 16", len(tweak))
270	}
271
272	var zeros [8]byte
273	if !bytes.Equal(tweak[8:], zeros[:]) {
274		return errors.New("XTS received tweak with invalid structure. Ensure that configuration specifies a 'number' tweak")
275	}
276
277	sectorNum := binary.LittleEndian.Uint64(tweak[:8])
278
279	c, err := xts.NewCipher(aes.NewCipher, key)
280	if err != nil {
281		return err
282	}
283
284	if decrypt {
285		c.Decrypt(msg, msg, sectorNum)
286	} else {
287		c.Encrypt(msg, msg, sectorNum)
288	}
289
290	return reply(msg)
291}
292
293func hkdfMAC(args [][]byte) error {
294	if len(args) != 4 {
295		return fmt.Errorf("HKDF received %d args, wanted 4", len(args))
296	}
297
298	key := args[0]
299	salt := args[1]
300	info := args[2]
301	lengthBytes := args[3]
302
303	if len(lengthBytes) != 4 {
304		return fmt.Errorf("uint32 length was %d bytes long", len(lengthBytes))
305	}
306
307	length := binary.LittleEndian.Uint32(lengthBytes)
308
309	mac := hkdf.New(sha256.New, key, salt, info)
310	ret := make([]byte, length)
311	mac.Read(ret)
312
313	return reply(ret)
314}
315
316func hmacDRBGReseed(args [][]byte) error {
317	if len(args) != 8 {
318		return fmt.Errorf("hmacDRBG received %d args, wanted 8", len(args))
319	}
320
321	outLenBytes, entropy, personalisation, reseedAdditionalData, reseedEntropy, additionalData1, additionalData2, nonce := args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
322
323	if len(outLenBytes) != 4 {
324		return fmt.Errorf("uint32 length was %d bytes long", len(outLenBytes))
325	}
326	outLen := binary.LittleEndian.Uint32(outLenBytes)
327	out := make([]byte, outLen)
328
329	drbg := NewHMACDRBG(entropy, nonce, personalisation)
330	drbg.Reseed(reseedEntropy, reseedAdditionalData)
331	drbg.Generate(out, additionalData1)
332	drbg.Generate(out, additionalData2)
333
334	return reply(out)
335}
336
337func hmacDRBGPredictionResistance(args [][]byte) error {
338	if len(args) != 8 {
339		return fmt.Errorf("hmacDRBG received %d args, wanted 8", len(args))
340	}
341
342	outLenBytes, entropy, personalisation, additionalData1, entropy1, additionalData2, entropy2, nonce := args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
343
344	if len(outLenBytes) != 4 {
345		return fmt.Errorf("uint32 length was %d bytes long", len(outLenBytes))
346	}
347	outLen := binary.LittleEndian.Uint32(outLenBytes)
348	out := make([]byte, outLen)
349
350	drbg := NewHMACDRBG(entropy, nonce, personalisation)
351	drbg.Reseed(entropy1, additionalData1)
352	drbg.Generate(out, nil)
353	drbg.Reseed(entropy2, additionalData2)
354	drbg.Generate(out, nil)
355
356	return reply(out)
357}
358
// swapFinalTwoAESBlocks exchanges, in place, the last two AES-block-sized
// chunks of d. It panics if d is shorter than two blocks.
func swapFinalTwoAESBlocks(d []byte) {
	tail := d[len(d)-2*aes.BlockSize:]
	penultimate, last := tail[:aes.BlockSize], tail[aes.BlockSize:]
	for i := range penultimate {
		penultimate[i], last[i] = last[i], penultimate[i]
	}
}
365
// roundUp returns n plus the amount of padding needed to reach a multiple of
// m. For non-negative n this is the smallest multiple of m that is >= n.
func roundUp(n, m int) int {
	remainder := n % m
	padding := (m - remainder) % m
	return n + padding
}
369
370func doCTSEncrypt(key, origPlaintext, iv []byte) []byte {
371	// https://nvlpubs.nist.gov/nistpubs/legacy/sp/nistspecialpublication800-38a-add.pdf
372	if len(origPlaintext) < aes.BlockSize {
373		panic("input too small")
374	}
375
376	plaintext := make([]byte, roundUp(len(origPlaintext), aes.BlockSize))
377	copy(plaintext, origPlaintext)
378
379	block, err := aes.NewCipher(key)
380	if err != nil {
381		panic(err)
382	}
383	cbcEncryptor := cipher.NewCBCEncrypter(block, iv)
384	cbcEncryptor.CryptBlocks(plaintext, plaintext)
385	ciphertext := plaintext
386
387	if len(origPlaintext) > aes.BlockSize {
388		swapFinalTwoAESBlocks(ciphertext)
389
390		if len(origPlaintext)%16 != 0 {
391			// Truncate the ciphertext
392			ciphertext = ciphertext[:len(ciphertext)-aes.BlockSize+(len(origPlaintext)%aes.BlockSize)]
393		}
394	}
395
396	if len(ciphertext) != len(origPlaintext) {
397		panic("internal error")
398	}
399
400	return ciphertext
401}
402
403func doCTSDecrypt(key, origCiphertext, iv []byte) []byte {
404	if len(origCiphertext) < aes.BlockSize {
405		panic("input too small")
406	}
407
408	ciphertext := make([]byte, roundUp(len(origCiphertext), aes.BlockSize))
409	copy(ciphertext, origCiphertext)
410
411	if len(ciphertext) > aes.BlockSize {
412		swapFinalTwoAESBlocks(ciphertext)
413	}
414
415	block, err := aes.NewCipher(key)
416	if err != nil {
417		panic(err)
418	}
419	cbcDecrypter := cipher.NewCBCDecrypter(block, iv)
420
421	var plaintext []byte
422	if overhang := len(origCiphertext) % aes.BlockSize; overhang == 0 {
423		cbcDecrypter.CryptBlocks(ciphertext, ciphertext)
424		plaintext = ciphertext
425	} else {
426		ciphertext, finalBlock := ciphertext[:len(ciphertext)-aes.BlockSize], ciphertext[len(ciphertext)-aes.BlockSize:]
427		var plaintextFinalBlock [aes.BlockSize]byte
428		block.Decrypt(plaintextFinalBlock[:], finalBlock)
429		copy(ciphertext[len(ciphertext)-aes.BlockSize+overhang:], plaintextFinalBlock[overhang:])
430		plaintext = make([]byte, len(origCiphertext))
431		cbcDecrypter.CryptBlocks(plaintext, ciphertext)
432		for i := 0; i < overhang; i++ {
433			plaintextFinalBlock[i] ^= ciphertext[len(ciphertext)-aes.BlockSize+i]
434		}
435		copy(plaintext[len(ciphertext):], plaintextFinalBlock[:overhang])
436	}
437
438	return plaintext
439}
440
441func ctsEncrypt(args [][]byte) error {
442	if len(args) != 4 {
443		return fmt.Errorf("ctsEncrypt received %d args, wanted 4", len(args))
444	}
445
446	key, plaintext, iv, numIterations32 := args[0], args[1], args[2], args[3]
447	if len(numIterations32) != 4 || binary.LittleEndian.Uint32(numIterations32) != 1 {
448		return errors.New("only a single iteration supported for ctsEncrypt")
449	}
450
451	if len(plaintext) < aes.BlockSize {
452		return fmt.Errorf("ctsEncrypt plaintext too short: %d bytes", len(plaintext))
453	}
454
455	return reply(doCTSEncrypt(key, plaintext, iv))
456}
457
458func ctsDecrypt(args [][]byte) error {
459	if len(args) != 4 {
460		return fmt.Errorf("ctsDecrypt received %d args, wanted 4", len(args))
461	}
462
463	key, ciphertext, iv, numIterations32 := args[0], args[1], args[2], args[3]
464	if len(numIterations32) != 4 || binary.LittleEndian.Uint32(numIterations32) != 1 {
465		return errors.New("only a single iteration supported for ctsDecrypt")
466	}
467
468	if len(ciphertext) < aes.BlockSize {
469		return errors.New("ctsDecrypt ciphertext too short")
470	}
471
472	return reply(doCTSDecrypt(key, ciphertext, iv))
473}
474
// Limits applied to requests and responses.
const (
	// maxArgs is the largest number of values, including the operation name,
	// accepted in one request or sent in one reply.
	maxArgs = 9
	// maxArgLength is the largest accepted length, in bytes, of one argument.
	maxArgLength = 1024 * 1024
	// maxNameLength is the largest accepted length, in bytes, of an
	// operation name.
	maxNameLength = 30
)
480
481func main() {
482	if err := do(); err != nil {
483		fmt.Fprintf(os.Stderr, "%s.\n", err)
484		os.Exit(1)
485	}
486}
487
488func do() error {
489	// In order to exercise pipelining, all output is buffered until a "flush".
490	outputBuffer = new(bytes.Buffer)
491	output = outputBuffer
492
493	var nums [4 * (1 + maxArgs)]byte
494	var argLengths [maxArgs]uint32
495	var args [maxArgs][]byte
496	var argsData []byte
497
498	for {
499		if _, err := io.ReadFull(os.Stdin, nums[:8]); err != nil {
500			return err
501		}
502
503		numArgs := binary.LittleEndian.Uint32(nums[:4])
504		if numArgs == 0 {
505			return errors.New("Invalid, zero-argument operation requested")
506		} else if numArgs > maxArgs {
507			return fmt.Errorf("Operation requested with %d args, but %d is the limit", numArgs, maxArgs)
508		}
509
510		if numArgs > 1 {
511			if _, err := io.ReadFull(os.Stdin, nums[8:4+4*numArgs]); err != nil {
512				return err
513			}
514		}
515
516		input := nums[4:]
517		var need uint64
518		for i := uint32(0); i < numArgs; i++ {
519			argLength := binary.LittleEndian.Uint32(input[:4])
520			if i == 0 && argLength > maxNameLength {
521				return fmt.Errorf("Operation with name of length %d exceeded limit of %d", argLength, maxNameLength)
522			} else if argLength > maxArgLength {
523				return fmt.Errorf("Operation with argument of length %d exceeded limit of %d", argLength, maxArgLength)
524			}
525			need += uint64(argLength)
526			argLengths[i] = argLength
527			input = input[4:]
528		}
529
530		if need > uint64(cap(argsData)) {
531			argsData = make([]byte, need)
532		} else {
533			argsData = argsData[:need]
534		}
535
536		if _, err := io.ReadFull(os.Stdin, argsData); err != nil {
537			return err
538		}
539
540		input = argsData
541		for i := uint32(0); i < numArgs; i++ {
542			args[i] = input[:argLengths[i]]
543			input = input[argLengths[i]:]
544		}
545
546		name := string(args[0])
547		if handler, ok := handlers[name]; !ok {
548			return fmt.Errorf("unknown operation %q", name)
549		} else {
550			if err := handler(args[1:numArgs]); err != nil {
551				return err
552			}
553		}
554	}
555}
556