• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright (c) 2021, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
// testmodulewrapper is a modulewrapper binary that works with acvptool and
// implements the primitives that BoringSSL's modulewrapper doesn't, so that
// we have something that can exercise all the code in acvptool.
18
19package main
20
21import (
22	"bytes"
23	"crypto/aes"
24	"crypto/cipher"
25	"crypto/hmac"
26	"crypto/rand"
27	"crypto/sha256"
28	"encoding/binary"
29	"errors"
30	"fmt"
31	"io"
32	"os"
33
34	"golang.org/x/crypto/hkdf"
35	"golang.org/x/crypto/xts"
36)
37
38var handlers = map[string]func([][]byte) error{
39	"getConfig":                getConfig,
40	"KDF-counter":              kdfCounter,
41	"AES-XTS/encrypt":          xtsEncrypt,
42	"AES-XTS/decrypt":          xtsDecrypt,
43	"HKDF/SHA2-256":            hkdfMAC,
44	"hmacDRBG-reseed/SHA2-256": hmacDRBGReseed,
45	"hmacDRBG-pr/SHA2-256":     hmacDRBGPredictionResistance,
46	"AES-CBC-CS3/encrypt":      ctsEncrypt,
47	"AES-CBC-CS3/decrypt":      ctsDecrypt,
48}
49
50func getConfig(args [][]byte) error {
51	if len(args) != 0 {
52		return fmt.Errorf("getConfig received %d args", len(args))
53	}
54
55	return reply([]byte(`[
56	{
57		"algorithm": "KDF",
58		"revision": "1.0",
59		"capabilities": [{
60			"kdfMode": "counter",
61			"macMode": [
62				"HMAC-SHA2-256"
63			],
64			"supportedLengths": [{
65				"min": 8,
66				"max": 4096,
67				"increment": 8
68			}],
69			"fixedDataOrder": [
70				"before fixed data"
71			],
72			"counterLength": [
73				32
74			]
75		}]
76	}, {
77		"algorithm": "ACVP-AES-XTS",
78		"revision": "1.0",
79		"direction": [
80		  "encrypt",
81		  "decrypt"
82		],
83		"keyLen": [
84		  128,
85		  256
86		],
87		"payloadLen": [
88		  1024
89		],
90		"tweakMode": [
91		  "number"
92		]
93	}, {
94		"algorithm": "KDA",
95		"mode": "HKDF",
96		"revision": "Sp800-56Cr1",
97		"fixedInfoPattern": "uPartyInfo||vPartyInfo",
98		"encoding": [
99			"concatenation"
100		],
101		"hmacAlg": [
102			"SHA2-256"
103		],
104		"macSaltMethods": [
105			"default",
106			"random"
107		],
108		"l": 256,
109		"z": [256, 384]
110	}, {
111		"algorithm": "hmacDRBG",
112		"revision": "1.0",
113		"predResistanceEnabled": [false, true],
114		"reseedImplemented": true,
115		"capabilities": [{
116			"mode": "SHA2-256",
117			"derFuncEnabled": false,
118			"entropyInputLen": [
119				256
120			],
121			"nonceLen": [
122				128
123			],
124			"persoStringLen": [
125				256
126			],
127			"additionalInputLen": [
128				256
129			],
130			"returnedBitsLen": 256
131		}]
132	}, {
133		"algorithm": "ACVP-AES-CBC-CS3",
134		"revision": "1.0",
135		"payloadLen": [{
136			"min": 128,
137			"max": 2048,
138			"increment": 8
139		}],
140		"direction": [
141		  "encrypt",
142		  "decrypt"
143		],
144		"keyLen": [
145		  128,
146		  256
147		]
148	}
149]`))
150}
151
152func kdfCounter(args [][]byte) error {
153	if len(args) != 5 {
154		return fmt.Errorf("KDF received %d args", len(args))
155	}
156
157	outputBytes32, prf, counterLocation, key, counterBits32 := args[0], args[1], args[2], args[3], args[4]
158	outputBytes := binary.LittleEndian.Uint32(outputBytes32)
159	counterBits := binary.LittleEndian.Uint32(counterBits32)
160
161	if !bytes.Equal(prf, []byte("HMAC-SHA2-256")) {
162		return fmt.Errorf("KDF received unsupported PRF %q", string(prf))
163	}
164	if !bytes.Equal(counterLocation, []byte("before fixed data")) {
165		return fmt.Errorf("KDF received unsupported counter location %q", counterLocation)
166	}
167	if counterBits != 32 {
168		return fmt.Errorf("KDF received unsupported counter length %d", counterBits)
169	}
170
171	if len(key) == 0 {
172		key = make([]byte, 32)
173		rand.Reader.Read(key)
174	}
175
176	// See https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf section 5.1
177	if outputBytes+31 < outputBytes {
178		return fmt.Errorf("KDF received excessive output length %d", outputBytes)
179	}
180
181	n := (outputBytes + 31) / 32
182	result := make([]byte, 0, 32*n)
183	mac := hmac.New(sha256.New, key)
184	var input [4 + 8]byte
185	var digest []byte
186	rand.Reader.Read(input[4:])
187	for i := uint32(1); i <= n; i++ {
188		mac.Reset()
189		binary.BigEndian.PutUint32(input[:4], i)
190		mac.Write(input[:])
191		digest = mac.Sum(digest[:0])
192		result = append(result, digest...)
193	}
194
195	return reply(key, input[4:], result[:outputBytes])
196}
197
198func reply(responses ...[]byte) error {
199	if len(responses) > maxArgs {
200		return fmt.Errorf("%d responses is too many", len(responses))
201	}
202
203	var lengths [4 * (1 + maxArgs)]byte
204	binary.LittleEndian.PutUint32(lengths[:4], uint32(len(responses)))
205	for i, response := range responses {
206		binary.LittleEndian.PutUint32(lengths[4*(i+1):4*(i+2)], uint32(len(response)))
207	}
208
209	lengthsLength := (1 + len(responses)) * 4
210	if n, err := os.Stdout.Write(lengths[:lengthsLength]); n != lengthsLength || err != nil {
211		return fmt.Errorf("write failed: %s", err)
212	}
213
214	for _, response := range responses {
215		if n, err := os.Stdout.Write(response); n != len(response) || err != nil {
216			return fmt.Errorf("write failed: %s", err)
217		}
218	}
219
220	return nil
221}
222
223func xtsEncrypt(args [][]byte) error {
224	return doXTS(args, false)
225}
226
227func xtsDecrypt(args [][]byte) error {
228	return doXTS(args, true)
229}
230
231func doXTS(args [][]byte, decrypt bool) error {
232	if len(args) != 3 {
233		return fmt.Errorf("XTS received %d args, wanted 3", len(args))
234	}
235	key := args[0]
236	msg := args[1]
237	tweak := args[2]
238
239	if len(msg)%16 != 0 {
240		return fmt.Errorf("XTS received %d-byte msg, need multiple of 16", len(msg))
241	}
242	if len(tweak) != 16 {
243		return fmt.Errorf("XTS received %d-byte tweak, wanted 16", len(tweak))
244	}
245
246	var zeros [8]byte
247	if !bytes.Equal(tweak[8:], zeros[:]) {
248		return errors.New("XTS received tweak with invalid structure. Ensure that configuration specifies a 'number' tweak")
249	}
250
251	sectorNum := binary.LittleEndian.Uint64(tweak[:8])
252
253	c, err := xts.NewCipher(aes.NewCipher, key)
254	if err != nil {
255		return err
256	}
257
258	if decrypt {
259		c.Decrypt(msg, msg, sectorNum)
260	} else {
261		c.Encrypt(msg, msg, sectorNum)
262	}
263
264	return reply(msg)
265}
266
267func hkdfMAC(args [][]byte) error {
268	if len(args) != 4 {
269		return fmt.Errorf("HKDF received %d args, wanted 4", len(args))
270	}
271
272	key := args[0]
273	salt := args[1]
274	info := args[2]
275	lengthBytes := args[3]
276
277	if len(lengthBytes) != 4 {
278		return fmt.Errorf("uint32 length was %d bytes long", len(lengthBytes))
279	}
280
281	length := binary.LittleEndian.Uint32(lengthBytes)
282
283	mac := hkdf.New(sha256.New, key, salt, info)
284	ret := make([]byte, length)
285	mac.Read(ret)
286
287	return reply(ret)
288}
289
290func hmacDRBGReseed(args [][]byte) error {
291	if len(args) != 8 {
292		return fmt.Errorf("hmacDRBG received %d args, wanted 8", len(args))
293	}
294
295	outLenBytes, entropy, personalisation, reseedAdditionalData, reseedEntropy, additionalData1, additionalData2, nonce := args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
296
297	if len(outLenBytes) != 4 {
298		return fmt.Errorf("uint32 length was %d bytes long", len(outLenBytes))
299	}
300	outLen := binary.LittleEndian.Uint32(outLenBytes)
301	out := make([]byte, outLen)
302
303	drbg := NewHMACDRBG(entropy, nonce, personalisation)
304	drbg.Reseed(reseedEntropy, reseedAdditionalData)
305	drbg.Generate(out, additionalData1)
306	drbg.Generate(out, additionalData2)
307
308	return reply(out)
309}
310
311func hmacDRBGPredictionResistance(args [][]byte) error {
312	if len(args) != 8 {
313		return fmt.Errorf("hmacDRBG received %d args, wanted 8", len(args))
314	}
315
316	outLenBytes, entropy, personalisation, additionalData1, entropy1, additionalData2, entropy2, nonce := args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
317
318	if len(outLenBytes) != 4 {
319		return fmt.Errorf("uint32 length was %d bytes long", len(outLenBytes))
320	}
321	outLen := binary.LittleEndian.Uint32(outLenBytes)
322	out := make([]byte, outLen)
323
324	drbg := NewHMACDRBG(entropy, nonce, personalisation)
325	drbg.Reseed(entropy1, additionalData1)
326	drbg.Generate(out, nil)
327	drbg.Reseed(entropy2, additionalData2)
328	drbg.Generate(out, nil)
329
330	return reply(out)
331}
332
// swapFinalTwoAESBlocks exchanges, in place, the last two aes.BlockSize-byte
// blocks of d. d must contain at least two blocks; shorter input panics.
func swapFinalTwoAESBlocks(d []byte) {
	last := d[len(d)-aes.BlockSize:]
	penultimate := d[len(d)-2*aes.BlockSize : len(d)-aes.BlockSize]
	for i := range last {
		penultimate[i], last[i] = last[i], penultimate[i]
	}
}
339
// roundUp returns n increased by the smallest amount that makes it divisible
// by m (n itself when it already is). m must be non-zero.
func roundUp(n, m int) int {
	remainder := n % m
	padding := (m - remainder) % m
	return n + padding
}
343
344func doCTSEncrypt(key, origPlaintext, iv []byte) []byte {
345	// https://nvlpubs.nist.gov/nistpubs/legacy/sp/nistspecialpublication800-38a-add.pdf
346	if len(origPlaintext) < aes.BlockSize {
347		panic("input too small")
348	}
349
350	plaintext := make([]byte, roundUp(len(origPlaintext), aes.BlockSize))
351	copy(plaintext, origPlaintext)
352
353	block, err := aes.NewCipher(key)
354	if err != nil {
355		panic(err)
356	}
357	cbcEncryptor := cipher.NewCBCEncrypter(block, iv)
358	cbcEncryptor.CryptBlocks(plaintext, plaintext)
359	ciphertext := plaintext
360
361	if len(origPlaintext) > aes.BlockSize {
362		swapFinalTwoAESBlocks(ciphertext)
363
364		if len(origPlaintext)%16 != 0 {
365			// Truncate the ciphertext
366			ciphertext = ciphertext[:len(ciphertext)-aes.BlockSize+(len(origPlaintext)%aes.BlockSize)]
367		}
368	}
369
370	if len(ciphertext) != len(origPlaintext) {
371		panic("internal error")
372	}
373
374	return ciphertext
375}
376
377func doCTSDecrypt(key, origCiphertext, iv []byte) []byte {
378	if len(origCiphertext) < aes.BlockSize {
379		panic("input too small")
380	}
381
382	ciphertext := make([]byte, roundUp(len(origCiphertext), aes.BlockSize))
383	copy(ciphertext, origCiphertext)
384
385	if len(ciphertext) > aes.BlockSize {
386		swapFinalTwoAESBlocks(ciphertext)
387	}
388
389	block, err := aes.NewCipher(key)
390	if err != nil {
391		panic(err)
392	}
393	cbcDecrypter := cipher.NewCBCDecrypter(block, iv)
394
395	var plaintext []byte
396	if overhang := len(origCiphertext) % aes.BlockSize; overhang == 0 {
397		cbcDecrypter.CryptBlocks(ciphertext, ciphertext)
398		plaintext = ciphertext
399	} else {
400		ciphertext, finalBlock := ciphertext[:len(ciphertext)-aes.BlockSize], ciphertext[len(ciphertext)-aes.BlockSize:]
401		var plaintextFinalBlock [aes.BlockSize]byte
402		block.Decrypt(plaintextFinalBlock[:], finalBlock)
403		copy(ciphertext[len(ciphertext)-aes.BlockSize+overhang:], plaintextFinalBlock[overhang:])
404		plaintext = make([]byte, len(origCiphertext))
405		cbcDecrypter.CryptBlocks(plaintext, ciphertext)
406		for i := 0; i < overhang; i++ {
407			plaintextFinalBlock[i] ^= ciphertext[len(ciphertext)-aes.BlockSize+i]
408		}
409		copy(plaintext[len(ciphertext):], plaintextFinalBlock[:overhang])
410	}
411
412	return plaintext
413}
414
415func ctsEncrypt(args [][]byte) error {
416	if len(args) != 4 {
417		return fmt.Errorf("ctsEncrypt received %d args, wanted 4", len(args))
418	}
419
420	key, plaintext, iv, numIterations32 := args[0], args[1], args[2], args[3]
421	if len(numIterations32) != 4 || binary.LittleEndian.Uint32(numIterations32) != 1 {
422		return errors.New("only a single iteration supported for ctsEncrypt")
423	}
424
425	if len(plaintext) < aes.BlockSize {
426		return fmt.Errorf("ctsEncrypt plaintext too short: %d bytes", len(plaintext))
427	}
428
429	return reply(doCTSEncrypt(key, plaintext, iv))
430}
431
432func ctsDecrypt(args [][]byte) error {
433	if len(args) != 4 {
434		return fmt.Errorf("ctsDecrypt received %d args, wanted 4", len(args))
435	}
436
437	key, ciphertext, iv, numIterations32 := args[0], args[1], args[2], args[3]
438	if len(numIterations32) != 4 || binary.LittleEndian.Uint32(numIterations32) != 1 {
439		return errors.New("only a single iteration supported for ctsDecrypt")
440	}
441
442	if len(ciphertext) < aes.BlockSize {
443		return errors.New("ctsDecrypt ciphertext too short")
444	}
445
446	return reply(doCTSDecrypt(key, ciphertext, iv))
447}
448
// Framing limits enforced on incoming requests; maxArgs also caps the number
// of responses in one reply.
const (
	// maxArgs is the largest argument count in a request, including the
	// operation name.
	maxArgs = 9
	// maxNameLength bounds the operation name, in bytes.
	maxNameLength = 30
	// maxArgLength bounds each argument, in bytes (1 MiB).
	maxArgLength = 1024 * 1024
)
454
455func main() {
456	if err := do(); err != nil {
457		fmt.Fprintf(os.Stderr, "%s.\n", err)
458		os.Exit(1)
459	}
460}
461
462func do() error {
463	var nums [4 * (1 + maxArgs)]byte
464	var argLengths [maxArgs]uint32
465	var args [maxArgs][]byte
466	var argsData []byte
467
468	for {
469		if _, err := io.ReadFull(os.Stdin, nums[:8]); err != nil {
470			return err
471		}
472
473		numArgs := binary.LittleEndian.Uint32(nums[:4])
474		if numArgs == 0 {
475			return errors.New("Invalid, zero-argument operation requested")
476		} else if numArgs > maxArgs {
477			return fmt.Errorf("Operation requested with %d args, but %d is the limit", numArgs, maxArgs)
478		}
479
480		if numArgs > 1 {
481			if _, err := io.ReadFull(os.Stdin, nums[8:4+4*numArgs]); err != nil {
482				return err
483			}
484		}
485
486		input := nums[4:]
487		var need uint64
488		for i := uint32(0); i < numArgs; i++ {
489			argLength := binary.LittleEndian.Uint32(input[:4])
490			if i == 0 && argLength > maxNameLength {
491				return fmt.Errorf("Operation with name of length %d exceeded limit of %d", argLength, maxNameLength)
492			} else if argLength > maxArgLength {
493				return fmt.Errorf("Operation with argument of length %d exceeded limit of %d", argLength, maxArgLength)
494			}
495			need += uint64(argLength)
496			argLengths[i] = argLength
497			input = input[4:]
498		}
499
500		if need > uint64(cap(argsData)) {
501			argsData = make([]byte, need)
502		} else {
503			argsData = argsData[:need]
504		}
505
506		if _, err := io.ReadFull(os.Stdin, argsData); err != nil {
507			return err
508		}
509
510		input = argsData
511		for i := uint32(0); i < numArgs; i++ {
512			args[i] = input[:argLengths[i]]
513			input = input[argLengths[i]:]
514		}
515
516		name := string(args[0])
517		if handler, ok := handlers[name]; !ok {
518			return fmt.Errorf("unknown operation %q", name)
519		} else {
520			if err := handler(args[1:numArgs]); err != nil {
521				return err
522			}
523		}
524	}
525}
526