/*
 * Copyright 2016 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.protobuf.services;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Descriptors.ServiceDescriptor;
import io.grpc.BindableService;
import io.grpc.ExperimentalApi;
import io.grpc.InternalServer;
import io.grpc.Server;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.protobuf.ProtoFileDescriptorSupplier;
import io.grpc.reflection.v1alpha.ErrorResponse;
import io.grpc.reflection.v1alpha.ExtensionNumberResponse;
import io.grpc.reflection.v1alpha.ExtensionRequest;
import io.grpc.reflection.v1alpha.FileDescriptorResponse;
import io.grpc.reflection.v1alpha.ListServiceResponse;
import io.grpc.reflection.v1alpha.ServerReflectionGrpc;
import io.grpc.reflection.v1alpha.ServerReflectionRequest;
import io.grpc.reflection.v1alpha.ServerReflectionResponse;
import io.grpc.reflection.v1alpha.ServiceResponse;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.ArrayDeque;
import java.util.Collections; 47 import java.util.HashMap; 48 import java.util.HashSet; 49 import java.util.List; 50 import java.util.Map; 51 import java.util.Queue; 52 import java.util.Set; 53 import java.util.WeakHashMap; 54 import javax.annotation.Nullable; 55 import javax.annotation.concurrent.GuardedBy; 56 57 /** 58 * Provides a reflection service for Protobuf services (including the reflection service itself). 59 * 60 * <p>Separately tracks mutable and immutable services. Throws an exception if either group of 61 * services contains multiple Protobuf files with declarations of the same service, method, type, or 62 * extension. 63 */ 64 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222") 65 public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase { 66 67 private final Object lock = new Object(); 68 69 @GuardedBy("lock") 70 private final Map<Server, ServerReflectionIndex> serverReflectionIndexes = new WeakHashMap<>(); 71 ProtoReflectionService()72 private ProtoReflectionService() {} 73 74 /** 75 * Creates a instance of {@link ProtoReflectionService}. 76 */ newInstance()77 public static BindableService newInstance() { 78 return new ProtoReflectionService(); 79 } 80 81 /** 82 * Retrieves the index for services of the server that dispatches the current call. Computes 83 * one if not exist. The index is updated if any changes to the server's mutable services are 84 * detected. A change is any addition or removal in the set of file descriptors attached to the 85 * mutable services or a change in the service names. 
86 */ getRefreshedIndex()87 private ServerReflectionIndex getRefreshedIndex() { 88 synchronized (lock) { 89 Server server = InternalServer.SERVER_CONTEXT_KEY.get(); 90 ServerReflectionIndex index = serverReflectionIndexes.get(server); 91 if (index == null) { 92 index = 93 new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices()); 94 serverReflectionIndexes.put(server, index); 95 return index; 96 } 97 98 Set<FileDescriptor> serverFileDescriptors = new HashSet<>(); 99 Set<String> serverServiceNames = new HashSet<>(); 100 List<ServerServiceDefinition> serverMutableServices = server.getMutableServices(); 101 for (ServerServiceDefinition mutableService : serverMutableServices) { 102 io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor(); 103 if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { 104 String serviceName = serviceDescriptor.getName(); 105 FileDescriptor fileDescriptor = 106 ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) 107 .getFileDescriptor(); 108 serverFileDescriptors.add(fileDescriptor); 109 serverServiceNames.add(serviceName); 110 } 111 } 112 113 // Replace the index if the underlying mutable services have changed. Check both the file 114 // descriptors and the service names, because one file descriptor can define multiple 115 // services. 
116 FileDescriptorIndex mutableServicesIndex = index.getMutableServicesIndex(); 117 if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors) 118 || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) { 119 index = 120 new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices); 121 serverReflectionIndexes.put(server, index); 122 } 123 124 return index; 125 } 126 } 127 128 @Override serverReflectionInfo( final StreamObserver<ServerReflectionResponse> responseObserver)129 public StreamObserver<ServerReflectionRequest> serverReflectionInfo( 130 final StreamObserver<ServerReflectionResponse> responseObserver) { 131 final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver = 132 (ServerCallStreamObserver<ServerReflectionResponse>) responseObserver; 133 ProtoReflectionStreamObserver requestObserver = 134 new ProtoReflectionStreamObserver(getRefreshedIndex(), serverCallStreamObserver); 135 serverCallStreamObserver.setOnReadyHandler(requestObserver); 136 serverCallStreamObserver.disableAutoRequest(); 137 serverCallStreamObserver.request(1); 138 return requestObserver; 139 } 140 141 private static class ProtoReflectionStreamObserver 142 implements Runnable, StreamObserver<ServerReflectionRequest> { 143 private final ServerReflectionIndex serverReflectionIndex; 144 private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver; 145 146 private boolean closeAfterSend = false; 147 private ServerReflectionRequest request; 148 ProtoReflectionStreamObserver( ServerReflectionIndex serverReflectionIndex, ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver)149 ProtoReflectionStreamObserver( 150 ServerReflectionIndex serverReflectionIndex, 151 ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) { 152 this.serverReflectionIndex = serverReflectionIndex; 153 this.serverCallStreamObserver = 
checkNotNull(serverCallStreamObserver, "observer"); 154 } 155 156 @Override run()157 public void run() { 158 if (request != null) { 159 handleReflectionRequest(); 160 } 161 } 162 163 @Override onNext(ServerReflectionRequest request)164 public void onNext(ServerReflectionRequest request) { 165 checkState(this.request == null); 166 this.request = checkNotNull(request); 167 handleReflectionRequest(); 168 } 169 handleReflectionRequest()170 private void handleReflectionRequest() { 171 if (serverCallStreamObserver.isReady()) { 172 switch (request.getMessageRequestCase()) { 173 case FILE_BY_FILENAME: 174 getFileByName(request); 175 break; 176 case FILE_CONTAINING_SYMBOL: 177 getFileContainingSymbol(request); 178 break; 179 case FILE_CONTAINING_EXTENSION: 180 getFileByExtension(request); 181 break; 182 case ALL_EXTENSION_NUMBERS_OF_TYPE: 183 getAllExtensions(request); 184 break; 185 case LIST_SERVICES: 186 listServices(request); 187 break; 188 default: 189 sendErrorResponse( 190 request, 191 Status.Code.UNIMPLEMENTED, 192 "not implemented " + request.getMessageRequestCase()); 193 } 194 request = null; 195 if (closeAfterSend) { 196 serverCallStreamObserver.onCompleted(); 197 } else { 198 serverCallStreamObserver.request(1); 199 } 200 } 201 } 202 203 @Override onCompleted()204 public void onCompleted() { 205 if (request != null) { 206 closeAfterSend = true; 207 } else { 208 serverCallStreamObserver.onCompleted(); 209 } 210 } 211 212 @Override onError(Throwable cause)213 public void onError(Throwable cause) { 214 serverCallStreamObserver.onError(cause); 215 } 216 getFileByName(ServerReflectionRequest request)217 private void getFileByName(ServerReflectionRequest request) { 218 String name = request.getFileByFilename(); 219 FileDescriptor fd = serverReflectionIndex.getFileDescriptorByName(name); 220 if (fd != null) { 221 serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); 222 } else { 223 sendErrorResponse(request, Status.Code.NOT_FOUND, "File not 
found."); 224 } 225 } 226 getFileContainingSymbol(ServerReflectionRequest request)227 private void getFileContainingSymbol(ServerReflectionRequest request) { 228 String symbol = request.getFileContainingSymbol(); 229 FileDescriptor fd = serverReflectionIndex.getFileDescriptorBySymbol(symbol); 230 if (fd != null) { 231 serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); 232 } else { 233 sendErrorResponse(request, Status.Code.NOT_FOUND, "Symbol not found."); 234 } 235 } 236 getFileByExtension(ServerReflectionRequest request)237 private void getFileByExtension(ServerReflectionRequest request) { 238 ExtensionRequest extensionRequest = request.getFileContainingExtension(); 239 String type = extensionRequest.getContainingType(); 240 int extension = extensionRequest.getExtensionNumber(); 241 FileDescriptor fd = 242 serverReflectionIndex.getFileDescriptorByExtensionAndNumber(type, extension); 243 if (fd != null) { 244 serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); 245 } else { 246 sendErrorResponse(request, Status.Code.NOT_FOUND, "Extension not found."); 247 } 248 } 249 getAllExtensions(ServerReflectionRequest request)250 private void getAllExtensions(ServerReflectionRequest request) { 251 String type = request.getAllExtensionNumbersOfType(); 252 Set<Integer> extensions = serverReflectionIndex.getExtensionNumbersOfType(type); 253 if (extensions != null) { 254 ExtensionNumberResponse.Builder builder = 255 ExtensionNumberResponse.newBuilder() 256 .setBaseTypeName(type) 257 .addAllExtensionNumber(extensions); 258 serverCallStreamObserver.onNext( 259 ServerReflectionResponse.newBuilder() 260 .setValidHost(request.getHost()) 261 .setOriginalRequest(request) 262 .setAllExtensionNumbersResponse(builder) 263 .build()); 264 } else { 265 sendErrorResponse(request, Status.Code.NOT_FOUND, "Type not found."); 266 } 267 } 268 listServices(ServerReflectionRequest request)269 private void listServices(ServerReflectionRequest 
request) { 270 ListServiceResponse.Builder builder = ListServiceResponse.newBuilder(); 271 for (String serviceName : serverReflectionIndex.getServiceNames()) { 272 builder.addService(ServiceResponse.newBuilder().setName(serviceName)); 273 } 274 serverCallStreamObserver.onNext( 275 ServerReflectionResponse.newBuilder() 276 .setValidHost(request.getHost()) 277 .setOriginalRequest(request) 278 .setListServicesResponse(builder) 279 .build()); 280 } 281 sendErrorResponse( ServerReflectionRequest request, Status.Code code, String message)282 private void sendErrorResponse( 283 ServerReflectionRequest request, Status.Code code, String message) { 284 ServerReflectionResponse response = 285 ServerReflectionResponse.newBuilder() 286 .setValidHost(request.getHost()) 287 .setOriginalRequest(request) 288 .setErrorResponse( 289 ErrorResponse.newBuilder() 290 .setErrorCode(code.value()) 291 .setErrorMessage(message)) 292 .build(); 293 serverCallStreamObserver.onNext(response); 294 } 295 createServerReflectionResponse( ServerReflectionRequest request, FileDescriptor fd)296 private ServerReflectionResponse createServerReflectionResponse( 297 ServerReflectionRequest request, FileDescriptor fd) { 298 FileDescriptorResponse.Builder fdRBuilder = FileDescriptorResponse.newBuilder(); 299 300 Set<String> seenFiles = new HashSet<>(); 301 Queue<FileDescriptor> frontier = new ArrayDeque<>(); 302 seenFiles.add(fd.getName()); 303 frontier.add(fd); 304 while (!frontier.isEmpty()) { 305 FileDescriptor nextFd = frontier.remove(); 306 fdRBuilder.addFileDescriptorProto(nextFd.toProto().toByteString()); 307 for (FileDescriptor dependencyFd : nextFd.getDependencies()) { 308 if (!seenFiles.contains(dependencyFd.getName())) { 309 seenFiles.add(dependencyFd.getName()); 310 frontier.add(dependencyFd); 311 } 312 } 313 } 314 return ServerReflectionResponse.newBuilder() 315 .setValidHost(request.getHost()) 316 .setOriginalRequest(request) 317 .setFileDescriptorResponse(fdRBuilder) 318 .build(); 319 } 320 } 
321 322 /** 323 * Indexes the server's services and allows lookups of file descriptors by filename, symbol, type, 324 * and extension number. 325 * 326 * <p>Internally, this stores separate indices for the immutable and mutable services. When 327 * queried, the immutable service index is checked for a matching value. Only if there is no match 328 * in the immutable service index are the mutable services checked. 329 */ 330 private static final class ServerReflectionIndex { 331 private final FileDescriptorIndex immutableServicesIndex; 332 private final FileDescriptorIndex mutableServicesIndex; 333 ServerReflectionIndex( List<ServerServiceDefinition> immutableServices, List<ServerServiceDefinition> mutableServices)334 public ServerReflectionIndex( 335 List<ServerServiceDefinition> immutableServices, 336 List<ServerServiceDefinition> mutableServices) { 337 immutableServicesIndex = new FileDescriptorIndex(immutableServices); 338 mutableServicesIndex = new FileDescriptorIndex(mutableServices); 339 } 340 getMutableServicesIndex()341 private FileDescriptorIndex getMutableServicesIndex() { 342 return mutableServicesIndex; 343 } 344 getServiceNames()345 private Set<String> getServiceNames() { 346 Set<String> immutableServiceNames = immutableServicesIndex.getServiceNames(); 347 Set<String> mutableServiceNames = mutableServicesIndex.getServiceNames(); 348 Set<String> serviceNames = 349 new HashSet<>(immutableServiceNames.size() + mutableServiceNames.size()); 350 serviceNames.addAll(immutableServiceNames); 351 serviceNames.addAll(mutableServiceNames); 352 return serviceNames; 353 } 354 355 @Nullable getFileDescriptorByName(String name)356 private FileDescriptor getFileDescriptorByName(String name) { 357 FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name); 358 if (fd == null) { 359 fd = mutableServicesIndex.getFileDescriptorByName(name); 360 } 361 return fd; 362 } 363 364 @Nullable getFileDescriptorBySymbol(String symbol)365 private FileDescriptor 
getFileDescriptorBySymbol(String symbol) { 366 FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol); 367 if (fd == null) { 368 fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol); 369 } 370 return fd; 371 } 372 373 @Nullable getFileDescriptorByExtensionAndNumber(String type, int extension)374 private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) { 375 FileDescriptor fd = 376 immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); 377 if (fd == null) { 378 fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); 379 } 380 return fd; 381 } 382 383 @Nullable getExtensionNumbersOfType(String type)384 private Set<Integer> getExtensionNumbersOfType(String type) { 385 Set<Integer> extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type); 386 if (extensionNumbers == null) { 387 extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type); 388 } 389 return extensionNumbers; 390 } 391 } 392 393 /** 394 * Provides a set of methods for answering reflection queries for the file descriptors underlying 395 * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and 396 * mutable services. 
397 */ 398 private static final class FileDescriptorIndex { 399 private final Set<String> serviceNames = new HashSet<>(); 400 private final Set<FileDescriptor> serviceFileDescriptors = new HashSet<>(); 401 private final Map<String, FileDescriptor> fileDescriptorsByName = 402 new HashMap<>(); 403 private final Map<String, FileDescriptor> fileDescriptorsBySymbol = 404 new HashMap<>(); 405 private final Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber = 406 new HashMap<>(); 407 FileDescriptorIndex(List<ServerServiceDefinition> services)408 FileDescriptorIndex(List<ServerServiceDefinition> services) { 409 Queue<FileDescriptor> fileDescriptorsToProcess = new ArrayDeque<>(); 410 Set<String> seenFiles = new HashSet<>(); 411 for (ServerServiceDefinition service : services) { 412 io.grpc.ServiceDescriptor serviceDescriptor = service.getServiceDescriptor(); 413 if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { 414 FileDescriptor fileDescriptor = 415 ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) 416 .getFileDescriptor(); 417 String serviceName = serviceDescriptor.getName(); 418 checkState( 419 !serviceNames.contains(serviceName), "Service already defined: %s", serviceName); 420 serviceFileDescriptors.add(fileDescriptor); 421 serviceNames.add(serviceName); 422 if (!seenFiles.contains(fileDescriptor.getName())) { 423 seenFiles.add(fileDescriptor.getName()); 424 fileDescriptorsToProcess.add(fileDescriptor); 425 } 426 } 427 } 428 429 while (!fileDescriptorsToProcess.isEmpty()) { 430 FileDescriptor currentFd = fileDescriptorsToProcess.remove(); 431 processFileDescriptor(currentFd); 432 for (FileDescriptor dependencyFd : currentFd.getDependencies()) { 433 if (!seenFiles.contains(dependencyFd.getName())) { 434 seenFiles.add(dependencyFd.getName()); 435 fileDescriptorsToProcess.add(dependencyFd); 436 } 437 } 438 } 439 } 440 441 /** 442 * Returns the file descriptors for the indexed services, 
but not their dependencies. This is 443 * used to check if the server's mutable services have changed. 444 */ getServiceFileDescriptors()445 private Set<FileDescriptor> getServiceFileDescriptors() { 446 return Collections.unmodifiableSet(serviceFileDescriptors); 447 } 448 getServiceNames()449 private Set<String> getServiceNames() { 450 return Collections.unmodifiableSet(serviceNames); 451 } 452 453 @Nullable getFileDescriptorByName(String name)454 private FileDescriptor getFileDescriptorByName(String name) { 455 return fileDescriptorsByName.get(name); 456 } 457 458 @Nullable getFileDescriptorBySymbol(String symbol)459 private FileDescriptor getFileDescriptorBySymbol(String symbol) { 460 return fileDescriptorsBySymbol.get(symbol); 461 } 462 463 @Nullable getFileDescriptorByExtensionAndNumber(String type, int number)464 private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int number) { 465 if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { 466 return fileDescriptorsByExtensionAndNumber.get(type).get(number); 467 } 468 return null; 469 } 470 471 @Nullable getExtensionNumbersOfType(String type)472 private Set<Integer> getExtensionNumbersOfType(String type) { 473 if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { 474 return Collections.unmodifiableSet(fileDescriptorsByExtensionAndNumber.get(type).keySet()); 475 } 476 return null; 477 } 478 processFileDescriptor(FileDescriptor fd)479 private void processFileDescriptor(FileDescriptor fd) { 480 String fdName = fd.getName(); 481 checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName); 482 fileDescriptorsByName.put(fdName, fd); 483 for (ServiceDescriptor service : fd.getServices()) { 484 processService(service, fd); 485 } 486 for (Descriptor type : fd.getMessageTypes()) { 487 processType(type, fd); 488 } 489 for (FieldDescriptor extension : fd.getExtensions()) { 490 processExtension(extension, fd); 491 } 492 } 493 processService(ServiceDescriptor 
service, FileDescriptor fd)494 private void processService(ServiceDescriptor service, FileDescriptor fd) { 495 String serviceName = service.getFullName(); 496 checkState( 497 !fileDescriptorsBySymbol.containsKey(serviceName), 498 "Service already defined: %s", 499 serviceName); 500 fileDescriptorsBySymbol.put(serviceName, fd); 501 for (MethodDescriptor method : service.getMethods()) { 502 String methodName = method.getFullName(); 503 checkState( 504 !fileDescriptorsBySymbol.containsKey(methodName), 505 "Method already defined: %s", 506 methodName); 507 fileDescriptorsBySymbol.put(methodName, fd); 508 } 509 } 510 processType(Descriptor type, FileDescriptor fd)511 private void processType(Descriptor type, FileDescriptor fd) { 512 String typeName = type.getFullName(); 513 checkState( 514 !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName); 515 fileDescriptorsBySymbol.put(typeName, fd); 516 for (FieldDescriptor extension : type.getExtensions()) { 517 processExtension(extension, fd); 518 } 519 for (Descriptor nestedType : type.getNestedTypes()) { 520 processType(nestedType, fd); 521 } 522 } 523 processExtension(FieldDescriptor extension, FileDescriptor fd)524 private void processExtension(FieldDescriptor extension, FileDescriptor fd) { 525 String extensionName = extension.getContainingType().getFullName(); 526 int extensionNumber = extension.getNumber(); 527 if (!fileDescriptorsByExtensionAndNumber.containsKey(extensionName)) { 528 fileDescriptorsByExtensionAndNumber.put( 529 extensionName, new HashMap<Integer, FileDescriptor>()); 530 } 531 checkState( 532 !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber), 533 "Extension name and number already defined: %s, %s", 534 extensionName, 535 extensionNumber); 536 fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd); 537 } 538 } 539 } 540