• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
"""Tests for google.protobuf.service_reflection.

Exercises the service classes and client stubs generated for the
TestService definition in unittest_pb2.
"""

__author__ = 'petar@google.com (Petar Petrov)'
11
12
13import unittest
14
15from google.protobuf import service_reflection
16from google.protobuf import service
17from google.protobuf import unittest_pb2
18
19
class FooUnitTest(unittest.TestCase):
  """Tests for the generated TestService service class and its stub."""

  def testService(self):
    """Checks unimplemented and overridden methods of a generated service.

    A TestService subclass that does not override Foo/Bar must report
    "Method <name> not implemented." through the controller, whether the
    method is invoked by name or through CallMethod with a method
    descriptor. A subclass that does override them must have those
    overrides invoked, with no failure reported.
    """

    class MockRpcController(service.RpcController):
      """Controller that records the last message passed to SetFailed."""

      def SetFailed(self, msg):
        self.failure_message = msg

    class MyService(unittest_pb2.TestService):
      """Service subclass relying entirely on the generated defaults."""

    self.callback_response = None

    def MyCallback(response):
      # Capture whatever the service hands back so it can be asserted on.
      self.callback_response = response

    rpc_controller = MockRpcController()
    srvc = MyService()
    srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
    self.assertEqual('Method Foo not implemented.',
                     rpc_controller.failure_message)
    self.assertIsNone(self.callback_response)

    rpc_controller.failure_message = None

    service_descriptor = unittest_pb2.TestService.GetDescriptor()
    # The assertions below rely on methods[1] being the Bar method.
    bar_method = service_descriptor.methods[1]
    srvc.CallMethod(bar_method, rpc_controller, unittest_pb2.BarRequest(),
                    MyCallback)
    self.assertIs(srvc.GetRequestClass(bar_method), unittest_pb2.BarRequest)
    self.assertIs(srvc.GetResponseClass(bar_method),
                  unittest_pb2.BarResponse)
    self.assertEqual('Method Bar not implemented.',
                     rpc_controller.failure_message)
    self.assertIsNone(self.callback_response)

    class MyServiceImpl(unittest_pb2.TestService):
      """Service subclass that overrides both RPC methods."""

      def Foo(self, rpc_controller, request, done):
        self.foo_called = True

      def Bar(self, rpc_controller, request, done):
        self.bar_called = True

    srvc = MyServiceImpl()
    rpc_controller.failure_message = None
    srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
    self.assertIsNone(rpc_controller.failure_message)
    self.assertTrue(srvc.foo_called)

    # Dispatch through the descriptor must also reach the override.
    rpc_controller.failure_message = None
    srvc.CallMethod(bar_method, rpc_controller, unittest_pb2.BarRequest(),
                    MyCallback)
    self.assertIsNone(rpc_controller.failure_message)
    self.assertTrue(srvc.bar_called)

  def testServiceStub(self):
    """Checks that a generated stub forwards calls to its RpcChannel."""

    class MockRpcChannel(service.RpcChannel):
      """Channel that records its call arguments and replies immediately."""

      def CallMethod(self, method, controller, request,
                     response_class, callback):
        self.method = method
        self.controller = controller
        self.request = request
        # Reply with a freshly constructed, empty response message.
        callback(response_class())

    self.callback_response = None

    def MyCallback(response):
      self.callback_response = response

    channel = MockRpcChannel()
    stub = unittest_pb2.TestService_Stub(channel)
    # Plain strings are enough here: the mock channel only stores them.
    rpc_controller = 'controller'
    request = 'request'

    # GetDescriptor is now static, but still works as an instance method
    # for compatibility.
    self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(),
                     stub.GetDescriptor())

    # Invoke method.
    stub.Foo(rpc_controller, request, MyCallback)

    self.assertIsInstance(self.callback_response, unittest_pb2.FooResponse)
    self.assertEqual(request, channel.request)
    self.assertEqual(rpc_controller, channel.controller)
    self.assertEqual(stub.GetDescriptor().methods[0], channel.method)
114
if __name__ == '__main__':
  # Run this module's test cases when the file is executed directly.
  unittest.main()
117