• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright (c) 2019, Google Inc.
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #include <string>
16 #include <vector>
17 
18 #include <assert.h>
19 #include <errno.h>
20 #include <string.h>
21 #include <sys/uio.h>
22 #include <unistd.h>
23 #include <cstdarg>
24 
25 #include <openssl/aes.h>
26 #include <openssl/sha.h>
27 #include <openssl/span.h>
28 
// Limits on a single request. A request consists of an operation name
// followed by up to |kMaxArgs| - 1 further arguments.
static constexpr size_t kMaxArgs = 8;
// Maximum length, in bytes, of any argument other than the operation name.
static constexpr size_t kMaxArgLength = (1 << 20);
// Maximum length, in bytes, of the operation name (the first argument).
static constexpr size_t kMaxNameLength = 30;

// The largest message permitted by the limits above is the name plus
// |kMaxArgs| - 1 maximum-length arguments. Bounding that total here means that
// summing the argument lengths at runtime cannot overflow. (The parentheses
// matter: without them the subtraction underflows and the check is vacuous.)
static_assert((kMaxArgs - 1) * kMaxArgLength + kMaxNameLength <
                  (size_t{1} << 30),
              "Argument limits permit excessive messages");
35 
36 using namespace bssl;
37 
// ReadAll fills |in_data| with exactly |data_len| bytes read from |fd|,
// retrying reads that were interrupted by a signal. It returns true once all
// the bytes have been read and false if a read fails or hits end-of-file
// first. A |data_len| of zero succeeds without reading anything.
static bool ReadAll(int fd, void *in_data, size_t data_len) {
  uint8_t *cursor = static_cast<uint8_t *>(in_data);
  size_t remaining = data_len;

  while (remaining != 0) {
    const ssize_t got = read(fd, cursor, remaining);
    if (got == -1 && errno == EINTR) {
      // Interrupted before any data arrived; simply try again.
      continue;
    }
    if (got <= 0) {
      // Either a hard error or EOF before the buffer was filled.
      return false;
    }

    cursor += got;
    remaining -= static_cast<size_t>(got);
  }

  return true;
}
57 
58 template <typename... Args>
WriteReply(int fd,Args...args)59 static bool WriteReply(int fd, Args... args) {
60   std::vector<Span<const uint8_t>> spans = {args...};
61   if (spans.empty() || spans.size() > kMaxArgs) {
62     abort();
63   }
64 
65   uint32_t nums[1 + kMaxArgs];
66   iovec iovs[kMaxArgs + 1];
67   nums[0] = spans.size();
68   iovs[0].iov_base = nums;
69   iovs[0].iov_len = sizeof(uint32_t) * (1 + spans.size());
70 
71   for (size_t i = 0; i < spans.size(); i++) {
72     const auto &span = spans[i];
73     nums[i + 1] = span.size();
74     iovs[i + 1].iov_base = const_cast<uint8_t *>(span.data());
75     iovs[i + 1].iov_len = span.size();
76   }
77 
78   const size_t num_iov = spans.size() + 1;
79   size_t iov_done = 0;
80   while (iov_done < num_iov) {
81     ssize_t r;
82     do {
83       r = writev(fd, &iovs[iov_done], num_iov - iov_done);
84     } while (r == -1 && errno == EINTR);
85 
86     if (r <= 0) {
87       return false;
88     }
89 
90     size_t written = r;
91     for (size_t i = iov_done; written > 0 && i < num_iov; i++) {
92       iovec &iov = iovs[i];
93 
94       size_t done = written;
95       if (done > iov.iov_len) {
96         done = iov.iov_len;
97       }
98 
99       iov.iov_base = reinterpret_cast<uint8_t *>(iov.iov_base) + done;
100       iov.iov_len -= done;
101       written -= done;
102 
103       if (iov.iov_len == 0) {
104         iov_done++;
105       }
106     }
107 
108     assert(written == 0);
109   }
110 
111   return true;
112 }
113 
GetConfig(const Span<const uint8_t> args[])114 static bool GetConfig(const Span<const uint8_t> args[]) {
115   static constexpr char kConfig[] =
116       "["
117       "{"
118       "  \"algorithm\": \"SHA2-224\","
119       "  \"revision\": \"1.0\","
120       "  \"messageLength\": [{"
121       "    \"min\": 0, \"max\": 65528, \"increment\": 8"
122       "  }]"
123       "},"
124       "{"
125       "  \"algorithm\": \"SHA2-256\","
126       "  \"revision\": \"1.0\","
127       "  \"messageLength\": [{"
128       "    \"min\": 0, \"max\": 65528, \"increment\": 8"
129       "  }]"
130       "},"
131       "{"
132       "  \"algorithm\": \"SHA2-384\","
133       "  \"revision\": \"1.0\","
134       "  \"messageLength\": [{"
135       "    \"min\": 0, \"max\": 65528, \"increment\": 8"
136       "  }]"
137       "},"
138       "{"
139       "  \"algorithm\": \"SHA2-512\","
140       "  \"revision\": \"1.0\","
141       "  \"messageLength\": [{"
142       "    \"min\": 0, \"max\": 65528, \"increment\": 8"
143       "  }]"
144       "},"
145       "{"
146       "  \"algorithm\": \"SHA-1\","
147       "  \"revision\": \"1.0\","
148       "  \"messageLength\": [{"
149       "    \"min\": 0, \"max\": 65528, \"increment\": 8"
150       "  }]"
151       "},"
152       "{"
153       "  \"algorithm\": \"ACVP-AES-ECB\","
154       "  \"revision\": \"1.0\","
155       "  \"direction\": [\"encrypt\", \"decrypt\"],"
156       "  \"keyLen\": [128, 192, 256]"
157       "},"
158       "{"
159       "  \"algorithm\": \"ACVP-AES-CBC\","
160       "  \"revision\": \"1.0\","
161       "  \"direction\": [\"encrypt\", \"decrypt\"],"
162       "  \"keyLen\": [128, 192, 256]"
163       "}"
164       "]";
165   return WriteReply(
166       STDOUT_FILENO,
167       Span<const uint8_t>(reinterpret_cast<const uint8_t *>(kConfig),
168                           sizeof(kConfig) - 1));
169 }
170 
171 template <uint8_t *(*OneShotHash)(const uint8_t *, size_t, uint8_t *),
172           size_t DigestLength>
Hash(const Span<const uint8_t> args[])173 static bool Hash(const Span<const uint8_t> args[]) {
174   uint8_t digest[DigestLength];
175   OneShotHash(args[0].data(), args[0].size(), digest);
176   return WriteReply(STDOUT_FILENO, Span<const uint8_t>(digest));
177 }
178 
179 template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
180           void (*Block)(const uint8_t *in, uint8_t *out, const AES_KEY *key)>
AES(const Span<const uint8_t> args[])181 static bool AES(const Span<const uint8_t> args[]) {
182   AES_KEY key;
183   if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
184     return false;
185   }
186   if (args[1].size() % AES_BLOCK_SIZE != 0) {
187     return false;
188   }
189 
190   std::vector<uint8_t> out;
191   out.resize(args[1].size());
192   for (size_t i = 0; i < args[1].size(); i += AES_BLOCK_SIZE) {
193     Block(args[1].data() + i, &out[i], &key);
194   }
195   return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
196 }
197 
198 template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
199           int Direction>
AES_CBC(const Span<const uint8_t> args[])200 static bool AES_CBC(const Span<const uint8_t> args[]) {
201   AES_KEY key;
202   if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
203     return false;
204   }
205   if (args[1].size() % AES_BLOCK_SIZE != 0 ||
206       args[2].size() != AES_BLOCK_SIZE) {
207     return false;
208   }
209   uint8_t iv[AES_BLOCK_SIZE];
210   memcpy(iv, args[2].data(), AES_BLOCK_SIZE);
211 
212   std::vector<uint8_t> out;
213   out.resize(args[1].size());
214   AES_cbc_encrypt(args[1].data(), out.data(), args[1].size(), &key, iv,
215                   Direction);
216   return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
217 }
218 
219 static constexpr struct {
220   const char name[kMaxNameLength + 1];
221   uint8_t expected_args;
222   bool (*handler)(const Span<const uint8_t>[]);
223 } kFunctions[] = {
224     {"getConfig", 0, GetConfig},
225     {"SHA-1", 1, Hash<SHA1, SHA_DIGEST_LENGTH>},
226     {"SHA2-224", 1, Hash<SHA224, SHA224_DIGEST_LENGTH>},
227     {"SHA2-256", 1, Hash<SHA256, SHA256_DIGEST_LENGTH>},
228     {"SHA2-384", 1, Hash<SHA384, SHA256_DIGEST_LENGTH>},
229     {"SHA2-512", 1, Hash<SHA512, SHA512_DIGEST_LENGTH>},
230     {"AES/encrypt", 2, AES<AES_set_encrypt_key, AES_encrypt>},
231     {"AES/decrypt", 2, AES<AES_set_decrypt_key, AES_decrypt>},
232     {"AES-CBC/encrypt", 3, AES_CBC<AES_set_encrypt_key, AES_ENCRYPT>},
233     {"AES-CBC/decrypt", 3, AES_CBC<AES_set_decrypt_key, AES_DECRYPT>},
234 };
235 
main()236 int main() {
237   uint32_t nums[1 + kMaxArgs];
238   uint8_t *buf = nullptr;
239   size_t buf_len = 0;
240   Span<const uint8_t> args[kMaxArgs];
241 
242   for (;;) {
243     if (!ReadAll(STDIN_FILENO, nums, sizeof(uint32_t) * 2)) {
244       return 1;
245     }
246 
247     const size_t num_args = nums[0];
248     if (num_args == 0) {
249       fprintf(stderr, "Invalid, zero-argument operation requested.\n");
250       return 2;
251     } else if (num_args > kMaxArgs) {
252       fprintf(stderr,
253               "Operation requested with %zu args, but %zu is the limit.\n",
254               num_args, kMaxArgs);
255       return 2;
256     }
257 
258     if (num_args > 1 &&
259         !ReadAll(STDIN_FILENO, &nums[2], sizeof(uint32_t) * (num_args - 1))) {
260       return 1;
261     }
262 
263     size_t need = 0;
264     for (size_t i = 0; i < num_args; i++) {
265       const size_t arg_length = nums[i + 1];
266       if (i == 0 && arg_length > kMaxNameLength) {
267         fprintf(stderr,
268                 "Operation with name of length %zu exceeded limit of %zu.\n",
269                 arg_length, kMaxNameLength);
270         return 2;
271       } else if (arg_length > kMaxArgLength) {
272         fprintf(
273             stderr,
274             "Operation with argument of length %zu exceeded limit of %zu.\n",
275             arg_length, kMaxArgLength);
276         return 2;
277       }
278 
279       // static_assert around kMaxArgs etc enforces that this doesn't overflow.
280       need += arg_length;
281     }
282 
283     if (need > buf_len) {
284       free(buf);
285       size_t alloced = need + (need >> 1);
286       if (alloced < need) {
287         abort();
288       }
289       buf = reinterpret_cast<uint8_t *>(malloc(alloced));
290       if (buf == nullptr) {
291         abort();
292       }
293       buf_len = alloced;
294     }
295 
296     if (!ReadAll(STDIN_FILENO, buf, need)) {
297       return 1;
298     }
299 
300     size_t offset = 0;
301     for (size_t i = 0; i < num_args; i++) {
302       args[i] = Span<const uint8_t>(&buf[offset], nums[i + 1]);
303       offset += nums[i + 1];
304     }
305 
306     bool found = true;
307     for (const auto &func : kFunctions) {
308       if (args[0].size() == strlen(func.name) &&
309           memcmp(args[0].data(), func.name, args[0].size()) == 0) {
310         if (num_args - 1 != func.expected_args) {
311           fprintf(stderr,
312                   "\'%s\' operation received %zu arguments but expected %u.\n",
313                   func.name, num_args - 1, func.expected_args);
314           return 2;
315         }
316 
317         if (!func.handler(&args[1])) {
318           return 4;
319         }
320 
321         found = true;
322         break;
323       }
324     }
325 
326     if (!found) {
327       const std::string name(reinterpret_cast<const char *>(args[0].data()),
328                              args[0].size());
329       fprintf(stderr, "Unknown operation: %s\n", name.c_str());
330       return 3;
331     }
332   }
333 }
334