• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.protobuf.services;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 import static com.google.common.base.Preconditions.checkState;
21 
22 import com.google.protobuf.Descriptors.Descriptor;
23 import com.google.protobuf.Descriptors.FieldDescriptor;
24 import com.google.protobuf.Descriptors.FileDescriptor;
25 import com.google.protobuf.Descriptors.MethodDescriptor;
26 import com.google.protobuf.Descriptors.ServiceDescriptor;
27 import io.grpc.BindableService;
28 import io.grpc.ExperimentalApi;
29 import io.grpc.InternalServer;
30 import io.grpc.Server;
31 import io.grpc.ServerServiceDefinition;
32 import io.grpc.Status;
33 import io.grpc.protobuf.ProtoFileDescriptorSupplier;
34 import io.grpc.reflection.v1alpha.ErrorResponse;
35 import io.grpc.reflection.v1alpha.ExtensionNumberResponse;
36 import io.grpc.reflection.v1alpha.ExtensionRequest;
37 import io.grpc.reflection.v1alpha.FileDescriptorResponse;
38 import io.grpc.reflection.v1alpha.ListServiceResponse;
39 import io.grpc.reflection.v1alpha.ServerReflectionGrpc;
40 import io.grpc.reflection.v1alpha.ServerReflectionRequest;
41 import io.grpc.reflection.v1alpha.ServerReflectionResponse;
42 import io.grpc.reflection.v1alpha.ServiceResponse;
43 import io.grpc.stub.ServerCallStreamObserver;
44 import io.grpc.stub.StreamObserver;
45 import java.util.ArrayDeque;
46 import java.util.Collections;
47 import java.util.HashMap;
48 import java.util.HashSet;
49 import java.util.List;
50 import java.util.Map;
51 import java.util.Queue;
52 import java.util.Set;
53 import java.util.WeakHashMap;
54 import javax.annotation.Nullable;
55 import javax.annotation.concurrent.GuardedBy;
56 
57 /**
58  * Provides a reflection service for Protobuf services (including the reflection service itself).
59  *
60  * <p>Separately tracks mutable and immutable services. Throws an exception if either group of
61  * services contains multiple Protobuf files with declarations of the same service, method, type, or
62  * extension.
63  */
64 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222")
65 public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase {
66 
67   private final Object lock = new Object();
68 
69   @GuardedBy("lock")
70   private final Map<Server, ServerReflectionIndex> serverReflectionIndexes = new WeakHashMap<>();
71 
ProtoReflectionService()72   private ProtoReflectionService() {}
73 
74   /**
75    * Creates a instance of {@link ProtoReflectionService}.
76    */
newInstance()77   public static BindableService newInstance() {
78     return new ProtoReflectionService();
79   }
80 
81   /**
82    * Retrieves the index for services of the server that dispatches the current call. Computes
83    * one if not exist. The index is updated if any changes to the server's mutable services are
84    * detected. A change is any addition or removal in the set of file descriptors attached to the
85    * mutable services or a change in the service names.
86    */
getRefreshedIndex()87   private ServerReflectionIndex getRefreshedIndex() {
88     synchronized (lock) {
89       Server server = InternalServer.SERVER_CONTEXT_KEY.get();
90       ServerReflectionIndex index = serverReflectionIndexes.get(server);
91       if (index == null) {
92         index =
93             new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices());
94         serverReflectionIndexes.put(server, index);
95         return index;
96       }
97 
98       Set<FileDescriptor> serverFileDescriptors = new HashSet<>();
99       Set<String> serverServiceNames = new HashSet<>();
100       List<ServerServiceDefinition> serverMutableServices = server.getMutableServices();
101       for (ServerServiceDefinition mutableService : serverMutableServices) {
102         io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor();
103         if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) {
104           String serviceName = serviceDescriptor.getName();
105           FileDescriptor fileDescriptor =
106               ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
107                   .getFileDescriptor();
108           serverFileDescriptors.add(fileDescriptor);
109           serverServiceNames.add(serviceName);
110         }
111       }
112 
113       // Replace the index if the underlying mutable services have changed. Check both the file
114       // descriptors and the service names, because one file descriptor can define multiple
115       // services.
116       FileDescriptorIndex mutableServicesIndex = index.getMutableServicesIndex();
117       if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors)
118           || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) {
119         index =
120             new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices);
121         serverReflectionIndexes.put(server, index);
122       }
123 
124       return index;
125     }
126   }
127 
128   @Override
serverReflectionInfo( final StreamObserver<ServerReflectionResponse> responseObserver)129   public StreamObserver<ServerReflectionRequest> serverReflectionInfo(
130       final StreamObserver<ServerReflectionResponse> responseObserver) {
131     final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver =
132         (ServerCallStreamObserver<ServerReflectionResponse>) responseObserver;
133     ProtoReflectionStreamObserver requestObserver =
134         new ProtoReflectionStreamObserver(getRefreshedIndex(), serverCallStreamObserver);
135     serverCallStreamObserver.setOnReadyHandler(requestObserver);
136     serverCallStreamObserver.disableAutoRequest();
137     serverCallStreamObserver.request(1);
138     return requestObserver;
139   }
140 
141   private static class ProtoReflectionStreamObserver
142       implements Runnable, StreamObserver<ServerReflectionRequest> {
143     private final ServerReflectionIndex serverReflectionIndex;
144     private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver;
145 
146     private boolean closeAfterSend = false;
147     private ServerReflectionRequest request;
148 
ProtoReflectionStreamObserver( ServerReflectionIndex serverReflectionIndex, ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver)149     ProtoReflectionStreamObserver(
150         ServerReflectionIndex serverReflectionIndex,
151         ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) {
152       this.serverReflectionIndex = serverReflectionIndex;
153       this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer");
154     }
155 
156     @Override
run()157     public void run() {
158       if (request != null) {
159         handleReflectionRequest();
160       }
161     }
162 
163     @Override
onNext(ServerReflectionRequest request)164     public void onNext(ServerReflectionRequest request) {
165       checkState(this.request == null);
166       this.request = checkNotNull(request);
167       handleReflectionRequest();
168     }
169 
handleReflectionRequest()170     private void handleReflectionRequest() {
171       if (serverCallStreamObserver.isReady()) {
172         switch (request.getMessageRequestCase()) {
173           case FILE_BY_FILENAME:
174             getFileByName(request);
175             break;
176           case FILE_CONTAINING_SYMBOL:
177             getFileContainingSymbol(request);
178             break;
179           case FILE_CONTAINING_EXTENSION:
180             getFileByExtension(request);
181             break;
182           case ALL_EXTENSION_NUMBERS_OF_TYPE:
183             getAllExtensions(request);
184             break;
185           case LIST_SERVICES:
186             listServices(request);
187             break;
188           default:
189             sendErrorResponse(
190                 request,
191                 Status.Code.UNIMPLEMENTED,
192                 "not implemented " + request.getMessageRequestCase());
193         }
194         request = null;
195         if (closeAfterSend) {
196           serverCallStreamObserver.onCompleted();
197         } else {
198           serverCallStreamObserver.request(1);
199         }
200       }
201     }
202 
203     @Override
onCompleted()204     public void onCompleted() {
205       if (request != null) {
206         closeAfterSend = true;
207       } else {
208         serverCallStreamObserver.onCompleted();
209       }
210     }
211 
212     @Override
onError(Throwable cause)213     public void onError(Throwable cause) {
214       serverCallStreamObserver.onError(cause);
215     }
216 
getFileByName(ServerReflectionRequest request)217     private void getFileByName(ServerReflectionRequest request) {
218       String name = request.getFileByFilename();
219       FileDescriptor fd = serverReflectionIndex.getFileDescriptorByName(name);
220       if (fd != null) {
221         serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
222       } else {
223         sendErrorResponse(request, Status.Code.NOT_FOUND, "File not found.");
224       }
225     }
226 
getFileContainingSymbol(ServerReflectionRequest request)227     private void getFileContainingSymbol(ServerReflectionRequest request) {
228       String symbol = request.getFileContainingSymbol();
229       FileDescriptor fd = serverReflectionIndex.getFileDescriptorBySymbol(symbol);
230       if (fd != null) {
231         serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
232       } else {
233         sendErrorResponse(request, Status.Code.NOT_FOUND, "Symbol not found.");
234       }
235     }
236 
getFileByExtension(ServerReflectionRequest request)237     private void getFileByExtension(ServerReflectionRequest request) {
238       ExtensionRequest extensionRequest = request.getFileContainingExtension();
239       String type = extensionRequest.getContainingType();
240       int extension = extensionRequest.getExtensionNumber();
241       FileDescriptor fd =
242           serverReflectionIndex.getFileDescriptorByExtensionAndNumber(type, extension);
243       if (fd != null) {
244         serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
245       } else {
246         sendErrorResponse(request, Status.Code.NOT_FOUND, "Extension not found.");
247       }
248     }
249 
getAllExtensions(ServerReflectionRequest request)250     private void getAllExtensions(ServerReflectionRequest request) {
251       String type = request.getAllExtensionNumbersOfType();
252       Set<Integer> extensions = serverReflectionIndex.getExtensionNumbersOfType(type);
253       if (extensions != null) {
254         ExtensionNumberResponse.Builder builder =
255             ExtensionNumberResponse.newBuilder()
256                 .setBaseTypeName(type)
257                 .addAllExtensionNumber(extensions);
258         serverCallStreamObserver.onNext(
259             ServerReflectionResponse.newBuilder()
260                 .setValidHost(request.getHost())
261                 .setOriginalRequest(request)
262                 .setAllExtensionNumbersResponse(builder)
263                 .build());
264       } else {
265         sendErrorResponse(request, Status.Code.NOT_FOUND, "Type not found.");
266       }
267     }
268 
listServices(ServerReflectionRequest request)269     private void listServices(ServerReflectionRequest request) {
270       ListServiceResponse.Builder builder = ListServiceResponse.newBuilder();
271       for (String serviceName : serverReflectionIndex.getServiceNames()) {
272         builder.addService(ServiceResponse.newBuilder().setName(serviceName));
273       }
274       serverCallStreamObserver.onNext(
275           ServerReflectionResponse.newBuilder()
276               .setValidHost(request.getHost())
277               .setOriginalRequest(request)
278               .setListServicesResponse(builder)
279               .build());
280     }
281 
sendErrorResponse( ServerReflectionRequest request, Status.Code code, String message)282     private void sendErrorResponse(
283         ServerReflectionRequest request, Status.Code code, String message) {
284       ServerReflectionResponse response =
285           ServerReflectionResponse.newBuilder()
286               .setValidHost(request.getHost())
287               .setOriginalRequest(request)
288               .setErrorResponse(
289                   ErrorResponse.newBuilder()
290                       .setErrorCode(code.value())
291                       .setErrorMessage(message))
292               .build();
293       serverCallStreamObserver.onNext(response);
294     }
295 
createServerReflectionResponse( ServerReflectionRequest request, FileDescriptor fd)296     private ServerReflectionResponse createServerReflectionResponse(
297         ServerReflectionRequest request, FileDescriptor fd) {
298       FileDescriptorResponse.Builder fdRBuilder = FileDescriptorResponse.newBuilder();
299 
300       Set<String> seenFiles = new HashSet<>();
301       Queue<FileDescriptor> frontier = new ArrayDeque<>();
302       seenFiles.add(fd.getName());
303       frontier.add(fd);
304       while (!frontier.isEmpty()) {
305         FileDescriptor nextFd = frontier.remove();
306         fdRBuilder.addFileDescriptorProto(nextFd.toProto().toByteString());
307         for (FileDescriptor dependencyFd : nextFd.getDependencies()) {
308           if (!seenFiles.contains(dependencyFd.getName())) {
309             seenFiles.add(dependencyFd.getName());
310             frontier.add(dependencyFd);
311           }
312         }
313       }
314       return ServerReflectionResponse.newBuilder()
315           .setValidHost(request.getHost())
316           .setOriginalRequest(request)
317           .setFileDescriptorResponse(fdRBuilder)
318           .build();
319     }
320   }
321 
322   /**
323    * Indexes the server's services and allows lookups of file descriptors by filename, symbol, type,
324    * and extension number.
325    *
326    * <p>Internally, this stores separate indices for the immutable and mutable services. When
327    * queried, the immutable service index is checked for a matching value. Only if there is no match
328    * in the immutable service index are the mutable services checked.
329    */
330   private static final class ServerReflectionIndex {
331     private final FileDescriptorIndex immutableServicesIndex;
332     private final FileDescriptorIndex mutableServicesIndex;
333 
ServerReflectionIndex( List<ServerServiceDefinition> immutableServices, List<ServerServiceDefinition> mutableServices)334     public ServerReflectionIndex(
335         List<ServerServiceDefinition> immutableServices,
336         List<ServerServiceDefinition> mutableServices) {
337       immutableServicesIndex = new FileDescriptorIndex(immutableServices);
338       mutableServicesIndex = new FileDescriptorIndex(mutableServices);
339     }
340 
getMutableServicesIndex()341     private FileDescriptorIndex getMutableServicesIndex() {
342       return mutableServicesIndex;
343     }
344 
getServiceNames()345     private Set<String> getServiceNames() {
346       Set<String> immutableServiceNames = immutableServicesIndex.getServiceNames();
347       Set<String> mutableServiceNames = mutableServicesIndex.getServiceNames();
348       Set<String> serviceNames =
349           new HashSet<>(immutableServiceNames.size() + mutableServiceNames.size());
350       serviceNames.addAll(immutableServiceNames);
351       serviceNames.addAll(mutableServiceNames);
352       return serviceNames;
353     }
354 
355     @Nullable
getFileDescriptorByName(String name)356     private FileDescriptor getFileDescriptorByName(String name) {
357       FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name);
358       if (fd == null) {
359         fd = mutableServicesIndex.getFileDescriptorByName(name);
360       }
361       return fd;
362     }
363 
364     @Nullable
getFileDescriptorBySymbol(String symbol)365     private FileDescriptor getFileDescriptorBySymbol(String symbol) {
366       FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol);
367       if (fd == null) {
368         fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol);
369       }
370       return fd;
371     }
372 
373     @Nullable
getFileDescriptorByExtensionAndNumber(String type, int extension)374     private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) {
375       FileDescriptor fd =
376           immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
377       if (fd == null) {
378         fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
379       }
380       return fd;
381     }
382 
383     @Nullable
getExtensionNumbersOfType(String type)384     private Set<Integer> getExtensionNumbersOfType(String type) {
385       Set<Integer> extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type);
386       if (extensionNumbers == null) {
387         extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type);
388       }
389       return extensionNumbers;
390     }
391   }
392 
393   /**
394    * Provides a set of methods for answering reflection queries for the file descriptors underlying
395    * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and
396    * mutable services.
397    */
398   private static final class FileDescriptorIndex {
399     private final Set<String> serviceNames = new HashSet<>();
400     private final Set<FileDescriptor> serviceFileDescriptors = new HashSet<>();
401     private final Map<String, FileDescriptor> fileDescriptorsByName =
402         new HashMap<>();
403     private final Map<String, FileDescriptor> fileDescriptorsBySymbol =
404         new HashMap<>();
405     private final Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber =
406         new HashMap<>();
407 
FileDescriptorIndex(List<ServerServiceDefinition> services)408     FileDescriptorIndex(List<ServerServiceDefinition> services) {
409       Queue<FileDescriptor> fileDescriptorsToProcess = new ArrayDeque<>();
410       Set<String> seenFiles = new HashSet<>();
411       for (ServerServiceDefinition service : services) {
412         io.grpc.ServiceDescriptor serviceDescriptor = service.getServiceDescriptor();
413         if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) {
414           FileDescriptor fileDescriptor =
415               ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
416                   .getFileDescriptor();
417           String serviceName = serviceDescriptor.getName();
418           checkState(
419               !serviceNames.contains(serviceName), "Service already defined: %s", serviceName);
420           serviceFileDescriptors.add(fileDescriptor);
421           serviceNames.add(serviceName);
422           if (!seenFiles.contains(fileDescriptor.getName())) {
423             seenFiles.add(fileDescriptor.getName());
424             fileDescriptorsToProcess.add(fileDescriptor);
425           }
426         }
427       }
428 
429       while (!fileDescriptorsToProcess.isEmpty()) {
430         FileDescriptor currentFd = fileDescriptorsToProcess.remove();
431         processFileDescriptor(currentFd);
432         for (FileDescriptor dependencyFd : currentFd.getDependencies()) {
433           if (!seenFiles.contains(dependencyFd.getName())) {
434             seenFiles.add(dependencyFd.getName());
435             fileDescriptorsToProcess.add(dependencyFd);
436           }
437         }
438       }
439     }
440 
441     /**
442      * Returns the file descriptors for the indexed services, but not their dependencies. This is
443      * used to check if the server's mutable services have changed.
444      */
getServiceFileDescriptors()445     private Set<FileDescriptor> getServiceFileDescriptors() {
446       return Collections.unmodifiableSet(serviceFileDescriptors);
447     }
448 
getServiceNames()449     private Set<String> getServiceNames() {
450       return Collections.unmodifiableSet(serviceNames);
451     }
452 
453     @Nullable
getFileDescriptorByName(String name)454     private FileDescriptor getFileDescriptorByName(String name) {
455       return fileDescriptorsByName.get(name);
456     }
457 
458     @Nullable
getFileDescriptorBySymbol(String symbol)459     private FileDescriptor getFileDescriptorBySymbol(String symbol) {
460       return fileDescriptorsBySymbol.get(symbol);
461     }
462 
463     @Nullable
getFileDescriptorByExtensionAndNumber(String type, int number)464     private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int number) {
465       if (fileDescriptorsByExtensionAndNumber.containsKey(type)) {
466         return fileDescriptorsByExtensionAndNumber.get(type).get(number);
467       }
468       return null;
469     }
470 
471     @Nullable
getExtensionNumbersOfType(String type)472     private Set<Integer> getExtensionNumbersOfType(String type) {
473       if (fileDescriptorsByExtensionAndNumber.containsKey(type)) {
474         return Collections.unmodifiableSet(fileDescriptorsByExtensionAndNumber.get(type).keySet());
475       }
476       return null;
477     }
478 
processFileDescriptor(FileDescriptor fd)479     private void processFileDescriptor(FileDescriptor fd) {
480       String fdName = fd.getName();
481       checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName);
482       fileDescriptorsByName.put(fdName, fd);
483       for (ServiceDescriptor service : fd.getServices()) {
484         processService(service, fd);
485       }
486       for (Descriptor type : fd.getMessageTypes()) {
487         processType(type, fd);
488       }
489       for (FieldDescriptor extension : fd.getExtensions()) {
490         processExtension(extension, fd);
491       }
492     }
493 
processService(ServiceDescriptor service, FileDescriptor fd)494     private void processService(ServiceDescriptor service, FileDescriptor fd) {
495       String serviceName = service.getFullName();
496       checkState(
497           !fileDescriptorsBySymbol.containsKey(serviceName),
498           "Service already defined: %s",
499           serviceName);
500       fileDescriptorsBySymbol.put(serviceName, fd);
501       for (MethodDescriptor method : service.getMethods()) {
502         String methodName = method.getFullName();
503         checkState(
504             !fileDescriptorsBySymbol.containsKey(methodName),
505             "Method already defined: %s",
506             methodName);
507         fileDescriptorsBySymbol.put(methodName, fd);
508       }
509     }
510 
processType(Descriptor type, FileDescriptor fd)511     private void processType(Descriptor type, FileDescriptor fd) {
512       String typeName = type.getFullName();
513       checkState(
514           !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName);
515       fileDescriptorsBySymbol.put(typeName, fd);
516       for (FieldDescriptor extension : type.getExtensions()) {
517         processExtension(extension, fd);
518       }
519       for (Descriptor nestedType : type.getNestedTypes()) {
520         processType(nestedType, fd);
521       }
522     }
523 
processExtension(FieldDescriptor extension, FileDescriptor fd)524     private void processExtension(FieldDescriptor extension, FileDescriptor fd) {
525       String extensionName = extension.getContainingType().getFullName();
526       int extensionNumber = extension.getNumber();
527       if (!fileDescriptorsByExtensionAndNumber.containsKey(extensionName)) {
528         fileDescriptorsByExtensionAndNumber.put(
529             extensionName, new HashMap<Integer, FileDescriptor>());
530       }
531       checkState(
532           !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber),
533           "Extension name and number already defined: %s, %s",
534           extensionName,
535           extensionNumber);
536       fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd);
537     }
538   }
539 }
540