1 /* 2 * Copyright 2015 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.testing.integration; 18 19 import static com.google.common.truth.Truth.assertAbout; 20 import static io.grpc.testing.DeadlineSubject.deadline; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertNotSame; 23 import static org.junit.Assert.assertTrue; 24 import static org.junit.Assert.fail; 25 26 import com.google.common.util.concurrent.SettableFuture; 27 import io.grpc.Context; 28 import io.grpc.Context.CancellableContext; 29 import io.grpc.Deadline; 30 import io.grpc.ManagedChannel; 31 import io.grpc.Metadata; 32 import io.grpc.Server; 33 import io.grpc.ServerCall; 34 import io.grpc.ServerCallHandler; 35 import io.grpc.ServerInterceptor; 36 import io.grpc.ServerInterceptors; 37 import io.grpc.Status; 38 import io.grpc.StatusRuntimeException; 39 import io.grpc.inprocess.InProcessChannelBuilder; 40 import io.grpc.inprocess.InProcessServerBuilder; 41 import io.grpc.stub.ServerCallStreamObserver; 42 import io.grpc.stub.StreamObserver; 43 import io.grpc.testing.integration.Messages.SimpleRequest; 44 import io.grpc.testing.integration.Messages.SimpleResponse; 45 import java.io.IOException; 46 import java.util.concurrent.CountDownLatch; 47 import java.util.concurrent.ExecutionException; 48 import java.util.concurrent.ExecutorService; 49 import java.util.concurrent.Executors; 50 import java.util.concurrent.Future; 51 import java.util.concurrent.TimeUnit; 52 import java.util.concurrent.atomic.AtomicInteger; 53 import org.junit.After; 54 import org.junit.Before; 55 import org.junit.Rule; 56 import org.junit.Test; 57 import org.junit.runner.RunWith; 58 import org.junit.runners.JUnit4; 59 import org.mockito.Mock; 60 import org.mockito.junit.MockitoJUnit; 61 import org.mockito.junit.MockitoRule; 62 63 /** 64 * Integration test for various forms of cancellation and deadline propagation. 65 */ 66 @RunWith(JUnit4.class) 67 public class CascadingTest { 68 @Rule public final MockitoRule mocks = MockitoJUnit.rule(); 69 70 @Mock 71 TestServiceGrpc.TestServiceImplBase service; 72 private ManagedChannel channel; 73 private Server server; 74 private CountDownLatch observedCancellations; 75 private CountDownLatch receivedCancellations; 76 private TestServiceGrpc.TestServiceBlockingStub blockingStub; 77 private TestServiceGrpc.TestServiceStub asyncStub; 78 private TestServiceGrpc.TestServiceFutureStub futureStub; 79 private ExecutorService otherWork; 80 81 @Before setUp()82 public void setUp() throws Exception { 83 // Use a cached thread pool as we need a thread for each blocked call 84 otherWork = Executors.newCachedThreadPool(); 85 channel = InProcessChannelBuilder.forName("channel").executor(otherWork).build(); 86 blockingStub = TestServiceGrpc.newBlockingStub(channel); 87 asyncStub = TestServiceGrpc.newStub(channel); 88 futureStub = TestServiceGrpc.newFutureStub(channel); 89 } 90 91 @After tearDown()92 public void tearDown() { 93 channel.shutdownNow(); 94 server.shutdownNow(); 95 otherWork.shutdownNow(); 96 } 97 98 /** 99 * Test {@link Context} cancellation propagates from the first node in the call chain all the way 100 * to the last. 101 */ 102 @Test testCascadingCancellationViaOuterContextCancellation()103 public void testCascadingCancellationViaOuterContextCancellation() throws Exception { 104 observedCancellations = new CountDownLatch(2); 105 receivedCancellations = new CountDownLatch(3); 106 Future<?> chainReady = startChainingServer(3); 107 CancellableContext context = Context.current().withCancellation(); 108 Future<SimpleResponse> future; 109 Context prevContext = context.attach(); 110 try { 111 future = futureStub.unaryCall(SimpleRequest.getDefaultInstance()); 112 } finally { 113 context.detach(prevContext); 114 } 115 chainReady.get(5, TimeUnit.SECONDS); 116 117 context.cancel(null); 118 try { 119 future.get(5, TimeUnit.SECONDS); 120 fail("Expected cancellation"); 121 } catch (ExecutionException ex) { 122 Status status = Status.fromThrowable(ex); 123 assertEquals(Status.Code.CANCELLED, status.getCode()); 124 125 // Should have observed 2 cancellations responses from downstream servers 126 if (!observedCancellations.await(5, TimeUnit.SECONDS)) { 127 fail("Expected number of cancellations not observed by clients"); 128 } 129 if (!receivedCancellations.await(5, TimeUnit.SECONDS)) { 130 fail("Expected number of cancellations to be received by servers not observed"); 131 } 132 } 133 } 134 135 /** 136 * Test that cancellation via call cancellation propagates down the call. 137 */ 138 @Test testCascadingCancellationViaRpcCancel()139 public void testCascadingCancellationViaRpcCancel() throws Exception { 140 observedCancellations = new CountDownLatch(2); 141 receivedCancellations = new CountDownLatch(3); 142 Future<?> chainReady = startChainingServer(3); 143 Future<SimpleResponse> future = futureStub.unaryCall(SimpleRequest.getDefaultInstance()); 144 chainReady.get(5, TimeUnit.SECONDS); 145 146 future.cancel(true); 147 assertTrue(future.isCancelled()); 148 if (!observedCancellations.await(5, TimeUnit.SECONDS)) { 149 fail("Expected number of cancellations not observed by clients"); 150 } 151 if (!receivedCancellations.await(5, TimeUnit.SECONDS)) { 152 fail("Expected number of cancellations to be received by servers not observed"); 153 } 154 } 155 156 /** 157 * Test that when RPC cancellation propagates up a call chain, the cancellation of the parent 158 * RPC triggers cancellation of all of its children. 159 */ 160 @Test testCascadingCancellationViaLeafFailure()161 public void testCascadingCancellationViaLeafFailure() throws Exception { 162 // All nodes (15) except one edge of the tree (4) will be cancelled. 163 observedCancellations = new CountDownLatch(11); 164 receivedCancellations = new CountDownLatch(11); 165 startCallTreeServer(3); 166 try { 167 // Use response size limit to control tree nodeCount. 168 blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build()); 169 fail("Expected abort"); 170 } catch (StatusRuntimeException sre) { 171 // Wait for the workers to finish 172 Status status = sre.getStatus(); 173 // Outermost caller observes ABORTED propagating up from the failing leaf, 174 // The descendant RPCs are cancelled so they receive CANCELLED. 175 assertEquals(Status.Code.ABORTED, status.getCode()); 176 177 if (!observedCancellations.await(5, TimeUnit.SECONDS)) { 178 fail("Expected number of cancellations not observed by clients"); 179 } 180 if (!receivedCancellations.await(5, TimeUnit.SECONDS)) { 181 fail("Expected number of cancellations to be received by servers not observed"); 182 } 183 } 184 } 185 186 @Test testDeadlinePropagation()187 public void testDeadlinePropagation() throws Exception { 188 final AtomicInteger recursionDepthRemaining = new AtomicInteger(3); 189 final SettableFuture<Deadline> finalDeadline = SettableFuture.create(); 190 class DeadlineSaver extends TestServiceGrpc.TestServiceImplBase { 191 @Override 192 public void unaryCall(final SimpleRequest request, 193 final StreamObserver<SimpleResponse> responseObserver) { 194 Context.currentContextExecutor(otherWork).execute(new Runnable() { 195 @Override 196 public void run() { 197 try { 198 if (recursionDepthRemaining.decrementAndGet() == 0) { 199 finalDeadline.set(Context.current().getDeadline()); 200 responseObserver.onNext(SimpleResponse.getDefaultInstance()); 201 } else { 202 responseObserver.onNext(blockingStub.unaryCall(request)); 203 } 204 responseObserver.onCompleted(); 205 } catch (Exception ex) { 206 responseObserver.onError(ex); 207 } 208 } 209 }); 210 } 211 } 212 213 server = InProcessServerBuilder.forName("channel").executor(otherWork) 214 .addService(new DeadlineSaver()) 215 .build().start(); 216 217 Deadline initialDeadline = Deadline.after(1, TimeUnit.MINUTES); 218 blockingStub.withDeadline(initialDeadline).unaryCall(SimpleRequest.getDefaultInstance()); 219 assertNotSame(initialDeadline, finalDeadline); 220 // Since deadline is re-calculated at each hop, some variance is acceptable and expected. 221 assertAbout(deadline()) 222 .that(finalDeadline.get()).isWithin(1, TimeUnit.SECONDS).of(initialDeadline); 223 } 224 225 /** 226 * Create a chain of client to server calls which can be cancelled top down. 227 * 228 * @return a Future that completes when call chain is created 229 */ startChainingServer(final int depthThreshold)230 private Future<?> startChainingServer(final int depthThreshold) throws IOException { 231 final AtomicInteger serversReady = new AtomicInteger(); 232 final SettableFuture<Void> chainReady = SettableFuture.create(); 233 class ChainingService extends TestServiceGrpc.TestServiceImplBase { 234 @Override 235 public void unaryCall(final SimpleRequest request, 236 final StreamObserver<SimpleResponse> responseObserver) { 237 ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(new Runnable() { 238 @Override 239 public void run() { 240 receivedCancellations.countDown(); 241 } 242 }); 243 if (serversReady.incrementAndGet() == depthThreshold) { 244 // Stop recursion 245 chainReady.set(null); 246 return; 247 } 248 249 Context.currentContextExecutor(otherWork).execute(new Runnable() { 250 @Override 251 public void run() { 252 try { 253 blockingStub.unaryCall(request); 254 } catch (StatusRuntimeException e) { 255 Status status = e.getStatus(); 256 if (status.getCode() == Status.Code.CANCELLED) { 257 observedCancellations.countDown(); 258 } else { 259 responseObserver.onError(e); 260 } 261 } 262 } 263 }); 264 } 265 } 266 267 server = InProcessServerBuilder.forName("channel").executor(otherWork) 268 .addService(new ChainingService()) 269 .build().start(); 270 return chainReady; 271 } 272 273 /** 274 * Create a tree of client to server calls where each received call on the server 275 * fans out to two downstream calls. Uses SimpleRequest.response_size to limit the nodeCount 276 * of the tree. One of the leaves will ABORT to trigger cancellation back up to tree. 277 */ startCallTreeServer(int depthThreshold)278 private void startCallTreeServer(int depthThreshold) throws IOException { 279 final AtomicInteger nodeCount = new AtomicInteger((2 << depthThreshold) - 1); 280 server = InProcessServerBuilder.forName("channel").executor(otherWork).addService( 281 ServerInterceptors.intercept(service, 282 new ServerInterceptor() { 283 @Override 284 public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( 285 final ServerCall<ReqT, RespT> call, 286 Metadata headers, 287 ServerCallHandler<ReqT, RespT> next) { 288 // Respond with the headers but nothing else. 289 call.sendHeaders(new Metadata()); 290 call.request(1); 291 return new ServerCall.Listener<ReqT>() { 292 @Override 293 public void onMessage(final ReqT message) { 294 Messages.SimpleRequest req = (Messages.SimpleRequest) message; 295 if (nodeCount.decrementAndGet() == 0) { 296 // we are in the final leaf node so trigger an ABORT upwards 297 Context.currentContextExecutor(otherWork).execute(new Runnable() { 298 @Override 299 public void run() { 300 synchronized (call) { 301 call.close(Status.ABORTED, new Metadata()); 302 } 303 } 304 }); 305 } else if (req.getResponseSize() != 0) { 306 // We are in a non leaf node so fire off two requests 307 req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build(); 308 for (int i = 0; i < 2; i++) { 309 asyncStub.unaryCall(req, 310 new StreamObserver<Messages.SimpleResponse>() { 311 @Override 312 public void onNext(Messages.SimpleResponse value) { 313 } 314 315 @Override 316 public void onError(Throwable t) { 317 Status status = Status.fromThrowable(t); 318 if (status.getCode() == Status.Code.CANCELLED) { 319 observedCancellations.countDown(); 320 } 321 // Propagate closure upwards. 322 try { 323 synchronized (call) { 324 call.close(status, new Metadata()); 325 } 326 } catch (IllegalStateException t2) { 327 // Ignore error if already closed. 328 } 329 } 330 331 @Override 332 public void onCompleted() { 333 } 334 }); 335 } 336 } 337 } 338 339 @Override 340 public void onCancel() { 341 receivedCancellations.countDown(); 342 } 343 }; 344 } 345 }) 346 ).build(); 347 server.start(); 348 } 349 } 350