1 /*
2 * Copyright 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "test_channel_transport.h"
18
19 #include <errno.h> // for errno, EBADF
20 #include <stddef.h> // for size_t
21
22 #include <cstdint> // for uint8_t
23 #include <cstring> // for strerror
24 #include <type_traits> // for remove_extent_t
25
26 #include "log.h" // for LOG_INFO, ASSERT_LOG, LOG_WARN
27 #include "net/async_data_channel.h" // for AsyncDataChannel
28
29 using std::vector;
30
31 namespace rootcanal {
32
SetUp(std::shared_ptr<AsyncDataChannelServer> server,ConnectCallback connection_callback)33 bool TestChannelTransport::SetUp(std::shared_ptr<AsyncDataChannelServer> server,
34 ConnectCallback connection_callback) {
35 socket_server_ = server;
36 socket_server_->SetOnConnectCallback(connection_callback);
37 socket_server_->StartListening();
38 return socket_server_ != nullptr;
39 }
40
CleanUp()41 void TestChannelTransport::CleanUp() {
42 socket_server_->StopListening();
43 socket_server_->Close();
44 }
45
OnCommandReady(AsyncDataChannel * socket,std::function<void (void)> unwatch)46 void TestChannelTransport::OnCommandReady(AsyncDataChannel* socket,
47 std::function<void(void)> unwatch) {
48 uint8_t command_name_size = 0;
49 ssize_t bytes_read = socket->Recv(&command_name_size, 1);
50 if (bytes_read != 1) {
51 LOG_INFO("Unexpected (command_name_size) bytes_read: %zd != %d, %s",
52 bytes_read, 1, strerror(errno));
53 socket->Close();
54 }
55 vector<uint8_t> command_name_raw;
56 command_name_raw.resize(command_name_size);
57 bytes_read = socket->Recv(command_name_raw.data(), command_name_size);
58 if (bytes_read != command_name_size) {
59 LOG_INFO("Unexpected (command_name) bytes_read: %zd != %d, %s", bytes_read,
60 command_name_size, strerror(errno));
61 }
62 std::string command_name(command_name_raw.begin(), command_name_raw.end());
63
64 if (command_name == "CLOSE_TEST_CHANNEL" || command_name.empty()) {
65 LOG_INFO("Test channel closed");
66 unwatch();
67 socket->Close();
68 return;
69 }
70
71 uint8_t num_args = 0;
72 bytes_read = socket->Recv(&num_args, 1);
73 if (bytes_read != 1) {
74 LOG_INFO("Unexpected (num_args) bytes_read: %zd != %d, %s", bytes_read, 1,
75 strerror(errno));
76 }
77 vector<std::string> args;
78 for (uint8_t i = 0; i < num_args; ++i) {
79 uint8_t arg_size = 0;
80 bytes_read = socket->Recv(&arg_size, 1);
81 if (bytes_read != 1) {
82 LOG_INFO("Unexpected (arg_size) bytes_read: %zd != %d, %s", bytes_read, 1,
83 strerror(errno));
84 }
85 vector<uint8_t> arg;
86 arg.resize(arg_size);
87 bytes_read = socket->Recv(arg.data(), arg_size);
88 if (bytes_read != arg_size) {
89 LOG_INFO("Unexpected (arg) bytes_read: %zd != %d, %s", bytes_read,
90 arg_size, strerror(errno));
91 }
92 args.push_back(std::string(arg.begin(), arg.end()));
93 }
94
95 command_handler_(command_name, args);
96 }
97
SendResponse(std::shared_ptr<AsyncDataChannel> socket,const std::string & response)98 void TestChannelTransport::SendResponse(
99 std::shared_ptr<AsyncDataChannel> socket, const std::string& response) {
100 size_t size = response.size();
101 // Cap to 64K
102 if (size > 0xffff) {
103 size = 0xffff;
104 }
105 uint8_t size_buf[4] = {static_cast<uint8_t>(size & 0xff),
106 static_cast<uint8_t>((size >> 8) & 0xff),
107 static_cast<uint8_t>((size >> 16) & 0xff),
108 static_cast<uint8_t>((size >> 24) & 0xff)};
109 ssize_t written = socket->Send(size_buf, 4);
110 if (written == -1 && errno == EBADF) {
111 LOG_WARN("Unable to send a response. EBADF");
112 return;
113 }
114 ASSERT_LOG(written == 4, "What happened? written = %zd errno = %d", written,
115 errno);
116 written =
117 socket->Send(reinterpret_cast<const uint8_t*>(response.c_str()), size);
118 ASSERT_LOG(written == static_cast<int>(size),
119 "What happened? written = %zd errno = %d", written, errno);
120 }
121
RegisterCommandHandler(const std::function<void (const std::string &,const std::vector<std::string> &)> & callback)122 void TestChannelTransport::RegisterCommandHandler(
123 const std::function<void(const std::string&,
124 const std::vector<std::string>&)>& callback) {
125 command_handler_ = callback;
126 }
127
128 } // namespace rootcanal
129