• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>
19 
20 using testing::Return;
21 using testing::_;
22 
23 using android::pdx::ClientBase;
24 using android::pdx::LocalChannelHandle;
25 using android::pdx::LocalHandle;
26 using android::pdx::Message;
27 using android::pdx::ServiceBase;
28 using android::pdx::ServiceDispatcher;
29 using android::pdx::Status;
30 using android::pdx::rpc::DispatchRemoteMethod;
31 using android::pdx::uds::ClientChannel;
32 using android::pdx::uds::ClientChannelFactory;
33 using android::pdx::uds::Endpoint;
34 
35 namespace {
36 
37 struct TestProtocol {
38   using DataType = int8_t;
39   enum {
40     kOpSum = 0,
41   };
42   PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
43 };
44 
45 class TestService : public ServiceBase<TestService> {
46  public:
TestService(std::unique_ptr<Endpoint> endpoint)47   TestService(std::unique_ptr<Endpoint> endpoint)
48       : ServiceBase{"TestService", std::move(endpoint)} {}
49 
HandleMessage(Message & message)50   Status<void> HandleMessage(Message& message) override {
51     switch (message.GetOp()) {
52       case TestProtocol::kOpSum:
53         DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
54                                                 message);
55         return {};
56 
57       default:
58         return Service::HandleMessage(message);
59     }
60   }
61 
OnSum(Message &,const std::vector<TestProtocol::DataType> & data)62   int64_t OnSum(Message& /*message*/,
63                 const std::vector<TestProtocol::DataType>& data) {
64     return std::accumulate(data.begin(), data.end(), int64_t{0});
65   }
66 };
67 
68 class TestClient : public ClientBase<TestClient> {
69  public:
70   using ClientBase::ClientBase;
71 
Sum(const std::vector<TestProtocol::DataType> & data)72   int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
73     auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
74     return status ? status.get() : -1;
75   }
76 };
77 
78 class TestServiceRunner {
79  public:
TestServiceRunner(LocalHandle channel_socket)80   TestServiceRunner(LocalHandle channel_socket) {
81     auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
82     endpoint->RegisterNewChannelForTests(std::move(channel_socket));
83     service_ = TestService::Create(std::move(endpoint));
84     dispatcher_ = android::pdx::uds::ServiceDispatcher::Create();
85     dispatcher_->AddService(service_);
86     dispatch_thread_ = std::thread(
87         std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
88   }
89 
~TestServiceRunner()90   ~TestServiceRunner() {
91     dispatcher_->SetCanceled(true);
92     dispatch_thread_.join();
93     dispatcher_->RemoveService(service_);
94   }
95 
96  private:
97   std::shared_ptr<TestService> service_;
98   std::unique_ptr<ServiceDispatcher> dispatcher_;
99   std::thread dispatch_thread_;
100 };
101 
102 class ClientChannelTest : public testing::Test {
103  public:
SetUp()104   void SetUp() override {
105     int channel_sockets[2] = {};
106     ASSERT_EQ(
107         0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
108     LocalHandle service_channel{channel_sockets[0]};
109     LocalHandle client_channel{channel_sockets[1]};
110 
111     service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
112     auto factory = ClientChannelFactory::Create(std::move(client_channel));
113     auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
114     ASSERT_TRUE(status);
115     client_ = TestClient::Create(status.take());
116   }
117 
TearDown()118   void TearDown() override {
119     service_runner_.reset();
120     client_.reset();
121   }
122 
123  protected:
124   std::unique_ptr<TestServiceRunner> service_runner_;
125   std::shared_ptr<TestClient> client_;
126 };
127 
// Drives one shared TestClient from several threads at once. Each thread sends
// its own random payload repeatedly and checks that every reply is the sum of
// exactly that payload, i.e. concurrent RPCs on one channel do not get mixed
// up.
TEST_F(ClientChannelTest, MultithreadedClient) {
  constexpr int kNumTestThreads = 8;
  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.

  using DataType = TestProtocol::DataType;

  // std::uniform_int_distribution is only defined for short, int, long,
  // long long and their unsigned counterparts; instantiating it with a
  // character type such as int8_t is undefined behavior (and some standard
  // libraries reject it outright). Sample ints spanning DataType's range and
  // narrow each value instead.
  std::random_device rd;
  std::mt19937 gen{rd()};
  std::uniform_int_distribution<int> dist{
      std::numeric_limits<DataType>::min(),
      std::numeric_limits<DataType>::max()};

  // Per-thread body: the client and payload are taken by value so each
  // thread owns its own copy of the data.
  auto worker = [](std::shared_ptr<TestClient> client,
                   std::vector<DataType> data) {
    constexpr int kMaxIterations = 500;
    const int64_t expected =
        std::accumulate(data.begin(), data.end(), int64_t{0});
    for (int i = 0; i < kMaxIterations; i++) {
      ASSERT_EQ(expected, client->Sum(data));
    }
  };

  // Start client threads. |data| is copied into each thread, so refilling it
  // for the next thread does not affect threads already running.
  std::vector<DataType> data(kDataSize);
  std::vector<std::thread> threads;
  threads.reserve(kNumTestThreads);
  for (int i = 0; i < kNumTestThreads; i++) {
    std::generate(data.begin(), data.end(), [&dist, &gen]() {
      return static_cast<DataType>(dist(gen));
    });
    threads.emplace_back(worker, client_, data);
  }

  // Wait for threads to finish.
  for (auto& thread : threads)
    thread.join();
}
161 
162 }  // namespace
163