• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2015 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc.testing.integration;
18 
19 import static com.google.common.truth.Truth.assertAbout;
20 import static io.grpc.testing.DeadlineSubject.deadline;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertNotSame;
23 import static org.junit.Assert.assertTrue;
24 import static org.junit.Assert.fail;
25 
26 import com.google.common.util.concurrent.SettableFuture;
27 import io.grpc.Context;
28 import io.grpc.Context.CancellableContext;
29 import io.grpc.Deadline;
30 import io.grpc.ManagedChannel;
31 import io.grpc.Metadata;
32 import io.grpc.Server;
33 import io.grpc.ServerCall;
34 import io.grpc.ServerCallHandler;
35 import io.grpc.ServerInterceptor;
36 import io.grpc.ServerInterceptors;
37 import io.grpc.Status;
38 import io.grpc.StatusRuntimeException;
39 import io.grpc.inprocess.InProcessChannelBuilder;
40 import io.grpc.inprocess.InProcessServerBuilder;
41 import io.grpc.stub.ServerCallStreamObserver;
42 import io.grpc.stub.StreamObserver;
43 import io.grpc.testing.integration.Messages.SimpleRequest;
44 import io.grpc.testing.integration.Messages.SimpleResponse;
45 import java.io.IOException;
46 import java.util.concurrent.CountDownLatch;
47 import java.util.concurrent.ExecutionException;
48 import java.util.concurrent.ExecutorService;
49 import java.util.concurrent.Executors;
50 import java.util.concurrent.Future;
51 import java.util.concurrent.TimeUnit;
52 import java.util.concurrent.atomic.AtomicInteger;
53 import org.junit.After;
54 import org.junit.Before;
55 import org.junit.Rule;
56 import org.junit.Test;
57 import org.junit.runner.RunWith;
58 import org.junit.runners.JUnit4;
59 import org.mockito.Mock;
60 import org.mockito.junit.MockitoJUnit;
61 import org.mockito.junit.MockitoRule;
62 
63 /**
64  * Integration test for various forms of cancellation and deadline propagation.
65  */
66 @RunWith(JUnit4.class)
67 public class CascadingTest {
68   @Rule public final MockitoRule mocks = MockitoJUnit.rule();
69 
70   @Mock
71   TestServiceGrpc.TestServiceImplBase service;
72   private ManagedChannel channel;
73   private Server server;
74   private CountDownLatch observedCancellations;
75   private CountDownLatch receivedCancellations;
76   private TestServiceGrpc.TestServiceBlockingStub blockingStub;
77   private TestServiceGrpc.TestServiceStub asyncStub;
78   private TestServiceGrpc.TestServiceFutureStub futureStub;
79   private ExecutorService otherWork;
80 
81   @Before
setUp()82   public void setUp() throws Exception {
83     // Use a cached thread pool as we need a thread for each blocked call
84     otherWork = Executors.newCachedThreadPool();
85     channel = InProcessChannelBuilder.forName("channel").executor(otherWork).build();
86     blockingStub = TestServiceGrpc.newBlockingStub(channel);
87     asyncStub = TestServiceGrpc.newStub(channel);
88     futureStub = TestServiceGrpc.newFutureStub(channel);
89   }
90 
91   @After
tearDown()92   public void tearDown() {
93     channel.shutdownNow();
94     server.shutdownNow();
95     otherWork.shutdownNow();
96   }
97 
98   /**
99    * Test {@link Context} cancellation propagates from the first node in the call chain all the way
100    * to the last.
101    */
102   @Test
testCascadingCancellationViaOuterContextCancellation()103   public void testCascadingCancellationViaOuterContextCancellation() throws Exception {
104     observedCancellations = new CountDownLatch(2);
105     receivedCancellations = new CountDownLatch(3);
106     Future<?> chainReady = startChainingServer(3);
107     CancellableContext context = Context.current().withCancellation();
108     Future<SimpleResponse> future;
109     Context prevContext = context.attach();
110     try {
111       future = futureStub.unaryCall(SimpleRequest.getDefaultInstance());
112     } finally {
113       context.detach(prevContext);
114     }
115     chainReady.get(5, TimeUnit.SECONDS);
116 
117     context.cancel(null);
118     try {
119       future.get(5, TimeUnit.SECONDS);
120       fail("Expected cancellation");
121     } catch (ExecutionException ex) {
122       Status status = Status.fromThrowable(ex);
123       assertEquals(Status.Code.CANCELLED, status.getCode());
124 
125       // Should have observed 2 cancellations responses from downstream servers
126       if (!observedCancellations.await(5, TimeUnit.SECONDS)) {
127         fail("Expected number of cancellations not observed by clients");
128       }
129       if (!receivedCancellations.await(5, TimeUnit.SECONDS)) {
130         fail("Expected number of cancellations to be received by servers not observed");
131       }
132     }
133   }
134 
135   /**
136    * Test that cancellation via call cancellation propagates down the call.
137    */
138   @Test
testCascadingCancellationViaRpcCancel()139   public void testCascadingCancellationViaRpcCancel() throws Exception {
140     observedCancellations = new CountDownLatch(2);
141     receivedCancellations = new CountDownLatch(3);
142     Future<?> chainReady = startChainingServer(3);
143     Future<SimpleResponse> future = futureStub.unaryCall(SimpleRequest.getDefaultInstance());
144     chainReady.get(5, TimeUnit.SECONDS);
145 
146     future.cancel(true);
147     assertTrue(future.isCancelled());
148     if (!observedCancellations.await(5, TimeUnit.SECONDS)) {
149       fail("Expected number of cancellations not observed by clients");
150     }
151     if (!receivedCancellations.await(5, TimeUnit.SECONDS)) {
152       fail("Expected number of cancellations to be received by servers not observed");
153     }
154   }
155 
156   /**
157    * Test that when RPC cancellation propagates up a call chain, the cancellation of the parent
158    * RPC triggers cancellation of all of its children.
159    */
160   @Test
testCascadingCancellationViaLeafFailure()161   public void testCascadingCancellationViaLeafFailure() throws Exception {
162     // All nodes (15) except one edge of the tree (4) will be cancelled.
163     observedCancellations = new CountDownLatch(11);
164     receivedCancellations = new CountDownLatch(11);
165     startCallTreeServer(3);
166     try {
167       // Use response size limit to control tree nodeCount.
168       blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build());
169       fail("Expected abort");
170     } catch (StatusRuntimeException sre) {
171       // Wait for the workers to finish
172       Status status = sre.getStatus();
173       // Outermost caller observes ABORTED propagating up from the failing leaf,
174       // The descendant RPCs are cancelled so they receive CANCELLED.
175       assertEquals(Status.Code.ABORTED, status.getCode());
176 
177       if (!observedCancellations.await(5, TimeUnit.SECONDS)) {
178         fail("Expected number of cancellations not observed by clients");
179       }
180       if (!receivedCancellations.await(5, TimeUnit.SECONDS)) {
181         fail("Expected number of cancellations to be received by servers not observed");
182       }
183     }
184   }
185 
186   @Test
testDeadlinePropagation()187   public void testDeadlinePropagation() throws Exception {
188     final AtomicInteger recursionDepthRemaining = new AtomicInteger(3);
189     final SettableFuture<Deadline> finalDeadline = SettableFuture.create();
190     class DeadlineSaver extends TestServiceGrpc.TestServiceImplBase {
191       @Override
192       public void unaryCall(final SimpleRequest request,
193           final StreamObserver<SimpleResponse> responseObserver) {
194         Context.currentContextExecutor(otherWork).execute(new Runnable() {
195           @Override
196           public void run() {
197             try {
198               if (recursionDepthRemaining.decrementAndGet() == 0) {
199                 finalDeadline.set(Context.current().getDeadline());
200                 responseObserver.onNext(SimpleResponse.getDefaultInstance());
201               } else {
202                 responseObserver.onNext(blockingStub.unaryCall(request));
203               }
204               responseObserver.onCompleted();
205             } catch (Exception ex) {
206               responseObserver.onError(ex);
207             }
208           }
209         });
210       }
211     }
212 
213     server = InProcessServerBuilder.forName("channel").executor(otherWork)
214         .addService(new DeadlineSaver())
215         .build().start();
216 
217     Deadline initialDeadline = Deadline.after(1, TimeUnit.MINUTES);
218     blockingStub.withDeadline(initialDeadline).unaryCall(SimpleRequest.getDefaultInstance());
219     assertNotSame(initialDeadline, finalDeadline);
220     // Since deadline is re-calculated at each hop, some variance is acceptable and expected.
221     assertAbout(deadline())
222         .that(finalDeadline.get()).isWithin(1, TimeUnit.SECONDS).of(initialDeadline);
223   }
224 
225   /**
226    * Create a chain of client to server calls which can be cancelled top down.
227    *
228    * @return a Future that completes when call chain is created
229    */
startChainingServer(final int depthThreshold)230   private Future<?> startChainingServer(final int depthThreshold) throws IOException {
231     final AtomicInteger serversReady = new AtomicInteger();
232     final SettableFuture<Void> chainReady = SettableFuture.create();
233     class ChainingService extends TestServiceGrpc.TestServiceImplBase {
234       @Override
235       public void unaryCall(final SimpleRequest request,
236           final StreamObserver<SimpleResponse> responseObserver) {
237         ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(new Runnable() {
238           @Override
239           public void run() {
240             receivedCancellations.countDown();
241           }
242         });
243         if (serversReady.incrementAndGet() == depthThreshold) {
244           // Stop recursion
245           chainReady.set(null);
246           return;
247         }
248 
249         Context.currentContextExecutor(otherWork).execute(new Runnable() {
250           @Override
251           public void run() {
252             try {
253               blockingStub.unaryCall(request);
254             } catch (StatusRuntimeException e) {
255               Status status = e.getStatus();
256               if (status.getCode() == Status.Code.CANCELLED) {
257                 observedCancellations.countDown();
258               } else {
259                 responseObserver.onError(e);
260               }
261             }
262           }
263         });
264       }
265     }
266 
267     server = InProcessServerBuilder.forName("channel").executor(otherWork)
268         .addService(new ChainingService())
269         .build().start();
270     return chainReady;
271   }
272 
273   /**
274    * Create a tree of client to server calls where each received call on the server
275    * fans out to two downstream calls. Uses SimpleRequest.response_size to limit the nodeCount
276    * of the tree. One of the leaves will ABORT to trigger cancellation back up to tree.
277    */
startCallTreeServer(int depthThreshold)278   private void startCallTreeServer(int depthThreshold) throws IOException {
279     final AtomicInteger nodeCount = new AtomicInteger((2 << depthThreshold) - 1);
280     server = InProcessServerBuilder.forName("channel").executor(otherWork).addService(
281         ServerInterceptors.intercept(service,
282             new ServerInterceptor() {
283               @Override
284               public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
285                   final ServerCall<ReqT, RespT> call,
286                   Metadata headers,
287                   ServerCallHandler<ReqT, RespT> next) {
288                 // Respond with the headers but nothing else.
289                 call.sendHeaders(new Metadata());
290                 call.request(1);
291                 return new ServerCall.Listener<ReqT>() {
292                   @Override
293                   public void onMessage(final ReqT message) {
294                     Messages.SimpleRequest req = (Messages.SimpleRequest) message;
295                     if (nodeCount.decrementAndGet() == 0) {
296                       // we are in the final leaf node so trigger an ABORT upwards
297                       Context.currentContextExecutor(otherWork).execute(new Runnable() {
298                         @Override
299                         public void run() {
300                           synchronized (call) {
301                             call.close(Status.ABORTED, new Metadata());
302                           }
303                         }
304                       });
305                     } else if (req.getResponseSize() != 0) {
306                       // We are in a non leaf node so fire off two requests
307                       req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build();
308                       for (int i = 0; i < 2; i++) {
309                         asyncStub.unaryCall(req,
310                             new StreamObserver<Messages.SimpleResponse>() {
311                               @Override
312                               public void onNext(Messages.SimpleResponse value) {
313                               }
314 
315                               @Override
316                               public void onError(Throwable t) {
317                                 Status status = Status.fromThrowable(t);
318                                 if (status.getCode() == Status.Code.CANCELLED) {
319                                   observedCancellations.countDown();
320                                 }
321                                 // Propagate closure upwards.
322                                 try {
323                                   synchronized (call) {
324                                     call.close(status, new Metadata());
325                                   }
326                                 } catch (IllegalStateException t2) {
327                                   // Ignore error if already closed.
328                                 }
329                               }
330 
331                               @Override
332                               public void onCompleted() {
333                               }
334                             });
335                       }
336                     }
337                   }
338 
339                   @Override
340                   public void onCancel() {
341                     receivedCancellations.countDown();
342                   }
343                 };
344               }
345             })
346     ).build();
347     server.start();
348   }
349 }
350