• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include <brillo/dbus/dbus_method_invoker.h>

#include <unistd.h>

#include <string>

#include <base/files/scoped_file.h>
#include <brillo/bind_lambda.h>
#include <dbus/mock_bus.h>
#include <dbus/mock_object_proxy.h>
#include <dbus/scoped_dbus_error.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "brillo/dbus/test.pb.h"
18 
19 using testing::AnyNumber;
20 using testing::InSequence;
21 using testing::Invoke;
22 using testing::Return;
23 using testing::_;
24 
25 using dbus::MessageReader;
26 using dbus::MessageWriter;
27 using dbus::Response;
28 
29 namespace brillo {
30 namespace dbus_utils {
31 
// Object path, service name and interface of the mocked D-Bus object, plus
// the member names of the test methods the mock handlers recognize.
constexpr char kTestPath[] = "/test/path";
constexpr char kTestServiceName[] = "org.test.Object";
constexpr char kTestInterface[] = "org.test.Object.TestInterface";
constexpr char kTestMethod1[] = "TestMethod1";
constexpr char kTestMethod2[] = "TestMethod2";
constexpr char kTestMethod3[] = "TestMethod3";
constexpr char kTestMethod4[] = "TestMethod4";
39 
40 class DBusMethodInvokerTest : public testing::Test {
41  public:
SetUp()42   void SetUp() override {
43     dbus::Bus::Options options;
44     options.bus_type = dbus::Bus::SYSTEM;
45     bus_ = new dbus::MockBus(options);
46     // By default, don't worry about threading assertions.
47     EXPECT_CALL(*bus_, AssertOnOriginThread()).Times(AnyNumber());
48     EXPECT_CALL(*bus_, AssertOnDBusThread()).Times(AnyNumber());
49     // Use a mock exported object.
50     mock_object_proxy_ = new dbus::MockObjectProxy(
51         bus_.get(), kTestServiceName, dbus::ObjectPath(kTestPath));
52     EXPECT_CALL(*bus_,
53                 GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath)))
54         .WillRepeatedly(Return(mock_object_proxy_.get()));
55     int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT;
56     EXPECT_CALL(*mock_object_proxy_,
57                 MockCallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _))
58         .WillRepeatedly(Invoke(this, &DBusMethodInvokerTest::CreateResponse));
59   }
60 
TearDown()61   void TearDown() override { bus_ = nullptr; }
62 
CreateResponse(dbus::MethodCall * method_call,int,dbus::ScopedDBusError * dbus_error)63   Response* CreateResponse(dbus::MethodCall* method_call,
64                            int /* timeout_ms */,
65                            dbus::ScopedDBusError* dbus_error) {
66     if (method_call->GetInterface() == kTestInterface) {
67       if (method_call->GetMember() == kTestMethod1) {
68         MessageReader reader(method_call);
69         int v1, v2;
70         // Input: two ints.
71         // Output: sum of the ints converted to string.
72         if (reader.PopInt32(&v1) && reader.PopInt32(&v2)) {
73           auto response = Response::CreateEmpty();
74           MessageWriter writer(response.get());
75           writer.AppendString(std::to_string(v1 + v2));
76           return response.release();
77         }
78       } else if (method_call->GetMember() == kTestMethod2) {
79         method_call->SetSerial(123);
80         dbus_set_error(dbus_error->get(), "org.MyError", "My error message");
81         return nullptr;
82       } else if (method_call->GetMember() == kTestMethod3) {
83         MessageReader reader(method_call);
84         dbus_utils_test::TestMessage msg;
85         if (PopValueFromReader(&reader, &msg)) {
86           auto response = Response::CreateEmpty();
87           MessageWriter writer(response.get());
88           AppendValueToWriter(&writer, msg);
89           return response.release();
90         }
91       } else if (method_call->GetMember() == kTestMethod4) {
92         method_call->SetSerial(123);
93         MessageReader reader(method_call);
94         base::ScopedFD fd;
95         if (reader.PopFileDescriptor(&fd)) {
96           auto response = Response::CreateEmpty();
97           MessageWriter writer(response.get());
98           fd.CheckValidity();
99           writer.AppendFileDescriptor(fd);
100           return response.release();
101         }
102       }
103     }
104 
105     LOG(ERROR) << "Unexpected method call: " << method_call->ToString();
106     return nullptr;
107   }
108 
CallTestMethod(int v1,int v2)109   std::string CallTestMethod(int v1, int v2) {
110     std::unique_ptr<dbus::Response> response =
111         brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
112                                                kTestInterface, kTestMethod1,
113                                                nullptr, v1, v2);
114     EXPECT_NE(nullptr, response.get());
115     std::string result;
116     using brillo::dbus_utils::ExtractMethodCallResults;
117     EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &result));
118     return result;
119   }
120 
CallProtobufTestMethod(const dbus_utils_test::TestMessage & message)121   dbus_utils_test::TestMessage CallProtobufTestMethod(
122       const dbus_utils_test::TestMessage& message) {
123     std::unique_ptr<dbus::Response> response =
124         brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
125                                                kTestInterface, kTestMethod3,
126                                                nullptr, message);
127     EXPECT_NE(nullptr, response.get());
128     dbus_utils_test::TestMessage result;
129     using brillo::dbus_utils::ExtractMethodCallResults;
130     EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &result));
131     return result;
132   }
133 
134   // Sends a file descriptor received over D-Bus back to the caller.
EchoFD(const base::ScopedFD & fd_in)135   base::ScopedFD EchoFD(const base::ScopedFD& fd_in) {
136     std::unique_ptr<dbus::Response> response =
137         brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
138                                                kTestInterface, kTestMethod4,
139                                                nullptr, fd_in);
140     EXPECT_NE(nullptr, response.get());
141     base::ScopedFD fd_out;
142     using brillo::dbus_utils::ExtractMethodCallResults;
143     EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &fd_out));
144     return fd_out;
145   }
146 
147   scoped_refptr<dbus::MockBus> bus_;
148   scoped_refptr<dbus::MockObjectProxy> mock_object_proxy_;
149 };
150 
// Each blocking call to TestMethod1 should reply with the decimal sum of its
// two integer arguments.
TEST_F(DBusMethodInvokerTest, TestSuccess) {
  struct SumCase {
    int lhs;
    int rhs;
    const char* expected_sum;
  };
  const SumCase kCases[] = {
      {2, 2, "4"},
      {3, 7, "10"},
      {13, -17, "-4"},
  };
  for (const SumCase& sum_case : kCases)
    EXPECT_EQ(sum_case.expected_sum,
              CallTestMethod(sum_case.lhs, sum_case.rhs));
}
156 
// A blocking call that fails remotely should return no response and report
// the D-Bus error name and message through |error|, in the D-Bus error domain.
TEST_F(DBusMethodInvokerTest, TestFailure) {
  brillo::ErrorPtr error;
  std::unique_ptr<dbus::Response> response =
      brillo::dbus_utils::CallMethodAndBlock(
          mock_object_proxy_.get(), kTestInterface, kTestMethod2, &error);
  EXPECT_EQ(nullptr, response.get());
  // Stop the test instead of dereferencing a null |error| below.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(brillo::errors::dbus::kDomain, error->GetDomain());
  EXPECT_EQ("org.MyError", error->GetCode());
  EXPECT_EQ("My error message", error->GetMessage());
}
167 
// A protobuf message sent to TestMethod3 should come back with every field
// intact.
TEST_F(DBusMethodInvokerTest, TestProtobuf) {
  dbus_utils_test::TestMessage request;
  request.set_foo(123);
  request.set_bar("bar");

  const dbus_utils_test::TestMessage echoed = CallProtobufTestMethod(request);

  EXPECT_EQ(123, echoed.foo());
  EXPECT_EQ("bar", echoed.bar());
}
178 
// Passing a file descriptor over D-Bus would effectively duplicate the fd.
// So the resulting file descriptor value would be different but it still
// should be valid.
TEST_F(DBusMethodInvokerTest, TestFileDescriptors) {
  for (int std_fd : {STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO}) {
    SCOPED_TRACE(std_fd);
    // base::ScopedFD closes the descriptor it owns when it goes out of scope.
    // Wrap a duplicate of the standard stream rather than the stream itself,
    // so the test does not close the process's stdin/stdout/stderr.
    base::ScopedFD fd_in(dup(std_fd));
    fd_in.CheckValidity();
    EXPECT_NE(fd_in.value(), EchoFD(fd_in).value());
  }
}
193 
194 //////////////////////////////////////////////////////////////////////////////
195 // Asynchronous method invocation support
196 
197 class AsyncDBusMethodInvokerTest : public testing::Test {
198  public:
SetUp()199   void SetUp() override {
200     dbus::Bus::Options options;
201     options.bus_type = dbus::Bus::SYSTEM;
202     bus_ = new dbus::MockBus(options);
203     // By default, don't worry about threading assertions.
204     EXPECT_CALL(*bus_, AssertOnOriginThread()).Times(AnyNumber());
205     EXPECT_CALL(*bus_, AssertOnDBusThread()).Times(AnyNumber());
206     // Use a mock exported object.
207     mock_object_proxy_ = new dbus::MockObjectProxy(
208         bus_.get(), kTestServiceName, dbus::ObjectPath(kTestPath));
209     EXPECT_CALL(*bus_,
210                 GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath)))
211         .WillRepeatedly(Return(mock_object_proxy_.get()));
212     int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT;
213     EXPECT_CALL(*mock_object_proxy_,
214                 CallMethodWithErrorCallback(_, def_timeout_ms, _, _))
215         .WillRepeatedly(Invoke(this, &AsyncDBusMethodInvokerTest::HandleCall));
216   }
217 
TearDown()218   void TearDown() override { bus_ = nullptr; }
219 
HandleCall(dbus::MethodCall * method_call,int,dbus::ObjectProxy::ResponseCallback success_callback,dbus::ObjectProxy::ErrorCallback error_callback)220   void HandleCall(dbus::MethodCall* method_call,
221                   int /* timeout_ms */,
222                   dbus::ObjectProxy::ResponseCallback success_callback,
223                   dbus::ObjectProxy::ErrorCallback error_callback) {
224     if (method_call->GetInterface() == kTestInterface) {
225       if (method_call->GetMember() == kTestMethod1) {
226         MessageReader reader(method_call);
227         int v1, v2;
228         // Input: two ints.
229         // Output: sum of the ints converted to string.
230         if (reader.PopInt32(&v1) && reader.PopInt32(&v2)) {
231           auto response = Response::CreateEmpty();
232           MessageWriter writer(response.get());
233           writer.AppendString(std::to_string(v1 + v2));
234           success_callback.Run(response.get());
235         }
236         return;
237       } else if (method_call->GetMember() == kTestMethod2) {
238         method_call->SetSerial(123);
239         auto error_response = dbus::ErrorResponse::FromMethodCall(
240             method_call, "org.MyError", "My error message");
241         error_callback.Run(error_response.get());
242         return;
243       }
244     }
245 
246     LOG(FATAL) << "Unexpected method call: " << method_call->ToString();
247   }
248 
SuccessCallback(const std::string & in_result,int * in_counter)249   base::Callback<void(const std::string&)> SuccessCallback(
250       const std::string& in_result, int* in_counter) {
251     return base::Bind(
252         [](const std::string& result,
253            int* counter,
254            const std::string& actual_result) {
255           (*counter)++;
256           EXPECT_EQ(result, actual_result);
257         },
258         in_result,
259         base::Unretained(in_counter));
260   }
261 
SuccessCallback(int * in_counter)262   base::Callback<void(const std::string&)> SuccessCallback(int* in_counter) {
263     return base::Bind(
264         [](int* counter, const std::string& actual_result) {
265           (*counter)++;
266           EXPECT_EQ("", actual_result);
267         },
268         base::Unretained(in_counter));
269   }
270 
ErrorCallback(int * in_counter)271   AsyncErrorCallback ErrorCallback(int* in_counter) {
272     return base::Bind(
273         [](int* counter, brillo::Error* error) {
274           (*counter)++;
275           EXPECT_NE(nullptr, error);
276           EXPECT_EQ("", error->GetDomain());
277           EXPECT_EQ("", error->GetCode());
278           EXPECT_EQ("", error->GetMessage());
279         },
280         base::Unretained(in_counter));
281   }
282 
ErrorCallback(const std::string & domain,const std::string & code,const std::string & message,int * in_counter)283   AsyncErrorCallback ErrorCallback(const std::string& domain,
284                                    const std::string& code,
285                                    const std::string& message,
286                                    int* in_counter) {
287     return base::Bind(
288         [](const std::string& domain,
289            const std::string& code,
290            const std::string& message,
291            int* counter,
292            brillo::Error* error) {
293           (*counter)++;
294           EXPECT_NE(nullptr, error);
295           EXPECT_EQ(domain, error->GetDomain());
296           EXPECT_EQ(code, error->GetCode());
297           EXPECT_EQ(message, error->GetMessage());
298         },
299         domain,
300         code,
301         message,
302         base::Unretained(in_counter));
303   }
304 
305   scoped_refptr<dbus::MockBus> bus_;
306   scoped_refptr<dbus::MockObjectProxy> mock_object_proxy_;
307 };
308 
// Each asynchronous call to TestMethod1 should run only its success callback,
// with the decimal sum of the two arguments.
TEST_F(AsyncDBusMethodInvokerTest, TestSuccess) {
  int error_count = 0;
  int success_count = 0;
  // SuccessCallback()/ErrorCallback() already return bound base::Callbacks,
  // so they are passed directly instead of being re-wrapped in base::Bind().
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod1,
      SuccessCallback("4", &success_count),
      ErrorCallback(&error_count),
      2, 2);
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod1,
      SuccessCallback("10", &success_count),
      ErrorCallback(&error_count),
      3, 7);
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod1,
      SuccessCallback("-4", &success_count),
      ErrorCallback(&error_count),
      13, -17);
  EXPECT_EQ(0, error_count);
  EXPECT_EQ(3, success_count);
}
336 
// An asynchronous call that fails remotely should run only the error
// callback, with the error in the D-Bus domain and the remote error name and
// message as code and message.
TEST_F(AsyncDBusMethodInvokerTest, TestFailure) {
  int error_count = 0;
  int success_count = 0;
  // The helpers already return bound base::Callbacks; pass them directly.
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod2,
      SuccessCallback(&success_count),
      ErrorCallback(brillo::errors::dbus::kDomain,
                    "org.MyError",
                    "My error message",
                    &error_count),
      2, 2);
  EXPECT_EQ(1, error_count);
  EXPECT_EQ(0, success_count);
}
353 
354 }  // namespace dbus_utils
355 }  // namespace brillo
356