1 //
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // https://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 // Unittest for the sandbox2::Comms class.
15
16 #include "sandboxed_api/sandbox2/comms.h"
17
18 #include <fcntl.h>
19 #include <sys/socket.h>
20 #include <sys/types.h>
21 #include <unistd.h>
22
23 #include <cstdint>
24 #include <cstring>
25 #include <functional>
26 #include <memory>
27 #include <string>
28 #include <vector>
29
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32 #include "absl/container/fixed_array.h"
33 #include "absl/log/check.h"
34 #include "absl/log/log.h"
35 #include "absl/status/status.h"
36 #include "absl/strings/string_view.h"
37 #include "sandboxed_api/sandbox2/comms_test.pb.h"
38 #include "sandboxed_api/util/status_matchers.h"
39 #include "sandboxed_api/util/thread.h"
40
41 namespace sandbox2 {
42 namespace {
43
44 using ::sapi::IsOk;
45 using ::sapi::StatusIs;
46 using ::testing::Eq;
47 using ::testing::IsEmpty;
48 using ::testing::IsFalse;
49 using ::testing::IsTrue;
50
51 using CommunicationHandler = std::function<void(Comms* comms)>;
52
53 constexpr char kProtoStr[] = "ABCD";
NullTestString()54 static absl::string_view NullTestString() {
55 static constexpr char kHelperStr[] = "test\0\n\r\t\x01\x02";
56 return absl::string_view(kHelperStr, sizeof(kHelperStr) - 1);
57 }
58
59 // Helper function that handles the communication between the two handler
60 // functions.
HandleCommunication(const CommunicationHandler & a,const CommunicationHandler & b)61 void HandleCommunication(const CommunicationHandler& a,
62 const CommunicationHandler& b) {
63 int sv[2];
64 CHECK_NE(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), -1);
65 Comms comms(sv[0]);
66
67 // Start handler a.
68 sapi::Thread remote([sv, &a]() {
69 Comms my_comms(sv[1]);
70 a(&my_comms);
71 });
72
73 // Accept connection and run handler b.
74 b(&comms);
75 remote.Join();
76 }
77
TEST(CommsTest, TestSendRecv8) {
  // Remote thread: sends an unsigned byte, then waits for a signed one.
  const auto remote_side = [](Comms* comms) {
    ASSERT_THAT(comms->SendUint8(192), IsTrue());

    int8_t received = 0;
    ASSERT_THAT(comms->RecvInt8(&received), IsTrue());
    EXPECT_THAT(received, Eq(-7));
  };
  // Calling thread: reads the unsigned byte and replies with a signed one.
  const auto local_side = [](Comms* comms) {
    uint8_t received = 0;
    ASSERT_THAT(comms->RecvUint8(&received), IsTrue());
    EXPECT_THAT(received, Eq(192));

    ASSERT_THAT(comms->SendInt8(-7), IsTrue());
  };
  HandleCommunication(remote_side, local_side);
}
99
TEST(CommsTest, TestSendRecv16) {
  // Remote thread: sends an unsigned 16-bit value, then expects a signed one.
  const auto remote_side = [](Comms* comms) {
    ASSERT_THAT(comms->SendUint16(40001), IsTrue());

    int16_t received = 0;
    ASSERT_THAT(comms->RecvInt16(&received), IsTrue());
    EXPECT_THAT(received, Eq(-22050));
  };
  // Calling thread: reads the unsigned value and answers with a signed one.
  const auto local_side = [](Comms* comms) {
    uint16_t received = 0;
    ASSERT_THAT(comms->RecvUint16(&received), IsTrue());
    EXPECT_THAT(received, Eq(40001));

    ASSERT_THAT(comms->SendInt16(-22050), IsTrue());
  };
  HandleCommunication(remote_side, local_side);
}
121
TEST(CommsTest, TestSendRecv32) {
  // Remote thread: sends an unsigned 32-bit value, then expects a signed one.
  const auto remote_side = [](Comms* comms) {
    ASSERT_THAT(comms->SendUint32(3221225472UL), IsTrue());

    int32_t received = 0;
    ASSERT_THAT(comms->RecvInt32(&received), IsTrue());
    EXPECT_THAT(received, Eq(-1073741824));
  };
  // Calling thread: reads the unsigned value and answers with a signed one.
  const auto local_side = [](Comms* comms) {
    uint32_t received = 0;
    ASSERT_THAT(comms->RecvUint32(&received), IsTrue());
    EXPECT_THAT(received, Eq(3221225472UL));

    ASSERT_THAT(comms->SendInt32(-1073741824), IsTrue());
  };
  HandleCommunication(remote_side, local_side);
}
143
TEST(CommsTest, TestSendRecv64) {
  // Remote thread: sends an unsigned 64-bit value, then expects a signed one.
  const auto remote_side = [](Comms* comms) {
    ASSERT_THAT(comms->SendUint64(1099511627776ULL), IsTrue());

    int64_t received = 0;
    ASSERT_THAT(comms->RecvInt64(&received), IsTrue());
    EXPECT_THAT(received, Eq(-1099511627776LL));
  };
  // Calling thread: reads the unsigned value and answers with a signed one.
  const auto local_side = [](Comms* comms) {
    uint64_t received = 0;
    ASSERT_THAT(comms->RecvUint64(&received), IsTrue());
    EXPECT_THAT(received, Eq(1099511627776ULL));

    ASSERT_THAT(comms->SendInt64(-1099511627776LL), IsTrue());
  };
  HandleCommunication(remote_side, local_side);
}
165
TEST(CommsTest, TestTypeMismatch) {
  // Remote thread expects a Uint8. Because the peer sends an Int8, the receive
  // must report failure.
  const auto expects_unsigned = [](Comms* comms) {
    uint8_t value = 0;
    EXPECT_THAT(comms->RecvUint8(&value), IsFalse());
  };
  // Calling thread sends a signed byte where an unsigned one is expected.
  const auto sends_signed = [](Comms* comms) {
    ASSERT_THAT(comms->SendInt8(-93), IsTrue());
  };
  HandleCommunication(expects_unsigned, sends_signed);
}
178
TEST(CommsTest, TestSendRecvString) {
  // Remote thread: receives the string and checks it arrived intact, including
  // the embedded NUL and every byte after it.
  auto a = [](Comms* comms) {
    std::string tmps;
    ASSERT_THAT(comms->RecvString(&tmps), IsTrue());
    // EXPECT_THAT with Eq prints both strings on mismatch. EXPECT_TRUE(a == b)
    // would only report "false".
    EXPECT_THAT(tmps, Eq(NullTestString()));
    EXPECT_THAT(tmps.size(), Eq(NullTestString().size()));
  };
  // Calling thread: sends the NUL-containing test string.
  auto b = [](Comms* comms) {
    ASSERT_THAT(comms->SendString(std::string(NullTestString())), IsTrue());
  };
  HandleCommunication(a, b);
}
191
TEST(CommsTest, TestSendRecvArray) {
  static constexpr size_t kOneMiB = 1024 * 1024;
  // Remote thread: receives a 1 MiB byte array and checks its length.
  const auto receiver = [](Comms* comms) {
    std::vector<uint8_t> received;
    ASSERT_THAT(comms->RecvBytes(&received), IsTrue());
    EXPECT_THAT(received.size(), Eq(kOneMiB));
  };
  // Calling thread: sends 1 MiB of zero bytes.
  const auto sender = [](Comms* comms) {
    const std::vector<uint8_t> payload(kOneMiB, 0);
    ASSERT_THAT(comms->SendBytes(payload), IsTrue());
  };
  HandleCommunication(receiver, sender);
}
206
TEST(CommsTest, TestSendRecvEmptyArray) {
  // Remote thread: a zero-length byte array must arrive as an empty vector.
  const auto receiver = [](Comms* comms) {
    std::vector<uint8_t> received;
    ASSERT_THAT(comms->RecvBytes(&received), IsTrue());
    EXPECT_THAT(received, IsEmpty());
  };
  // Calling thread: sends an empty byte array.
  const auto sender = [](Comms* comms) {
    ASSERT_THAT(comms->SendBytes({}), IsTrue());
  };
  HandleCommunication(receiver, sender);
}
216
TEST(CommsTest, TestSendRecvFD) {
  // Remote thread: receives a descriptor, checks that it is open, and closes it.
  const auto receiver = [](Comms* comms) {
    int received_fd = -1;
    ASSERT_THAT(comms->RecvFD(&received_fd), IsTrue());
    EXPECT_GE(received_fd, 0);
    // fcntl(F_GETFD) returns -1 for a descriptor that is not open.
    EXPECT_NE(fcntl(received_fd, F_GETFD), -1);
    close(received_fd);
  };
  // Calling thread: passes its stderr descriptor across the channel.
  const auto sender = [](Comms* comms) {
    ASSERT_THAT(comms->SendFD(STDERR_FILENO), IsTrue());
  };
  HandleCommunication(receiver, sender);
}
232
TEST(CommsTest, TestSendRecvEmptyTLV) {
  // Remote thread: a TLV that carries only a tag must arrive with the same tag
  // and an empty value.
  const auto receiver = [](Comms* comms) {
    uint32_t got_tag;
    std::vector<uint8_t> got_value;
    ASSERT_THAT(comms->RecvTLV(&got_tag, &got_value), IsTrue());  // NOLINT
    EXPECT_THAT(got_tag, Eq(0x00DEADBE));
    EXPECT_THAT(got_value.size(), Eq(0));
  };
  // Calling thread: sends the tag with zero length and a null value pointer.
  const auto sender = [](Comms* comms) {
    ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue());
  };
  HandleCommunication(receiver, sender);
}
248
TEST(CommsTest, TestSendRecvEmptyTLV2) {
  // TestSendRecvEmptyTLV already covers receiving into an empty vector. This
  // variant receives into a vector that already holds data. After a
  // zero-length TLV, the stale contents must not be reported as the value.
  auto a = [](Comms* comms) {
    uint32_t tag;
    std::vector<uint8_t> data = {1, 2, 3};
    ASSERT_THAT(comms->RecvTLV(&tag, &data), IsTrue());
    EXPECT_THAT(tag, Eq(0x00DEADBE));
    EXPECT_THAT(data, IsEmpty());
  };
  // Calling thread: sends a TLV without a value.
  auto b = [](Comms* comms) {
    ASSERT_THAT(comms->SendTLV(0x00DEADBE, 0, nullptr), IsTrue());
  };
  HandleCommunication(a, b);
}
264
TEST(CommsTest, TestSendRecvProto) {
  // Remote thread: receives a protobuf and checks its single repeated value.
  // Messages live on the stack; heap allocation via `new` bought nothing here.
  auto a = [](Comms* comms) {
    CommsTestMsg comms_msg;
    ASSERT_THAT(comms->RecvProtoBuf(&comms_msg), IsTrue());
    ASSERT_THAT(comms_msg.value_size(), Eq(1));
    EXPECT_THAT(comms_msg.value(0), Eq(kProtoStr));
  };
  // Calling thread: builds a one-element message and sends it.
  auto b = [](Comms* comms) {
    CommsTestMsg comms_msg;
    comms_msg.add_value(kProtoStr);
    ASSERT_THAT(comms_msg.value_size(), Eq(1));
    ASSERT_THAT(comms->SendProtoBuf(comms_msg), IsTrue());
  };
  HandleCommunication(a, b);
}
282
TEST(CommsTest, TestSendRecvStatusOK) {
  // Remote thread: the status that arrives must be OK.
  const auto receiver = [](Comms* comms) {
    absl::Status received;
    ASSERT_THAT(comms->RecvStatus(&received), IsTrue());
    EXPECT_THAT(received, IsOk());
  };
  // Calling thread: transmits an OK status.
  const auto sender = [](Comms* comms) {
    const absl::Status ok_status = absl::OkStatus();
    ASSERT_THAT(comms->SendStatus(ok_status), IsTrue());
  };
  HandleCommunication(receiver, sender);
}
296
TEST(CommsTest, TestSendRecvStatusFailing) {
  // Remote thread: receives a failing status. Both the code and the message
  // must survive the round trip.
  auto a = [](Comms* comms) {
    absl::Status status;
    ASSERT_THAT(comms->RecvStatus(&status), IsTrue());
    // Qualified so the call does not depend on argument-dependent lookup.
    EXPECT_THAT(status, ::testing::Not(IsOk()));
    EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "something odd"));
  };
  // Calling thread: sends an internal-error status with a message.
  auto b = [](Comms* comms) {
    ASSERT_THAT(comms->SendStatus(
                    absl::Status{absl::StatusCode::kInternal, "something odd"}),
                IsTrue());
  };
  HandleCommunication(a, b);
}
313
TEST(CommsTest, TestUsesDistinctBuffers) {
  // Size of each message; static so both lambdas can use it without capturing.
  static constexpr size_t kBufferSize = 1024 * 1024;
  // Remote thread: receives the same 1 MiB payload twice into separate vectors.
  // It then checks that both stay readable and do not share storage.
  auto a = [](Comms* comms) {
    std::vector<uint8_t> buffer1;
    std::vector<uint8_t> buffer2;
    ASSERT_THAT(comms->RecvBytes(&buffer1), IsTrue());  // NOLINT
    // ASSERT rather than EXPECT: the code below indexes the last element, which
    // would be an out-of-bounds read if the size were wrong.
    ASSERT_THAT(buffer1.size(), Eq(kBufferSize));

    ASSERT_THAT(comms->RecvBytes(&buffer2), IsTrue());  // NOLINT
    ASSERT_THAT(buffer2.size(), Eq(kBufferSize));

    // Read the last byte of each buffer and compare it with the value the
    // sender wrote (zero). Under ASAN/MSAN this also flags freed or
    // uninitialized memory.
    EXPECT_THAT(buffer1.back(), Eq(0));
    EXPECT_THAT(buffer2.back(), Eq(0));
    EXPECT_NE(buffer1.data(), buffer2.data());
  };
  // Calling thread: sends one zero-filled 1 MiB buffer twice.
  auto b = [](Comms* comms) {
    absl::FixedArray<uint8_t> buf(kBufferSize, 0);
    ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue());
    ASSERT_THAT(comms->SendBytes(buf.data(), buf.size()), IsTrue());
  };
  HandleCommunication(a, b);
}
339
TEST(CommsTest, TestSendRecvCredentials) {
  // Both ends of the channel belong to this process (the peer is a thread), so
  // RecvCreds is expected to report our own pid, uid and gid.
  const auto checker = [](Comms* comms) {
    pid_t peer_pid;
    uid_t peer_uid;
    gid_t peer_gid;
    ASSERT_THAT(comms->RecvCreds(&peer_pid, &peer_uid, &peer_gid), IsTrue());
    EXPECT_THAT(peer_pid, Eq(getpid()));
    EXPECT_THAT(peer_uid, Eq(getuid()));
    EXPECT_THAT(peer_gid, Eq(getgid()));
  };
  // The other end does not need to send anything.
  const auto idle = [](Comms*) {};
  HandleCommunication(checker, idle);
}
356
TEST(CommsTest, TestSendTooMuchData) {
  // The remote side stays silent.
  const auto idle = [](Comms*) {};
  // Sending one byte more than the maximum message size must be refused.
  const auto oversized_sender = [](Comms* comms) {
    const auto too_big = comms->GetMaxMsgSize() + 1;
    ASSERT_THAT(comms->SendBytes(nullptr, too_big), IsFalse());
  };
  HandleCommunication(idle, oversized_sender);
}
368
TEST(CommsTest, TestSendRecvBytes) {
  // Remote thread: echoes back whatever byte vector it receives.
  auto a = [](Comms* comms) {
    std::vector<uint8_t> buffer;
    ASSERT_THAT(comms->RecvBytes(&buffer), IsTrue());
    ASSERT_THAT(comms->SendBytes(buffer), IsTrue());
  };
  // Calling thread: sends a request and expects the identical bytes back.
  auto b = [](Comms* comms) {
    const std::vector<uint8_t> request = {0, 1, 2, 3, 7};
    ASSERT_THAT(comms->SendBytes(request), IsTrue());

    std::vector<uint8_t> response;
    ASSERT_THAT(comms->RecvBytes(&response), IsTrue());
    // Actual value first, expected value inside the matcher. This keeps the
    // "Value of:" / "Expected:" labels in failure output correct.
    EXPECT_THAT(response, Eq(request));
  };
  HandleCommunication(a, b);
}
385
TEST(CommsTest, SendRecvFailsAfterTerminate) {
  // After Terminate(), every send and receive on this Comms must fail.
  const auto terminator = [](Comms* comms) {
    comms->Terminate();
    ASSERT_THAT(comms->IsTerminated(), IsTrue());

    EXPECT_THAT(comms->SendInt8(0), IsFalse());
    EXPECT_THAT(comms->SendFD(STDERR_FILENO), IsFalse());

    int8_t int8_value = 0;
    EXPECT_THAT(comms->RecvInt8(&int8_value), IsFalse());
    std::string string_value;
    EXPECT_THAT(comms->RecvString(&string_value), IsFalse());
    std::vector<uint8_t> bytes_value;
    EXPECT_THAT(comms->RecvBytes(&bytes_value), IsFalse());
    int fd_value = -1;
    EXPECT_THAT(comms->RecvFD(&fd_value), IsFalse());
    CommsTestMsg proto_value;
    EXPECT_THAT(comms->RecvProtoBuf(&proto_value), IsFalse());
  };
  // The other end does nothing.
  const auto idle = [](Comms*) {};
  HandleCommunication(terminator, idle);
}
406
TEST(CommsTest, RecvIntFailsOnTagMismatch) {
  // Remote thread: reading an Int8 when a Uint8 was sent must fail.
  const auto int8_reader = [](Comms* comms) {
    int8_t value = 0;
    EXPECT_THAT(comms->RecvInt8(&value), IsFalse());
  };
  // Calling thread: sends an unsigned byte.
  const auto uint8_sender = [](Comms* comms) {
    ASSERT_THAT(comms->SendUint8(0), IsTrue());
  };
  HandleCommunication(int8_reader, uint8_sender);
}
415
TEST(CommsTest, RecvStringBytesFailsOnTagMismatch) {
  // Remote thread: a bytes message cannot be read as a string, and the output
  // must stay empty. It then replies with a genuine string.
  const auto string_reader = [](Comms* comms) {
    std::string text;
    EXPECT_THAT(comms->RecvString(&text), IsFalse());
    EXPECT_THAT(text, IsEmpty());
    ASSERT_THAT(comms->SendString("hello"), IsTrue());
  };
  // Calling thread: sends bytes, then tries to read the string reply as bytes.
  // That read must fail and leave the output empty.
  const auto bytes_reader = [](Comms* comms) {
    ASSERT_THAT(comms->SendBytes({1, 0}), IsTrue());
    std::vector<uint8_t> bytes;
    EXPECT_THAT(comms->RecvBytes(&bytes), IsFalse());
    EXPECT_THAT(bytes, IsEmpty());
  };
  HandleCommunication(string_reader, bytes_reader);
}
431
TEST(CommsTest, RecvFDFailsOnTagMismatch) {
  // Remote thread: a bytes message is not a file descriptor, so RecvFD fails.
  const auto fd_reader = [](Comms* comms) {
    int fd = -1;
    EXPECT_THAT(comms->RecvFD(&fd), IsFalse());
  };
  // Calling thread: sends an empty byte array.
  const auto bytes_sender = [](Comms* comms) {
    ASSERT_THAT(comms->SendBytes({}), IsTrue());
  };
  HandleCommunication(fd_reader, bytes_sender);
}
440
TEST(CommsTest, RecvProtoBufFailsOnTagMismatch) {
  // Remote thread: a string message cannot be parsed as a protobuf.
  const auto proto_reader = [](Comms* comms) {
    CommsTestMsg received;
    EXPECT_THAT(comms->RecvProtoBuf(&received), IsFalse());
  };
  // Calling thread: sends a plain string.
  const auto string_sender = [](Comms* comms) {
    ASSERT_THAT(comms->SendString("hello"), IsTrue());
  };
  HandleCommunication(proto_reader, string_sender);
}
451
TEST(ListeningCommsTest, AbstractSocket) {
  // Listens on an abstract-namespace UNIX socket, connects from another thread,
  // and exchanges a single bool.
  static constexpr absl::string_view kSocketName = "s2_test_comms";
  SAPI_ASSERT_OK_AND_ASSIGN(
      ListeningComms listening_comms,
      ListeningComms::Create(kSocketName, /*abstract_uds=*/true));
  sapi::Thread remote([]() {
    SAPI_ASSERT_OK_AND_ASSIGN(
        Comms comms,
        Comms::Connect(std::string(kSocketName), /*abstract_uds=*/true));
    // Check the send result. Otherwise a failure here would surface only
    // indirectly, as a failed receive on the accepting side.
    ASSERT_THAT(comms.SendBool(true), IsTrue());
  });
  SAPI_ASSERT_OK_AND_ASSIGN(Comms comms, listening_comms.Accept());
  bool received = false;
  ASSERT_THAT(comms.RecvBool(&received), IsTrue());
  EXPECT_THAT(received, IsTrue());
  remote.Join();
}
469
470 } // namespace
471 } // namespace sandbox2
472