1 /* Copyright (c) 2019, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/uio.h>
#include <unistd.h>

#include <cstdarg>
#include <string>
#include <vector>

#include <openssl/aes.h>
#include <openssl/sha.h>
#include <openssl/span.h>
28
// Limits on a single request: at most kMaxArgs arguments (including the
// operation name), each at most kMaxArgLength bytes, with the name itself at
// most kMaxNameLength bytes.
static constexpr size_t kMaxArgs = 8;
static constexpr size_t kMaxArgLength = (1 << 20);
static constexpr size_t kMaxNameLength = 30;

// The largest permitted request is one name plus (kMaxArgs - 1) maximal
// arguments. main() sums argument lengths into a size_t and relies on this
// assertion to rule out overflow, so keep the total comfortably below 2^30.
// (The previous form, `kMaxArgs - 1 * kMaxArgLength`, underflowed size_t and
// made the assertion vacuously true.)
static_assert((kMaxArgs - 1) * kMaxArgLength + kMaxNameLength < (1 << 30),
              "Argument limits permit excessive messages");
35
36 using namespace bssl;
37
// Reads exactly |data_len| bytes from |fd| into |in_data|, retrying reads that
// are interrupted by a signal (EINTR). Returns false if the stream ends early
// or read() fails; returns true immediately when |data_len| is zero.
static bool ReadAll(int fd, void *in_data, size_t data_len) {
  auto *cursor = static_cast<uint8_t *>(in_data);
  size_t remaining = data_len;

  while (remaining > 0) {
    const ssize_t n = read(fd, cursor, remaining);
    if (n == -1 && errno == EINTR) {
      continue;  // Interrupted before any data arrived; just try again.
    }
    if (n <= 0) {
      return false;  // EOF (n == 0) or a genuine read error.
    }
    cursor += n;
    remaining -= static_cast<size_t>(n);
  }

  return true;
}
57
58 template <typename... Args>
WriteReply(int fd,Args...args)59 static bool WriteReply(int fd, Args... args) {
60 std::vector<Span<const uint8_t>> spans = {args...};
61 if (spans.empty() || spans.size() > kMaxArgs) {
62 abort();
63 }
64
65 uint32_t nums[1 + kMaxArgs];
66 iovec iovs[kMaxArgs + 1];
67 nums[0] = spans.size();
68 iovs[0].iov_base = nums;
69 iovs[0].iov_len = sizeof(uint32_t) * (1 + spans.size());
70
71 for (size_t i = 0; i < spans.size(); i++) {
72 const auto &span = spans[i];
73 nums[i + 1] = span.size();
74 iovs[i + 1].iov_base = const_cast<uint8_t *>(span.data());
75 iovs[i + 1].iov_len = span.size();
76 }
77
78 const size_t num_iov = spans.size() + 1;
79 size_t iov_done = 0;
80 while (iov_done < num_iov) {
81 ssize_t r;
82 do {
83 r = writev(fd, &iovs[iov_done], num_iov - iov_done);
84 } while (r == -1 && errno == EINTR);
85
86 if (r <= 0) {
87 return false;
88 }
89
90 size_t written = r;
91 for (size_t i = iov_done; written > 0 && i < num_iov; i++) {
92 iovec &iov = iovs[i];
93
94 size_t done = written;
95 if (done > iov.iov_len) {
96 done = iov.iov_len;
97 }
98
99 iov.iov_base = reinterpret_cast<uint8_t *>(iov.iov_base) + done;
100 iov.iov_len -= done;
101 written -= done;
102
103 if (iov.iov_len == 0) {
104 iov_done++;
105 }
106 }
107
108 assert(written == 0);
109 }
110
111 return true;
112 }
113
GetConfig(const Span<const uint8_t> args[])114 static bool GetConfig(const Span<const uint8_t> args[]) {
115 static constexpr char kConfig[] =
116 "["
117 "{"
118 " \"algorithm\": \"SHA2-224\","
119 " \"revision\": \"1.0\","
120 " \"messageLength\": [{"
121 " \"min\": 0, \"max\": 65528, \"increment\": 8"
122 " }]"
123 "},"
124 "{"
125 " \"algorithm\": \"SHA2-256\","
126 " \"revision\": \"1.0\","
127 " \"messageLength\": [{"
128 " \"min\": 0, \"max\": 65528, \"increment\": 8"
129 " }]"
130 "},"
131 "{"
132 " \"algorithm\": \"SHA2-384\","
133 " \"revision\": \"1.0\","
134 " \"messageLength\": [{"
135 " \"min\": 0, \"max\": 65528, \"increment\": 8"
136 " }]"
137 "},"
138 "{"
139 " \"algorithm\": \"SHA2-512\","
140 " \"revision\": \"1.0\","
141 " \"messageLength\": [{"
142 " \"min\": 0, \"max\": 65528, \"increment\": 8"
143 " }]"
144 "},"
145 "{"
146 " \"algorithm\": \"SHA-1\","
147 " \"revision\": \"1.0\","
148 " \"messageLength\": [{"
149 " \"min\": 0, \"max\": 65528, \"increment\": 8"
150 " }]"
151 "},"
152 "{"
153 " \"algorithm\": \"ACVP-AES-ECB\","
154 " \"revision\": \"1.0\","
155 " \"direction\": [\"encrypt\", \"decrypt\"],"
156 " \"keyLen\": [128, 192, 256]"
157 "},"
158 "{"
159 " \"algorithm\": \"ACVP-AES-CBC\","
160 " \"revision\": \"1.0\","
161 " \"direction\": [\"encrypt\", \"decrypt\"],"
162 " \"keyLen\": [128, 192, 256]"
163 "}"
164 "]";
165 return WriteReply(
166 STDOUT_FILENO,
167 Span<const uint8_t>(reinterpret_cast<const uint8_t *>(kConfig),
168 sizeof(kConfig) - 1));
169 }
170
171 template <uint8_t *(*OneShotHash)(const uint8_t *, size_t, uint8_t *),
172 size_t DigestLength>
Hash(const Span<const uint8_t> args[])173 static bool Hash(const Span<const uint8_t> args[]) {
174 uint8_t digest[DigestLength];
175 OneShotHash(args[0].data(), args[0].size(), digest);
176 return WriteReply(STDOUT_FILENO, Span<const uint8_t>(digest));
177 }
178
179 template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
180 void (*Block)(const uint8_t *in, uint8_t *out, const AES_KEY *key)>
AES(const Span<const uint8_t> args[])181 static bool AES(const Span<const uint8_t> args[]) {
182 AES_KEY key;
183 if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
184 return false;
185 }
186 if (args[1].size() % AES_BLOCK_SIZE != 0) {
187 return false;
188 }
189
190 std::vector<uint8_t> out;
191 out.resize(args[1].size());
192 for (size_t i = 0; i < args[1].size(); i += AES_BLOCK_SIZE) {
193 Block(args[1].data() + i, &out[i], &key);
194 }
195 return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
196 }
197
198 template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
199 int Direction>
AES_CBC(const Span<const uint8_t> args[])200 static bool AES_CBC(const Span<const uint8_t> args[]) {
201 AES_KEY key;
202 if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
203 return false;
204 }
205 if (args[1].size() % AES_BLOCK_SIZE != 0 ||
206 args[2].size() != AES_BLOCK_SIZE) {
207 return false;
208 }
209 uint8_t iv[AES_BLOCK_SIZE];
210 memcpy(iv, args[2].data(), AES_BLOCK_SIZE);
211
212 std::vector<uint8_t> out;
213 out.resize(args[1].size());
214 AES_cbc_encrypt(args[1].data(), out.data(), args[1].size(), &key, iv,
215 Direction);
216 return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
217 }
218
219 static constexpr struct {
220 const char name[kMaxNameLength + 1];
221 uint8_t expected_args;
222 bool (*handler)(const Span<const uint8_t>[]);
223 } kFunctions[] = {
224 {"getConfig", 0, GetConfig},
225 {"SHA-1", 1, Hash<SHA1, SHA_DIGEST_LENGTH>},
226 {"SHA2-224", 1, Hash<SHA224, SHA224_DIGEST_LENGTH>},
227 {"SHA2-256", 1, Hash<SHA256, SHA256_DIGEST_LENGTH>},
228 {"SHA2-384", 1, Hash<SHA384, SHA256_DIGEST_LENGTH>},
229 {"SHA2-512", 1, Hash<SHA512, SHA512_DIGEST_LENGTH>},
230 {"AES/encrypt", 2, AES<AES_set_encrypt_key, AES_encrypt>},
231 {"AES/decrypt", 2, AES<AES_set_decrypt_key, AES_decrypt>},
232 {"AES-CBC/encrypt", 3, AES_CBC<AES_set_encrypt_key, AES_ENCRYPT>},
233 {"AES-CBC/decrypt", 3, AES_CBC<AES_set_decrypt_key, AES_DECRYPT>},
234 };
235
main()236 int main() {
237 uint32_t nums[1 + kMaxArgs];
238 uint8_t *buf = nullptr;
239 size_t buf_len = 0;
240 Span<const uint8_t> args[kMaxArgs];
241
242 for (;;) {
243 if (!ReadAll(STDIN_FILENO, nums, sizeof(uint32_t) * 2)) {
244 return 1;
245 }
246
247 const size_t num_args = nums[0];
248 if (num_args == 0) {
249 fprintf(stderr, "Invalid, zero-argument operation requested.\n");
250 return 2;
251 } else if (num_args > kMaxArgs) {
252 fprintf(stderr,
253 "Operation requested with %zu args, but %zu is the limit.\n",
254 num_args, kMaxArgs);
255 return 2;
256 }
257
258 if (num_args > 1 &&
259 !ReadAll(STDIN_FILENO, &nums[2], sizeof(uint32_t) * (num_args - 1))) {
260 return 1;
261 }
262
263 size_t need = 0;
264 for (size_t i = 0; i < num_args; i++) {
265 const size_t arg_length = nums[i + 1];
266 if (i == 0 && arg_length > kMaxNameLength) {
267 fprintf(stderr,
268 "Operation with name of length %zu exceeded limit of %zu.\n",
269 arg_length, kMaxNameLength);
270 return 2;
271 } else if (arg_length > kMaxArgLength) {
272 fprintf(
273 stderr,
274 "Operation with argument of length %zu exceeded limit of %zu.\n",
275 arg_length, kMaxArgLength);
276 return 2;
277 }
278
279 // static_assert around kMaxArgs etc enforces that this doesn't overflow.
280 need += arg_length;
281 }
282
283 if (need > buf_len) {
284 free(buf);
285 size_t alloced = need + (need >> 1);
286 if (alloced < need) {
287 abort();
288 }
289 buf = reinterpret_cast<uint8_t *>(malloc(alloced));
290 if (buf == nullptr) {
291 abort();
292 }
293 buf_len = alloced;
294 }
295
296 if (!ReadAll(STDIN_FILENO, buf, need)) {
297 return 1;
298 }
299
300 size_t offset = 0;
301 for (size_t i = 0; i < num_args; i++) {
302 args[i] = Span<const uint8_t>(&buf[offset], nums[i + 1]);
303 offset += nums[i + 1];
304 }
305
306 bool found = true;
307 for (const auto &func : kFunctions) {
308 if (args[0].size() == strlen(func.name) &&
309 memcmp(args[0].data(), func.name, args[0].size()) == 0) {
310 if (num_args - 1 != func.expected_args) {
311 fprintf(stderr,
312 "\'%s\' operation received %zu arguments but expected %u.\n",
313 func.name, num_args - 1, func.expected_args);
314 return 2;
315 }
316
317 if (!func.handler(&args[1])) {
318 return 4;
319 }
320
321 found = true;
322 break;
323 }
324 }
325
326 if (!found) {
327 const std::string name(reinterpret_cast<const char *>(args[0].data()),
328 args[0].size());
329 fprintf(stderr, "Unknown operation: %s\n", name.c_str());
330 return 3;
331 }
332 }
333 }
334