• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of grpc_reflection.v1alpha.reflection."""
15
16import unittest
17
18from google.protobuf.descriptor_pool import DescriptorPool
19import grpc
20from grpc_reflection.v1alpha import reflection
21from grpc_reflection.v1alpha.proto_reflection_descriptor_database import (
22    ProtoReflectionDescriptorDatabase,
23)
24
25from src.proto.grpc.testing import test_pb2
26
27# Needed to load the EmptyWithExtensions message
28from src.proto.grpc.testing.proto2 import empty2_extensions_pb2
29from tests.unit import test_common
30
# Protobuf package and the .proto files the tests resolve by name.
_PROTO_PACKAGE_NAME = "grpc.testing"
_PROTO_FILE_NAME = "src/proto/grpc/testing/test.proto"
_EMPTY_PROTO_FILE_NAME = "src/proto/grpc/testing/empty.proto"

# Fully-qualified message symbols, built from the package name above.
_EMPTY_PROTO_SYMBOL_NAME = _PROTO_PACKAGE_NAME + ".Empty"
_EMPTY_EXTENSIONS_SYMBOL_NAME = (
    _PROTO_PACKAGE_NAME + ".proto2.EmptyWithExtensions"
)

# Names the tests expect NOT to resolve (lookups must raise KeyError).
_INVALID_FILE_NAME = "i-do-not-exist.proto"
_INVALID_SYMBOL_NAME = "IDoNotExist"
38
39
class ReflectionClientTest(unittest.TestCase):
    """Tests for ProtoReflectionDescriptorDatabase against a live server.

    Each test starts a local server that exposes reflection for TestService
    and for the reflection service itself, then resolves descriptors through
    a DescriptorPool backed by the reflection client.
    """

    def setUp(self):
        """Start a reflection-enabled server and connect a reflection client.

        Cleanups are registered with ``addCleanup`` as soon as each resource
        exists. Unlike ``tearDown``, cleanups still run when a later step of
        ``setUp`` raises, so the server and channel are not leaked.
        """
        self._server = test_common.test_server()
        self._SERVICE_NAMES = (
            test_pb2.DESCRIPTOR.services_by_name["TestService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(self._SERVICE_NAMES, self._server)
        # Binding to port 0 picks a free port; the bound port is returned.
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self.addCleanup(self._server.stop, None)

        self._channel = grpc.insecure_channel("localhost:%d" % port)
        # Cleanups run in LIFO order, so the channel is closed before the
        # server is stopped.
        self.addCleanup(self._channel.close)

        self._reflection_db = ProtoReflectionDescriptorDatabase(self._channel)
        self.desc_pool = DescriptorPool(self._reflection_db)

    def testListServices(self):
        """get_services() returns exactly the services registered in setUp."""
        services = self._reflection_db.get_services()
        self.assertCountEqual(self._SERVICE_NAMES, services)

    def testReflectionServiceName(self):
        """The v1alpha reflection service has its expected full name."""
        self.assertEqual(
            reflection.SERVICE_NAME, "grpc.reflection.v1alpha.ServerReflection"
        )

    def testFindFile(self):
        """Files resolve by name and carry their package and contents."""
        file_name = _PROTO_FILE_NAME
        file_desc = self.desc_pool.FindFileByName(file_name)
        self.assertEqual(file_name, file_desc.name)
        self.assertEqual(_PROTO_PACKAGE_NAME, file_desc.package)
        self.assertIn("TestService", file_desc.services_by_name)

        file_name = _EMPTY_PROTO_FILE_NAME
        file_desc = self.desc_pool.FindFileByName(file_name)
        self.assertEqual(file_name, file_desc.name)
        self.assertEqual(_PROTO_PACKAGE_NAME, file_desc.package)
        self.assertIn("Empty", file_desc.message_types_by_name)

    def testFindFileError(self):
        """Looking up an unknown file raises KeyError."""
        with self.assertRaises(KeyError):
            self.desc_pool.FindFileByName(_INVALID_FILE_NAME)

    def testFindMessage(self):
        """Messages resolve by fully-qualified name."""
        message_name = _EMPTY_PROTO_SYMBOL_NAME
        message_desc = self.desc_pool.FindMessageTypeByName(message_name)
        self.assertEqual(message_name, message_desc.full_name)
        # The short name is the last dotted component of the full name.
        self.assertEqual(message_name.rpartition(".")[2], message_desc.name)

    def testFindMessageError(self):
        """Looking up an unknown message raises KeyError."""
        with self.assertRaises(KeyError):
            self.desc_pool.FindMessageTypeByName(_INVALID_SYMBOL_NAME)

    def testFindServiceFindMethod(self):
        """A service and its methods resolve to consistent descriptors."""
        service_name = self._SERVICE_NAMES[0]
        service_desc = self.desc_pool.FindServiceByName(service_name)
        self.assertEqual(service_name, service_desc.full_name)
        self.assertEqual(service_name.rpartition(".")[2], service_desc.name)
        file_desc = self.desc_pool.FindFileByName(_PROTO_FILE_NAME)
        self.assertIs(file_desc, service_desc.file)

        method_name = "EmptyCall"
        self.assertIn(method_name, service_desc.methods_by_name)

        method_desc = service_desc.FindMethodByName(method_name)
        self.assertIs(method_desc, service_desc.methods_by_name[method_name])
        self.assertIs(service_desc, method_desc.containing_service)
        self.assertEqual(method_name, method_desc.name)
        # A method's full name is "<service full name>.<method name>".
        self.assertEqual(
            "%s.%s" % (service_name, method_name), method_desc.full_name
        )

        empty_message_desc = self.desc_pool.FindMessageTypeByName(
            _EMPTY_PROTO_SYMBOL_NAME
        )
        self.assertEqual(empty_message_desc, method_desc.input_type)
        self.assertEqual(empty_message_desc, method_desc.output_type)

    def testFindServiceError(self):
        """Looking up an unknown service raises KeyError."""
        with self.assertRaises(KeyError):
            self.desc_pool.FindServiceByName(_INVALID_SYMBOL_NAME)

    def testFindMethodError(self):
        """Looking up an unknown method on a known service fails."""
        service_name = self._SERVICE_NAMES[0]
        service_desc = self.desc_pool.FindServiceByName(service_name)

        # FindMethodByName sometimes raises a KeyError, and sometimes returns
        # None. See https://github.com/protocolbuffers/protobuf/issues/9592
        # Both outcomes are normalized to KeyError here.
        with self.assertRaises(KeyError):
            res = service_desc.FindMethodByName(_INVALID_SYMBOL_NAME)
            if res is None:
                raise KeyError()

    def testFindExtensionNotImplemented(self):
        """Extension lookups through the reflection-backed pool find nothing.

        Extensions aren't implemented in Protobuf for Python.
        For now, simply assert that indeed they don't work.
        """
        message_name = _EMPTY_EXTENSIONS_SYMBOL_NAME
        message_desc = self.desc_pool.FindMessageTypeByName(message_name)
        self.assertEqual(message_name, message_desc.full_name)
        self.assertEqual(message_name.rpartition(".")[2], message_desc.name)
        extension_field_descs = self.desc_pool.FindAllExtensions(message_desc)

        # Compare as a list so a failure shows any unexpected extensions.
        self.assertEqual([], list(extension_field_descs))
        with self.assertRaises(KeyError):
            self.desc_pool.FindExtensionByName(message_name)
149
150
if __name__ == "__main__":
    # Running this file directly executes the tests with per-test output.
    unittest.main(verbosity=2)
153